snssqs: fix consumer starvation (#3478)

Signed-off-by: Gustavo Chain <me@qustavo.cc>
Signed-off-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Bernd Verst <github@bernd.dev>
Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
This commit is contained in:
Gustavo Chaín 2024-11-26 00:33:22 -03:00 committed by GitHub
parent 2aea31969f
commit 1137759a9b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 49 additions and 6 deletions

View File

@ -57,6 +57,8 @@ type snsSqsMetadata struct {
AccountID string `mapstructure:"accountID"` AccountID string `mapstructure:"accountID"`
// processing concurrency mode // processing concurrency mode
ConcurrencyMode pubsub.ConcurrencyMode `mapstructure:"concurrencyMode"` ConcurrencyMode pubsub.ConcurrencyMode `mapstructure:"concurrencyMode"`
// limits the number of concurrent goroutines handling messages in parallel mode; 0 means no limit
ConcurrencyLimit int `mapstructure:"concurrencyLimit"`
} }
func maskLeft(s string) string { func maskLeft(s string) string {
@ -130,6 +132,10 @@ func (s *snsSqs) getSnsSqsMetadata(meta pubsub.Metadata) (*snsSqsMetadata, error
return nil, err return nil, err
} }
if md.ConcurrencyLimit < 0 {
return nil, errors.New("concurrencyLimit must be greater than or equal to 0")
}
s.logger.Debug(md.hideDebugPrintedCredentials()) s.logger.Debug(md.hideDebugPrintedCredentials())
return md, nil return md, nil

View File

@ -128,6 +128,15 @@ metadata:
default: '"parallel"' default: '"parallel"'
example: '"single", "parallel"' example: '"single", "parallel"'
type: string type: string
- name: concurrencyLimit
required: false
description: |
Defines the maximum number of concurrent workers handling messages.
This value is ignored when "concurrencyMode" is set to "single".
To avoid limiting the number of concurrent workers, set this to "0".
type: number
default: '0'
example: '100'
- name: accountId - name: accountId
required: false required: false
description: | description: |

View File

@ -595,6 +595,13 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
WaitTimeSeconds: aws.Int64(s.metadata.MessageWaitTimeSeconds), WaitTimeSeconds: aws.Int64(s.metadata.MessageWaitTimeSeconds),
} }
// sem is a semaphore used to control the concurrencyLimit.
// It is set only when we are in parallel mode and limit is > 0.
var sem chan (struct{}) = nil
if (s.metadata.ConcurrencyMode == pubsub.Parallel) && s.metadata.ConcurrencyLimit > 0 {
sem = make(chan struct{}, s.metadata.ConcurrencyLimit)
}
for { for {
// If the context is canceled, stop requesting messages // If the context is canceled, stop requesting messages
if ctx.Err() != nil { if ctx.Err() != nil {
@ -629,7 +636,6 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
} }
s.logger.Debugf("%v message(s) received on queue %s", len(messageResponse.Messages), queueInfo.arn) s.logger.Debugf("%v message(s) received on queue %s", len(messageResponse.Messages), queueInfo.arn)
var wg sync.WaitGroup
for _, message := range messageResponse.Messages { for _, message := range messageResponse.Messages {
if err := s.validateMessage(ctx, message, queueInfo, deadLettersQueueInfo); err != nil { if err := s.validateMessage(ctx, message, queueInfo, deadLettersQueueInfo); err != nil {
s.logger.Errorf("message is not valid for further processing by the handler. error is: %v", err) s.logger.Errorf("message is not valid for further processing by the handler. error is: %v", err)
@ -637,25 +643,30 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
} }
f := func(message *sqs.Message) { f := func(message *sqs.Message) {
defer wg.Done()
if err := s.callHandler(ctx, message, queueInfo); err != nil { if err := s.callHandler(ctx, message, queueInfo); err != nil {
s.logger.Errorf("error while handling received message. error is: %v", err) s.logger.Errorf("error while handling received message. error is: %v", err)
} }
} }
wg.Add(1)
switch s.metadata.ConcurrencyMode { switch s.metadata.ConcurrencyMode {
case pubsub.Single: case pubsub.Single:
f(message) f(message)
case pubsub.Parallel: case pubsub.Parallel:
wg.Add(1) // This is the back pressure mechanism.
// It will block until another goroutine frees a slot.
if sem != nil {
sem <- struct{}{}
}
go func(message *sqs.Message) { go func(message *sqs.Message) {
defer wg.Done() if sem != nil {
defer func() { <-sem }()
}
f(message) f(message)
}(message) }(message)
} }
} }
wg.Wait()
} }
} }

View File

@ -51,6 +51,7 @@ func Test_getSnsSqsMetadata_AllConfiguration(t *testing.T) {
"consumerID": "consumer", "consumerID": "consumer",
"Endpoint": "endpoint", "Endpoint": "endpoint",
"concurrencyMode": string(pubsub.Single), "concurrencyMode": string(pubsub.Single),
"concurrencyLimit": "42",
"accessKey": "a", "accessKey": "a",
"secretKey": "s", "secretKey": "s",
"sessionToken": "t", "sessionToken": "t",
@ -68,6 +69,7 @@ func Test_getSnsSqsMetadata_AllConfiguration(t *testing.T) {
r.Equal("consumer", md.SqsQueueName) r.Equal("consumer", md.SqsQueueName)
r.Equal("endpoint", md.Endpoint) r.Equal("endpoint", md.Endpoint)
r.Equal(pubsub.Single, md.ConcurrencyMode) r.Equal(pubsub.Single, md.ConcurrencyMode)
r.Equal(42, md.ConcurrencyLimit)
r.Equal("a", md.AccessKey) r.Equal("a", md.AccessKey)
r.Equal("s", md.SecretKey) r.Equal("s", md.SecretKey)
r.Equal("t", md.SessionToken) r.Equal("t", md.SessionToken)
@ -105,6 +107,7 @@ func Test_getSnsSqsMetadata_defaults(t *testing.T) {
r.Equal("", md.SessionToken) r.Equal("", md.SessionToken)
r.Equal("r", md.Region) r.Equal("r", md.Region)
r.Equal(pubsub.Parallel, md.ConcurrencyMode) r.Equal(pubsub.Parallel, md.ConcurrencyMode)
r.Equal(0, md.ConcurrencyLimit)
r.Equal(int64(10), md.MessageVisibilityTimeout) r.Equal(int64(10), md.MessageVisibilityTimeout)
r.Equal(int64(10), md.MessageRetryLimit) r.Equal(int64(10), md.MessageRetryLimit)
r.Equal(int64(2), md.MessageWaitTimeSeconds) r.Equal(int64(2), md.MessageWaitTimeSeconds)
@ -273,6 +276,20 @@ func Test_getSnsSqsMetadata_invalidMetadataSetup(t *testing.T) {
}}}, }}},
name: "invalid message concurrencyMode", name: "invalid message concurrencyMode",
}, },
// invalid concurrencyLimit
{
metadata: pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
"consumerID": "consumer",
"Endpoint": "endpoint",
"AccessKey": "acctId",
"SecretKey": "secret",
"awsToken": "token",
"Region": "region",
"messageRetryLimit": "10",
"concurrencyLimit": "-1",
}}},
name: "invalid message concurrencyLimit",
},
} }
l := logger.NewLogger("SnsSqs unit test") l := logger.NewLogger("SnsSqs unit test")