Pubsub AWS SNS/SQS - adding context, cancellation, timeouts, retrying with backoff & disabling deletion of messages on failure (#1433)

* squash

Signed-off-by: Amit Mor <amit.mor@hotmail.com>

* comment

Signed-off-by: Amit Mor <amit.mor@hotmail.com>

* gofumpted

Signed-off-by: Amit Mor <amit.mor@hotmail.com>

* breakdown of metadata loading

Signed-off-by: Amit Mor <amit.mor@hotmail.com>

* metadata further refactoring

Signed-off-by: Amit Mor <amit.mor@hotmail.com>
Amit Mor 2022-01-13 17:27:49 +02:00 committed by GitHub
parent 3c28fee80f
commit c8844ccaed
3 changed files with 627 additions and 309 deletions
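The options added in this commit are all driven by the component's metadata properties. The sketch below is illustrative only: the property names are the ones parsed by the metadata loader in this diff, the backOff-prefixed keys assume the prefix-based decoding done by github.com/dapr/kit/retry (they are not spelled out in the diff), and the import path, credentials and values are assumptions.

package main

import (
	"github.com/dapr/components-contrib/pubsub"
	"github.com/dapr/components-contrib/pubsub/aws/snssqs"
	"github.com/dapr/kit/logger"
)

func main() {
	// Hypothetical configuration exercising the options introduced by this change.
	md := pubsub.Metadata{Properties: map[string]string{
		"consumerID":                     "my-app", // used as the SQS queue name
		"accessKey":                      "my-access-key",
		"secretKey":                      "my-secret-key",
		"region":                         "us-east-1",
		"messageRetryLimit":              "10",
		"disableDeleteOnRetryLimit":      "true", // keep the message and reset its visibility instead of deleting it
		"assetsManagementTimeoutSeconds": "5",    // timeout for SNS/SQS/STS management calls
		// Assumed key names for the backOff-prefixed retry configuration.
		"backOffPolicy":     "exponential",
		"backOffMaxRetries": "3",
	}}

	component := snssqs.NewSnsSqs(logger.NewLogger("snssqs-example"))
	if err := component.Init(md); err != nil {
		panic(err)
	}
	defer component.Close()
}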


@ -0,0 +1,351 @@
package snssqs
import (
"errors"
"fmt"
"strconv"
"github.com/dapr/components-contrib/pubsub"
)
type snsSqsMetadata struct {
// aws endpoint for the component to use.
Endpoint string
// access key to use for accessing sqs/sns.
AccessKey string
// secret key to use for accessing sqs/sns.
SecretKey string
// aws session token to use.
SessionToken string
// aws region in which SNS/SQS should create resources.
Region string
// name of the queue for this application. This is provided by the runtime as "consumerID".
sqsQueueName string
// name of the dead letter queue for this application.
sqsDeadLettersQueueName string
// flag to enable/disable SNS and SQS FIFO.
fifo bool
// a namespace for SNS SQS FIFO to order messages within that group. Limits consumer concurrency if set, but guarantees that all
// published messages are ordered by their arrival time to SQS.
// see: https://aws.amazon.com/blogs/compute/solving-complex-ordering-challenges-with-amazon-sqs-fifo-queues/
fifoMessageGroupID string
// amount of time in seconds that a message is hidden from receive requests after it is sent to a subscriber. Default: 10.
messageVisibilityTimeout int64
// number of times to resend a message after processing of that message fails before removing that message from the queue. Default: 10.
messageRetryLimit int64
// upon reaching the messageRetryLimit, disables the default deletion behaviour of the message from the SQS queue and instead resets the message visibility on SQS
// so that other consumers can try consuming that message.
disableDeleteOnRetryLimit bool
// if sqsDeadLettersQueueName is set to a value, then the messageReceiveLimit defines the number of times a message is received
// before it is moved to the dead-letters queue. This value must be smaller than messageRetryLimit.
messageReceiveLimit int64
// amount of time to await receipt of a message before making another request. Default: 1.
messageWaitTimeSeconds int64
// maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10.
messageMaxNumber int64
// disable resource provisioning of SNS and SQS.
disableEntityManagement bool
// timeout, in seconds, for SNS/SQS assets creation and management operations.
assetsManagementTimeoutSeconds float64
// aws account ID. internally resolved if not given.
accountID string
}
func getAliasedProperty(aliases []string, metadata pubsub.Metadata) (string, bool) {
props := metadata.Properties
for _, s := range aliases {
if val, ok := props[s]; ok {
return val, true
}
}
return "", false
}
func parseInt64(input string, propertyName string) (int64, error) {
number, err := strconv.Atoi(input)
if err != nil {
return -1, fmt.Errorf("parsing %s failed with: %w", propertyName, err)
}
return int64(number), nil
}
func parseBool(input string, propertyName string) (bool, error) {
val, err := strconv.ParseBool(input)
if err != nil {
return false, fmt.Errorf("parsing %s failed with: %w", propertyName, err)
}
return val, nil
}
func parseFloat64(input string, propertyName string) (float64, error) {
val, err := strconv.ParseFloat(input, 64)
if err != nil {
return 0, fmt.Errorf("parsing %s failed with: %w", propertyName, err)
}
return val, nil
}
func maskLeft(s string) string {
rs := []rune(s)
for i := 0; i < len(rs)-4; i++ {
rs[i] = 'X'
}
return string(rs)
}
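For illustration, with a hypothetical key value, the masking keeps only the last four characters visible in debug output:

	// maskLeft replaces everything except the last four runes with 'X'.
	fmt.Println(maskLeft("AKIAEXAMPLE12345")) // prints: XXXXXXXXXXXX2345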
func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata, error) {
md := &snsSqsMetadata{}
if err := md.setCredsAndQueueNameConfig(metadata); err != nil {
return nil, err
}
props := metadata.Properties
if err := md.setMessageVisibilityTimeout(props); err != nil {
return nil, err
}
if err := md.setMessageRetryLimit(props); err != nil {
return nil, err
}
if err := md.setDeadlettersQueueConfig(props); err != nil {
return nil, err
}
if err := md.setDisableDeleteOnRetryLimit(props); err != nil {
return nil, err
}
if err := md.setFifoConfig(props); err != nil {
return nil, err
}
if err := md.setMessageWaitTimeSeconds(props); err != nil {
return nil, err
}
if err := md.setMessageMaxNumber(props); err != nil {
return nil, err
}
if err := md.setDisableEntityManagement(props); err != nil {
return nil, err
}
if err := md.setAssetsManagementTimeoutSeconds(props); err != nil {
return nil, err
}
s.logger.Debug(md.hideDebugPrintedCredentials())
return md, nil
}
func (md *snsSqsMetadata) hideDebugPrintedCredentials() string {
mdCopy := *md
mdCopy.AccessKey = maskLeft(md.AccessKey)
mdCopy.SecretKey = maskLeft(md.SecretKey)
mdCopy.SessionToken = maskLeft(md.SessionToken)
return fmt.Sprintf("%#v\n", mdCopy)
}
func (md *snsSqsMetadata) setCredsAndQueueNameConfig(metadata pubsub.Metadata) error {
if val, ok := getAliasedProperty([]string{"Endpoint", "endpoint"}, metadata); ok {
md.Endpoint = val
}
if val, ok := getAliasedProperty([]string{"awsAccountID", "accessKey"}, metadata); ok {
md.AccessKey = val
}
if val, ok := getAliasedProperty([]string{"awsSecret", "secretKey"}, metadata); ok {
md.SecretKey = val
}
if val, ok := metadata.Properties["sessionToken"]; ok {
md.SessionToken = val
}
if val, ok := getAliasedProperty([]string{"awsRegion", "region"}, metadata); ok {
md.Region = val
}
if val, ok := metadata.Properties["consumerID"]; ok {
md.sqsQueueName = val
} else {
return errors.New("consumerID must be set")
}
return nil
}
func (md *snsSqsMetadata) setAssetsManagementTimeoutSeconds(props map[string]string) error {
if val, ok := props["assetsManagementTimeoutSeconds"]; ok {
parsed, err := parseFloat64(val, "assetsManagementTimeoutSeconds")
if err != nil {
return err
}
md.assetsManagementTimeoutSeconds = parsed
} else {
md.assetsManagementTimeoutSeconds = assetsManagementDefaultTimeoutSeconds
}
return nil
}
func (md *snsSqsMetadata) setDisableEntityManagement(props map[string]string) error {
if val, ok := props["disableEntityManagement"]; ok {
parsed, err := parseBool(val, "disableEntityManagement")
if err != nil {
return err
}
md.disableEntityManagement = parsed
}
return nil
}
func (md *snsSqsMetadata) setMessageMaxNumber(props map[string]string) error {
if val, ok := props["messageMaxNumber"]; !ok {
md.messageMaxNumber = 10
} else {
maxNumber, err := parseInt64(val, "messageMaxNumber")
if err != nil {
return err
}
if maxNumber < 1 {
return errors.New("messageMaxNumber must be greater than 0")
} else if maxNumber > 10 {
return errors.New("messageMaxNumber must be less than or equal to 10")
}
md.messageMaxNumber = maxNumber
}
return nil
}
func (md *snsSqsMetadata) setMessageWaitTimeSeconds(props map[string]string) error {
if val, ok := props["messageWaitTimeSeconds"]; !ok {
md.messageWaitTimeSeconds = 1
} else {
waitTime, err := parseInt64(val, "messageWaitTimeSeconds")
if err != nil {
return err
}
if waitTime < 1 {
return errors.New("messageWaitTimeSeconds must be greater than 0")
}
md.messageWaitTimeSeconds = waitTime
}
return nil
}
func (md *snsSqsMetadata) setFifoConfig(props map[string]string) error {
// fifo settings: enable/disable SNS and SQS FIFO.
if val, ok := props["fifo"]; ok {
fifo, err := parseBool(val, "fifo")
if err != nil {
return err
}
md.fifo = fifo
} else {
md.fifo = false
}
// fifo settings: assign user provided Message Group ID
// for more details, see: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/using-messagegroupid-property.html
if val, ok := props["fifoMessageGroupID"]; ok {
md.fifoMessageGroupID = val
}
return nil
}
func (md *snsSqsMetadata) setDeadlettersQueueConfig(props map[string]string) error {
if val, ok := props["sqsDeadLettersQueueName"]; ok {
md.sqsDeadLettersQueueName = val
}
if val, ok := props["messageReceiveLimit"]; ok {
messageReceiveLimit, err := parseInt64(val, "messageReceiveLimit")
if err != nil {
return err
}
// assign: user-provided configuration
md.messageReceiveLimit = messageReceiveLimit
}
// error if only one of messageReceiveLimit and sqsDeadLettersQueueName is set (logical XOR): both are required to use an SQS dead-letters queue.
if (md.messageReceiveLimit > 0 || len(md.sqsDeadLettersQueueName) > 0) && !(md.messageReceiveLimit > 0 && len(md.sqsDeadLettersQueueName) > 0) {
return errors.New("to use SQS dead letters queue, messageReceiveLimit and sqsDeadLettersQueueName must both be set to a value")
}
return nil
}
func (md *snsSqsMetadata) setDisableDeleteOnRetryLimit(props map[string]string) error {
if val, ok := props["disableDeleteOnRetryLimit"]; ok {
disableDeleteOnRetryLimit, err := parseBool(val, "disableDeleteOnRetryLimit")
if err != nil {
return err
}
if len(md.sqsDeadLettersQueueName) > 0 && disableDeleteOnRetryLimit {
return errors.New("configuration conflict: 'disableDeleteOnRetryLimit' cannot be set to 'true' when 'sqsDeadLettersQueueName' is set to a value. either remove this configuration or set 'disableDeleteOnRetryLimit' to 'false'")
}
md.disableDeleteOnRetryLimit = disableDeleteOnRetryLimit
} else {
// default when not configured.
md.disableDeleteOnRetryLimit = false
}
return nil
}
func (md *snsSqsMetadata) setMessageRetryLimit(props map[string]string) error {
if val, ok := props["messageRetryLimit"]; !ok {
md.messageRetryLimit = 10
} else {
retryLimit, err := parseInt64(val, "messageRetryLimit")
if err != nil {
return err
}
if retryLimit < 2 {
return errors.New("messageRetryLimit must be greater than 1")
}
md.messageRetryLimit = retryLimit
}
return nil
}
func (md *snsSqsMetadata) setMessageVisibilityTimeout(props map[string]string) error {
if val, ok := props["messageVisibilityTimeout"]; !ok {
md.messageVisibilityTimeout = 10
} else {
timeout, err := parseInt64(val, "messageVisibilityTimeout")
if err != nil {
return err
}
if timeout < 1 {
return errors.New("messageVisibilityTimeout must be greater than 0")
}
md.messageVisibilityTimeout = timeout
}
return nil
}
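To make the dead-letters and retry-limit validation rules above concrete, here is a test-style sketch of accepted and rejected property combinations, written in the same style as the package's unit tests further below; the queue names and values are hypothetical.

func Test_deadLettersConfigCombinations_sketch(t *testing.T) {
	r := require.New(t)
	i := &snsSqs{logger: logger.NewLogger("SnsSqs unit test")}

	// accepted: sqsDeadLettersQueueName and messageReceiveLimit are set together.
	_, err := i.getSnsSqsMetatdata(pubsub.Metadata{Properties: map[string]string{
		"consumerID":              "consumer",
		"sqsDeadLettersQueueName": "my-queue",
		"messageReceiveLimit":     "9",
	}})
	r.NoError(err)

	// rejected: messageReceiveLimit without sqsDeadLettersQueueName fails the XOR check.
	_, err = i.getSnsSqsMetatdata(pubsub.Metadata{Properties: map[string]string{
		"consumerID":          "consumer",
		"messageReceiveLimit": "9",
	}})
	r.Error(err)

	// rejected: disableDeleteOnRetryLimit conflicts with a configured dead-letters queue.
	_, err = i.getSnsSqsMetatdata(pubsub.Metadata{Properties: map[string]string{
		"consumerID":                "consumer",
		"sqsDeadLettersQueueName":   "my-queue",
		"messageReceiveLimit":       "9",
		"disableDeleteOnRetryLimit": "true",
	}})
	r.Error(err)
}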


@ -16,17 +16,20 @@ package snssqs
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/sns"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/dapr/kit/retry"
gonanoid "github.com/matoous/go-nanoid/v2"
aws_auth "github.com/dapr/components-contrib/authentication/aws"
@ -38,17 +41,22 @@ type snsSqs struct {
// key is the topic name, value is the ARN of the topic.
topics sync.Map
// key is the sanitized topic name, value is the actual topic name.
topicsSanitized sync.Map
sanitizedTopics sync.Map
// key is the topic name, value holds the ARN of the queue and its url.
queues sync.Map
// key is a composite key of queue ARN and topic ARN mapping to subscription ARN.
subscriptions sync.Map
snsClient *sns.SNS
sqsClient *sqs.SQS
stsClient *sts.STS
metadata *snsSqsMetadata
logger logger.Logger
id string
opsTimeout time.Duration
ctx context.Context
cancelFn context.CancelFunc
backOffConfig retry.Config
}
type sqsQueueInfo struct {
@ -56,49 +64,23 @@ type sqsQueueInfo struct {
url string
}
type snsSqsMetadata struct {
// aws endpoint for the component to use.
Endpoint string
// access key to use for accessing sqs/sns.
AccessKey string
// secret key to use for accessing sqs/sns.
SecretKey string
// aws session token to use.
SessionToken string
// aws region in which SNS/SQS should create resources.
Region string
// name of the queue for this application. The is provided by the runtime as "consumerID".
sqsQueueName string
// name of the dead letter queue for this application.
sqsDeadLettersQueueName string
// flag to SNS and SQS FIFO.
fifo bool
// a namespace for SNS SQS FIFO to order messages within that group. limits consumer concurrency if set but guarantees that all
// published messages would be ordered by their arrival time to SQS.
// see: https://aws.amazon.com/blogs/compute/solving-complex-ordering-challenges-with-amazon-sqs-fifo-queues/
fifoMessageGroupID string
// amount of time in seconds that a message is hidden from receive requests after it is sent to a subscriber. Default: 10.
messageVisibilityTimeout int64
// number of times to resend a message after processing of that message fails before removing that message from the queue. Default: 10.
messageRetryLimit int64
// if sqsDeadLettersQueueName is set to a value, then the messageReceiveLimit defines the number of times a message is received
// before it is moved to the dead-letters queue. This value must be smaller than messageRetryLimit.
messageReceiveLimit int64
// amount of time to await receipt of a message before making another request. Default: 1.
messageWaitTimeSeconds int64
// maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10.
messageMaxNumber int64
// disable resource provisioning of SNS and SQS.
disableEntityManagement bool
// aws account ID.
accountID string
type snsMessage struct {
Message string
TopicArn string
}
func (sn *snsMessage) parseTopicArn() string {
arn := sn.TopicArn
return arn[strings.LastIndex(arn, ":")+1:]
}
const (
awsSqsQueueNameKey = "dapr-queue-name"
awsSnsTopicNameKey = "dapr-topic-name"
awsSqsFifoSuffix = ".fifo"
maxAWSNameLength = 80
awsSqsQueueNameKey = "dapr-queue-name"
awsSnsTopicNameKey = "dapr-topic-name"
awsSqsFifoSuffix = ".fifo"
maxAWSNameLength = 80
assetsManagementDefaultTimeoutSeconds = 5.0
awsAccountIDLength = 12
)
// NewSnsSqs - constructor for a new snssqs dapr component.
@ -114,34 +96,6 @@ func NewSnsSqs(l logger.Logger) pubsub.PubSub {
}
}
func getAliasedProperty(aliases []string, metadata pubsub.Metadata) (string, bool) {
props := metadata.Properties
for _, s := range aliases {
if val, ok := props[s]; ok {
return val, true
}
}
return "", false
}
func parseInt64(input string, propertyName string) (int64, error) {
number, err := strconv.Atoi(input)
if err != nil {
return -1, fmt.Errorf("parsing %s failed with: %w", propertyName, err)
}
return int64(number), nil
}
func parseBool(input string, propertyName string) (bool, error) {
val, err := strconv.ParseBool(input)
if err != nil {
return false, fmt.Errorf("parsing %s failed with: %w", propertyName, err)
}
return val, nil
}
// sanitize topic/queue name to conform with:
// https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/quotas-queues.html
func nameToAWSSanitizedName(name string, isFifo bool) string {
@ -181,143 +135,6 @@ func nameToAWSSanitizedName(name string, isFifo bool) string {
return string(s[:j])
}
func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata, error) {
md := snsSqsMetadata{}
props := metadata.Properties
md.sqsQueueName = metadata.Properties["consumerID"]
s.logger.Debugf("Setting queue name to %s", md.sqsQueueName)
if val, ok := getAliasedProperty([]string{"Endpoint", "endpoint"}, metadata); ok {
s.logger.Debugf("endpoint: %s", val)
md.Endpoint = val
}
if val, ok := getAliasedProperty([]string{"awsAccountID", "accessKey"}, metadata); ok {
s.logger.Debugf("accessKey: %s", val)
md.AccessKey = val
}
if val, ok := getAliasedProperty([]string{"awsSecret", "secretKey"}, metadata); ok {
s.logger.Debugf("secretKey: %s", val)
md.SecretKey = val
}
if val, ok := props["sessionToken"]; ok {
md.SessionToken = val
}
if val, ok := getAliasedProperty([]string{"awsRegion", "region"}, metadata); ok {
md.Region = val
}
if val, ok := props["messageVisibilityTimeout"]; !ok {
md.messageVisibilityTimeout = 10
} else {
timeout, err := parseInt64(val, "messageVisibilityTimeout")
if err != nil {
return nil, err
}
if timeout < 1 {
return nil, errors.New("messageVisibilityTimeout must be greater than 0")
}
md.messageVisibilityTimeout = timeout
}
if val, ok := props["messageRetryLimit"]; !ok {
md.messageRetryLimit = 10
} else {
retryLimit, err := parseInt64(val, "messageRetryLimit")
if err != nil {
return nil, err
}
if retryLimit < 2 {
return nil, errors.New("messageRetryLimit must be greater than 1")
}
md.messageRetryLimit = retryLimit
}
if val, ok := props["sqsDeadLettersQueueName"]; ok {
md.sqsDeadLettersQueueName = val
}
if val, ok := props["messageReceiveLimit"]; ok {
messageReceiveLimit, err := parseInt64(val, "messageReceiveLimit")
if err != nil {
return nil, err
}
// assign: used provided configuration
md.messageReceiveLimit = messageReceiveLimit
}
// XOR on having either a valid messageReceiveLimit and invalid sqsDeadLettersQueueName, and vice versa.
if (md.messageReceiveLimit > 0 || len(md.sqsDeadLettersQueueName) > 0) && !(md.messageReceiveLimit > 0 && len(md.sqsDeadLettersQueueName) > 0) {
return nil, errors.New("to use SQS dead letters queue, messageReceiveLimit and sqsDeadLettersQueueName must both be set to a value")
}
// fifo settings: enable/disable SNS and SQS FIFO.
if val, ok := props["fifo"]; ok {
fifo, err := parseBool(val, "fifo")
if err != nil {
return nil, err
}
md.fifo = fifo
} else {
md.fifo = false
}
// fifo settings: assign user provided Message Group ID
// for more details, see: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/using-messagegroupid-property.html
if val, ok := props["fifoMessageGroupID"]; ok {
md.fifoMessageGroupID = val
}
if val, ok := props["messageWaitTimeSeconds"]; !ok {
md.messageWaitTimeSeconds = 1
} else {
waitTime, err := parseInt64(val, "messageWaitTimeSeconds")
if err != nil {
return nil, err
}
if waitTime < 1 {
return nil, errors.New("messageWaitTimeSeconds must be greater than 0")
}
md.messageWaitTimeSeconds = waitTime
}
if val, ok := props["messageMaxNumber"]; !ok {
md.messageMaxNumber = 10
} else {
maxNumber, err := parseInt64(val, "messageMaxNumber")
if err != nil {
return nil, err
}
if maxNumber < 1 {
return nil, errors.New("messageMaxNumber must be greater than 0")
} else if maxNumber > 10 {
return nil, errors.New("messageMaxNumber must be less than or equal to 10")
}
md.messageMaxNumber = maxNumber
}
if val, ok := props["disableEntityManagement"]; ok {
parsed, err := parseBool(val, "disableEntityManagement")
if err != nil {
return nil, err
}
md.disableEntityManagement = parsed
}
return &md, nil
}
func (s *snsSqs) Init(metadata pubsub.Metadata) error {
md, err := s.getSnsSqsMetatdata(metadata)
if err != nil {
@ -329,7 +146,7 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error {
// both Publish and Subscribe need reference the topic ARN, queue ARN and subscription ARN between topic and queue
// track these ARNs in these maps.
s.topics = sync.Map{}
s.topicsSanitized = sync.Map{}
s.sanitizedTopics = sync.Map{}
s.queues = sync.Map{}
s.subscriptions = sync.Map{}
@ -337,23 +154,48 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error {
if err != nil {
return fmt.Errorf("error creating an AWS client: %w", err)
}
// AWS sns, sqs and sts clients.
s.snsClient = sns.New(sess)
s.sqsClient = sqs.New(sess)
s.stsClient = sts.New(sess)
callerIDOutput, err := s.stsClient.GetCallerIdentity(&sts.GetCallerIdentityInput{})
s.opsTimeout = time.Duration(md.assetsManagementTimeoutSeconds * float64(time.Second))
s.ctx, s.cancelFn = context.WithCancel(context.Background())
if err := s.setAwsAccountIDIfNotProvided(); err != nil {
return err
}
// Default retry configuration is used if no
// backOff properties are set.
if err := retry.DecodeConfigWithPrefix(
&s.backOffConfig,
metadata.Properties,
"backOff"); err != nil {
return fmt.Errorf("error decoding backOff config: %w", err)
}
return nil
}
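Init above decodes any backOff-prefixed properties into a dapr/kit retry.Config, which later produces the backoff used around the SQS polling loop. A minimal sketch of that mechanism in isolation follows; the key names are assumed from the library's prefix-based decoding and are not spelled out in this diff.

package main

import (
	"context"
	"fmt"

	"github.com/dapr/kit/retry"
)

func main() {
	var cfg retry.Config

	// Hypothetical properties; keys assume the "backOff" prefix stripping done by the decoder.
	props := map[string]string{
		"backOffPolicy":     "exponential",
		"backOffMaxRetries": "3",
	}
	if err := retry.DecodeConfigWithPrefix(&cfg, props, "backOff"); err != nil {
		panic(err)
	}

	// The component calls NewBackOffWithContext once per subscription, sleeps for
	// NextBackOff() after a failed receive, and calls Reset() after a successful one.
	bo := cfg.NewBackOffWithContext(context.Background())
	fmt.Println("first backoff interval:", bo.NextBackOff())
	bo.Reset()
}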
func (s *snsSqs) setAwsAccountIDIfNotProvided() error {
if len(s.metadata.accountID) == awsAccountIDLength {
return nil
}
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
defer cancelFn()
callerIDOutput, err := s.stsClient.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
return fmt.Errorf("error fetching sts caller ID: %w", err)
}
s.metadata.accountID = *callerIDOutput.Account
s.snsClient = sns.New(sess)
s.sqsClient = sqs.New(sess)
return nil
}
func (s *snsSqs) buildARN(serviceName, entityName string) string {
// arn:aws:sns:us-east-1:302212680347:aws-controltower-SecurityNotifications
return fmt.Sprintf("arn:aws:%s:%s:%s:%s", serviceName, s.metadata.Region, s.metadata.accountID, entityName)
}
@ -369,7 +211,10 @@ func (s *snsSqs) createTopic(topic string) (string, error) {
snsCreateTopicInput.SetAttributes(attributes)
}
createTopicResponse, err := s.snsClient.CreateTopic(snsCreateTopicInput)
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
defer cancelFn()
createTopicResponse, err := s.snsClient.CreateTopicWithContext(ctx, snsCreateTopicInput)
if err != nil {
return "", fmt.Errorf("error while creating an SNS topic: %w", err)
}
@ -378,8 +223,11 @@ func (s *snsSqs) createTopic(topic string) (string, error) {
}
func (s *snsSqs) getTopicArn(topic string) (string, error) {
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
defer cancelFn()
arn := s.buildARN("sns", topic)
getTopicOutput, err := s.snsClient.GetTopicAttributes(&sns.GetTopicAttributesInput{TopicArn: aws.String(arn)})
getTopicOutput, err := s.snsClient.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{TopicArn: aws.String(arn)})
if err != nil {
return "", fmt.Errorf("error: %w while getting topic: %v with arn: %v", err, topic, arn)
}
@ -422,7 +270,7 @@ func (s *snsSqs) getOrCreateTopic(topic string) (string, error) {
// record topic ARN.
s.topics.Store(topic, topicArn)
s.topicsSanitized.Store(sanitizedName, topic)
s.sanitizedTopics.Store(sanitizedName, topic)
return topicArn, nil
}
@ -438,13 +286,18 @@ func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) {
attributes := map[string]*string{"FifoQueue": aws.String("true"), "ContentBasedDeduplication": aws.String("true")}
sqsCreateQueueInput.SetAttributes(attributes)
}
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
defer cancelFn()
createQueueResponse, err := s.sqsClient.CreateQueue(sqsCreateQueueInput)
createQueueResponse, err := s.sqsClient.CreateQueueWithContext(ctx, sqsCreateQueueInput)
if err != nil {
return nil, fmt.Errorf("error creaing an SQS queue: %w", err)
}
queueAttributesResponse, err := s.sqsClient.GetQueueAttributes(&sqs.GetQueueAttributesInput{
aCtx, aCancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
defer aCancelFn()
queueAttributesResponse, err := s.sqsClient.GetQueueAttributesWithContext(aCtx, &sqs.GetQueueAttributesInput{
AttributeNames: []*string{aws.String("QueueArn")},
QueueUrl: createQueueResponse.QueueUrl,
})
@ -459,14 +312,20 @@ func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) {
}
func (s *snsSqs) getQueueArn(queueName string) (*sqsQueueInfo, error) {
queueURLOutput, err := s.sqsClient.GetQueueUrl(&sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.accountID)})
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
defer cancelFn()
queueURLOutput, err := s.sqsClient.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.accountID)})
if err != nil {
return nil, fmt.Errorf("error: %w while getting url of queue: %s", err, queueName)
}
url := queueURLOutput.QueueUrl
aCtx, aCancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
defer aCancelFn()
var getQueueOutput *sqs.GetQueueAttributesOutput
getQueueOutput, err = s.sqsClient.GetQueueAttributes(&sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}})
getQueueOutput, err = s.sqsClient.GetQueueAttributesWithContext(aCtx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}})
if err != nil {
return nil, fmt.Errorf("error: %w while getting information for queue: %s, with url: %s", err, queueName, *url)
}
@ -525,7 +384,10 @@ func (s *snsSqs) getMessageGroupID(req *pubsub.PublishRequest) *string {
}
func (s *snsSqs) createSnsSqsSubscription(queueArn, topicArn string) (string, error) {
subscribeOutput, err := s.snsClient.Subscribe(&sns.SubscribeInput{
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
defer cancelFn()
subscribeOutput, err := s.snsClient.SubscribeWithContext(ctx, &sns.SubscribeInput{
Attributes: nil,
Endpoint: aws.String(queueArn), // create SQS queue per subscription.
Protocol: aws.String("sqs"),
@ -543,7 +405,10 @@ func (s *snsSqs) createSnsSqsSubscription(queueArn, topicArn string) (string, er
}
func (s *snsSqs) getSnsSqsSubscriptionArn(topicArn string) (string, error) {
listSubscriptionsOutput, err := s.snsClient.ListSubscriptionsByTopic(&sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)})
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
defer cancelFn()
listSubscriptionsOutput, err := s.snsClient.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)})
if err != nil {
return "", fmt.Errorf("error listing subsriptions for topic arn: %v: %w", topicArn, err)
}
@ -594,146 +459,176 @@ func (s *snsSqs) getOrCreateSnsSqsSubscription(queueArn, topicArn string) (strin
return subscriptionArn, nil
}
func (s *snsSqs) Publish(req *pubsub.PublishRequest) error {
topicArn, err := s.getOrCreateTopic(req.Topic)
if err != nil {
s.logger.Errorf("error getting topic ARN for %s: %v", req.Topic, err)
}
message := string(req.Data)
snsPublishInput := &sns.PublishInput{
Message: &message,
TopicArn: &topicArn,
}
if s.metadata.fifo {
snsPublishInput.MessageGroupId = s.getMessageGroupID(req)
}
_, err = s.snsClient.Publish(snsPublishInput)
if err != nil {
wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %w", req.Topic, topicArn, err)
s.logger.Error(wrappedErr)
return wrappedErr
}
return nil
}
type snsMessage struct {
Message string
TopicArn string
}
func parseTopicArn(arn string) string {
return arn[strings.LastIndex(arn, ":")+1:]
}
func (s *snsSqs) acknowledgeMessage(queueURL string, receiptHandle *string) error {
if _, err := s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{
QueueUrl: &queueURL,
ctx, cancelFn := context.WithCancel(s.ctx)
defer cancelFn()
deleteMessageInput := &sqs.DeleteMessageInput{
QueueUrl: aws.String(queueURL),
ReceiptHandle: receiptHandle,
}); err != nil {
return fmt.Errorf("error deleting SQS message: %w", err)
}
if _, err := s.sqsClient.DeleteMessageWithContext(ctx, deleteMessageInput); err != nil {
return fmt.Errorf("error deleting message: %w", err)
}
return nil
}
func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueueInfo *sqsQueueInfo, handler pubsub.Handler) error {
func (s *snsSqs) resetMessageVisibilityTimeout(queueURL string, receiptHandle *string) error {
ctx, cancelFn := context.WithCancel(s.ctx)
defer cancelFn()
// reset the timeout to its initial value so that the remaining timeout is overridden by the initial value, allowing another consumer to attempt processing.
changeMessageVisibilityInput := &sqs.ChangeMessageVisibilityInput{
QueueUrl: aws.String(queueURL),
ReceiptHandle: receiptHandle,
VisibilityTimeout: aws.Int64(s.metadata.messageVisibilityTimeout),
}
if _, err := s.sqsClient.ChangeMessageVisibilityWithContext(ctx, changeMessageVisibilityInput); err != nil {
return fmt.Errorf("error changing message visibility timeout: %w", err)
}
return nil
}
func (s *snsSqs) parseReceiveCount(message *sqs.Message) (int64, error) {
// if this message has been received > x times, delete from queue, it's borked.
recvCount, ok := message.Attributes[sqs.MessageSystemAttributeNameApproximateReceiveCount]
if !ok {
return fmt.Errorf(
return 0, fmt.Errorf(
"no ApproximateReceiveCount returned with response, will not attempt further processing: %v", message)
}
recvCountInt, err := strconv.ParseInt(*recvCount, 10, 32)
if err != nil {
return fmt.Errorf("error parsing ApproximateReceiveCount from message: %v", message)
return 0, fmt.Errorf("error parsing ApproximateReceiveCount from message: %v", message)
}
return recvCountInt, nil
}
func (s *snsSqs) validateMessage(message *sqs.Message, queueInfo, deadLettersQueueInfo *sqsQueueInfo, handler pubsub.Handler) error {
recvCount, err := s.parseReceiveCount(message)
if err != nil {
return err
}
// if we are over the allowable retry limit, and there is no dead-letters queue, delete the message from the queue.
if deadLettersQueueInfo == nil && recvCountInt >= s.metadata.messageRetryLimit {
if innerErr := s.acknowledgeMessage(queueInfo.url, message.ReceiptHandle); innerErr != nil {
return fmt.Errorf("error acknowledging message after receiving the message too many times: %w", innerErr)
messageRetryLimit := s.metadata.messageRetryLimit
if deadLettersQueueInfo == nil && recvCount >= messageRetryLimit {
// if we are over the allowable retry limit, and there is no dead-letters queue, and we don't disable deletes, then delete the message from the queue.
if !s.metadata.disableDeleteOnRetryLimit {
if innerErr := s.acknowledgeMessage(queueInfo.url, message.ReceiptHandle); innerErr != nil {
return fmt.Errorf("error acknowledging message after receiving the message too many times: %w", innerErr)
}
return fmt.Errorf("message received greater than %v times, deleting this message without further processing", messageRetryLimit)
}
// if we are over the allowable retry limit, and there is no dead-letters queue, and deletes are disabled, then don't delete the message from the queue.
// reset the already "consumed" message visibility clock.
s.logger.Debugf("message received greater than %v times. deletion past the thredhold is diabled. noop", messageRetryLimit)
if err := s.resetMessageVisibilityTimeout(queueInfo.url, message.ReceiptHandle); err != nil {
return fmt.Errorf("error resetting message visibility timeout: %w", err)
}
return fmt.Errorf(
"message received greater than %v times, deleting this message without further processing", s.metadata.messageRetryLimit)
return nil
}
// ... else, there is no need to actively do something if we reached the limit defined in messageReceiveLimit as the message had
// already been moved to the dead-letters queue by SQS. meaning, the below condition should not be reached as SQS would not send
// a message if we've already surpassed the s.metadata.messageReceiveLimit value.
if deadLettersQueueInfo != nil && recvCountInt > s.metadata.messageReceiveLimit {
// a message if we've already surpassed the messageRetryLimit value.
if deadLettersQueueInfo != nil && recvCount > messageRetryLimit {
awsErr := fmt.Errorf(
"message received greater than %v times, this message should have been moved without further processing to dead-letters queue: %v", s.metadata.messageReceiveLimit, s.metadata.sqsDeadLettersQueueName)
s.logger.Error(awsErr)
"message received greater than %v times, this message should have been moved without further processing to dead-letters queue: %v", messageRetryLimit, s.metadata.sqsDeadLettersQueueName)
return awsErr
}
// otherwise try to handle the message.
var messageBody snsMessage
err = json.Unmarshal([]byte(*(message.Body)), &messageBody)
return nil
}
func (s *snsSqs) callHandler(message *sqs.Message, queueInfo *sqsQueueInfo, handler pubsub.Handler) error {
// otherwise, try to handle the message.
var snsMessagePayload snsMessage
err := json.Unmarshal([]byte(*(message.Body)), &snsMessagePayload)
if err != nil {
return fmt.Errorf("error unmarshalling message: %w", err)
}
// messageBody.TopicArn can only carry a sanitized topic name as we conform to AWS naming standards.
// snsMessagePayload.TopicArn can only carry a sanitized topic name as we conform to AWS naming standards.
// for the user to be able to understand the source of the incoming message, we use the original,
// unsanitized name, carried over in the pubsub.NewMessage Topic field.
sanitizedTopic := parseTopicArn(messageBody.TopicArn)
cachedTopic, ok := s.topicsSanitized.Load(sanitizedTopic)
sanitizedTopic := snsMessagePayload.parseTopicArn()
cachedTopic, ok := s.sanitizedTopics.Load(sanitizedTopic)
if !ok {
return fmt.Errorf("failed loading topic (sanitized): %s from internal topics cache. SNS topic might be just created", sanitizedTopic)
}
err = handler(context.Background(), &pubsub.NewMessage{
Data: []byte(messageBody.Message),
Topic: cachedTopic.(string),
})
s.logger.Debugf("Processing SNS message id: %s of topic: %s", message.MessageId, sanitizedTopic)
if err != nil {
ctx, cancelFn := context.WithCancel(s.ctx)
defer cancelFn()
if err := handler(ctx, &pubsub.NewMessage{
Data: []byte(snsMessagePayload.Message),
Topic: cachedTopic.(string),
}); err != nil {
return fmt.Errorf("error handling message: %w", err)
}
// otherwise, there was no error, acknowledge the message.
return s.acknowledgeMessage(queueInfo.url, message.ReceiptHandle)
}
func (s *snsSqs) consumeSubscription(queueInfo, deadLettersQueueInfo *sqsQueueInfo, handler pubsub.Handler) {
go func() {
ctx, cancelFn := context.WithCancel(s.ctx)
defer cancelFn()
sqsPullExponentialBackoff := s.backOffConfig.NewBackOffWithContext(ctx)
receiveMessageInput := &sqs.ReceiveMessageInput{
// use this property to decide when a message should be discarded.
AttributeNames: []*string{
aws.String(sqs.MessageSystemAttributeNameApproximateReceiveCount),
},
MaxNumberOfMessages: aws.Int64(s.metadata.messageMaxNumber),
QueueUrl: aws.String(queueInfo.url),
VisibilityTimeout: aws.Int64(s.metadata.messageVisibilityTimeout),
WaitTimeSeconds: aws.Int64(s.metadata.messageWaitTimeSeconds),
}
for {
messageResponse, err := s.sqsClient.ReceiveMessage(&sqs.ReceiveMessageInput{
// use this property to decide when a message should be discarded.
AttributeNames: []*string{
aws.String(sqs.MessageSystemAttributeNameApproximateReceiveCount),
},
MaxNumberOfMessages: aws.Int64(s.metadata.messageMaxNumber),
QueueUrl: aws.String(queueInfo.url),
VisibilityTimeout: aws.Int64(s.metadata.messageVisibilityTimeout),
WaitTimeSeconds: aws.Int64(s.metadata.messageWaitTimeSeconds),
})
// Internally, by default, the aws go sdk performs 3 retries with exponential backoff to contact
// sqs and try to pull messages. Since we are iteratively short polling (based on the defined
// s.metadata.messageWaitTimeSeconds) the sdk backoff is not effective as it gets reset on each polling
// iteration. Therefore, a global backoff (on top of the sdk's internal backoff) is used (sqsPullExponentialBackoff).
messageResponse, err := s.sqsClient.ReceiveMessageWithContext(ctx, receiveMessageInput)
if err != nil {
s.logger.Errorf("error consuming topic: %v", err)
if awsErr, ok := err.(awserr.Error); ok {
s.logger.Errorf("AWS operation error while consuming from queue url: %v with error: %w. retrying...", queueInfo.url, awsErr.Error())
} else {
s.logger.Errorf("error consuming from queue url: %v with error: %w. retrying...", queueInfo.url, err)
}
time.Sleep(sqsPullExponentialBackoff.NextBackOff())
continue
}
// error either recovered or did not happen at all. resetting the backoff counter (and duration).
sqsPullExponentialBackoff.Reset()
// retry receiving messages.
if len(messageResponse.Messages) < 1 {
s.logger.Debug("No messages received, requesting again")
s.logger.Debug("No messages received, continuing")
continue
}
s.logger.Debugf("%v message(s) received", len(messageResponse.Messages))
for _, m := range messageResponse.Messages {
if err := s.handleMessage(m, queueInfo, deadLettersQueueInfo, handler); err != nil {
s.logger.Error(err)
for _, message := range messageResponse.Messages {
if err := s.validateMessage(message, queueInfo, deadLettersQueueInfo, handler); err != nil {
s.logger.Errorf("message is not valid for further processing by the handler. error is: %w", err)
continue
}
if err := s.callHandler(message, queueInfo, handler); err != nil {
s.logger.Errorf("error handling received message with error: %w", err)
continue
}
}
}
@ -782,8 +677,11 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(sqsQueueInfo *sqsQueueInfo,
if s.metadata.disableEntityManagement {
return nil
}
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
defer cancelFn()
// only permit SNS to send messages to SQS using the created subscription.
getQueueAttributesOutput, err := s.sqsClient.GetQueueAttributes(&sqs.GetQueueAttributesInput{QueueUrl: &sqsQueueInfo.url, AttributeNames: []*string{aws.String(sqs.QueueAttributeNamePolicy)}})
getQueueAttributesOutput, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: &sqsQueueInfo.url, AttributeNames: []*string{aws.String(sqs.QueueAttributeNamePolicy)}})
if err != nil {
return fmt.Errorf("error getting queue attributes: %w", err)
}
@ -818,7 +716,10 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(sqsQueueInfo *sqsQueueInfo,
return fmt.Errorf("failed serializing new sqs policy: %w", uerr)
}
if _, err = s.sqsClient.SetQueueAttributes(&(sqs.SetQueueAttributesInput{
aCtx, aCancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
defer aCancelFn()
if _, err = s.sqsClient.SetQueueAttributesWithContext(aCtx, &(sqs.SetQueueAttributesInput{
Attributes: map[string]*string{
"Policy": aws.String(string(b)),
},
@ -882,7 +783,10 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
return wrappedErr
}
_, derr = s.sqsClient.SetQueueAttributes(sqsSetQueueAttributesInput)
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
defer cancelFn()
_, derr = s.sqsClient.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput)
if derr != nil {
wrappedErr := fmt.Errorf("error updating queue attributes with dead-letter queue: %w", derr)
s.logger.Error(wrappedErr)
@ -904,7 +808,38 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
return nil
}
func (s *snsSqs) Publish(req *pubsub.PublishRequest) error {
topicArn, err := s.getOrCreateTopic(req.Topic)
if err != nil {
s.logger.Errorf("error getting topic ARN for %s: %v", req.Topic, err)
}
message := string(req.Data)
snsPublishInput := &sns.PublishInput{
Message: aws.String(message),
TopicArn: aws.String(topicArn),
}
if s.metadata.fifo {
snsPublishInput.MessageGroupId = s.getMessageGroupID(req)
}
ctx, cancelFn := context.WithCancel(s.ctx)
defer cancelFn()
// sns client has internal exponential backoffs.
_, err = s.snsClient.PublishWithContext(ctx, snsPublishInput)
if err != nil {
wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %w", req.Topic, topicArn, err)
s.logger.Error(wrappedErr)
return wrappedErr
}
return nil
}
func (s *snsSqs) Close() error {
s.cancelFn()
return nil
}
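Putting the pieces together, here is a minimal end-to-end sketch of using the component after this change; the topic name and metadata are hypothetical, credentials are left to the default AWS resolution chain, and error handling is kept terse.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/dapr/components-contrib/pubsub"
	"github.com/dapr/components-contrib/pubsub/aws/snssqs"
	"github.com/dapr/kit/logger"
)

func main() {
	ps := snssqs.NewSnsSqs(logger.NewLogger("snssqs-example"))

	// Hypothetical minimal metadata; see the metadata loader above for the full list of properties.
	if err := ps.Init(pubsub.Metadata{Properties: map[string]string{
		"consumerID": "my-app",
		"region":     "us-east-1",
	}}); err != nil {
		panic(err)
	}
	defer ps.Close() // cancels the component context and stops the polling goroutine

	// Subscribe starts the SQS receive loop, retried with the configured backoff on errors.
	err := ps.Subscribe(pubsub.SubscribeRequest{Topic: "orders"}, func(ctx context.Context, msg *pubsub.NewMessage) error {
		fmt.Printf("got message on %s: %s\n", msg.Topic, string(msg.Data))
		return nil // a nil return acknowledges (deletes) the SQS message
	})
	if err != nil {
		panic(err)
	}

	// Publish goes through SNS; the topic is created on demand unless entity management is disabled.
	if err := ps.Publish(&pubsub.PublishRequest{Topic: "orders", Data: []byte("hello")}); err != nil {
		panic(err)
	}

	time.Sleep(10 * time.Second) // give the sketch time to receive the published message
}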


@ -31,7 +31,8 @@ func Test_parseTopicArn(t *testing.T) {
t.Parallel()
// no further guarantees are made about this function.
r := require.New(t)
r.Equal("qqnoob", parseTopicArn("arn:aws:sqs:us-east-1:000000000000:qqnoob"))
tSnsMessage := &snsMessage{TopicArn: "arn:aws:sqs:us-east-1:000000000000:qqnoob"}
r.Equal("qqnoob", tSnsMessage.parseTopicArn())
}
// Verify that all metadata ends up in the correct spot.
@ -103,6 +104,9 @@ func Test_getSnsSqsMetatdata_defaults(t *testing.T) {
r.Equal(int64(10), md.messageRetryLimit)
r.Equal(int64(1), md.messageWaitTimeSeconds)
r.Equal(int64(10), md.messageMaxNumber)
r.Equal(false, md.disableEntityManagement)
r.Equal(float64(5), md.assetsManagementTimeoutSeconds)
r.Equal(false, md.disableDeleteOnRetryLimit)
}
func Test_getSnsSqsMetatdata_legacyaliases(t *testing.T) {
@ -188,6 +192,20 @@ func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) {
}},
name: "deadletters message queue without deadletters receive limit",
},
{
metadata: pubsub.Metadata{Properties: map[string]string{
"consumerID": "consumer",
"Endpoint": "endpoint",
"AccessKey": "acctId",
"SecretKey": "secret",
"awsToken": "token",
"Region": "region",
"sqsDeadLettersQueueName": "my-queue",
"messageReceiveLimit": "9",
"disableDeleteOnRetryLimit": "true",
}},
name: "deadletters message queue with disableDeleteOnRetryLimit",
},
{
metadata: pubsub.Metadata{Properties: map[string]string{
"consumerID": "consumer",
@ -248,6 +266,20 @@ func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) {
}},
name: "invalid message retry limit",
},
// disableEntityManagement
{
metadata: pubsub.Metadata{Properties: map[string]string{
"consumerID": "consumer",
"Endpoint": "endpoint",
"AccessKey": "acctId",
"SecretKey": "secret",
"awsToken": "token",
"Region": "region",
"messageRetryLimit": "10",
"disableEntityManagement": "y",
}},
name: "invalid message disableEntityManagement",
},
}
l := logger.NewLogger("SnsSqs unit test")