diff --git a/go.mod b/go.mod index bdba1afa7..9902f6408 100644 --- a/go.mod +++ b/go.mod @@ -90,6 +90,7 @@ require ( github.com/oracle/oci-go-sdk/v54 v54.0.0 github.com/pashagolub/pgxmock/v2 v2.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/puzpuzpuz/xsync/v3 v3.0.0 github.com/rabbitmq/amqp091-go v1.8.1 github.com/redis/go-redis/v9 v9.2.1 github.com/sendgrid/sendgrid-go v3.13.0+incompatible diff --git a/go.sum b/go.sum index 302f26ca3..e76d19003 100644 --- a/go.sum +++ b/go.sum @@ -1732,6 +1732,8 @@ github.com/prometheus/statsd_exporter v0.21.0/go.mod h1:rbT83sZq2V+p73lHhPZfMc3M github.com/prometheus/statsd_exporter v0.22.7 h1:7Pji/i2GuhK6Lu7DHrtTkFmNBCudCPT1pX2CziuyQR0= github.com/prometheus/statsd_exporter v0.22.7/go.mod h1:N/TevpjkIh9ccs6nuzY3jQn9dFqnUakOjnEuMPJJJnI= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= +github.com/puzpuzpuz/xsync/v3 v3.0.0 h1:QwUcmah+dZZxy6va/QSU26M6O6Q422afP9jO8JlnRSA= +github.com/puzpuzpuz/xsync/v3 v3.0.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= github.com/rabbitmq/amqp091-go v1.8.1 h1:RejT1SBUim5doqcL6s7iN6SBmsQqyTgXb1xMlH0h1hA= github.com/rabbitmq/amqp091-go v1.8.1/go.mod h1:+jPrT9iY2eLjRaMSRHUhc3z14E/l85kv/f+6luSD3pc= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= diff --git a/pubsub/aws/snssqs/metadata.go b/pubsub/aws/snssqs/metadata.go index 2f3016cce..b2d58e96d 100644 --- a/pubsub/aws/snssqs/metadata.go +++ b/pubsub/aws/snssqs/metadata.go @@ -26,7 +26,7 @@ type snsSqsMetadata struct { // aws partition in which SNS/SQS should create resources. internalPartition string `mapstructure:"-"` // name of the queue for this application. The is provided by the runtime as "consumerID". - SqsQueueName string `mapstructure:"consumerID" mdignore:"true"` + SqsQueueName string `mapstructure:"consumerID" mdignore:"true"` // name of the dead letter queue for this application. SqsDeadLettersQueueName string `mapstructure:"sqsDeadLettersQueueName"` // flag to SNS and SQS FIFO. diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 1d71d21ae..37fdd7849 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -41,36 +41,24 @@ import ( "github.com/dapr/kit/logger" ) -type topicHandler struct { - topicName string - handler pubsub.Handler - ctx context.Context -} - type snsSqs struct { + topicsLocker TopicsLocker // key is the sanitized topic name topicArns map[string]string - // key is the sanitized topic name - topicHandlers map[string]topicHandler - topicsLock sync.RWMutex // key is the topic name, value holds the ARN of the queue and its url. - queues sync.Map + queues map[string]*sqsQueueInfo // key is a composite key of queue ARN and topic ARN mapping to subscription ARN. 
- subscriptions sync.Map - - snsClient *sns.SNS - sqsClient *sqs.SQS - stsClient *sts.STS - metadata *snsSqsMetadata - logger logger.Logger - id string - opsTimeout time.Duration - backOffConfig retry.Config - pollerRunning chan struct{} - - closeCh chan struct{} - closed atomic.Bool - wg sync.WaitGroup + subscriptions map[string]string + snsClient *sns.SNS + sqsClient *sqs.SQS + stsClient *sts.STS + metadata *snsSqsMetadata + logger logger.Logger + id string + opsTimeout time.Duration + backOffConfig retry.Config + subscriptionManager SubscriptionManagement + closed atomic.Bool } type sqsQueueInfo struct { @@ -105,11 +93,8 @@ func NewSnsSqs(l logger.Logger) pubsub.PubSub { } return &snsSqs{ - logger: l, - id: id, - topicsLock: sync.RWMutex{}, - pollerRunning: make(chan struct{}, 1), - closeCh: make(chan struct{}), + logger: l, + id: id, } } @@ -160,13 +145,6 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { s.metadata = md - // both Publish and Subscribe need reference the topic ARN, queue ARN and subscription ARN between topic and queue - // track these ARNs in these maps. - s.topicArns = make(map[string]string) - s.topicHandlers = make(map[string]topicHandler) - s.queues = sync.Map{} - s.subscriptions = sync.Map{} - sess, err := awsAuth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint) if err != nil { return fmt.Errorf("error creating an AWS client: %w", err) @@ -189,6 +167,13 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { if err != nil { return fmt.Errorf("error decoding backOff config: %w", err) } + // subscription manager responsible for managing the lifecycle of subscriptions. + s.subscriptionManager = NewSubscriptionMgmt(s.logger) + s.topicsLocker = NewLockManager() + + s.topicArns = make(map[string]string) + s.queues = make(map[string]*sqsQueueInfo) + s.subscriptions = make(map[string]string) return nil } @@ -243,7 +228,7 @@ func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, e }) cancelFn() if err != nil { - return "", fmt.Errorf("error: %w while getting topic: %v with arn: %v", err, topic, arn) + return "", fmt.Errorf("error: %w, while getting (sanitized) topic: %v with arn: %v", err, topic, arn) } return *getTopicOutput.Attributes["TopicArn"], nil @@ -251,40 +236,45 @@ func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, e // get the topic ARN from the topics map. If it doesn't exist in the map, try to fetch it from AWS, if it doesn't exist // at all, issue a request to create the topic. 
-func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn string, sanitizedName string, err error) { - s.topicsLock.Lock() - defer s.topicsLock.Unlock() +func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn string, sanitizedTopic string, err error) { + sanitizedTopic = nameToAWSSanitizedName(topic, s.metadata.Fifo) - sanitizedName = nameToAWSSanitizedName(topic, s.metadata.Fifo) + var loadOK bool + if topicArn, loadOK = s.topicArns[sanitizedTopic]; loadOK { + if len(topicArn) > 0 { + s.logger.Debugf("Found existing topic ARN for topic %s: %s", topic, topicArn) - topicArnCached, ok := s.topicArns[sanitizedName] - if ok && topicArnCached != "" { - s.logger.Debugf("found existing topic ARN for topic %s: %s", topic, topicArnCached) - return topicArnCached, sanitizedName, nil + return topicArn, sanitizedTopic, err + } else { + err = fmt.Errorf("the ARN for (sanitized) topic: %s was empty", sanitizedTopic) + + return topicArn, sanitizedTopic, err + } } + // creating queues is idempotent, the names serve as unique keys among a given region. - s.logger.Debugf("No SNS topic arn found for %s\nCreating SNS topic", topic) + s.logger.Debugf("No SNS topic ARN found for topic: %s. creating SNS with (sanitized) topic: %s", topic, sanitizedTopic) if !s.metadata.DisableEntityManagement { - topicArn, err = s.createTopic(ctx, sanitizedName) + topicArn, err = s.createTopic(ctx, sanitizedTopic) if err != nil { - s.logger.Errorf("error creating new topic %s: %w", topic, err) + err = fmt.Errorf("error creating new (sanitized) topic '%s': %w", topic, err) - return "", "", err + return topicArn, sanitizedTopic, err } } else { - topicArn, err = s.getTopicArn(ctx, sanitizedName) + topicArn, err = s.getTopicArn(ctx, sanitizedTopic) if err != nil { - s.logger.Errorf("error fetching info for topic %s: %w", topic, err) + err = fmt.Errorf("error fetching info for (sanitized) topic: %s. wrapped error is: %w", topic, err) - return "", "", err + return topicArn, sanitizedTopic, err } } // record topic ARN. - s.topicArns[sanitizedName] = topicArn + s.topicArns[sanitizedTopic] = topicArn - return topicArn, sanitizedName, nil + return topicArn, sanitizedTopic, err } func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQueueInfo, error) { @@ -346,13 +336,13 @@ func (s *snsSqs) getOrCreateQueue(ctx context.Context, queueName string) (*sqsQu queueInfo *sqsQueueInfo ) - if cachedQueueInfo, ok := s.queues.Load(queueName); ok { - s.logger.Debugf("Found queue arn for %s: %s", queueName, cachedQueueInfo.(*sqsQueueInfo).arn) + if cachedQueueInfo, ok := s.queues[queueName]; ok { + s.logger.Debugf("Found queue ARN for %s: %s", queueName, cachedQueueInfo.arn) - return cachedQueueInfo.(*sqsQueueInfo), nil + return cachedQueueInfo, nil } // creating queues is idempotent, the names serve as unique keys among a given region. 
- s.logger.Debugf("No SQS queue arn found for %s\nCreating SQS queue", queueName) + s.logger.Debugf("No SQS queue ARN found for %s\nCreating SQS queue", queueName) sanitizedName := nameToAWSSanitizedName(queueName, s.metadata.Fifo) @@ -372,8 +362,8 @@ func (s *snsSqs) getOrCreateQueue(ctx context.Context, queueName string) (*sqsQu } } - s.queues.Store(queueName, queueInfo) - s.logger.Debugf("Created SQS queue: %s: with arn: %s", queueName, queueInfo.arn) + s.queues[queueName] = queueInfo + s.logger.Debugf("created SQS queue: %s: with arn: %s", queueName, queueInfo.arn) return queueInfo, nil } @@ -429,13 +419,13 @@ func (s *snsSqs) getSnsSqsSubscriptionArn(parentCtx context.Context, topicArn st func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, topicArn string) (subscriptionArn string, err error) { compositeKey := fmt.Sprintf("%s:%s", queueArn, topicArn) - if cachedSubscriptionArn, ok := s.subscriptions.Load(compositeKey); ok { + if cachedSubscriptionArn, ok := s.subscriptions[compositeKey]; ok { s.logger.Debugf("Found subscription of queue arn: %s to topic arn: %s: %s", queueArn, topicArn, cachedSubscriptionArn) - return cachedSubscriptionArn.(string), nil + return cachedSubscriptionArn, nil } - s.logger.Debugf("No subscription arn found of queue arn:%s to topic arn: %s\nCreating subscription", queueArn, topicArn) + s.logger.Debugf("No subscription ARN found of queue arn:%s to topic arn: %s\nCreating subscription", queueArn, topicArn) if !s.metadata.DisableEntityManagement { subscriptionArn, err = s.createSnsSqsSubscription(ctx, queueArn, topicArn) @@ -447,13 +437,13 @@ func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, to } else { subscriptionArn, err = s.getSnsSqsSubscriptionArn(ctx, topicArn) if err != nil { - s.logger.Errorf("error fetching info for topic arn %s: %w", topicArn, err) + s.logger.Errorf("error fetching info for topic ARN %s: %w", topicArn, err) return "", err } } - s.subscriptions.Store(compositeKey, subscriptionArn) + s.subscriptions[compositeKey] = subscriptionArn s.logger.Debugf("Subscribed to topic %s: %s", topicArn, subscriptionArn) return subscriptionArn, nil @@ -555,18 +545,25 @@ func (s *snsSqs) callHandler(ctx context.Context, message *sqs.Message, queueInf // for the user to be able to understand the source of the coming message, we'd use the original, // dirty name to be carried over in the pubsub.NewMessage Topic field. 
sanitizedTopic := snsMessagePayload.parseTopicArn() - s.topicsLock.RLock() - handler, ok := s.topicHandlers[sanitizedTopic] - s.topicsLock.RUnlock() - if !ok || handler.topicName == "" { - return fmt.Errorf("handler for topic (sanitized): %s not found", sanitizedTopic) + // get a handler by sanitized topic name and perform validations + var ( + handler *SubscriptionTopicHandler + loadOK bool + ) + if handler, loadOK = s.subscriptionManager.GetSubscriptionTopicHandler(sanitizedTopic); loadOK { + if len(handler.requestTopic) == 0 { + return fmt.Errorf("handler topic name is missing") + } + } else { + return fmt.Errorf("handler for (sanitized) topic: %s was not found", sanitizedTopic) } - s.logger.Debugf("Processing SNS message id: %s of topic: %s", *message.MessageId, sanitizedTopic) + s.logger.Debugf("Processing SNS message id: %s of (sanitized) topic: %s", *message.MessageId, sanitizedTopic) + // call the handler with its own subscription context err = handler.handler(handler.ctx, &pubsub.NewMessage{ Data: []byte(snsMessagePayload.Message), - Topic: handler.topicName, + Topic: handler.requestTopic, }) if err != nil { return fmt.Errorf("error handling message: %w", err) @@ -575,6 +572,8 @@ func (s *snsSqs) callHandler(ctx context.Context, message *sqs.Message, queueInf return s.acknowledgeMessage(ctx, queueInfo.url, message.ReceiptHandle) } +// consumeSubscription is responsible for polling messages from the queue and calling the handler. +// it is being passed as a callback to the subscription manager that initializes the context of the handler. func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLettersQueueInfo *sqsQueueInfo) { sqsPullExponentialBackoff := s.backOffConfig.NewBackOffWithContext(ctx) @@ -601,12 +600,13 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters // iteration. Therefore, a global backoff (to the internal backoff) is used (sqsPullExponentialBackoff). messageResponse, err := s.sqsClient.ReceiveMessageWithContext(ctx, receiveMessageInput) if err != nil { - if err == context.Canceled || err == context.DeadlineExceeded { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || ctx.Err() != nil { s.logger.Warn("context canceled; stopping consuming from queue arn: %v", queueInfo.arn) continue } - if awsErr, ok := err.(awserr.Error); ok { + var awsErr awserr.Error + if errors.As(err, &awsErr) { s.logger.Errorf("AWS operation error while consuming from queue arn: %v with error: %w. retrying...", queueInfo.arn, awsErr.Error()) } else { s.logger.Errorf("error consuming from queue arn: %v with error: %w. retrying...", queueInfo.arn, err) @@ -619,7 +619,6 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters sqsPullExponentialBackoff.Reset() if len(messageResponse.Messages) < 1 { - // s.logger.Debug("No messages received, continuing") continue } s.logger.Debugf("%v message(s) received on queue %s", len(messageResponse.Messages), queueInfo.arn) @@ -632,11 +631,10 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters } f := func(message *sqs.Message) { + defer wg.Done() if err := s.callHandler(ctx, message, queueInfo); err != nil { s.logger.Errorf("error while handling received message. 
error is: %v", err) } - - wg.Done() } wg.Add(1) @@ -653,9 +651,6 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters } wg.Wait() } - - // Signal that the poller stopped - <-s.pollerRunning } func (s *snsSqs) createDeadLettersQueueAttributes(queueInfo, deadLettersQueueInfo *sqsQueueInfo) (*sqs.SetQueueAttributesInput, error) { @@ -763,6 +758,9 @@ func (s *snsSqs) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han return errors.New("component is closed") } + s.topicsLocker.Lock(req.Topic) + defer s.topicsLocker.Unlock(req.Topic) + // subscribers declare a topic ARN and declare a SQS queue to use // these should be idempotent - queues should not be created if they exist. topicArn, sanitizedName, err := s.getOrCreateTopic(ctx, req.Topic) @@ -824,63 +822,15 @@ func (s *snsSqs) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han return wrappedErr } - // Store the handler for this topic - s.topicsLock.Lock() - defer s.topicsLock.Unlock() - s.topicHandlers[sanitizedName] = topicHandler{ - topicName: req.Topic, - handler: handler, - ctx: ctx, - } + // start the subscription manager + s.subscriptionManager.Init(queueInfo, deadLettersQueueInfo, s.consumeSubscription) - // pollerCancel is used to cancel the polling goroutine. We use a noop cancel - // func in case the poller is already running and there is no cancel to use - // from the select below. - var pollerCancel context.CancelFunc = func() {} - // Start the poller for the queue if it's not running already - select { - case s.pollerRunning <- struct{}{}: - // If inserting in the channel succeeds, then it's not running already - // Use a context that is tied to the background context - var subctx context.Context - subctx, pollerCancel = context.WithCancel(context.Background()) - s.wg.Add(2) - go func() { - defer s.wg.Done() - defer pollerCancel() - select { - case <-s.closeCh: - case <-subctx.Done(): - } - }() - go func() { - defer s.wg.Done() - s.consumeSubscription(subctx, queueInfo, deadLettersQueueInfo) - }() - default: - // Do nothing, it means the poller is already running - } - - // Watch for subscription context cancellation to remove this subscription - s.wg.Add(1) - go func() { - defer s.wg.Done() - select { - case <-ctx.Done(): - case <-s.closeCh: - } - - s.topicsLock.Lock() - defer s.topicsLock.Unlock() - - // Remove the handler - delete(s.topicHandlers, sanitizedName) - - // If we don't have any topic left, close the poller. - if len(s.topicHandlers) == 0 { - pollerCancel() - } - }() + s.subscriptionManager.Subscribe(&SubscriptionTopicHandler{ + topic: sanitizedName, + requestTopic: req.Topic, + handler: handler, + ctx: ctx, + }) return nil } @@ -920,9 +870,9 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error // client. Blocks until all goroutines have returned. 
func (s *snsSqs) Close() error { if s.closed.CompareAndSwap(false, true) { - close(s.closeCh) + s.subscriptionManager.Close() } - s.wg.Wait() + return nil } diff --git a/pubsub/aws/snssqs/subscription_mgmt.go b/pubsub/aws/snssqs/subscription_mgmt.go new file mode 100644 index 000000000..c868d9dcc --- /dev/null +++ b/pubsub/aws/snssqs/subscription_mgmt.go @@ -0,0 +1,179 @@ +package snssqs + +import ( + "context" + "sync" + + "github.com/puzpuzpuz/xsync/v3" + + "github.com/dapr/components-contrib/pubsub" + "github.com/dapr/kit/logger" +) + +type ( + SubscriptionAction int +) + +const ( + Subscribe SubscriptionAction = iota + Unsubscribe +) + +type SubscriptionTopicHandler struct { + topic string + requestTopic string + handler pubsub.Handler + ctx context.Context +} + +type changeSubscriptionTopicHandler struct { + action SubscriptionAction + handler *SubscriptionTopicHandler +} + +type SubscriptionManager struct { + logger logger.Logger + consumeCancelFunc context.CancelFunc + closeCh chan struct{} + topicsChangeCh chan changeSubscriptionTopicHandler + topicsHandlers *xsync.MapOf[string, *SubscriptionTopicHandler] + lock sync.Mutex + wg sync.WaitGroup + initOnce sync.Once +} + +type SubscriptionManagement interface { + Init(queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo, cbk func(context.Context, *sqsQueueInfo, *sqsQueueInfo)) + Subscribe(topicHandler *SubscriptionTopicHandler) + Close() + GetSubscriptionTopicHandler(topic string) (*SubscriptionTopicHandler, bool) +} + +func NewSubscriptionMgmt(log logger.Logger) SubscriptionManagement { + return &SubscriptionManager{ + logger: log, + consumeCancelFunc: func() {}, // noop until we (re)start sqs consumption + closeCh: make(chan struct{}), + topicsChangeCh: make(chan changeSubscriptionTopicHandler), + topicsHandlers: xsync.NewMapOf[string, *SubscriptionTopicHandler](), + } +} + +func createQueueConsumerCbk(queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo, cbk func(ctx context.Context, queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo)) func(ctx context.Context) { + return func(ctx context.Context) { + cbk(ctx, queueInfo, dlqInfo) + } +} + +func (sm *SubscriptionManager) Init(queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo, cbk func(context.Context, *sqsQueueInfo, *sqsQueueInfo)) { + sm.initOnce.Do(func() { + queueConsumerCbk := createQueueConsumerCbk(queueInfo, dlqInfo, cbk) + go sm.queueConsumerController(queueConsumerCbk) + sm.logger.Debug("Subscription manager initialized") + }) +} + +// queueConsumerController is responsible for managing the subscription lifecycle +// and the only place where the topicsHandlers map is updated. +// it is running in a separate goroutine and is responsible for starting and stopping sqs consumption +// where its lifecycle is managed by the subscription manager, +// and it has its own context with its child contexts used for sqs consumption and aborting of the consumption. +// it is also responsible for managing the lifecycle of the subscription handlers. 
+func (sm *SubscriptionManager) queueConsumerController(queueConsumerCbk func(context.Context)) {
+	ctx := context.Background()
+
+	for {
+		select {
+		case changeEvent := <-sm.topicsChangeCh:
+			topic := changeEvent.handler.topic
+			sm.logger.Debugf("Subscription change event received with action: %v, on topic: %s", changeEvent.action, topic)
+			// topic change events are serialized so that no interleaving can occur
+			sm.lock.Lock()
+			// although we hold the lock here, the topicsHandlers map is thread safe, so subscribers that are already consuming messages
+			// can still look up the handler for their topic while we update the map, without being blocked
+			current := sm.topicsHandlers.Size()
+
+			switch changeEvent.action {
+			case Subscribe:
+				sm.topicsHandlers.Store(topic, changeEvent.handler)
+				// if there were no subscriptions before this one was added, this subscribe signals us to start consuming from SQS
+				if current == 0 {
+					var subCtx context.Context
+					// create a new context for SQS consumption with a cancel func to be used when we unsubscribe from all topics
+					subCtx, sm.consumeCancelFunc = context.WithCancel(ctx)
+					// start SQS consumption
+					sm.logger.Info("Starting SQS consumption")
+					go queueConsumerCbk(subCtx)
+				}
+			case Unsubscribe:
+				sm.topicsHandlers.Delete(topic)
+				// for idempotency, we check the size of the map after the delete operation, as we might have already deleted the subscription
+				afterDelete := sm.topicsHandlers.Size()
+				// if the subscription we just removed was the last one, this signals us to stop SQS consumption
+				if current == 1 && afterDelete == 0 {
+					sm.logger.Info("Last subscription removed. No more handlers are mapped to topics. Stopping SQS consumption")
+					sm.consumeCancelFunc()
+				}
+			}
+
+			sm.lock.Unlock()
+		case <-sm.closeCh:
+			return
+		}
+	}
+}
+
+func (sm *SubscriptionManager) Subscribe(topicHandler *SubscriptionTopicHandler) {
+	sm.logger.Debug("Subscribing to topic: ", topicHandler.topic)
+
+	sm.wg.Add(1)
+	go func() {
+		defer sm.wg.Done()
+		sm.createSubscribeListener(topicHandler)
+	}()
+}
+
+func (sm *SubscriptionManager) createSubscribeListener(topicHandler *SubscriptionTopicHandler) {
+	sm.logger.Debug("Creating a subscribe listener for topic: ", topicHandler.topic)
+
+	sm.topicsChangeCh <- changeSubscriptionTopicHandler{Subscribe, topicHandler}
+	closeCh := make(chan struct{})
+	// the unsubscribe listener is expected to be terminated by the dapr runtime, which cancels the subscription context upon unsubscribe
+	go sm.createUnsubscribeListener(topicHandler.ctx, topicHandler.topic, closeCh)
+	// if the SubscriptionManager is being closed and somehow the dapr runtime did not call unsubscribe, we close the control
+	// channel here to terminate the unsubscribe listener and return
+	defer close(closeCh)
+	<-sm.closeCh
+}
+
+// ctx is a context provided by daprd per subscription.
unrelated to the consuming sm.baseCtx +func (sm *SubscriptionManager) createUnsubscribeListener(ctx context.Context, topic string, closeCh <-chan struct{}) { + sm.logger.Debug("Creating an unsubscribe listener for topic: ", topic) + + defer sm.unsubscribe(topic) + for { + select { + case <-ctx.Done(): + return + case <-closeCh: + return + } + } +} + +func (sm *SubscriptionManager) unsubscribe(topic string) { + sm.logger.Debug("Unsubscribing from topic: ", topic) + + if value, ok := sm.GetSubscriptionTopicHandler(topic); ok { + sm.topicsChangeCh <- changeSubscriptionTopicHandler{Unsubscribe, value} + } +} + +func (sm *SubscriptionManager) Close() { + close(sm.closeCh) + sm.wg.Wait() +} + +func (sm *SubscriptionManager) GetSubscriptionTopicHandler(topic string) (*SubscriptionTopicHandler, bool) { + return sm.topicsHandlers.Load(topic) +} diff --git a/pubsub/aws/snssqs/topics_locker.go b/pubsub/aws/snssqs/topics_locker.go new file mode 100644 index 000000000..bbb934c02 --- /dev/null +++ b/pubsub/aws/snssqs/topics_locker.go @@ -0,0 +1,44 @@ +package snssqs + +import ( + "sync" + + "github.com/puzpuzpuz/xsync/v3" +) + +// TopicsLockManager is a singleton for fine-grained locking, to prevent the component r/w operations +// from locking the entire component out when performing operations on different topics. +type TopicsLockManager struct { + xLockMap *xsync.MapOf[string, *sync.Mutex] +} + +type TopicsLocker interface { + Lock(topic string) *sync.Mutex + Unlock(topic string) +} + +func NewLockManager() *TopicsLockManager { + return &TopicsLockManager{xLockMap: xsync.NewMapOf[string, *sync.Mutex]()} +} + +func (lm *TopicsLockManager) Lock(key string) *sync.Mutex { + lock, _ := lm.xLockMap.LoadOrCompute(key, func() *sync.Mutex { + l := &sync.Mutex{} + l.Lock() + + return l + }) + + return lock +} + +func (lm *TopicsLockManager) Unlock(key string) { + lm.xLockMap.Compute(key, func(oldValue *sync.Mutex, exists bool) (newValue *sync.Mutex, delete bool) { + // if exists then the mutex must be already locked, and we unlock it + if exists { + oldValue.Unlock() + } + // we return to comply with the Compute signature, but not using the returned values + return oldValue, false + }) +} diff --git a/tests/certification/go.mod b/tests/certification/go.mod index e7b845873..dcb2d0275 100644 --- a/tests/certification/go.mod +++ b/tests/certification/go.mod @@ -227,6 +227,7 @@ require ( github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.11.0 // indirect github.com/prometheus/statsd_exporter v0.22.7 // indirect + github.com/puzpuzpuz/xsync/v3 v3.0.0 // indirect github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect github.com/redis/go-redis/v9 v9.2.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect diff --git a/tests/certification/go.sum b/tests/certification/go.sum index 03f96ce1a..341205ff1 100644 --- a/tests/certification/go.sum +++ b/tests/certification/go.sum @@ -1108,6 +1108,8 @@ github.com/prometheus/statsd_exporter v0.21.0/go.mod h1:rbT83sZq2V+p73lHhPZfMc3M github.com/prometheus/statsd_exporter v0.22.7 h1:7Pji/i2GuhK6Lu7DHrtTkFmNBCudCPT1pX2CziuyQR0= github.com/prometheus/statsd_exporter v0.22.7/go.mod h1:N/TevpjkIh9ccs6nuzY3jQn9dFqnUakOjnEuMPJJJnI= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= +github.com/puzpuzpuz/xsync/v3 v3.0.0 h1:QwUcmah+dZZxy6va/QSU26M6O6Q422afP9jO8JlnRSA= +github.com/puzpuzpuz/xsync/v3 v3.0.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= 
github.com/rabbitmq/amqp091-go v1.8.1 h1:RejT1SBUim5doqcL6s7iN6SBmsQqyTgXb1xMlH0h1hA= github.com/rabbitmq/amqp091-go v1.8.1/go.mod h1:+jPrT9iY2eLjRaMSRHUhc3z14E/l85kv/f+6luSD3pc= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
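The new topics_locker.go and subscription_mgmt.go lean on a small set of puzpuzpuz/xsync/v3 map operations: LoadOrCompute (per-topic mutex creation), Compute (unlock without removing the entry), and Size (the start/stop-consumption decision). A minimal, self-contained sketch of those calls follows for quick reference; it is illustrative only, uses hypothetical names (such as the "orders" topic), and is not part of the patch.

```go
package main

import (
	"fmt"
	"sync"

	"github.com/puzpuzpuz/xsync/v3"
)

func main() {
	// Per-topic locks, as in TopicsLockManager.Lock: the compute callback runs only when the
	// key is absent, so the mutex is created already locked on first use of a topic.
	locks := xsync.NewMapOf[string, *sync.Mutex]()
	_, loaded := locks.LoadOrCompute("orders", func() *sync.Mutex {
		l := &sync.Mutex{}
		l.Lock()
		return l
	})
	fmt.Println("mutex already existed:", loaded) // false on the first call for a key

	// Unlocking via Compute, as in TopicsLockManager.Unlock: the callback sees the stored
	// value (if any); returning delete=false keeps the entry in the map.
	locks.Compute("orders", func(old *sync.Mutex, exists bool) (*sync.Mutex, bool) {
		if exists {
			old.Unlock()
		}
		return old, false
	})

	// Handler registry sizing, as used by queueConsumerController to decide when to start
	// (0 -> 1 handlers) and stop (1 -> 0 handlers) SQS consumption.
	handlers := xsync.NewMapOf[string, string]()
	handlers.Store("orders", "handler-A")
	fmt.Println("registered handlers:", handlers.Size())
	handlers.Delete("orders")
	fmt.Println("registered handlers after delete:", handlers.Size())
}
```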