Fix SQS: Error during Dapr Shutdown and Message Polling Behavior Issue (AWS) (#3174)

Signed-off-by: Amit Mor <amit.mor@hotmail.com>
Signed-off-by: Amit Mor <amitm@at-bay.com>
Signed-off-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Amit Mor 2023-11-01 18:42:22 +02:00 committed by GitHub
parent d47d1a7aad
commit 7fd5524c58
8 changed files with 319 additions and 140 deletions

go.mod
View File

@@ -90,6 +90,7 @@ require (
 github.com/oracle/oci-go-sdk/v54 v54.0.0
 github.com/pashagolub/pgxmock/v2 v2.12.0
 github.com/patrickmn/go-cache v2.1.0+incompatible
+github.com/puzpuzpuz/xsync/v3 v3.0.0
 github.com/rabbitmq/amqp091-go v1.8.1
 github.com/redis/go-redis/v9 v9.2.1
 github.com/sendgrid/sendgrid-go v3.13.0+incompatible

go.sum
View File

@@ -1732,6 +1732,8 @@ github.com/prometheus/statsd_exporter v0.21.0/go.mod h1:rbT83sZq2V+p73lHhPZfMc3M
 github.com/prometheus/statsd_exporter v0.22.7 h1:7Pji/i2GuhK6Lu7DHrtTkFmNBCudCPT1pX2CziuyQR0=
 github.com/prometheus/statsd_exporter v0.22.7/go.mod h1:N/TevpjkIh9ccs6nuzY3jQn9dFqnUakOjnEuMPJJJnI=
 github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
+github.com/puzpuzpuz/xsync/v3 v3.0.0 h1:QwUcmah+dZZxy6va/QSU26M6O6Q422afP9jO8JlnRSA=
+github.com/puzpuzpuz/xsync/v3 v3.0.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
 github.com/rabbitmq/amqp091-go v1.8.1 h1:RejT1SBUim5doqcL6s7iN6SBmsQqyTgXb1xMlH0h1hA=
 github.com/rabbitmq/amqp091-go v1.8.1/go.mod h1:+jPrT9iY2eLjRaMSRHUhc3z14E/l85kv/f+6luSD3pc=
 github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=

View File

@@ -41,23 +41,14 @@ import (
 "github.com/dapr/kit/logger"
 )

-type topicHandler struct {
-topicName string
-handler pubsub.Handler
-ctx context.Context
-}
 type snsSqs struct {
+topicsLocker TopicsLocker
 // key is the sanitized topic name
 topicArns map[string]string
-// key is the sanitized topic name
-topicHandlers map[string]topicHandler
-topicsLock sync.RWMutex
 // key is the topic name, value holds the ARN of the queue and its url.
-queues sync.Map
+queues map[string]*sqsQueueInfo
 // key is a composite key of queue ARN and topic ARN mapping to subscription ARN.
-subscriptions sync.Map
+subscriptions map[string]string
 snsClient *sns.SNS
 sqsClient *sqs.SQS
 stsClient *sts.STS
@@ -66,11 +57,8 @@ type snsSqs struct {
 id string
 opsTimeout time.Duration
 backOffConfig retry.Config
-pollerRunning chan struct{}
-closeCh chan struct{}
+subscriptionManager SubscriptionManagement
 closed atomic.Bool
-wg sync.WaitGroup
 }

 type sqsQueueInfo struct {
@@ -107,9 +95,6 @@ func NewSnsSqs(l logger.Logger) pubsub.PubSub {
 return &snsSqs{
 logger: l,
 id: id,
-topicsLock: sync.RWMutex{},
-pollerRunning: make(chan struct{}, 1),
-closeCh: make(chan struct{}),
 }
 }
@@ -160,13 +145,6 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error {
 s.metadata = md
-// both Publish and Subscribe need reference the topic ARN, queue ARN and subscription ARN between topic and queue
-// track these ARNs in these maps.
-s.topicArns = make(map[string]string)
-s.topicHandlers = make(map[string]topicHandler)
-s.queues = sync.Map{}
-s.subscriptions = sync.Map{}
 sess, err := awsAuth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint)
 if err != nil {
 return fmt.Errorf("error creating an AWS client: %w", err)
@@ -189,6 +167,13 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error {
 if err != nil {
 return fmt.Errorf("error decoding backOff config: %w", err)
 }
+// subscription manager responsible for managing the lifecycle of subscriptions.
+s.subscriptionManager = NewSubscriptionMgmt(s.logger)
+s.topicsLocker = NewLockManager()
+s.topicArns = make(map[string]string)
+s.queues = make(map[string]*sqsQueueInfo)
+s.subscriptions = make(map[string]string)
 return nil
 }
@@ -243,7 +228,7 @@ func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, e
 })
 cancelFn()
 if err != nil {
-return "", fmt.Errorf("error: %w while getting topic: %v with arn: %v", err, topic, arn)
+return "", fmt.Errorf("error: %w, while getting (sanitized) topic: %v with arn: %v", err, topic, arn)
 }
 return *getTopicOutput.Attributes["TopicArn"], nil
@@ -251,40 +236,45 @@ func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, e
 // get the topic ARN from the topics map. If it doesn't exist in the map, try to fetch it from AWS, if it doesn't exist
 // at all, issue a request to create the topic.
-func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn string, sanitizedName string, err error) {
-s.topicsLock.Lock()
-defer s.topicsLock.Unlock()
-sanitizedName = nameToAWSSanitizedName(topic, s.metadata.Fifo)
-topicArnCached, ok := s.topicArns[sanitizedName]
-if ok && topicArnCached != "" {
-s.logger.Debugf("found existing topic ARN for topic %s: %s", topic, topicArnCached)
-return topicArnCached, sanitizedName, nil
-}
+func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn string, sanitizedTopic string, err error) {
+sanitizedTopic = nameToAWSSanitizedName(topic, s.metadata.Fifo)
+var loadOK bool
+if topicArn, loadOK = s.topicArns[sanitizedTopic]; loadOK {
+if len(topicArn) > 0 {
+s.logger.Debugf("Found existing topic ARN for topic %s: %s", topic, topicArn)
+return topicArn, sanitizedTopic, err
+} else {
+err = fmt.Errorf("the ARN for (sanitized) topic: %s was empty", sanitizedTopic)
+return topicArn, sanitizedTopic, err
+}
+}
 // creating queues is idempotent, the names serve as unique keys among a given region.
-s.logger.Debugf("No SNS topic arn found for %s\nCreating SNS topic", topic)
+s.logger.Debugf("No SNS topic ARN found for topic: %s. creating SNS with (sanitized) topic: %s", topic, sanitizedTopic)
 if !s.metadata.DisableEntityManagement {
-topicArn, err = s.createTopic(ctx, sanitizedName)
+topicArn, err = s.createTopic(ctx, sanitizedTopic)
 if err != nil {
-s.logger.Errorf("error creating new topic %s: %w", topic, err)
-return "", "", err
+err = fmt.Errorf("error creating new (sanitized) topic '%s': %w", topic, err)
+return topicArn, sanitizedTopic, err
 }
 } else {
-topicArn, err = s.getTopicArn(ctx, sanitizedName)
+topicArn, err = s.getTopicArn(ctx, sanitizedTopic)
 if err != nil {
-s.logger.Errorf("error fetching info for topic %s: %w", topic, err)
-return "", "", err
+err = fmt.Errorf("error fetching info for (sanitized) topic: %s. wrapped error is: %w", topic, err)
+return topicArn, sanitizedTopic, err
 }
 }
 // record topic ARN.
-s.topicArns[sanitizedName] = topicArn
-return topicArn, sanitizedName, nil
+s.topicArns[sanitizedTopic] = topicArn
+return topicArn, sanitizedTopic, err
 }

 func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQueueInfo, error) {
@@ -346,13 +336,13 @@ func (s *snsSqs) getOrCreateQueue(ctx context.Context, queueName string) (*sqsQu
 queueInfo *sqsQueueInfo
 )
-if cachedQueueInfo, ok := s.queues.Load(queueName); ok {
-s.logger.Debugf("Found queue arn for %s: %s", queueName, cachedQueueInfo.(*sqsQueueInfo).arn)
-return cachedQueueInfo.(*sqsQueueInfo), nil
+if cachedQueueInfo, ok := s.queues[queueName]; ok {
+s.logger.Debugf("Found queue ARN for %s: %s", queueName, cachedQueueInfo.arn)
+return cachedQueueInfo, nil
 }
 // creating queues is idempotent, the names serve as unique keys among a given region.
-s.logger.Debugf("No SQS queue arn found for %s\nCreating SQS queue", queueName)
+s.logger.Debugf("No SQS queue ARN found for %s\nCreating SQS queue", queueName)
 sanitizedName := nameToAWSSanitizedName(queueName, s.metadata.Fifo)
@@ -372,8 +362,8 @@ func (s *snsSqs) getOrCreateQueue(ctx context.Context, queueName string) (*sqsQu
 }
 }
-s.queues.Store(queueName, queueInfo)
-s.logger.Debugf("Created SQS queue: %s: with arn: %s", queueName, queueInfo.arn)
+s.queues[queueName] = queueInfo
+s.logger.Debugf("created SQS queue: %s: with arn: %s", queueName, queueInfo.arn)
 return queueInfo, nil
 }
@@ -429,13 +419,13 @@ func (s *snsSqs) getSnsSqsSubscriptionArn(parentCtx context.Context, topicArn st
 func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, topicArn string) (subscriptionArn string, err error) {
 compositeKey := fmt.Sprintf("%s:%s", queueArn, topicArn)
-if cachedSubscriptionArn, ok := s.subscriptions.Load(compositeKey); ok {
+if cachedSubscriptionArn, ok := s.subscriptions[compositeKey]; ok {
 s.logger.Debugf("Found subscription of queue arn: %s to topic arn: %s: %s", queueArn, topicArn, cachedSubscriptionArn)
-return cachedSubscriptionArn.(string), nil
+return cachedSubscriptionArn, nil
 }
-s.logger.Debugf("No subscription arn found of queue arn:%s to topic arn: %s\nCreating subscription", queueArn, topicArn)
+s.logger.Debugf("No subscription ARN found of queue arn:%s to topic arn: %s\nCreating subscription", queueArn, topicArn)
 if !s.metadata.DisableEntityManagement {
 subscriptionArn, err = s.createSnsSqsSubscription(ctx, queueArn, topicArn)
@@ -447,13 +437,13 @@ func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, to
 } else {
 subscriptionArn, err = s.getSnsSqsSubscriptionArn(ctx, topicArn)
 if err != nil {
-s.logger.Errorf("error fetching info for topic arn %s: %w", topicArn, err)
+s.logger.Errorf("error fetching info for topic ARN %s: %w", topicArn, err)
 return "", err
 }
 }
-s.subscriptions.Store(compositeKey, subscriptionArn)
+s.subscriptions[compositeKey] = subscriptionArn
 s.logger.Debugf("Subscribed to topic %s: %s", topicArn, subscriptionArn)
 return subscriptionArn, nil
@@ -555,18 +545,25 @@ func (s *snsSqs) callHandler(ctx context.Context, message *sqs.Message, queueInf
 // for the user to be able to understand the source of the coming message, we'd use the original,
 // dirty name to be carried over in the pubsub.NewMessage Topic field.
 sanitizedTopic := snsMessagePayload.parseTopicArn()
-s.topicsLock.RLock()
-handler, ok := s.topicHandlers[sanitizedTopic]
-s.topicsLock.RUnlock()
-if !ok || handler.topicName == "" {
-return fmt.Errorf("handler for topic (sanitized): %s not found", sanitizedTopic)
+// get a handler by sanitized topic name and perform validations
+var (
+handler *SubscriptionTopicHandler
+loadOK bool
+)
+if handler, loadOK = s.subscriptionManager.GetSubscriptionTopicHandler(sanitizedTopic); loadOK {
+if len(handler.requestTopic) == 0 {
+return fmt.Errorf("handler topic name is missing")
+}
+} else {
+return fmt.Errorf("handler for (sanitized) topic: %s was not found", sanitizedTopic)
 }
-s.logger.Debugf("Processing SNS message id: %s of topic: %s", *message.MessageId, sanitizedTopic)
+s.logger.Debugf("Processing SNS message id: %s of (sanitized) topic: %s", *message.MessageId, sanitizedTopic)
+// call the handler with its own subscription context
 err = handler.handler(handler.ctx, &pubsub.NewMessage{
 Data: []byte(snsMessagePayload.Message),
-Topic: handler.topicName,
+Topic: handler.requestTopic,
 })
 if err != nil {
 return fmt.Errorf("error handling message: %w", err)
@@ -575,6 +572,8 @@ func (s *snsSqs) callHandler(ctx context.Context, message *sqs.Message, queueInf
 return s.acknowledgeMessage(ctx, queueInfo.url, message.ReceiptHandle)
 }
+// consumeSubscription is responsible for polling messages from the queue and calling the handler.
+// it is being passed as a callback to the subscription manager that initializes the context of the handler.
 func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLettersQueueInfo *sqsQueueInfo) {
 sqsPullExponentialBackoff := s.backOffConfig.NewBackOffWithContext(ctx)
@@ -601,12 +600,13 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
 // iteration. Therefore, a global backoff (to the internal backoff) is used (sqsPullExponentialBackoff).
 messageResponse, err := s.sqsClient.ReceiveMessageWithContext(ctx, receiveMessageInput)
 if err != nil {
-if err == context.Canceled || err == context.DeadlineExceeded {
+if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || ctx.Err() != nil {
 s.logger.Warn("context canceled; stopping consuming from queue arn: %v", queueInfo.arn)
 continue
 }
-if awsErr, ok := err.(awserr.Error); ok {
+var awsErr awserr.Error
+if errors.As(err, &awsErr) {
 s.logger.Errorf("AWS operation error while consuming from queue arn: %v with error: %w. retrying...", queueInfo.arn, awsErr.Error())
 } else {
 s.logger.Errorf("error consuming from queue arn: %v with error: %w. retrying...", queueInfo.arn, err)
@@ -619,7 +619,6 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
 sqsPullExponentialBackoff.Reset()
 if len(messageResponse.Messages) < 1 {
-// s.logger.Debug("No messages received, continuing")
 continue
 }
 s.logger.Debugf("%v message(s) received on queue %s", len(messageResponse.Messages), queueInfo.arn)
@@ -632,11 +631,10 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
 }
 f := func(message *sqs.Message) {
+defer wg.Done()
 if err := s.callHandler(ctx, message, queueInfo); err != nil {
 s.logger.Errorf("error while handling received message. error is: %v", err)
 }
-wg.Done()
 }
 wg.Add(1)
@@ -653,9 +651,6 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
 }
 wg.Wait()
 }
-// Signal that the poller stopped
-<-s.pollerRunning
 }

 func (s *snsSqs) createDeadLettersQueueAttributes(queueInfo, deadLettersQueueInfo *sqsQueueInfo) (*sqs.SetQueueAttributesInput, error) {
@@ -763,6 +758,9 @@ func (s *snsSqs) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han
 return errors.New("component is closed")
 }
+s.topicsLocker.Lock(req.Topic)
+defer s.topicsLocker.Unlock(req.Topic)
 // subscribers declare a topic ARN and declare a SQS queue to use
 // these should be idempotent - queues should not be created if they exist.
 topicArn, sanitizedName, err := s.getOrCreateTopic(ctx, req.Topic)
@@ -824,63 +822,15 @@ func (s *snsSqs) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han
 return wrappedErr
 }
-// Store the handler for this topic
-s.topicsLock.Lock()
-defer s.topicsLock.Unlock()
-s.topicHandlers[sanitizedName] = topicHandler{
-topicName: req.Topic,
-handler: handler,
-ctx: ctx,
-}
-// pollerCancel is used to cancel the polling goroutine. We use a noop cancel
-// func in case the poller is already running and there is no cancel to use
-// from the select below.
-var pollerCancel context.CancelFunc = func() {}
-// Start the poller for the queue if it's not running already
-select {
-case s.pollerRunning <- struct{}{}:
-// If inserting in the channel succeeds, then it's not running already
-// Use a context that is tied to the background context
-var subctx context.Context
-subctx, pollerCancel = context.WithCancel(context.Background())
-s.wg.Add(2)
-go func() {
-defer s.wg.Done()
-defer pollerCancel()
-select {
-case <-s.closeCh:
-case <-subctx.Done():
-}
-}()
-go func() {
-defer s.wg.Done()
-s.consumeSubscription(subctx, queueInfo, deadLettersQueueInfo)
-}()
-default:
-// Do nothing, it means the poller is already running
-}
-// Watch for subscription context cancellation to remove this subscription
-s.wg.Add(1)
-go func() {
-defer s.wg.Done()
-select {
-case <-ctx.Done():
-case <-s.closeCh:
-}
-s.topicsLock.Lock()
-defer s.topicsLock.Unlock()
-// Remove the handler
-delete(s.topicHandlers, sanitizedName)
-// If we don't have any topic left, close the poller.
-if len(s.topicHandlers) == 0 {
-pollerCancel()
-}
-}()
+// start the subscription manager
+s.subscriptionManager.Init(queueInfo, deadLettersQueueInfo, s.consumeSubscription)
+s.subscriptionManager.Subscribe(&SubscriptionTopicHandler{
+topic: sanitizedName,
+requestTopic: req.Topic,
+handler: handler,
+ctx: ctx,
+})
 return nil
 }
@@ -920,9 +870,9 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error
 // client. Blocks until all goroutines have returned.
 func (s *snsSqs) Close() error {
 if s.closed.CompareAndSwap(false, true) {
-close(s.closeCh)
+s.subscriptionManager.Close()
 }
-s.wg.Wait()
 return nil
 }
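Note on the net effect of the changes above: Subscribe now only registers a handler with the subscription manager (which owns the single polling goroutine), and Close delegates shutdown to that manager instead of juggling closeCh, a WaitGroup and a poller channel. Below is a rough caller-side sketch of the resulting behavior; the import path, metadata keys, and topic name are illustrative assumptions, not values taken from this diff.

package main

import (
	"context"

	"github.com/dapr/components-contrib/pubsub"
	snssqs "github.com/dapr/components-contrib/pubsub/aws/snssqs" // assumed import path
	"github.com/dapr/kit/logger"
)

func main() {
	log := logger.NewLogger("snssqs-demo")
	ps := snssqs.NewSnsSqs(log)

	// Metadata keys shown here are illustrative; consult the component docs for the real set.
	md := pubsub.Metadata{}
	md.Properties = map[string]string{"region": "us-east-1"}
	if err := ps.Init(context.Background(), md); err != nil {
		log.Fatalf("init failed: %v", err)
	}

	// Each Subscribe hands its handler to the subscription manager; the first
	// registration is what starts the SQS polling loop.
	subCtx, cancel := context.WithCancel(context.Background())
	err := ps.Subscribe(subCtx, pubsub.SubscribeRequest{Topic: "orders"}, func(ctx context.Context, msg *pubsub.NewMessage) error {
		log.Infof("message on topic %s: %s", msg.Topic, string(msg.Data))
		return nil
	})
	if err != nil {
		log.Fatalf("subscribe failed: %v", err)
	}

	// Cancelling the subscription context unregisters the handler; once the last
	// handler is gone the manager stops polling, and Close waits for the manager's
	// goroutines instead of racing them during Dapr shutdown.
	cancel()
	_ = ps.Close()
}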

View File

@@ -0,0 +1,179 @@
package snssqs
import (
"context"
"sync"
"github.com/puzpuzpuz/xsync/v3"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
type (
SubscriptionAction int
)
const (
Subscribe SubscriptionAction = iota
Unsubscribe
)
type SubscriptionTopicHandler struct {
topic string
requestTopic string
handler pubsub.Handler
ctx context.Context
}
type changeSubscriptionTopicHandler struct {
action SubscriptionAction
handler *SubscriptionTopicHandler
}
type SubscriptionManager struct {
logger logger.Logger
consumeCancelFunc context.CancelFunc
closeCh chan struct{}
topicsChangeCh chan changeSubscriptionTopicHandler
topicsHandlers *xsync.MapOf[string, *SubscriptionTopicHandler]
lock sync.Mutex
wg sync.WaitGroup
initOnce sync.Once
}
type SubscriptionManagement interface {
Init(queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo, cbk func(context.Context, *sqsQueueInfo, *sqsQueueInfo))
Subscribe(topicHandler *SubscriptionTopicHandler)
Close()
GetSubscriptionTopicHandler(topic string) (*SubscriptionTopicHandler, bool)
}
func NewSubscriptionMgmt(log logger.Logger) SubscriptionManagement {
return &SubscriptionManager{
logger: log,
consumeCancelFunc: func() {}, // noop until we (re)start sqs consumption
closeCh: make(chan struct{}),
topicsChangeCh: make(chan changeSubscriptionTopicHandler),
topicsHandlers: xsync.NewMapOf[string, *SubscriptionTopicHandler](),
}
}
func createQueueConsumerCbk(queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo, cbk func(ctx context.Context, queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo)) func(ctx context.Context) {
return func(ctx context.Context) {
cbk(ctx, queueInfo, dlqInfo)
}
}
func (sm *SubscriptionManager) Init(queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo, cbk func(context.Context, *sqsQueueInfo, *sqsQueueInfo)) {
sm.initOnce.Do(func() {
queueConsumerCbk := createQueueConsumerCbk(queueInfo, dlqInfo, cbk)
go sm.queueConsumerController(queueConsumerCbk)
sm.logger.Debug("Subscription manager initialized")
})
}
// queueConsumerController is responsible for managing the subscription lifecycle
// and the only place where the topicsHandlers map is updated.
// it is running in a separate goroutine and is responsible for starting and stopping sqs consumption
// where its lifecycle is managed by the subscription manager,
// and it has its own context with its child contexts used for sqs consumption and aborting of the consumption.
// it is also responsible for managing the lifecycle of the subscription handlers.
func (sm *SubscriptionManager) queueConsumerController(queueConsumerCbk func(context.Context)) {
ctx := context.Background()
for {
select {
case changeEvent := <-sm.topicsChangeCh:
topic := changeEvent.handler.topic
sm.logger.Debugf("Subscription change event received with action: %v, on topic: %s", changeEvent.action, topic)
// topic change events are serialized so that no interleaving can occur
sm.lock.Lock()
// although we have a lock here, the topicsHandlers map is thread safe and can be accessed concurrently so other subscribers that are already consuming messages
// can get the handler for the topic while we're still updating the map without blocking them
current := sm.topicsHandlers.Size()
switch changeEvent.action {
case Subscribe:
sm.topicsHandlers.Store(topic, changeEvent.handler)
// if before we've added the subscription there were no subscriptions, this subscribe signals us to start consuming from sqs
if current == 0 {
var subCtx context.Context
// create a new context for sqs consumption with a cancel func to be used when we unsubscribe from all topics
subCtx, sm.consumeCancelFunc = context.WithCancel(ctx)
// start sqs consumption
sm.logger.Info("Starting SQS consumption")
go queueConsumerCbk(subCtx)
}
case Unsubscribe:
sm.topicsHandlers.Delete(topic)
// for idempotency, we check the size of the map after the delete operation, as we might have already deleted the subscription
afterDelete := sm.topicsHandlers.Size()
// if before we've removed this subscription we had one (last) subscription, this signals us to stop sqs consumption
if current == 1 && afterDelete == 0 {
sm.logger.Info("Last subscription removed. no more handlers are mapped to topics. stopping SQS consumption")
sm.consumeCancelFunc()
}
}
sm.lock.Unlock()
case <-sm.closeCh:
return
}
}
}
func (sm *SubscriptionManager) Subscribe(topicHandler *SubscriptionTopicHandler) {
sm.logger.Debug("Subscribing to topic: ", topicHandler.topic)
sm.wg.Add(1)
go func() {
defer sm.wg.Done()
sm.createSubscribeListener(topicHandler)
}()
}
func (sm *SubscriptionManager) createSubscribeListener(topicHandler *SubscriptionTopicHandler) {
sm.logger.Debug("Creating a subscribe listener for topic: ", topicHandler.topic)
sm.topicsChangeCh <- changeSubscriptionTopicHandler{Subscribe, topicHandler}
closeCh := make(chan struct{})
// the unsubscriber is expected to be terminated by the dapr runtime as it cancels the context upon unsubscribe
go sm.createUnsubscribeListener(topicHandler.ctx, topicHandler.topic, closeCh)
// if the SubscriptionManager is being closed and somehow the dapr runtime did not call unsubscribe, we close the control
// channel here to terminate the unsubscriber and return
defer close(closeCh)
<-sm.closeCh
}
// ctx is a context provided by daprd per subscription. unrelated to the consuming sm.baseCtx
func (sm *SubscriptionManager) createUnsubscribeListener(ctx context.Context, topic string, closeCh <-chan struct{}) {
sm.logger.Debug("Creating an unsubscribe listener for topic: ", topic)
defer sm.unsubscribe(topic)
for {
select {
case <-ctx.Done():
return
case <-closeCh:
return
}
}
}
func (sm *SubscriptionManager) unsubscribe(topic string) {
sm.logger.Debug("Unsubscribing from topic: ", topic)
if value, ok := sm.GetSubscriptionTopicHandler(topic); ok {
sm.topicsChangeCh <- changeSubscriptionTopicHandler{Unsubscribe, value}
}
}
func (sm *SubscriptionManager) Close() {
close(sm.closeCh)
sm.wg.Wait()
}
func (sm *SubscriptionManager) GetSubscriptionTopicHandler(topic string) (*SubscriptionTopicHandler, bool) {
return sm.topicsHandlers.Load(topic)
}
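For orientation, here is a minimal sketch of how this SubscriptionManagement interface behaves: the first Subscribe starts the single polling loop, cancelling the last subscription context stops it, and Close shuts the controller down. The consumer callback, queue ARN, and the sleep are placeholders for illustration only, assuming the sketch lives in the same package so it can use the unexported fields.

package snssqs

import (
	"context"
	"time"

	"github.com/dapr/components-contrib/pubsub"
	"github.com/dapr/kit/logger"
)

func exerciseSubscriptionManager() {
	log := logger.NewLogger("snssqs-sm-sketch")
	sm := NewSubscriptionMgmt(log)

	// Stand-in for snsSqs.consumeSubscription: runs until its context is cancelled.
	consume := func(ctx context.Context, queueInfo, dlqInfo *sqsQueueInfo) {
		<-ctx.Done()
		log.Info("polling stopped")
	}
	// Init is guarded by sync.Once; only the first call starts queueConsumerController.
	sm.Init(&sqsQueueInfo{arn: "queue-arn"}, nil, consume)

	subCtx, cancel := context.WithCancel(context.Background())
	sm.Subscribe(&SubscriptionTopicHandler{
		topic:        "orders",
		requestTopic: "orders",
		handler: func(ctx context.Context, msg *pubsub.NewMessage) error {
			return nil
		},
		ctx: subCtx,
	})

	// The handler is now resolvable by the poller's callHandler path.
	if _, ok := sm.GetSubscriptionTopicHandler("orders"); ok {
		log.Info("handler registered")
	}

	// Cancelling the last subscription context triggers Unsubscribe, which cancels
	// the consume context; Close then stops the controller and waits for listeners.
	cancel()
	time.Sleep(100 * time.Millisecond) // sketch only: allow the unsubscribe event to propagate
	sm.Close()
}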

View File

@@ -0,0 +1,44 @@
package snssqs
import (
"sync"
"github.com/puzpuzpuz/xsync/v3"
)
// TopicsLockManager is a singleton for fine-grained locking, to prevent the component r/w operations
// from locking the entire component out when performing operations on different topics.
type TopicsLockManager struct {
xLockMap *xsync.MapOf[string, *sync.Mutex]
}
type TopicsLocker interface {
Lock(topic string) *sync.Mutex
Unlock(topic string)
}
func NewLockManager() *TopicsLockManager {
return &TopicsLockManager{xLockMap: xsync.NewMapOf[string, *sync.Mutex]()}
}
func (lm *TopicsLockManager) Lock(key string) *sync.Mutex {
lock, _ := lm.xLockMap.LoadOrCompute(key, func() *sync.Mutex {
l := &sync.Mutex{}
l.Lock()
return l
})
return lock
}
func (lm *TopicsLockManager) Unlock(key string) {
lm.xLockMap.Compute(key, func(oldValue *sync.Mutex, exists bool) (newValue *sync.Mutex, delete bool) {
// if exists then the mutex must be already locked, and we unlock it
if exists {
oldValue.Unlock()
}
// we return to comply with the Compute signature, but not using the returned values
return oldValue, false
})
}
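As used in Subscribe above, the lock manager keys a mutex per raw topic name. A minimal sketch of the calling pattern follows; the provision callback is a hypothetical stand-in for the topic/queue/subscription provisioning steps that Subscribe guards.

package snssqs

// topicsLockerUsageSketch shows the calling pattern used by Subscribe: the lock is
// acquired for the topic key before provisioning work for that topic and released
// once the work is done, so operations on other topics are not held up.
func topicsLockerUsageSketch(locker TopicsLocker, topic string, provision func() error) error {
	locker.Lock(topic)
	defer locker.Unlock(topic)

	// per-topic guarded work (getOrCreateTopic, getOrCreateQueue, getOrCreateSnsSqsSubscription, ...)
	return provision()
}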

View File

@@ -227,6 +227,7 @@ require (
 github.com/prometheus/common v0.44.0 // indirect
 github.com/prometheus/procfs v0.11.0 // indirect
 github.com/prometheus/statsd_exporter v0.22.7 // indirect
+github.com/puzpuzpuz/xsync/v3 v3.0.0 // indirect
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect
 github.com/redis/go-redis/v9 v9.2.1 // indirect
 github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect

View File

@@ -1108,6 +1108,8 @@ github.com/prometheus/statsd_exporter v0.21.0/go.mod h1:rbT83sZq2V+p73lHhPZfMc3M
 github.com/prometheus/statsd_exporter v0.22.7 h1:7Pji/i2GuhK6Lu7DHrtTkFmNBCudCPT1pX2CziuyQR0=
 github.com/prometheus/statsd_exporter v0.22.7/go.mod h1:N/TevpjkIh9ccs6nuzY3jQn9dFqnUakOjnEuMPJJJnI=
 github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
+github.com/puzpuzpuz/xsync/v3 v3.0.0 h1:QwUcmah+dZZxy6va/QSU26M6O6Q422afP9jO8JlnRSA=
+github.com/puzpuzpuz/xsync/v3 v3.0.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
 github.com/rabbitmq/amqp091-go v1.8.1 h1:RejT1SBUim5doqcL6s7iN6SBmsQqyTgXb1xMlH0h1hA=
 github.com/rabbitmq/amqp091-go v1.8.1/go.mod h1:+jPrT9iY2eLjRaMSRHUhc3z14E/l85kv/f+6luSD3pc=
 github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=