Fix SQS: Error during Dapr Shutdown and Message Polling Behavior Issue (AWS) (#3174)
Signed-off-by: Amit Mor <amit.mor@hotmail.com>
Signed-off-by: Amit Mor <amitm@at-bay.com>
Signed-off-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent d47d1a7aad
commit 7fd5524c58

go.mod
@@ -90,6 +90,7 @@ require (
    github.com/oracle/oci-go-sdk/v54 v54.0.0
    github.com/pashagolub/pgxmock/v2 v2.12.0
    github.com/patrickmn/go-cache v2.1.0+incompatible
    github.com/puzpuzpuz/xsync/v3 v3.0.0
    github.com/rabbitmq/amqp091-go v1.8.1
    github.com/redis/go-redis/v9 v9.2.1
    github.com/sendgrid/sendgrid-go v3.13.0+incompatible
go.sum
@@ -1732,6 +1732,8 @@ github.com/prometheus/statsd_exporter v0.21.0/go.mod h1:rbT83sZq2V+p73lHhPZfMc3M
github.com/prometheus/statsd_exporter v0.22.7 h1:7Pji/i2GuhK6Lu7DHrtTkFmNBCudCPT1pX2CziuyQR0=
github.com/prometheus/statsd_exporter v0.22.7/go.mod h1:N/TevpjkIh9ccs6nuzY3jQn9dFqnUakOjnEuMPJJJnI=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/puzpuzpuz/xsync/v3 v3.0.0 h1:QwUcmah+dZZxy6va/QSU26M6O6Q422afP9jO8JlnRSA=
github.com/puzpuzpuz/xsync/v3 v3.0.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/rabbitmq/amqp091-go v1.8.1 h1:RejT1SBUim5doqcL6s7iN6SBmsQqyTgXb1xMlH0h1hA=
github.com/rabbitmq/amqp091-go v1.8.1/go.mod h1:+jPrT9iY2eLjRaMSRHUhc3z14E/l85kv/f+6luSD3pc=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
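
The single new dependency is github.com/puzpuzpuz/xsync/v3, whose generic, concurrency-safe MapOf backs both the new subscription manager and the topics lock manager later in this commit. A minimal sketch of the MapOf operations the new code relies on; the key and value types here are illustrative:

package main

import (
    "fmt"

    "github.com/puzpuzpuz/xsync/v3"
)

func main() {
    // Generic, concurrency-safe map; no external locking is needed around these calls.
    m := xsync.NewMapOf[string, int]()
    m.Store("orders", 1)

    // LoadOrCompute runs the callback only when the key is absent.
    v, loaded := m.LoadOrCompute("invoices", func() int { return 2 })
    fmt.Println(v, loaded, m.Size()) // 2 false 2

    m.Delete("orders")
    if _, ok := m.Load("orders"); !ok {
        fmt.Println("orders removed; size is now", m.Size()) // 1
    }
}
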
@@ -41,23 +41,14 @@ import (
    "github.com/dapr/kit/logger"
)

type topicHandler struct {
    topicName string
    handler pubsub.Handler
    ctx context.Context
}

type snsSqs struct {
    topicsLocker TopicsLocker
    // key is the sanitized topic name
    topicArns map[string]string
    // key is the sanitized topic name
    topicHandlers map[string]topicHandler
    topicsLock sync.RWMutex
    // key is the topic name, value holds the ARN of the queue and its url.
    queues sync.Map
    queues map[string]*sqsQueueInfo
    // key is a composite key of queue ARN and topic ARN mapping to subscription ARN.
    subscriptions sync.Map

    subscriptions map[string]string
    snsClient *sns.SNS
    sqsClient *sqs.SQS
    stsClient *sts.STS

@@ -66,11 +57,8 @@ type snsSqs struct {
    id string
    opsTimeout time.Duration
    backOffConfig retry.Config
    pollerRunning chan struct{}

    closeCh chan struct{}
    subscriptionManager SubscriptionManagement
    closed atomic.Bool
    wg sync.WaitGroup
}

type sqsQueueInfo struct {

@@ -107,9 +95,6 @@ func NewSnsSqs(l logger.Logger) pubsub.PubSub {
    return &snsSqs{
        logger: l,
        id: id,
        topicsLock: sync.RWMutex{},
        pollerRunning: make(chan struct{}, 1),
        closeCh: make(chan struct{}),
    }
}

@@ -160,13 +145,6 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error {

    s.metadata = md

    // both Publish and Subscribe need reference the topic ARN, queue ARN and subscription ARN between topic and queue
    // track these ARNs in these maps.
    s.topicArns = make(map[string]string)
    s.topicHandlers = make(map[string]topicHandler)
    s.queues = sync.Map{}
    s.subscriptions = sync.Map{}

    sess, err := awsAuth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint)
    if err != nil {
        return fmt.Errorf("error creating an AWS client: %w", err)

@@ -189,6 +167,13 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error {
    if err != nil {
        return fmt.Errorf("error decoding backOff config: %w", err)
    }
    // subscription manager responsible for managing the lifecycle of subscriptions.
    s.subscriptionManager = NewSubscriptionMgmt(s.logger)
    s.topicsLocker = NewLockManager()

    s.topicArns = make(map[string]string)
    s.queues = make(map[string]*sqsQueueInfo)
    s.subscriptions = make(map[string]string)

    return nil
}

@@ -243,7 +228,7 @@ func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, e
    })
    cancelFn()
    if err != nil {
        return "", fmt.Errorf("error: %w while getting topic: %v with arn: %v", err, topic, arn)
        return "", fmt.Errorf("error: %w, while getting (sanitized) topic: %v with arn: %v", err, topic, arn)
    }

    return *getTopicOutput.Attributes["TopicArn"], nil
@@ -251,40 +236,45 @@ func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, e

// get the topic ARN from the topics map. If it doesn't exist in the map, try to fetch it from AWS, if it doesn't exist
// at all, issue a request to create the topic.
func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn string, sanitizedName string, err error) {
    s.topicsLock.Lock()
    defer s.topicsLock.Unlock()
func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn string, sanitizedTopic string, err error) {
    sanitizedTopic = nameToAWSSanitizedName(topic, s.metadata.Fifo)

    sanitizedName = nameToAWSSanitizedName(topic, s.metadata.Fifo)
    var loadOK bool
    if topicArn, loadOK = s.topicArns[sanitizedTopic]; loadOK {
        if len(topicArn) > 0 {
            s.logger.Debugf("Found existing topic ARN for topic %s: %s", topic, topicArn)

    topicArnCached, ok := s.topicArns[sanitizedName]
    if ok && topicArnCached != "" {
        s.logger.Debugf("found existing topic ARN for topic %s: %s", topic, topicArnCached)
        return topicArnCached, sanitizedName, nil
            return topicArn, sanitizedTopic, err
        } else {
            err = fmt.Errorf("the ARN for (sanitized) topic: %s was empty", sanitizedTopic)

            return topicArn, sanitizedTopic, err
        }
    }

    // creating queues is idempotent, the names serve as unique keys among a given region.
    s.logger.Debugf("No SNS topic arn found for %s\nCreating SNS topic", topic)
    s.logger.Debugf("No SNS topic ARN found for topic: %s. creating SNS with (sanitized) topic: %s", topic, sanitizedTopic)

    if !s.metadata.DisableEntityManagement {
        topicArn, err = s.createTopic(ctx, sanitizedName)
        topicArn, err = s.createTopic(ctx, sanitizedTopic)
        if err != nil {
            s.logger.Errorf("error creating new topic %s: %w", topic, err)
            err = fmt.Errorf("error creating new (sanitized) topic '%s': %w", topic, err)

            return "", "", err
            return topicArn, sanitizedTopic, err
        }
    } else {
        topicArn, err = s.getTopicArn(ctx, sanitizedName)
        topicArn, err = s.getTopicArn(ctx, sanitizedTopic)
        if err != nil {
            s.logger.Errorf("error fetching info for topic %s: %w", topic, err)
            err = fmt.Errorf("error fetching info for (sanitized) topic: %s. wrapped error is: %w", topic, err)

            return "", "", err
            return topicArn, sanitizedTopic, err
        }
    }

    // record topic ARN.
    s.topicArns[sanitizedName] = topicArn
    s.topicArns[sanitizedTopic] = topicArn

    return topicArn, sanitizedName, nil
    return topicArn, sanitizedTopic, err
}

func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQueueInfo, error) {
@@ -346,13 +336,13 @@ func (s *snsSqs) getOrCreateQueue(ctx context.Context, queueName string) (*sqsQu
        queueInfo *sqsQueueInfo
    )

    if cachedQueueInfo, ok := s.queues.Load(queueName); ok {
        s.logger.Debugf("Found queue arn for %s: %s", queueName, cachedQueueInfo.(*sqsQueueInfo).arn)
    if cachedQueueInfo, ok := s.queues[queueName]; ok {
        s.logger.Debugf("Found queue ARN for %s: %s", queueName, cachedQueueInfo.arn)

        return cachedQueueInfo.(*sqsQueueInfo), nil
        return cachedQueueInfo, nil
    }
    // creating queues is idempotent, the names serve as unique keys among a given region.
    s.logger.Debugf("No SQS queue arn found for %s\nCreating SQS queue", queueName)
    s.logger.Debugf("No SQS queue ARN found for %s\nCreating SQS queue", queueName)

    sanitizedName := nameToAWSSanitizedName(queueName, s.metadata.Fifo)

@@ -372,8 +362,8 @@ func (s *snsSqs) getOrCreateQueue(ctx context.Context, queueName string) (*sqsQu
        }
    }

    s.queues.Store(queueName, queueInfo)
    s.logger.Debugf("Created SQS queue: %s: with arn: %s", queueName, queueInfo.arn)
    s.queues[queueName] = queueInfo
    s.logger.Debugf("created SQS queue: %s: with arn: %s", queueName, queueInfo.arn)

    return queueInfo, nil
}

@@ -429,13 +419,13 @@ func (s *snsSqs) getSnsSqsSubscriptionArn(parentCtx context.Context, topicArn st

func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, topicArn string) (subscriptionArn string, err error) {
    compositeKey := fmt.Sprintf("%s:%s", queueArn, topicArn)
    if cachedSubscriptionArn, ok := s.subscriptions.Load(compositeKey); ok {
    if cachedSubscriptionArn, ok := s.subscriptions[compositeKey]; ok {
        s.logger.Debugf("Found subscription of queue arn: %s to topic arn: %s: %s", queueArn, topicArn, cachedSubscriptionArn)

        return cachedSubscriptionArn.(string), nil
        return cachedSubscriptionArn, nil
    }

    s.logger.Debugf("No subscription arn found of queue arn:%s to topic arn: %s\nCreating subscription", queueArn, topicArn)
    s.logger.Debugf("No subscription ARN found of queue arn:%s to topic arn: %s\nCreating subscription", queueArn, topicArn)

    if !s.metadata.DisableEntityManagement {
        subscriptionArn, err = s.createSnsSqsSubscription(ctx, queueArn, topicArn)

@@ -447,13 +437,13 @@ func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, to
    } else {
        subscriptionArn, err = s.getSnsSqsSubscriptionArn(ctx, topicArn)
        if err != nil {
            s.logger.Errorf("error fetching info for topic arn %s: %w", topicArn, err)
            s.logger.Errorf("error fetching info for topic ARN %s: %w", topicArn, err)

            return "", err
        }
    }

    s.subscriptions.Store(compositeKey, subscriptionArn)
    s.subscriptions[compositeKey] = subscriptionArn
    s.logger.Debugf("Subscribed to topic %s: %s", topicArn, subscriptionArn)

    return subscriptionArn, nil
@@ -555,18 +545,25 @@ func (s *snsSqs) callHandler(ctx context.Context, message *sqs.Message, queueInf
    // for the user to be able to understand the source of the coming message, we'd use the original,
    // dirty name to be carried over in the pubsub.NewMessage Topic field.
    sanitizedTopic := snsMessagePayload.parseTopicArn()
    s.topicsLock.RLock()
    handler, ok := s.topicHandlers[sanitizedTopic]
    s.topicsLock.RUnlock()
    if !ok || handler.topicName == "" {
        return fmt.Errorf("handler for topic (sanitized): %s not found", sanitizedTopic)
    // get a handler by sanitized topic name and perform validations
    var (
        handler *SubscriptionTopicHandler
        loadOK bool
    )
    if handler, loadOK = s.subscriptionManager.GetSubscriptionTopicHandler(sanitizedTopic); loadOK {
        if len(handler.requestTopic) == 0 {
            return fmt.Errorf("handler topic name is missing")
        }
    } else {
        return fmt.Errorf("handler for (sanitized) topic: %s was not found", sanitizedTopic)
    }

    s.logger.Debugf("Processing SNS message id: %s of topic: %s", *message.MessageId, sanitizedTopic)
    s.logger.Debugf("Processing SNS message id: %s of (sanitized) topic: %s", *message.MessageId, sanitizedTopic)

    // call the handler with its own subscription context
    err = handler.handler(handler.ctx, &pubsub.NewMessage{
        Data: []byte(snsMessagePayload.Message),
        Topic: handler.topicName,
        Topic: handler.requestTopic,
    })
    if err != nil {
        return fmt.Errorf("error handling message: %w", err)
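
callHandler now looks the handler up through the subscription manager and invokes it with the subscription's own context and the original (unsanitized) topic name. For orientation, a minimal sketch of what such a pubsub.Handler looks like on the application side; the handler name, body and log message are illustrative:

package main

import (
    "context"
    "log"

    "github.com/dapr/components-contrib/pubsub"
)

// exampleHandler is an illustrative pubsub.Handler; the component invokes it with the
// subscription's own context and the original (unsanitized) topic name.
func exampleHandler(ctx context.Context, msg *pubsub.NewMessage) error {
    log.Printf("got message on topic %s: %s", msg.Topic, msg.Data)
    // A non-nil error propagates back to callHandler, which then skips acknowledging the SQS message.
    return nil
}

func main() {
    var _ pubsub.Handler = exampleHandler // compile-time check that the signature matches
}
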
@@ -575,6 +572,8 @@ func (s *snsSqs) callHandler(ctx context.Context, message *sqs.Message, queueInf
    return s.acknowledgeMessage(ctx, queueInfo.url, message.ReceiptHandle)
}

// consumeSubscription is responsible for polling messages from the queue and calling the handler.
// it is being passed as a callback to the subscription manager that initializes the context of the handler.
func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLettersQueueInfo *sqsQueueInfo) {
    sqsPullExponentialBackoff := s.backOffConfig.NewBackOffWithContext(ctx)

@@ -601,12 +600,13 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
        // iteration. Therefore, a global backoff (to the internal backoff) is used (sqsPullExponentialBackoff).
        messageResponse, err := s.sqsClient.ReceiveMessageWithContext(ctx, receiveMessageInput)
        if err != nil {
            if err == context.Canceled || err == context.DeadlineExceeded {
            if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || ctx.Err() != nil {
                s.logger.Warn("context canceled; stopping consuming from queue arn: %v", queueInfo.arn)
                continue
            }

            if awsErr, ok := err.(awserr.Error); ok {
            var awsErr awserr.Error
            if errors.As(err, &awsErr) {
                s.logger.Errorf("AWS operation error while consuming from queue arn: %v with error: %w. retrying...", queueInfo.arn, awsErr.Error())
            } else {
                s.logger.Errorf("error consuming from queue arn: %v with error: %w. retrying...", queueInfo.arn, err)
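
The receive loop now inspects errors with errors.Is (context cancellation, even when wrapped) and errors.As (AWS SDK errors) instead of direct comparison and a type assertion. A standalone sketch of the same inspection idiom; the wrapping done in main is illustrative:

package main

import (
    "context"
    "errors"
    "fmt"

    "github.com/aws/aws-sdk-go/aws/awserr"
)

// classify shows the errors.Is / errors.As idiom the consume loop switched to.
func classify(err error) string {
    // errors.Is still matches context.Canceled when the error was wrapped with %w.
    if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
        return "context done"
    }
    // errors.As unwraps to the awserr.Error interface when an AWS API call failed.
    var awsErr awserr.Error
    if errors.As(err, &awsErr) {
        return "aws error: " + awsErr.Code()
    }
    return "other error"
}

func main() {
    fmt.Println(classify(fmt.Errorf("receive failed: %w", context.Canceled))) // context done
}
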
@@ -619,7 +619,6 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
        sqsPullExponentialBackoff.Reset()

        if len(messageResponse.Messages) < 1 {
            // s.logger.Debug("No messages received, continuing")
            continue
        }
        s.logger.Debugf("%v message(s) received on queue %s", len(messageResponse.Messages), queueInfo.arn)

@@ -632,11 +631,10 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
        }

        f := func(message *sqs.Message) {
            defer wg.Done()
            if err := s.callHandler(ctx, message, queueInfo); err != nil {
                s.logger.Errorf("error while handling received message. error is: %v", err)
            }

            wg.Done()
        }

        wg.Add(1)

@@ -653,9 +651,6 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
        }
        wg.Wait()
    }

    // Signal that the poller stopped
    <-s.pollerRunning
}

func (s *snsSqs) createDeadLettersQueueAttributes(queueInfo, deadLettersQueueInfo *sqsQueueInfo) (*sqs.SetQueueAttributesInput, error) {

@@ -763,6 +758,9 @@ func (s *snsSqs) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han
        return errors.New("component is closed")
    }

    s.topicsLocker.Lock(req.Topic)
    defer s.topicsLocker.Unlock(req.Topic)

    // subscribers declare a topic ARN and declare a SQS queue to use
    // these should be idempotent - queues should not be created if they exist.
    topicArn, sanitizedName, err := s.getOrCreateTopic(ctx, req.Topic)
@@ -824,63 +822,15 @@ func (s *snsSqs) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han
        return wrappedErr
    }

    // Store the handler for this topic
    s.topicsLock.Lock()
    defer s.topicsLock.Unlock()
    s.topicHandlers[sanitizedName] = topicHandler{
        topicName: req.Topic,
    // start the subscription manager
    s.subscriptionManager.Init(queueInfo, deadLettersQueueInfo, s.consumeSubscription)

    s.subscriptionManager.Subscribe(&SubscriptionTopicHandler{
        topic: sanitizedName,
        requestTopic: req.Topic,
        handler: handler,
        ctx: ctx,
    }

    // pollerCancel is used to cancel the polling goroutine. We use a noop cancel
    // func in case the poller is already running and there is no cancel to use
    // from the select below.
    var pollerCancel context.CancelFunc = func() {}
    // Start the poller for the queue if it's not running already
    select {
    case s.pollerRunning <- struct{}{}:
        // If inserting in the channel succeeds, then it's not running already
        // Use a context that is tied to the background context
        var subctx context.Context
        subctx, pollerCancel = context.WithCancel(context.Background())
        s.wg.Add(2)
        go func() {
            defer s.wg.Done()
            defer pollerCancel()
            select {
            case <-s.closeCh:
            case <-subctx.Done():
            }
        }()
        go func() {
            defer s.wg.Done()
            s.consumeSubscription(subctx, queueInfo, deadLettersQueueInfo)
        }()
    default:
        // Do nothing, it means the poller is already running
    }

    // Watch for subscription context cancellation to remove this subscription
    s.wg.Add(1)
    go func() {
        defer s.wg.Done()
        select {
        case <-ctx.Done():
        case <-s.closeCh:
        }

        s.topicsLock.Lock()
        defer s.topicsLock.Unlock()

        // Remove the handler
        delete(s.topicHandlers, sanitizedName)

        // If we don't have any topic left, close the poller.
        if len(s.topicHandlers) == 0 {
            pollerCancel()
        }
    }()
    })

    return nil
}
@@ -920,9 +870,9 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error
// client. Blocks until all goroutines have returned.
func (s *snsSqs) Close() error {
    if s.closed.CompareAndSwap(false, true) {
        close(s.closeCh)
        s.subscriptionManager.Close()
    }
    s.wg.Wait()

    return nil
}
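
Close is now idempotent: only the first call closes closeCh and shuts down the subscription manager, and every call waits for outstanding goroutines. Below is a hedged sketch of the component lifecycle as a caller might drive it from the same package; the function name, metadata keys and topic are assumptions for illustration, not a working AWS configuration:

package snssqs

import (
    "context"

    "github.com/dapr/components-contrib/metadata"
    "github.com/dapr/components-contrib/pubsub"
    "github.com/dapr/kit/logger"
)

// exampleLifecycle is a hypothetical walkthrough of Init, Subscribe and the now-idempotent Close.
// The metadata values are placeholders, not a working AWS configuration.
func exampleLifecycle() error {
    ps := NewSnsSqs(logger.NewLogger("snssqs-example"))

    md := pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
        "region": "us-east-1", // placeholder; real deployments also need credentials or an IAM role
    }}}
    if err := ps.Init(context.Background(), md); err != nil {
        return err
    }

    handler := func(ctx context.Context, msg *pubsub.NewMessage) error {
        return nil // a nil return lets callHandler acknowledge (delete) the SQS message
    }
    if err := ps.Subscribe(context.Background(), pubsub.SubscribeRequest{Topic: "orders"}, handler); err != nil {
        return err
    }

    // The first Close flips the atomic flag, closes closeCh, shuts the subscription manager
    // down and waits for the poller goroutines; a second call is a safe no-op.
    if err := ps.Close(); err != nil {
        return err
    }
    return ps.Close()
}
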
@@ -0,0 +1,179 @@
package snssqs

import (
    "context"
    "sync"

    "github.com/puzpuzpuz/xsync/v3"

    "github.com/dapr/components-contrib/pubsub"
    "github.com/dapr/kit/logger"
)

type (
    SubscriptionAction int
)

const (
    Subscribe SubscriptionAction = iota
    Unsubscribe
)

type SubscriptionTopicHandler struct {
    topic string
    requestTopic string
    handler pubsub.Handler
    ctx context.Context
}

type changeSubscriptionTopicHandler struct {
    action SubscriptionAction
    handler *SubscriptionTopicHandler
}

type SubscriptionManager struct {
    logger logger.Logger
    consumeCancelFunc context.CancelFunc
    closeCh chan struct{}
    topicsChangeCh chan changeSubscriptionTopicHandler
    topicsHandlers *xsync.MapOf[string, *SubscriptionTopicHandler]
    lock sync.Mutex
    wg sync.WaitGroup
    initOnce sync.Once
}

type SubscriptionManagement interface {
    Init(queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo, cbk func(context.Context, *sqsQueueInfo, *sqsQueueInfo))
    Subscribe(topicHandler *SubscriptionTopicHandler)
    Close()
    GetSubscriptionTopicHandler(topic string) (*SubscriptionTopicHandler, bool)
}

func NewSubscriptionMgmt(log logger.Logger) SubscriptionManagement {
    return &SubscriptionManager{
        logger: log,
        consumeCancelFunc: func() {}, // noop until we (re)start sqs consumption
        closeCh: make(chan struct{}),
        topicsChangeCh: make(chan changeSubscriptionTopicHandler),
        topicsHandlers: xsync.NewMapOf[string, *SubscriptionTopicHandler](),
    }
}

func createQueueConsumerCbk(queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo, cbk func(ctx context.Context, queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo)) func(ctx context.Context) {
    return func(ctx context.Context) {
        cbk(ctx, queueInfo, dlqInfo)
    }
}

func (sm *SubscriptionManager) Init(queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo, cbk func(context.Context, *sqsQueueInfo, *sqsQueueInfo)) {
    sm.initOnce.Do(func() {
        queueConsumerCbk := createQueueConsumerCbk(queueInfo, dlqInfo, cbk)
        go sm.queueConsumerController(queueConsumerCbk)
        sm.logger.Debug("Subscription manager initialized")
    })
}

// queueConsumerController is responsible for managing the subscription lifecycle
// and the only place where the topicsHandlers map is updated.
// it is running in a separate goroutine and is responsible for starting and stopping sqs consumption
// where its lifecycle is managed by the subscription manager,
// and it has its own context with its child contexts used for sqs consumption and aborting of the consumption.
// it is also responsible for managing the lifecycle of the subscription handlers.
func (sm *SubscriptionManager) queueConsumerController(queueConsumerCbk func(context.Context)) {
    ctx := context.Background()

    for {
        select {
        case changeEvent := <-sm.topicsChangeCh:
            topic := changeEvent.handler.topic
            sm.logger.Debugf("Subscription change event received with action: %v, on topic: %s", changeEvent.action, topic)
            // topic change events are serialized so that no interleaving can occur
            sm.lock.Lock()
            // although we have a lock here, the topicsHandlers map is thread safe and can be accessed concurrently so other subscribers that are already consuming messages
            // can get the handler for the topic while we're still updating the map without blocking them
            current := sm.topicsHandlers.Size()

            switch changeEvent.action {
            case Subscribe:
                sm.topicsHandlers.Store(topic, changeEvent.handler)
                // if before we've added the subscription there were no subscriptions, this subscribe signals us to start consuming from sqs
                if current == 0 {
                    var subCtx context.Context
                    // create a new context for sqs consumption with a cancel func to be used when we unsubscribe from all topics
                    subCtx, sm.consumeCancelFunc = context.WithCancel(ctx)
                    // start sqs consumption
                    sm.logger.Info("Starting SQS consumption")
                    go queueConsumerCbk(subCtx)
                }
            case Unsubscribe:
                sm.topicsHandlers.Delete(topic)
                // for idempotency, we check the size of the map after the delete operation, as we might have already deleted the subscription
                afterDelete := sm.topicsHandlers.Size()
                // if before we've removed this subscription we had one (last) subscription, this signals us to stop sqs consumption
                if current == 1 && afterDelete == 0 {
                    sm.logger.Info("Last subscription removed. no more handlers are mapped to topics. stopping SQS consumption")
                    sm.consumeCancelFunc()
                }
            }

            sm.lock.Unlock()
        case <-sm.closeCh:
            return
        }
    }
}

func (sm *SubscriptionManager) Subscribe(topicHandler *SubscriptionTopicHandler) {
    sm.logger.Debug("Subscribing to topic: ", topicHandler.topic)

    sm.wg.Add(1)
    go func() {
        defer sm.wg.Done()
        sm.createSubscribeListener(topicHandler)
    }()
}

func (sm *SubscriptionManager) createSubscribeListener(topicHandler *SubscriptionTopicHandler) {
    sm.logger.Debug("Creating a subscribe listener for topic: ", topicHandler.topic)

    sm.topicsChangeCh <- changeSubscriptionTopicHandler{Subscribe, topicHandler}
    closeCh := make(chan struct{})
    // the unsubscriber is expected to be terminated by the dapr runtime as it cancels the context upon unsubscribe
    go sm.createUnsubscribeListener(topicHandler.ctx, topicHandler.topic, closeCh)
    // if the SubscriptinoManager is being closed and somehow the dapr runtime did not call unsubscribe, we close the control
    // channel here to terminate the unsubscriber and return
    defer close(closeCh)
    <-sm.closeCh
}

// ctx is a context provided by daprd per subscription. unrelated to the consuming sm.baseCtx
func (sm *SubscriptionManager) createUnsubscribeListener(ctx context.Context, topic string, closeCh <-chan struct{}) {
    sm.logger.Debug("Creating an unsubscribe listener for topic: ", topic)

    defer sm.unsubscribe(topic)
    for {
        select {
        case <-ctx.Done():
            return
        case <-closeCh:
            return
        }
    }
}

func (sm *SubscriptionManager) unsubscribe(topic string) {
    sm.logger.Debug("Unsubscribing from topic: ", topic)

    if value, ok := sm.GetSubscriptionTopicHandler(topic); ok {
        sm.topicsChangeCh <- changeSubscriptionTopicHandler{Unsubscribe, value}
    }
}

func (sm *SubscriptionManager) Close() {
    close(sm.closeCh)
    sm.wg.Wait()
}

func (sm *SubscriptionManager) GetSubscriptionTopicHandler(topic string) (*SubscriptionTopicHandler, bool) {
    return sm.topicsHandlers.Load(topic)
}
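
Taken together: the controller above starts the consume callback when the first handler is registered and cancels its context when the last handler is removed, which happens when daprd cancels the per-subscription context. A hedged sketch of that flow from the same package; the function name, queue infos, topic name and handler are stand-ins:

package snssqs

import (
    "context"

    "github.com/dapr/components-contrib/pubsub"
    "github.com/dapr/kit/logger"
)

// exampleSubscriptionFlow is a hypothetical sketch; the queue infos, topic name and handler are stand-ins.
func exampleSubscriptionFlow(queueInfo, dlqInfo *sqsQueueInfo) {
    mgr := NewSubscriptionMgmt(logger.NewLogger("snssqs-sub-example"))

    // The consume callback runs until its context is cancelled; here it just waits.
    consume := func(ctx context.Context, q *sqsQueueInfo, dlq *sqsQueueInfo) {
        <-ctx.Done()
    }
    // Init is guarded by sync.Once, so repeated calls start the controller goroutine only once.
    mgr.Init(queueInfo, dlqInfo, consume)

    subCtx, cancel := context.WithCancel(context.Background())
    mgr.Subscribe(&SubscriptionTopicHandler{
        topic:        "orders", // sanitized name, used as the lookup key
        requestTopic: "orders", // original name, surfaced to the application handler
        handler: func(ctx context.Context, msg *pubsub.NewMessage) error {
            return nil
        },
        ctx: subCtx,
    })

    // Cancelling the per-subscription (daprd) context triggers the unsubscribe listener;
    // once the last topic is removed, the manager cancels the consume context as well.
    cancel()

    // Close stops the controller goroutine and waits for the subscribe listeners to drain.
    mgr.Close()
}
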
@@ -0,0 +1,44 @@
package snssqs

import (
    "sync"

    "github.com/puzpuzpuz/xsync/v3"
)

// TopicsLockManager is a singleton for fine-grained locking, to prevent the component r/w operations
// from locking the entire component out when performing operations on different topics.
type TopicsLockManager struct {
    xLockMap *xsync.MapOf[string, *sync.Mutex]
}

type TopicsLocker interface {
    Lock(topic string) *sync.Mutex
    Unlock(topic string)
}

func NewLockManager() *TopicsLockManager {
    return &TopicsLockManager{xLockMap: xsync.NewMapOf[string, *sync.Mutex]()}
}

func (lm *TopicsLockManager) Lock(key string) *sync.Mutex {
    lock, _ := lm.xLockMap.LoadOrCompute(key, func() *sync.Mutex {
        l := &sync.Mutex{}
        l.Lock()

        return l
    })

    return lock
}

func (lm *TopicsLockManager) Unlock(key string) {
    lm.xLockMap.Compute(key, func(oldValue *sync.Mutex, exists bool) (newValue *sync.Mutex, delete bool) {
        // if exists then the mutex must be already locked, and we unlock it
        if exists {
            oldValue.Unlock()
        }
        // we return to comply with the Compute signature, but not using the returned values
        return oldValue, false
    })
}
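
The lock manager keeps one mutex per topic so that Subscribe and Publish calls for different topics no longer serialize behind a single component-wide lock. A short sketch of the per-topic pattern, mirroring its use in Subscribe above; the function name, topic names and the work inside the critical section are illustrative:

package snssqs

import "sync"

// examplePerTopicLocking mirrors how Subscribe guards per-topic setup with the lock manager;
// the topic names and the work inside the critical section are illustrative.
func examplePerTopicLocking() {
    var locker TopicsLocker = NewLockManager()

    setup := func(topic string) {
        locker.Lock(topic)
        defer locker.Unlock(topic)
        // ... create the SNS topic, SQS queue and SNS-to-SQS subscription for this topic ...
    }

    var wg sync.WaitGroup
    for _, topic := range []string{"orders", "invoices"} {
        wg.Add(1)
        go func(t string) {
            defer wg.Done()
            setup(t) // different topics resolve to different mutexes, so these calls do not contend
        }(topic)
    }
    wg.Wait()
}
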
@@ -227,6 +227,7 @@ require (
    github.com/prometheus/common v0.44.0 // indirect
    github.com/prometheus/procfs v0.11.0 // indirect
    github.com/prometheus/statsd_exporter v0.22.7 // indirect
    github.com/puzpuzpuz/xsync/v3 v3.0.0 // indirect
    github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect
    github.com/redis/go-redis/v9 v9.2.1 // indirect
    github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect

@@ -1108,6 +1108,8 @@ github.com/prometheus/statsd_exporter v0.21.0/go.mod h1:rbT83sZq2V+p73lHhPZfMc3M
github.com/prometheus/statsd_exporter v0.22.7 h1:7Pji/i2GuhK6Lu7DHrtTkFmNBCudCPT1pX2CziuyQR0=
github.com/prometheus/statsd_exporter v0.22.7/go.mod h1:N/TevpjkIh9ccs6nuzY3jQn9dFqnUakOjnEuMPJJJnI=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/puzpuzpuz/xsync/v3 v3.0.0 h1:QwUcmah+dZZxy6va/QSU26M6O6Q422afP9jO8JlnRSA=
github.com/puzpuzpuz/xsync/v3 v3.0.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/rabbitmq/amqp091-go v1.8.1 h1:RejT1SBUim5doqcL6s7iN6SBmsQqyTgXb1xMlH0h1hA=
github.com/rabbitmq/amqp091-go v1.8.1/go.mod h1:+jPrT9iY2eLjRaMSRHUhc3z14E/l85kv/f+6luSD3pc=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=