Pubsub AWS SNS/SQS - adding context, cancellation, timeouts, retrying \w backoff & disable delete of messages on failure (#1433)
* squash Signed-off-by: Amit Mor <amit.mor@hotmail.com> * comment Signed-off-by: Amit Mor <amit.mor@hotmail.com> * gofumpted Signed-off-by: Amit Mor <amit.mor@hotmail.com> * breakdown of metadata loading Signed-off-by: Amit Mor <amit.mor@hotmail.com> * metadata further refactoring Signed-off-by: Amit Mor <amit.mor@hotmail.com>
This commit is contained in:
parent
3c28fee80f
commit
c8844ccaed
|
@ -0,0 +1,351 @@
|
|||
package snssqs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/dapr/components-contrib/pubsub"
|
||||
)
|
||||
|
||||
type snsSqsMetadata struct {
|
||||
// aws endpoint for the component to use.
|
||||
Endpoint string
|
||||
// access key to use for accessing sqs/sns.
|
||||
AccessKey string
|
||||
// secret key to use for accessing sqs/sns.
|
||||
SecretKey string
|
||||
// aws session token to use.
|
||||
SessionToken string
|
||||
// aws region in which SNS/SQS should create resources.
|
||||
Region string
|
||||
// name of the queue for this application. The is provided by the runtime as "consumerID".
|
||||
sqsQueueName string
|
||||
// name of the dead letter queue for this application.
|
||||
sqsDeadLettersQueueName string
|
||||
// flag to SNS and SQS FIFO.
|
||||
fifo bool
|
||||
// a namespace for SNS SQS FIFO to order messages within that group. limits consumer concurrency if set but guarantees that all
|
||||
// published messages would be ordered by their arrival time to SQS.
|
||||
// see: https://aws.amazon.com/blogs/compute/solving-complex-ordering-challenges-with-amazon-sqs-fifo-queues/
|
||||
fifoMessageGroupID string
|
||||
// amount of time in seconds that a message is hidden from receive requests after it is sent to a subscriber. Default: 10.
|
||||
messageVisibilityTimeout int64
|
||||
// number of times to resend a message after processing of that message fails before removing that message from the queue. Default: 10.
|
||||
messageRetryLimit int64
|
||||
// upon reaching the messageRetryLimit, disables the default deletion behaviour of the message from the SQS queue, and resetting the message visibilty on SQS
|
||||
// so that other consumers can try consuming that message.
|
||||
disableDeleteOnRetryLimit bool
|
||||
// if sqsDeadLettersQueueName is set to a value, then the messageReceiveLimit defines the number of times a message is received
|
||||
// before it is moved to the dead-letters queue. This value must be smaller than messageRetryLimit.
|
||||
messageReceiveLimit int64
|
||||
// amount of time to await receipt of a message before making another request. Default: 1.
|
||||
messageWaitTimeSeconds int64
|
||||
// maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10.
|
||||
messageMaxNumber int64
|
||||
// disable resource provisioning of SNS and SQS.
|
||||
disableEntityManagement bool
|
||||
// assets creation timeout.
|
||||
assetsManagementTimeoutSeconds float64
|
||||
// aws account ID. internally resolved if not given.
|
||||
accountID string
|
||||
}
|
||||
|
||||
func getAliasedProperty(aliases []string, metadata pubsub.Metadata) (string, bool) {
|
||||
props := metadata.Properties
|
||||
for _, s := range aliases {
|
||||
if val, ok := props[s]; ok {
|
||||
return val, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func parseInt64(input string, propertyName string) (int64, error) {
|
||||
number, err := strconv.Atoi(input)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("parsing %s failed with: %w", propertyName, err)
|
||||
}
|
||||
|
||||
return int64(number), nil
|
||||
}
|
||||
|
||||
func parseBool(input string, propertyName string) (bool, error) {
|
||||
val, err := strconv.ParseBool(input)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing %s failed with: %w", propertyName, err)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func parseFloat64(input string, propertyName string) (float64, error) {
|
||||
val, err := strconv.ParseFloat(input, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing %s failed with: %w", propertyName, err)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func maskLeft(s string) string {
|
||||
rs := []rune(s)
|
||||
for i := 0; i < len(rs)-4; i++ {
|
||||
rs[i] = 'X'
|
||||
}
|
||||
return string(rs)
|
||||
}
|
||||
|
||||
func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata, error) {
|
||||
md := &snsSqsMetadata{}
|
||||
if err := md.setCredsAndQueueNameConfig(metadata); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
props := metadata.Properties
|
||||
|
||||
if err := md.setMessageVisibilityTimeout(props); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := md.setMessageRetryLimit(props); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := md.setDeadlettersQueueConfig(props); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := md.setDisableDeleteOnRetryLimit(props); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := md.setFifoConfig(props); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := md.setMessageWaitTimeSeconds(props); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := md.setMessageMaxNumber(props); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := md.setDisableEntityManagement(props); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := md.setAssetsManagementTimeoutSeconds(props); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.logger.Debug(md.hideDebugPrintedCredentials())
|
||||
|
||||
return md, nil
|
||||
}
|
||||
|
||||
func (md *snsSqsMetadata) hideDebugPrintedCredentials() string {
|
||||
mdCopy := *md
|
||||
mdCopy.AccessKey = maskLeft(md.AccessKey)
|
||||
mdCopy.SecretKey = maskLeft(md.SecretKey)
|
||||
mdCopy.SessionToken = maskLeft(md.SessionToken)
|
||||
|
||||
return fmt.Sprintf("%#v\n", mdCopy)
|
||||
}
|
||||
|
||||
func (md *snsSqsMetadata) setCredsAndQueueNameConfig(metadata pubsub.Metadata) error {
|
||||
if val, ok := getAliasedProperty([]string{"Endpoint", "endpoint"}, metadata); ok {
|
||||
md.Endpoint = val
|
||||
}
|
||||
|
||||
if val, ok := getAliasedProperty([]string{"awsAccountID", "accessKey"}, metadata); ok {
|
||||
md.AccessKey = val
|
||||
}
|
||||
|
||||
if val, ok := getAliasedProperty([]string{"awsSecret", "secretKey"}, metadata); ok {
|
||||
md.SecretKey = val
|
||||
}
|
||||
|
||||
if val, ok := metadata.Properties["sessionToken"]; ok {
|
||||
md.SessionToken = val
|
||||
}
|
||||
|
||||
if val, ok := getAliasedProperty([]string{"awsRegion", "region"}, metadata); ok {
|
||||
md.Region = val
|
||||
}
|
||||
|
||||
if val, ok := metadata.Properties["consumerID"]; ok {
|
||||
md.sqsQueueName = val
|
||||
} else {
|
||||
return errors.New("consumerID must be set")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *snsSqsMetadata) setAssetsManagementTimeoutSeconds(props map[string]string) error {
|
||||
if val, ok := props["assetsManagementTimeoutSeconds"]; ok {
|
||||
parsed, err := parseFloat64(val, "assetsManagementTimeoutSeconds")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
md.assetsManagementTimeoutSeconds = parsed
|
||||
} else {
|
||||
md.assetsManagementTimeoutSeconds = assetsManagementDefaultTimeoutSeconds
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *snsSqsMetadata) setDisableEntityManagement(props map[string]string) error {
|
||||
if val, ok := props["disableEntityManagement"]; ok {
|
||||
parsed, err := parseBool(val, "disableEntityManagement")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
md.disableEntityManagement = parsed
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *snsSqsMetadata) setMessageMaxNumber(props map[string]string) error {
|
||||
if val, ok := props["messageMaxNumber"]; !ok {
|
||||
md.messageMaxNumber = 10
|
||||
} else {
|
||||
maxNumber, err := parseInt64(val, "messageMaxNumber")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if maxNumber < 1 {
|
||||
return errors.New("messageMaxNumber must be greater than 0")
|
||||
} else if maxNumber > 10 {
|
||||
return errors.New("messageMaxNumber must be less than or equal to 10")
|
||||
}
|
||||
|
||||
md.messageMaxNumber = maxNumber
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *snsSqsMetadata) setMessageWaitTimeSeconds(props map[string]string) error {
|
||||
if val, ok := props["messageWaitTimeSeconds"]; !ok {
|
||||
md.messageWaitTimeSeconds = 1
|
||||
} else {
|
||||
waitTime, err := parseInt64(val, "messageWaitTimeSeconds")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if waitTime < 1 {
|
||||
return errors.New("messageWaitTimeSeconds must be greater than 0")
|
||||
}
|
||||
|
||||
md.messageWaitTimeSeconds = waitTime
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *snsSqsMetadata) setFifoConfig(props map[string]string) error {
|
||||
// fifo settings: enable/disable SNS and SQS FIFO.
|
||||
if val, ok := props["fifo"]; ok {
|
||||
fifo, err := parseBool(val, "fifo")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
md.fifo = fifo
|
||||
} else {
|
||||
md.fifo = false
|
||||
}
|
||||
|
||||
// fifo settings: assign user provided Message Group ID
|
||||
// for more details, see: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/using-messagegroupid-property.html
|
||||
if val, ok := props["fifoMessageGroupID"]; ok {
|
||||
md.fifoMessageGroupID = val
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *snsSqsMetadata) setDeadlettersQueueConfig(props map[string]string) error {
|
||||
if val, ok := props["sqsDeadLettersQueueName"]; ok {
|
||||
md.sqsDeadLettersQueueName = val
|
||||
}
|
||||
|
||||
if val, ok := props["messageReceiveLimit"]; ok {
|
||||
messageReceiveLimit, err := parseInt64(val, "messageReceiveLimit")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// assign: used provided configuration
|
||||
md.messageReceiveLimit = messageReceiveLimit
|
||||
}
|
||||
|
||||
// XOR on having either a valid messageReceiveLimit and invalid sqsDeadLettersQueueName, and vice versa.
|
||||
if (md.messageReceiveLimit > 0 || len(md.sqsDeadLettersQueueName) > 0) && !(md.messageReceiveLimit > 0 && len(md.sqsDeadLettersQueueName) > 0) {
|
||||
return errors.New("to use SQS dead letters queue, messageReceiveLimit and sqsDeadLettersQueueName must both be set to a value")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *snsSqsMetadata) setDisableDeleteOnRetryLimit(props map[string]string) error {
|
||||
if val, ok := props["disableDeleteOnRetryLimit"]; ok {
|
||||
disableDeleteOnRetryLimit, err := parseBool(val, "disableDeleteOnRetryLimit")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(md.sqsDeadLettersQueueName) > 0 && disableDeleteOnRetryLimit {
|
||||
return errors.New("configuration conflict: 'disableDeleteOnRetryLimit' cannot be set to 'true' when 'sqsDeadLettersQueueName' is set to a value. either remove this configuration or set 'disableDeleteOnRetryLimit' to 'false'")
|
||||
}
|
||||
|
||||
md.disableDeleteOnRetryLimit = disableDeleteOnRetryLimit
|
||||
} else {
|
||||
// default when not configured.
|
||||
md.disableDeleteOnRetryLimit = false
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *snsSqsMetadata) setMessageRetryLimit(props map[string]string) error {
|
||||
if val, ok := props["messageRetryLimit"]; !ok {
|
||||
md.messageRetryLimit = 10
|
||||
} else {
|
||||
retryLimit, err := parseInt64(val, "messageRetryLimit")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if retryLimit < 2 {
|
||||
return errors.New("messageRetryLimit must be greater than 1")
|
||||
}
|
||||
|
||||
md.messageRetryLimit = retryLimit
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *snsSqsMetadata) setMessageVisibilityTimeout(props map[string]string) error {
|
||||
if val, ok := props["messageVisibilityTimeout"]; !ok {
|
||||
md.messageVisibilityTimeout = 10
|
||||
} else {
|
||||
timeout, err := parseInt64(val, "messageVisibilityTimeout")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if timeout < 1 {
|
||||
return errors.New("messageVisibilityTimeout must be greater than 0")
|
||||
}
|
||||
|
||||
md.messageVisibilityTimeout = timeout
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -16,17 +16,20 @@ package snssqs
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/service/sns"
|
||||
"github.com/aws/aws-sdk-go/service/sqs"
|
||||
"github.com/aws/aws-sdk-go/service/sts"
|
||||
|
||||
"github.com/dapr/kit/retry"
|
||||
|
||||
gonanoid "github.com/matoous/go-nanoid/v2"
|
||||
|
||||
aws_auth "github.com/dapr/components-contrib/authentication/aws"
|
||||
|
@ -38,17 +41,22 @@ type snsSqs struct {
|
|||
// key is the topic name, value is the ARN of the topic.
|
||||
topics sync.Map
|
||||
// key is the sanitized topic name, value is the actual topic name.
|
||||
topicsSanitized sync.Map
|
||||
sanitizedTopics sync.Map
|
||||
// key is the topic name, value holds the ARN of the queue and its url.
|
||||
queues sync.Map
|
||||
// key is a composite key of queue ARN and topic ARN mapping to subscription ARN.
|
||||
subscriptions sync.Map
|
||||
|
||||
snsClient *sns.SNS
|
||||
sqsClient *sqs.SQS
|
||||
stsClient *sts.STS
|
||||
metadata *snsSqsMetadata
|
||||
logger logger.Logger
|
||||
id string
|
||||
opsTimeout time.Duration
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
backOffConfig retry.Config
|
||||
}
|
||||
|
||||
type sqsQueueInfo struct {
|
||||
|
@ -56,49 +64,23 @@ type sqsQueueInfo struct {
|
|||
url string
|
||||
}
|
||||
|
||||
type snsSqsMetadata struct {
|
||||
// aws endpoint for the component to use.
|
||||
Endpoint string
|
||||
// access key to use for accessing sqs/sns.
|
||||
AccessKey string
|
||||
// secret key to use for accessing sqs/sns.
|
||||
SecretKey string
|
||||
// aws session token to use.
|
||||
SessionToken string
|
||||
// aws region in which SNS/SQS should create resources.
|
||||
Region string
|
||||
// name of the queue for this application. The is provided by the runtime as "consumerID".
|
||||
sqsQueueName string
|
||||
// name of the dead letter queue for this application.
|
||||
sqsDeadLettersQueueName string
|
||||
// flag to SNS and SQS FIFO.
|
||||
fifo bool
|
||||
// a namespace for SNS SQS FIFO to order messages within that group. limits consumer concurrency if set but guarantees that all
|
||||
// published messages would be ordered by their arrival time to SQS.
|
||||
// see: https://aws.amazon.com/blogs/compute/solving-complex-ordering-challenges-with-amazon-sqs-fifo-queues/
|
||||
fifoMessageGroupID string
|
||||
// amount of time in seconds that a message is hidden from receive requests after it is sent to a subscriber. Default: 10.
|
||||
messageVisibilityTimeout int64
|
||||
// number of times to resend a message after processing of that message fails before removing that message from the queue. Default: 10.
|
||||
messageRetryLimit int64
|
||||
// if sqsDeadLettersQueueName is set to a value, then the messageReceiveLimit defines the number of times a message is received
|
||||
// before it is moved to the dead-letters queue. This value must be smaller than messageRetryLimit.
|
||||
messageReceiveLimit int64
|
||||
// amount of time to await receipt of a message before making another request. Default: 1.
|
||||
messageWaitTimeSeconds int64
|
||||
// maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10.
|
||||
messageMaxNumber int64
|
||||
// disable resource provisioning of SNS and SQS.
|
||||
disableEntityManagement bool
|
||||
// aws account ID.
|
||||
accountID string
|
||||
type snsMessage struct {
|
||||
Message string
|
||||
TopicArn string
|
||||
}
|
||||
|
||||
func (sn *snsMessage) parseTopicArn() string {
|
||||
arn := sn.TopicArn
|
||||
return arn[strings.LastIndex(arn, ":")+1:]
|
||||
}
|
||||
|
||||
const (
|
||||
awsSqsQueueNameKey = "dapr-queue-name"
|
||||
awsSnsTopicNameKey = "dapr-topic-name"
|
||||
awsSqsFifoSuffix = ".fifo"
|
||||
maxAWSNameLength = 80
|
||||
awsSqsQueueNameKey = "dapr-queue-name"
|
||||
awsSnsTopicNameKey = "dapr-topic-name"
|
||||
awsSqsFifoSuffix = ".fifo"
|
||||
maxAWSNameLength = 80
|
||||
assetsManagementDefaultTimeoutSeconds = 5.0
|
||||
awsAccountIDLength = 12
|
||||
)
|
||||
|
||||
// NewSnsSqs - constructor for a new snssqs dapr component.
|
||||
|
@ -114,34 +96,6 @@ func NewSnsSqs(l logger.Logger) pubsub.PubSub {
|
|||
}
|
||||
}
|
||||
|
||||
func getAliasedProperty(aliases []string, metadata pubsub.Metadata) (string, bool) {
|
||||
props := metadata.Properties
|
||||
for _, s := range aliases {
|
||||
if val, ok := props[s]; ok {
|
||||
return val, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func parseInt64(input string, propertyName string) (int64, error) {
|
||||
number, err := strconv.Atoi(input)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("parsing %s failed with: %w", propertyName, err)
|
||||
}
|
||||
|
||||
return int64(number), nil
|
||||
}
|
||||
|
||||
func parseBool(input string, propertyName string) (bool, error) {
|
||||
val, err := strconv.ParseBool(input)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing %s failed with: %w", propertyName, err)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// sanitize topic/queue name to conform with:
|
||||
// https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/quotas-queues.html
|
||||
func nameToAWSSanitizedName(name string, isFifo bool) string {
|
||||
|
@ -181,143 +135,6 @@ func nameToAWSSanitizedName(name string, isFifo bool) string {
|
|||
return string(s[:j])
|
||||
}
|
||||
|
||||
func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata, error) {
|
||||
md := snsSqsMetadata{}
|
||||
props := metadata.Properties
|
||||
md.sqsQueueName = metadata.Properties["consumerID"]
|
||||
s.logger.Debugf("Setting queue name to %s", md.sqsQueueName)
|
||||
|
||||
if val, ok := getAliasedProperty([]string{"Endpoint", "endpoint"}, metadata); ok {
|
||||
s.logger.Debugf("endpoint: %s", val)
|
||||
md.Endpoint = val
|
||||
}
|
||||
|
||||
if val, ok := getAliasedProperty([]string{"awsAccountID", "accessKey"}, metadata); ok {
|
||||
s.logger.Debugf("accessKey: %s", val)
|
||||
md.AccessKey = val
|
||||
}
|
||||
|
||||
if val, ok := getAliasedProperty([]string{"awsSecret", "secretKey"}, metadata); ok {
|
||||
s.logger.Debugf("secretKey: %s", val)
|
||||
md.SecretKey = val
|
||||
}
|
||||
|
||||
if val, ok := props["sessionToken"]; ok {
|
||||
md.SessionToken = val
|
||||
}
|
||||
|
||||
if val, ok := getAliasedProperty([]string{"awsRegion", "region"}, metadata); ok {
|
||||
md.Region = val
|
||||
}
|
||||
|
||||
if val, ok := props["messageVisibilityTimeout"]; !ok {
|
||||
md.messageVisibilityTimeout = 10
|
||||
} else {
|
||||
timeout, err := parseInt64(val, "messageVisibilityTimeout")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if timeout < 1 {
|
||||
return nil, errors.New("messageVisibilityTimeout must be greater than 0")
|
||||
}
|
||||
|
||||
md.messageVisibilityTimeout = timeout
|
||||
}
|
||||
|
||||
if val, ok := props["messageRetryLimit"]; !ok {
|
||||
md.messageRetryLimit = 10
|
||||
} else {
|
||||
retryLimit, err := parseInt64(val, "messageRetryLimit")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if retryLimit < 2 {
|
||||
return nil, errors.New("messageRetryLimit must be greater than 1")
|
||||
}
|
||||
|
||||
md.messageRetryLimit = retryLimit
|
||||
}
|
||||
|
||||
if val, ok := props["sqsDeadLettersQueueName"]; ok {
|
||||
md.sqsDeadLettersQueueName = val
|
||||
}
|
||||
|
||||
if val, ok := props["messageReceiveLimit"]; ok {
|
||||
messageReceiveLimit, err := parseInt64(val, "messageReceiveLimit")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// assign: used provided configuration
|
||||
md.messageReceiveLimit = messageReceiveLimit
|
||||
}
|
||||
|
||||
// XOR on having either a valid messageReceiveLimit and invalid sqsDeadLettersQueueName, and vice versa.
|
||||
if (md.messageReceiveLimit > 0 || len(md.sqsDeadLettersQueueName) > 0) && !(md.messageReceiveLimit > 0 && len(md.sqsDeadLettersQueueName) > 0) {
|
||||
return nil, errors.New("to use SQS dead letters queue, messageReceiveLimit and sqsDeadLettersQueueName must both be set to a value")
|
||||
}
|
||||
|
||||
// fifo settings: enable/disable SNS and SQS FIFO.
|
||||
if val, ok := props["fifo"]; ok {
|
||||
fifo, err := parseBool(val, "fifo")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
md.fifo = fifo
|
||||
} else {
|
||||
md.fifo = false
|
||||
}
|
||||
|
||||
// fifo settings: assign user provided Message Group ID
|
||||
// for more details, see: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/using-messagegroupid-property.html
|
||||
if val, ok := props["fifoMessageGroupID"]; ok {
|
||||
md.fifoMessageGroupID = val
|
||||
}
|
||||
|
||||
if val, ok := props["messageWaitTimeSeconds"]; !ok {
|
||||
md.messageWaitTimeSeconds = 1
|
||||
} else {
|
||||
waitTime, err := parseInt64(val, "messageWaitTimeSeconds")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if waitTime < 1 {
|
||||
return nil, errors.New("messageWaitTimeSeconds must be greater than 0")
|
||||
}
|
||||
|
||||
md.messageWaitTimeSeconds = waitTime
|
||||
}
|
||||
|
||||
if val, ok := props["messageMaxNumber"]; !ok {
|
||||
md.messageMaxNumber = 10
|
||||
} else {
|
||||
maxNumber, err := parseInt64(val, "messageMaxNumber")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if maxNumber < 1 {
|
||||
return nil, errors.New("messageMaxNumber must be greater than 0")
|
||||
} else if maxNumber > 10 {
|
||||
return nil, errors.New("messageMaxNumber must be less than or equal to 10")
|
||||
}
|
||||
|
||||
md.messageMaxNumber = maxNumber
|
||||
}
|
||||
|
||||
if val, ok := props["disableEntityManagement"]; ok {
|
||||
parsed, err := parseBool(val, "disableEntityManagement")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
md.disableEntityManagement = parsed
|
||||
}
|
||||
|
||||
return &md, nil
|
||||
}
|
||||
|
||||
func (s *snsSqs) Init(metadata pubsub.Metadata) error {
|
||||
md, err := s.getSnsSqsMetatdata(metadata)
|
||||
if err != nil {
|
||||
|
@ -329,7 +146,7 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error {
|
|||
// both Publish and Subscribe need reference the topic ARN, queue ARN and subscription ARN between topic and queue
|
||||
// track these ARNs in these maps.
|
||||
s.topics = sync.Map{}
|
||||
s.topicsSanitized = sync.Map{}
|
||||
s.sanitizedTopics = sync.Map{}
|
||||
s.queues = sync.Map{}
|
||||
s.subscriptions = sync.Map{}
|
||||
|
||||
|
@ -337,23 +154,48 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("error creating an AWS client: %w", err)
|
||||
}
|
||||
|
||||
// AWS sns,sqs,sts client.
|
||||
s.snsClient = sns.New(sess)
|
||||
s.sqsClient = sqs.New(sess)
|
||||
s.stsClient = sts.New(sess)
|
||||
callerIDOutput, err := s.stsClient.GetCallerIdentity(&sts.GetCallerIdentityInput{})
|
||||
|
||||
s.opsTimeout = time.Duration(md.assetsManagementTimeoutSeconds * float64(time.Second))
|
||||
s.ctx, s.cancelFn = context.WithCancel(context.Background())
|
||||
|
||||
if err := s.setAwsAccountIDIfNotProvided(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Default retry configuration is used if no
|
||||
// backOff properties are set.
|
||||
if err := retry.DecodeConfigWithPrefix(
|
||||
&s.backOffConfig,
|
||||
metadata.Properties,
|
||||
"backOff"); err != nil {
|
||||
return fmt.Errorf("error decoding backOff config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *snsSqs) setAwsAccountIDIfNotProvided() error {
|
||||
if len(s.metadata.accountID) == awsAccountIDLength {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
|
||||
defer cancelFn()
|
||||
|
||||
callerIDOutput, err := s.stsClient.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error fetching sts caller ID: %w", err)
|
||||
}
|
||||
|
||||
s.metadata.accountID = *callerIDOutput.Account
|
||||
|
||||
s.snsClient = sns.New(sess)
|
||||
s.sqsClient = sqs.New(sess)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *snsSqs) buildARN(serviceName, entityName string) string {
|
||||
// arn:aws:sns:us-east-1:302212680347:aws-controltower-SecurityNotifications
|
||||
return fmt.Sprintf("arn:aws:%s:%s:%s:%s", serviceName, s.metadata.Region, s.metadata.accountID, entityName)
|
||||
}
|
||||
|
||||
|
@ -369,7 +211,10 @@ func (s *snsSqs) createTopic(topic string) (string, error) {
|
|||
snsCreateTopicInput.SetAttributes(attributes)
|
||||
}
|
||||
|
||||
createTopicResponse, err := s.snsClient.CreateTopic(snsCreateTopicInput)
|
||||
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
|
||||
defer cancelFn()
|
||||
|
||||
createTopicResponse, err := s.snsClient.CreateTopicWithContext(ctx, snsCreateTopicInput)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error while creating an SNS topic: %w", err)
|
||||
}
|
||||
|
@ -378,8 +223,11 @@ func (s *snsSqs) createTopic(topic string) (string, error) {
|
|||
}
|
||||
|
||||
func (s *snsSqs) getTopicArn(topic string) (string, error) {
|
||||
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
|
||||
defer cancelFn()
|
||||
|
||||
arn := s.buildARN("sns", topic)
|
||||
getTopicOutput, err := s.snsClient.GetTopicAttributes(&sns.GetTopicAttributesInput{TopicArn: aws.String(arn)})
|
||||
getTopicOutput, err := s.snsClient.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{TopicArn: aws.String(arn)})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error: %w while getting topic: %v with arn: %v", err, topic, arn)
|
||||
}
|
||||
|
@ -422,7 +270,7 @@ func (s *snsSqs) getOrCreateTopic(topic string) (string, error) {
|
|||
|
||||
// record topic ARN.
|
||||
s.topics.Store(topic, topicArn)
|
||||
s.topicsSanitized.Store(sanitizedName, topic)
|
||||
s.sanitizedTopics.Store(sanitizedName, topic)
|
||||
|
||||
return topicArn, nil
|
||||
}
|
||||
|
@ -438,13 +286,18 @@ func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) {
|
|||
attributes := map[string]*string{"FifoQueue": aws.String("true"), "ContentBasedDeduplication": aws.String("true")}
|
||||
sqsCreateQueueInput.SetAttributes(attributes)
|
||||
}
|
||||
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
|
||||
defer cancelFn()
|
||||
|
||||
createQueueResponse, err := s.sqsClient.CreateQueue(sqsCreateQueueInput)
|
||||
createQueueResponse, err := s.sqsClient.CreateQueueWithContext(ctx, sqsCreateQueueInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creaing an SQS queue: %w", err)
|
||||
}
|
||||
|
||||
queueAttributesResponse, err := s.sqsClient.GetQueueAttributes(&sqs.GetQueueAttributesInput{
|
||||
aCtx, aCancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
|
||||
defer aCancelFn()
|
||||
|
||||
queueAttributesResponse, err := s.sqsClient.GetQueueAttributesWithContext(aCtx, &sqs.GetQueueAttributesInput{
|
||||
AttributeNames: []*string{aws.String("QueueArn")},
|
||||
QueueUrl: createQueueResponse.QueueUrl,
|
||||
})
|
||||
|
@ -459,14 +312,20 @@ func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) {
|
|||
}
|
||||
|
||||
func (s *snsSqs) getQueueArn(queueName string) (*sqsQueueInfo, error) {
|
||||
queueURLOutput, err := s.sqsClient.GetQueueUrl(&sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.accountID)})
|
||||
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
|
||||
defer cancelFn()
|
||||
|
||||
queueURLOutput, err := s.sqsClient.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.accountID)})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error: %w while getting url of queue: %s", err, queueName)
|
||||
}
|
||||
url := queueURLOutput.QueueUrl
|
||||
|
||||
aCtx, aCancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
|
||||
defer aCancelFn()
|
||||
|
||||
var getQueueOutput *sqs.GetQueueAttributesOutput
|
||||
getQueueOutput, err = s.sqsClient.GetQueueAttributes(&sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}})
|
||||
getQueueOutput, err = s.sqsClient.GetQueueAttributesWithContext(aCtx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error: %w while getting information for queue: %s, with url: %s", err, queueName, *url)
|
||||
}
|
||||
|
@ -525,7 +384,10 @@ func (s *snsSqs) getMessageGroupID(req *pubsub.PublishRequest) *string {
|
|||
}
|
||||
|
||||
func (s *snsSqs) createSnsSqsSubscription(queueArn, topicArn string) (string, error) {
|
||||
subscribeOutput, err := s.snsClient.Subscribe(&sns.SubscribeInput{
|
||||
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
|
||||
defer cancelFn()
|
||||
|
||||
subscribeOutput, err := s.snsClient.SubscribeWithContext(ctx, &sns.SubscribeInput{
|
||||
Attributes: nil,
|
||||
Endpoint: aws.String(queueArn), // create SQS queue per subscription.
|
||||
Protocol: aws.String("sqs"),
|
||||
|
@ -543,7 +405,10 @@ func (s *snsSqs) createSnsSqsSubscription(queueArn, topicArn string) (string, er
|
|||
}
|
||||
|
||||
func (s *snsSqs) getSnsSqsSubscriptionArn(topicArn string) (string, error) {
|
||||
listSubscriptionsOutput, err := s.snsClient.ListSubscriptionsByTopic(&sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)})
|
||||
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
|
||||
defer cancelFn()
|
||||
|
||||
listSubscriptionsOutput, err := s.snsClient.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error listing subsriptions for topic arn: %v: %w", topicArn, err)
|
||||
}
|
||||
|
@ -594,146 +459,176 @@ func (s *snsSqs) getOrCreateSnsSqsSubscription(queueArn, topicArn string) (strin
|
|||
return subscriptionArn, nil
|
||||
}
|
||||
|
||||
func (s *snsSqs) Publish(req *pubsub.PublishRequest) error {
|
||||
topicArn, err := s.getOrCreateTopic(req.Topic)
|
||||
if err != nil {
|
||||
s.logger.Errorf("error getting topic ARN for %s: %v", req.Topic, err)
|
||||
}
|
||||
|
||||
message := string(req.Data)
|
||||
snsPublishInput := &sns.PublishInput{
|
||||
Message: &message,
|
||||
TopicArn: &topicArn,
|
||||
}
|
||||
if s.metadata.fifo {
|
||||
snsPublishInput.MessageGroupId = s.getMessageGroupID(req)
|
||||
}
|
||||
|
||||
_, err = s.snsClient.Publish(snsPublishInput)
|
||||
if err != nil {
|
||||
wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %w", req.Topic, topicArn, err)
|
||||
s.logger.Error(wrappedErr)
|
||||
|
||||
return wrappedErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type snsMessage struct {
|
||||
Message string
|
||||
TopicArn string
|
||||
}
|
||||
|
||||
func parseTopicArn(arn string) string {
|
||||
return arn[strings.LastIndex(arn, ":")+1:]
|
||||
}
|
||||
|
||||
func (s *snsSqs) acknowledgeMessage(queueURL string, receiptHandle *string) error {
|
||||
if _, err := s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{
|
||||
QueueUrl: &queueURL,
|
||||
ctx, cancelFn := context.WithCancel(s.ctx)
|
||||
defer cancelFn()
|
||||
|
||||
deleteMessageInput := &sqs.DeleteMessageInput{
|
||||
QueueUrl: aws.String(queueURL),
|
||||
ReceiptHandle: receiptHandle,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error deleting SQS message: %w", err)
|
||||
}
|
||||
if _, err := s.sqsClient.DeleteMessageWithContext(ctx, deleteMessageInput); err != nil {
|
||||
return fmt.Errorf("error deleting message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueueInfo *sqsQueueInfo, handler pubsub.Handler) error {
|
||||
func (s *snsSqs) resetMessageVisibilityTimeout(queueURL string, receiptHandle *string) error {
|
||||
ctx, cancelFn := context.WithCancel(s.ctx)
|
||||
defer cancelFn()
|
||||
|
||||
// reset the timeout to its initial value so that the remaining timeout would be overridden by the initial value for other consumer to attempt processing.
|
||||
changeMessageVisibilityInput := &sqs.ChangeMessageVisibilityInput{
|
||||
QueueUrl: aws.String(queueURL),
|
||||
ReceiptHandle: receiptHandle,
|
||||
VisibilityTimeout: aws.Int64(s.metadata.messageVisibilityTimeout),
|
||||
}
|
||||
if _, err := s.sqsClient.ChangeMessageVisibilityWithContext(ctx, changeMessageVisibilityInput); err != nil {
|
||||
return fmt.Errorf("error changing message visibility timeout: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *snsSqs) parseReceiveCount(message *sqs.Message) (int64, error) {
|
||||
// if this message has been received > x times, delete from queue, it's borked.
|
||||
recvCount, ok := message.Attributes[sqs.MessageSystemAttributeNameApproximateReceiveCount]
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf(
|
||||
return 0, fmt.Errorf(
|
||||
"no ApproximateReceiveCount returned with response, will not attempt further processing: %v", message)
|
||||
}
|
||||
|
||||
recvCountInt, err := strconv.ParseInt(*recvCount, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing ApproximateReceiveCount from message: %v", message)
|
||||
return 0, fmt.Errorf("error parsing ApproximateReceiveCount from message: %v", message)
|
||||
}
|
||||
return recvCountInt, nil
|
||||
}
|
||||
|
||||
func (s *snsSqs) validateMessage(message *sqs.Message, queueInfo, deadLettersQueueInfo *sqsQueueInfo, handler pubsub.Handler) error {
|
||||
recvCount, err := s.parseReceiveCount(message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if we are over the allowable retry limit, and there is no dead-letters queue, delete the message from the queue.
|
||||
if deadLettersQueueInfo == nil && recvCountInt >= s.metadata.messageRetryLimit {
|
||||
if innerErr := s.acknowledgeMessage(queueInfo.url, message.ReceiptHandle); innerErr != nil {
|
||||
return fmt.Errorf("error acknowledging message after receiving the message too many times: %w", innerErr)
|
||||
messageRetryLimit := s.metadata.messageRetryLimit
|
||||
if deadLettersQueueInfo == nil && recvCount >= messageRetryLimit {
|
||||
// if we are over the allowable retry limit, and there is no dead-letters queue, and we don't disable deletes, then delete the message from the queue.
|
||||
if !s.metadata.disableDeleteOnRetryLimit {
|
||||
if innerErr := s.acknowledgeMessage(queueInfo.url, message.ReceiptHandle); innerErr != nil {
|
||||
return fmt.Errorf("error acknowledging message after receiving the message too many times: %w", innerErr)
|
||||
}
|
||||
return fmt.Errorf("message received greater than %v times, deleting this message without further processing", messageRetryLimit)
|
||||
}
|
||||
// if we are over the allowable retry limit, and there is no dead-letters queue, and deletes are disabled, then don't delete the message from the queue.
|
||||
// reset the already "consumed" message visibility clock.
|
||||
s.logger.Debugf("message received greater than %v times. deletion past the thredhold is diabled. noop", messageRetryLimit)
|
||||
if err := s.resetMessageVisibilityTimeout(queueInfo.url, message.ReceiptHandle); err != nil {
|
||||
return fmt.Errorf("error resetting message visibility timeout: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf(
|
||||
"message received greater than %v times, deleting this message without further processing", s.metadata.messageRetryLimit)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ... else, there is no need to actively do something if we reached the limit defined in messageReceiveLimit as the message had
|
||||
// already been moved to the dead-letters queue by SQS. meaning, the below condition should not be reached as SQS would not send
|
||||
// a message if we've already surpassed the s.metadata.messageReceiveLimit value.
|
||||
if deadLettersQueueInfo != nil && recvCountInt > s.metadata.messageReceiveLimit {
|
||||
// a message if we've already surpassed the messageRetryLimit value.
|
||||
if deadLettersQueueInfo != nil && recvCount > messageRetryLimit {
|
||||
awsErr := fmt.Errorf(
|
||||
"message received greater than %v times, this message should have been moved without further processing to dead-letters queue: %v", s.metadata.messageReceiveLimit, s.metadata.sqsDeadLettersQueueName)
|
||||
s.logger.Error(awsErr)
|
||||
"message received greater than %v times, this message should have been moved without further processing to dead-letters queue: %v", messageRetryLimit, s.metadata.sqsDeadLettersQueueName)
|
||||
|
||||
return awsErr
|
||||
}
|
||||
|
||||
// otherwise try to handle the message.
|
||||
var messageBody snsMessage
|
||||
err = json.Unmarshal([]byte(*(message.Body)), &messageBody)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *snsSqs) callHandler(message *sqs.Message, queueInfo *sqsQueueInfo, handler pubsub.Handler) error {
|
||||
// otherwise, try to handle the message.
|
||||
var snsMessagePayload snsMessage
|
||||
err := json.Unmarshal([]byte(*(message.Body)), &snsMessagePayload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshalling message: %w", err)
|
||||
}
|
||||
|
||||
// messageBody.TopicArn can only carry a sanitized topic name as we conform to AWS naming standards.
|
||||
// snsMessagePayload.TopicArn can only carry a sanitized topic name as we conform to AWS naming standards.
|
||||
// for the user to be able to understand the source of the coming message, we'd use the original,
|
||||
// dirty name to be carried over in the pubsub.NewMessage Topic field.
|
||||
sanitizedTopic := parseTopicArn(messageBody.TopicArn)
|
||||
cachedTopic, ok := s.topicsSanitized.Load(sanitizedTopic)
|
||||
sanitizedTopic := snsMessagePayload.parseTopicArn()
|
||||
cachedTopic, ok := s.sanitizedTopics.Load(sanitizedTopic)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed loading topic (sanitized): %s from internal topics cache. SNS topic might be just created", sanitizedTopic)
|
||||
}
|
||||
|
||||
err = handler(context.Background(), &pubsub.NewMessage{
|
||||
Data: []byte(messageBody.Message),
|
||||
Topic: cachedTopic.(string),
|
||||
})
|
||||
s.logger.Debugf("Processing SNS message id: %s of topic: %s", message.MessageId, sanitizedTopic)
|
||||
|
||||
if err != nil {
|
||||
ctx, cancelFn := context.WithCancel(s.ctx)
|
||||
defer cancelFn()
|
||||
|
||||
if err := handler(ctx, &pubsub.NewMessage{
|
||||
Data: []byte(snsMessagePayload.Message),
|
||||
Topic: cachedTopic.(string),
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error handling message: %w", err)
|
||||
}
|
||||
|
||||
// otherwise, there was no error, acknowledge the message.
|
||||
return s.acknowledgeMessage(queueInfo.url, message.ReceiptHandle)
|
||||
}
|
||||
|
||||
func (s *snsSqs) consumeSubscription(queueInfo, deadLettersQueueInfo *sqsQueueInfo, handler pubsub.Handler) {
|
||||
go func() {
|
||||
ctx, cancelFn := context.WithCancel(s.ctx)
|
||||
defer cancelFn()
|
||||
|
||||
sqsPullExponentialBackoff := s.backOffConfig.NewBackOffWithContext(ctx)
|
||||
|
||||
receiveMessageInput := &sqs.ReceiveMessageInput{
|
||||
// use this property to decide when a message should be discarded.
|
||||
AttributeNames: []*string{
|
||||
aws.String(sqs.MessageSystemAttributeNameApproximateReceiveCount),
|
||||
},
|
||||
MaxNumberOfMessages: aws.Int64(s.metadata.messageMaxNumber),
|
||||
QueueUrl: aws.String(queueInfo.url),
|
||||
VisibilityTimeout: aws.Int64(s.metadata.messageVisibilityTimeout),
|
||||
WaitTimeSeconds: aws.Int64(s.metadata.messageWaitTimeSeconds),
|
||||
}
|
||||
|
||||
for {
|
||||
messageResponse, err := s.sqsClient.ReceiveMessage(&sqs.ReceiveMessageInput{
|
||||
// use this property to decide when a message should be discarded.
|
||||
AttributeNames: []*string{
|
||||
aws.String(sqs.MessageSystemAttributeNameApproximateReceiveCount),
|
||||
},
|
||||
MaxNumberOfMessages: aws.Int64(s.metadata.messageMaxNumber),
|
||||
QueueUrl: aws.String(queueInfo.url),
|
||||
VisibilityTimeout: aws.Int64(s.metadata.messageVisibilityTimeout),
|
||||
WaitTimeSeconds: aws.Int64(s.metadata.messageWaitTimeSeconds),
|
||||
})
|
||||
// Internally, by default, aws go sdk performs 3 retires with exponential backoff to contact
|
||||
// sqs and try pull messages. Since we are iteratively short polling (based on the defined
|
||||
// s.metadata.messageWaitTimeSeconds) the sdk backoff is not effective as it gets reset per each polling
|
||||
// iteration. Therefore, a global backoff (to the internal backoff) is used (sqsPullExponentialBackoff).
|
||||
messageResponse, err := s.sqsClient.ReceiveMessageWithContext(ctx, receiveMessageInput)
|
||||
if err != nil {
|
||||
s.logger.Errorf("error consuming topic: %v", err)
|
||||
if awsErr, ok := err.(awserr.Error); ok {
|
||||
s.logger.Errorf("AWS operation error while consuming from queue url: %v with error: %w. retrying...", queueInfo.url, awsErr.Error())
|
||||
} else {
|
||||
s.logger.Errorf("error consuming from queue url: %v with error: %w. retrying...", queueInfo.url, err)
|
||||
}
|
||||
time.Sleep(sqsPullExponentialBackoff.NextBackOff())
|
||||
|
||||
continue
|
||||
}
|
||||
// error either recovered or did not happen at all. resetting the backoff counter (and duration).
|
||||
sqsPullExponentialBackoff.Reset()
|
||||
|
||||
// retry receiving messages.
|
||||
if len(messageResponse.Messages) < 1 {
|
||||
s.logger.Debug("No messages received, requesting again")
|
||||
s.logger.Debug("No messages received, continuing")
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Debugf("%v message(s) received", len(messageResponse.Messages))
|
||||
|
||||
for _, m := range messageResponse.Messages {
|
||||
if err := s.handleMessage(m, queueInfo, deadLettersQueueInfo, handler); err != nil {
|
||||
s.logger.Error(err)
|
||||
for _, message := range messageResponse.Messages {
|
||||
if err := s.validateMessage(message, queueInfo, deadLettersQueueInfo, handler); err != nil {
|
||||
s.logger.Errorf("message is not valid for further processing by the handler. error is: %w", err)
|
||||
continue
|
||||
}
|
||||
if err := s.callHandler(message, queueInfo, handler); err != nil {
|
||||
s.logger.Errorf("error handling received message with error: %w", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -782,8 +677,11 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(sqsQueueInfo *sqsQueueInfo,
|
|||
if s.metadata.disableEntityManagement {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
|
||||
defer cancelFn()
|
||||
// only permit SNS to send messages to SQS using the created subscription.
|
||||
getQueueAttributesOutput, err := s.sqsClient.GetQueueAttributes(&sqs.GetQueueAttributesInput{QueueUrl: &sqsQueueInfo.url, AttributeNames: []*string{aws.String(sqs.QueueAttributeNamePolicy)}})
|
||||
getQueueAttributesOutput, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: &sqsQueueInfo.url, AttributeNames: []*string{aws.String(sqs.QueueAttributeNamePolicy)}})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting queue attributes: %w", err)
|
||||
}
|
||||
|
@ -818,7 +716,10 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(sqsQueueInfo *sqsQueueInfo,
|
|||
return fmt.Errorf("failed serializing new sqs policy: %w", uerr)
|
||||
}
|
||||
|
||||
if _, err = s.sqsClient.SetQueueAttributes(&(sqs.SetQueueAttributesInput{
|
||||
aCtx, aCancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
|
||||
defer aCancelFn()
|
||||
|
||||
if _, err = s.sqsClient.SetQueueAttributesWithContext(aCtx, &(sqs.SetQueueAttributesInput{
|
||||
Attributes: map[string]*string{
|
||||
"Policy": aws.String(string(b)),
|
||||
},
|
||||
|
@ -882,7 +783,10 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
|
|||
return wrappedErr
|
||||
}
|
||||
|
||||
_, derr = s.sqsClient.SetQueueAttributes(sqsSetQueueAttributesInput)
|
||||
ctx, cancelFn := context.WithTimeout(s.ctx, s.opsTimeout)
|
||||
defer cancelFn()
|
||||
|
||||
_, derr = s.sqsClient.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput)
|
||||
if derr != nil {
|
||||
wrappedErr := fmt.Errorf("error updating queue attributes with dead-letter queue: %w", derr)
|
||||
s.logger.Error(wrappedErr)
|
||||
|
@ -904,7 +808,38 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *snsSqs) Publish(req *pubsub.PublishRequest) error {
|
||||
topicArn, err := s.getOrCreateTopic(req.Topic)
|
||||
if err != nil {
|
||||
s.logger.Errorf("error getting topic ARN for %s: %v", req.Topic, err)
|
||||
}
|
||||
|
||||
message := string(req.Data)
|
||||
snsPublishInput := &sns.PublishInput{
|
||||
Message: aws.String(message),
|
||||
TopicArn: aws.String(topicArn),
|
||||
}
|
||||
if s.metadata.fifo {
|
||||
snsPublishInput.MessageGroupId = s.getMessageGroupID(req)
|
||||
}
|
||||
|
||||
ctx, cancelFn := context.WithCancel(s.ctx)
|
||||
defer cancelFn()
|
||||
// sns client has internal exponential backoffs.
|
||||
_, err = s.snsClient.PublishWithContext(ctx, snsPublishInput)
|
||||
if err != nil {
|
||||
wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %w", req.Topic, topicArn, err)
|
||||
s.logger.Error(wrappedErr)
|
||||
|
||||
return wrappedErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *snsSqs) Close() error {
|
||||
s.cancelFn()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -31,7 +31,8 @@ func Test_parseTopicArn(t *testing.T) {
|
|||
t.Parallel()
|
||||
// no further guarantees are made about this function.
|
||||
r := require.New(t)
|
||||
r.Equal("qqnoob", parseTopicArn("arn:aws:sqs:us-east-1:000000000000:qqnoob"))
|
||||
tSnsMessage := &snsMessage{TopicArn: "arn:aws:sqs:us-east-1:000000000000:qqnoob"}
|
||||
r.Equal("qqnoob", tSnsMessage.parseTopicArn())
|
||||
}
|
||||
|
||||
// Verify that all metadata ends up in the correct spot.
|
||||
|
@ -103,6 +104,9 @@ func Test_getSnsSqsMetatdata_defaults(t *testing.T) {
|
|||
r.Equal(int64(10), md.messageRetryLimit)
|
||||
r.Equal(int64(1), md.messageWaitTimeSeconds)
|
||||
r.Equal(int64(10), md.messageMaxNumber)
|
||||
r.Equal(false, md.disableEntityManagement)
|
||||
r.Equal(float64(5), md.assetsManagementTimeoutSeconds)
|
||||
r.Equal(false, md.disableDeleteOnRetryLimit)
|
||||
}
|
||||
|
||||
func Test_getSnsSqsMetatdata_legacyaliases(t *testing.T) {
|
||||
|
@ -188,6 +192,20 @@ func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) {
|
|||
}},
|
||||
name: "deadletters message queue without deadletters receive limit",
|
||||
},
|
||||
{
|
||||
metadata: pubsub.Metadata{Properties: map[string]string{
|
||||
"consumerID": "consumer",
|
||||
"Endpoint": "endpoint",
|
||||
"AccessKey": "acctId",
|
||||
"SecretKey": "secret",
|
||||
"awsToken": "token",
|
||||
"Region": "region",
|
||||
"sqsDeadLettersQueueName": "my-queue",
|
||||
"messageReceiveLimit": "9",
|
||||
"disableDeleteOnRetryLimit": "true",
|
||||
}},
|
||||
name: "deadletters message queue with disableDeleteOnRetryLimit",
|
||||
},
|
||||
{
|
||||
metadata: pubsub.Metadata{Properties: map[string]string{
|
||||
"consumerID": "consumer",
|
||||
|
@ -248,6 +266,20 @@ func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) {
|
|||
}},
|
||||
name: "invalid message retry limit",
|
||||
},
|
||||
// disableEntityManagement
|
||||
{
|
||||
metadata: pubsub.Metadata{Properties: map[string]string{
|
||||
"consumerID": "consumer",
|
||||
"Endpoint": "endpoint",
|
||||
"AccessKey": "acctId",
|
||||
"SecretKey": "secret",
|
||||
"awsToken": "token",
|
||||
"Region": "region",
|
||||
"messageRetryLimit": "10",
|
||||
"disableEntityManagement": "y",
|
||||
}},
|
||||
name: "invalid message disableEntityManagement",
|
||||
},
|
||||
}
|
||||
|
||||
l := logger.NewLogger("SnsSqs unit test")
|
||||
|
|
Loading…
Reference in New Issue