snssqs: fix consumer starvation (#3478)
Signed-off-by: Gustavo Chain <me@qustavo.cc>
Signed-off-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Bernd Verst <github@bernd.dev>
Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
This commit is contained in:
parent
2aea31969f
commit
1137759a9b
|
|
@ -57,6 +57,8 @@ type snsSqsMetadata struct {
|
||||||
AccountID string `mapstructure:"accountID"`
|
AccountID string `mapstructure:"accountID"`
|
||||||
// processing concurrency mode
|
// processing concurrency mode
|
||||||
ConcurrencyMode pubsub.ConcurrencyMode `mapstructure:"concurrencyMode"`
|
ConcurrencyMode pubsub.ConcurrencyMode `mapstructure:"concurrencyMode"`
|
||||||
|
// limits the number of concurrent goroutines
|
||||||
|
ConcurrencyLimit int `mapstructure:"concurrencyLimit"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func maskLeft(s string) string {
|
func maskLeft(s string) string {
|
||||||
|
|
@ -130,6 +132,10 @@ func (s *snsSqs) getSnsSqsMetadata(meta pubsub.Metadata) (*snsSqsMetadata, error
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if md.ConcurrencyLimit < 0 {
|
||||||
|
return nil, errors.New("concurrencyLimit must be greater than or equal to 0")
|
||||||
|
}
|
||||||
|
|
||||||
s.logger.Debug(md.hideDebugPrintedCredentials())
|
s.logger.Debug(md.hideDebugPrintedCredentials())
|
||||||
|
|
||||||
return md, nil
|
return md, nil
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,15 @@ metadata:
|
||||||
default: '"parallel"'
|
default: '"parallel"'
|
||||||
example: '"single", "parallel"'
|
example: '"single", "parallel"'
|
||||||
type: string
|
type: string
|
||||||
|
- name: concurrencyLimit
|
||||||
|
required: false
|
||||||
|
description: |
|
||||||
|
Defines the maximum number of concurrent workers handling messages.
|
||||||
|
This value is ignored when "concurrencyMode" is set to "single".
|
||||||
|
To avoid limiting the number of concurrent workers set this to "0".
|
||||||
|
type: number
|
||||||
|
default: '0'
|
||||||
|
example: '100'
|
||||||
- name: accountId
|
- name: accountId
|
||||||
required: false
|
required: false
|
||||||
description: |
|
description: |
|
||||||
|
|
|
||||||
|
|
@ -595,6 +595,13 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
|
||||||
WaitTimeSeconds: aws.Int64(s.metadata.MessageWaitTimeSeconds),
|
WaitTimeSeconds: aws.Int64(s.metadata.MessageWaitTimeSeconds),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sem is a semaphore used to control the concurrencyLimit.
|
||||||
|
// It is set only when we are in parallel mode and limit is > 0.
|
||||||
|
var sem chan (struct{}) = nil
|
||||||
|
if (s.metadata.ConcurrencyMode == pubsub.Parallel) && s.metadata.ConcurrencyLimit > 0 {
|
||||||
|
sem = make(chan struct{}, s.metadata.ConcurrencyLimit)
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// If the context is canceled, stop requesting messages
|
// If the context is canceled, stop requesting messages
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
|
|
@ -629,7 +636,6 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
|
||||||
}
|
}
|
||||||
s.logger.Debugf("%v message(s) received on queue %s", len(messageResponse.Messages), queueInfo.arn)
|
s.logger.Debugf("%v message(s) received on queue %s", len(messageResponse.Messages), queueInfo.arn)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for _, message := range messageResponse.Messages {
|
for _, message := range messageResponse.Messages {
|
||||||
if err := s.validateMessage(ctx, message, queueInfo, deadLettersQueueInfo); err != nil {
|
if err := s.validateMessage(ctx, message, queueInfo, deadLettersQueueInfo); err != nil {
|
||||||
s.logger.Errorf("message is not valid for further processing by the handler. error is: %v", err)
|
s.logger.Errorf("message is not valid for further processing by the handler. error is: %v", err)
|
||||||
|
|
@ -637,25 +643,30 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
|
||||||
}
|
}
|
||||||
|
|
||||||
f := func(message *sqs.Message) {
|
f := func(message *sqs.Message) {
|
||||||
defer wg.Done()
|
|
||||||
if err := s.callHandler(ctx, message, queueInfo); err != nil {
|
if err := s.callHandler(ctx, message, queueInfo); err != nil {
|
||||||
s.logger.Errorf("error while handling received message. error is: %v", err)
|
s.logger.Errorf("error while handling received message. error is: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
switch s.metadata.ConcurrencyMode {
|
switch s.metadata.ConcurrencyMode {
|
||||||
case pubsub.Single:
|
case pubsub.Single:
|
||||||
f(message)
|
f(message)
|
||||||
case pubsub.Parallel:
|
case pubsub.Parallel:
|
||||||
wg.Add(1)
|
// This is the back pressure mechanism.
|
||||||
|
// It will block until another goroutine frees a slot.
|
||||||
|
if sem != nil {
|
||||||
|
sem <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
go func(message *sqs.Message) {
|
go func(message *sqs.Message) {
|
||||||
defer wg.Done()
|
if sem != nil {
|
||||||
|
defer func() { <-sem }()
|
||||||
|
}
|
||||||
|
|
||||||
f(message)
|
f(message)
|
||||||
}(message)
|
}(message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
wg.Wait()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,7 @@ func Test_getSnsSqsMetadata_AllConfiguration(t *testing.T) {
|
||||||
"consumerID": "consumer",
|
"consumerID": "consumer",
|
||||||
"Endpoint": "endpoint",
|
"Endpoint": "endpoint",
|
||||||
"concurrencyMode": string(pubsub.Single),
|
"concurrencyMode": string(pubsub.Single),
|
||||||
|
"concurrencyLimit": "42",
|
||||||
"accessKey": "a",
|
"accessKey": "a",
|
||||||
"secretKey": "s",
|
"secretKey": "s",
|
||||||
"sessionToken": "t",
|
"sessionToken": "t",
|
||||||
|
|
@ -68,6 +69,7 @@ func Test_getSnsSqsMetadata_AllConfiguration(t *testing.T) {
|
||||||
r.Equal("consumer", md.SqsQueueName)
|
r.Equal("consumer", md.SqsQueueName)
|
||||||
r.Equal("endpoint", md.Endpoint)
|
r.Equal("endpoint", md.Endpoint)
|
||||||
r.Equal(pubsub.Single, md.ConcurrencyMode)
|
r.Equal(pubsub.Single, md.ConcurrencyMode)
|
||||||
|
r.Equal(42, md.ConcurrencyLimit)
|
||||||
r.Equal("a", md.AccessKey)
|
r.Equal("a", md.AccessKey)
|
||||||
r.Equal("s", md.SecretKey)
|
r.Equal("s", md.SecretKey)
|
||||||
r.Equal("t", md.SessionToken)
|
r.Equal("t", md.SessionToken)
|
||||||
|
|
@ -105,6 +107,7 @@ func Test_getSnsSqsMetadata_defaults(t *testing.T) {
|
||||||
r.Equal("", md.SessionToken)
|
r.Equal("", md.SessionToken)
|
||||||
r.Equal("r", md.Region)
|
r.Equal("r", md.Region)
|
||||||
r.Equal(pubsub.Parallel, md.ConcurrencyMode)
|
r.Equal(pubsub.Parallel, md.ConcurrencyMode)
|
||||||
|
r.Equal(0, md.ConcurrencyLimit)
|
||||||
r.Equal(int64(10), md.MessageVisibilityTimeout)
|
r.Equal(int64(10), md.MessageVisibilityTimeout)
|
||||||
r.Equal(int64(10), md.MessageRetryLimit)
|
r.Equal(int64(10), md.MessageRetryLimit)
|
||||||
r.Equal(int64(2), md.MessageWaitTimeSeconds)
|
r.Equal(int64(2), md.MessageWaitTimeSeconds)
|
||||||
|
|
@ -273,6 +276,20 @@ func Test_getSnsSqsMetadata_invalidMetadataSetup(t *testing.T) {
|
||||||
}}},
|
}}},
|
||||||
name: "invalid message concurrencyMode",
|
name: "invalid message concurrencyMode",
|
||||||
},
|
},
|
||||||
|
// invalid concurrencyLimit
|
||||||
|
{
|
||||||
|
metadata: pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
|
||||||
|
"consumerID": "consumer",
|
||||||
|
"Endpoint": "endpoint",
|
||||||
|
"AccessKey": "acctId",
|
||||||
|
"SecretKey": "secret",
|
||||||
|
"awsToken": "token",
|
||||||
|
"Region": "region",
|
||||||
|
"messageRetryLimit": "10",
|
||||||
|
"concurrencyLimit": "-1",
|
||||||
|
}}},
|
||||||
|
name: "invalid message concurrencyLimit",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
l := logger.NewLogger("SnsSqs unit test")
|
l := logger.NewLogger("SnsSqs unit test")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue