/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package snssqs

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/service/sns"
	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/aws/aws-sdk-go/service/sts"

	"github.com/dapr/kit/retry"

	gonanoid "github.com/matoous/go-nanoid/v2"

	awsAuth "github.com/dapr/components-contrib/common/authentication/aws"
	"github.com/dapr/components-contrib/metadata"
	"github.com/dapr/components-contrib/pubsub"
	"github.com/dapr/kit/logger"
)

type snsSqs struct {
	topicsLocker TopicsLocker
	// key is the sanitized topic name
	topicArns map[string]string
	// key is the topic name, value holds the ARN of the queue and its url.
	queues map[string]*sqsQueueInfo
	// key is a composite key of queue ARN and topic ARN mapping to subscription ARN.
	subscriptions map[string]string

	snsClient           *sns.SNS
	sqsClient           *sqs.SQS
	stsClient           *sts.STS
	metadata            *snsSqsMetadata
	logger              logger.Logger
	id                  string
	opsTimeout          time.Duration
	backOffConfig       retry.Config
	subscriptionManager SubscriptionManagement
	closed              atomic.Bool
}

type sqsQueueInfo struct {
	arn string
	url string
}

type snsMessage struct {
	Message  string
	TopicArn string
}

func (sn *snsMessage) parseTopicArn() string {
	arn := sn.TopicArn
	return arn[strings.LastIndex(arn, ":")+1:]
}

const (
	awsSqsQueueNameKey                    = "dapr-queue-name"
	awsSnsTopicNameKey                    = "dapr-topic-name"
	awsSqsFifoSuffix                      = ".fifo"
	maxAWSNameLength                      = 80
	assetsManagementDefaultTimeoutSeconds = 5.0
	awsAccountIDLength                    = 12
)

// NewSnsSqs - constructor for a new snssqs dapr component.
func NewSnsSqs(l logger.Logger) pubsub.PubSub {
	id, err := gonanoid.New()
	if err != nil {
		l.Fatalf("failed generating unique nano id: %s", err)
	}

	return &snsSqs{
		logger: l,
		id:     id,
	}
}

// sanitize topic/queue name to conform with:
// https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/quotas-queues.html
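// For example (illustrative only): nameToAWSSanitizedName("my/topic.fifo", true) yields "mytopic.fifo";
// the disallowed '/' is dropped and the ".fifo" suffix is re-attached after sanitization.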
func nameToAWSSanitizedName(name string, isFifo bool) string {
	// first remove the suffix if it exists and a FIFO name was requested, then sanitize the passed-in name.
	hasFifoSuffix := false
	if strings.HasSuffix(name, awsSqsFifoSuffix) && isFifo {
		hasFifoSuffix = true
		name = name[:len(name)-len(awsSqsFifoSuffix)]
	}

	s := []byte(name)
	j := 0
	for _, b := range s {
		if ('a' <= b && b <= 'z') ||
			('A' <= b && b <= 'Z') ||
			('0' <= b && b <= '9') ||
			(b == '-') ||
			(b == '_') {
			s[j] = b
			j++

			if j == maxAWSNameLength {
				break
			}
		}
	}

	// reattach the suffix to the sanitized name, trimming further if adding the suffix would exceed the max length.
	if hasFifoSuffix || isFifo {
		delta := j + len(awsSqsFifoSuffix) - maxAWSNameLength
		if delta > 0 {
			j -= delta
		}
		return string(s[:j]) + awsSqsFifoSuffix
	}

	return string(s[:j])
}

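// Init parses the component metadata, creates the SNS, SQS and STS clients, and
// initializes the subscription manager, topics locker and the internal caches.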
func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error {
	md, err := s.getSnsSqsMetatdata(metadata)
	if err != nil {
		return err
	}

	s.metadata = md

	sess, err := awsAuth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint)
	if err != nil {
		return fmt.Errorf("error creating an AWS client: %w", err)
	}
	// AWS sns,sqs,sts client.
	s.snsClient = sns.New(sess)
	s.sqsClient = sqs.New(sess)
	s.stsClient = sts.New(sess)

	s.opsTimeout = time.Duration(md.AssetsManagementTimeoutSeconds * float64(time.Second))

	err = s.setAwsAccountIDIfNotProvided(ctx)
	if err != nil {
		return err
	}

	// Default retry configuration is used if no
	// backOff properties are set.
	err = retry.DecodeConfigWithPrefix(&s.backOffConfig, metadata.Properties, "backOff")
	if err != nil {
		return fmt.Errorf("error decoding backOff config: %w", err)
	}
	// subscription manager responsible for managing the lifecycle of subscriptions.
	s.subscriptionManager = NewSubscriptionMgmt(s.logger)
	s.topicsLocker = NewLockManager()

	s.topicArns = make(map[string]string)
	s.queues = make(map[string]*sqsQueueInfo)
	s.subscriptions = make(map[string]string)

	return nil
}

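// setAwsAccountIDIfNotProvided resolves the AWS account ID via STS GetCallerIdentity
// when it was not supplied in the component metadata.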
func (s *snsSqs) setAwsAccountIDIfNotProvided(parentCtx context.Context) error {
	if len(s.metadata.AccountID) == awsAccountIDLength {
		return nil
	}

	ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout)
	callerIDOutput, err := s.stsClient.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{})
	cancelFn()
	if err != nil {
		return fmt.Errorf("error fetching sts caller ID: %w", err)
	}

	s.metadata.AccountID = *callerIDOutput.Account

	return nil
}

func (s *snsSqs) buildARN(serviceName, entityName string) string {
	return fmt.Sprintf("arn:%s:%s:%s:%s:%s", s.metadata.internalPartition, serviceName, s.metadata.Region, s.metadata.AccountID, entityName)
}

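// createTopic creates an SNS topic (FIFO if configured) tagged with the original,
// unsanitized topic name, and returns the topic ARN.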
func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, error) {
	sanitizedName := nameToAWSSanitizedName(topic, s.metadata.Fifo)
	snsCreateTopicInput := &sns.CreateTopicInput{
		Name: aws.String(sanitizedName),
		Tags: []*sns.Tag{{Key: aws.String(awsSnsTopicNameKey), Value: aws.String(topic)}},
	}

	if s.metadata.Fifo {
		attributes := map[string]*string{"FifoTopic": aws.String("true"), "ContentBasedDeduplication": aws.String("true")}
		snsCreateTopicInput.SetAttributes(attributes)
	}

	ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout)
	createTopicResponse, err := s.snsClient.CreateTopicWithContext(ctx, snsCreateTopicInput)
	cancelFn()
	if err != nil {
		return "", fmt.Errorf("error while creating an SNS topic: %w", err)
	}

	return *(createTopicResponse.TopicArn), nil
}

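// getTopicArn looks up an existing SNS topic by its constructed ARN and returns the
// ARN reported by AWS. Used when entity management is disabled.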
func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, error) {
	ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout)
	arn := s.buildARN("sns", topic)
	getTopicOutput, err := s.snsClient.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{
		TopicArn: &arn,
	})
	cancelFn()
	if err != nil {
		return "", fmt.Errorf("error: %w, while getting (sanitized) topic: %v with arn: %v", err, topic, arn)
	}

	return *getTopicOutput.Attributes["TopicArn"], nil
}

// get the topic ARN from the topics map. If it doesn't exist in the map, try to fetch it from AWS; if it doesn't exist
// at all, issue a request to create the topic.
func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn string, sanitizedTopic string, err error) {
	sanitizedTopic = nameToAWSSanitizedName(topic, s.metadata.Fifo)

	var loadOK bool
	if topicArn, loadOK = s.topicArns[sanitizedTopic]; loadOK {
		if len(topicArn) > 0 {
			s.logger.Debugf("Found existing topic ARN for topic %s: %s", topic, topicArn)

			return topicArn, sanitizedTopic, err
		} else {
			err = fmt.Errorf("the ARN for (sanitized) topic: %s was empty", sanitizedTopic)

			return topicArn, sanitizedTopic, err
		}
	}

	// creating topics is idempotent, the names serve as unique keys among a given region.
	s.logger.Debugf("No SNS topic ARN found for topic: %s. creating SNS topic with (sanitized) name: %s", topic, sanitizedTopic)

	if !s.metadata.DisableEntityManagement {
		topicArn, err = s.createTopic(ctx, sanitizedTopic)
		if err != nil {
			err = fmt.Errorf("error creating new (sanitized) topic '%s': %w", topic, err)

			return topicArn, sanitizedTopic, err
		}
	} else {
		topicArn, err = s.getTopicArn(ctx, sanitizedTopic)
		if err != nil {
			err = fmt.Errorf("error fetching info for (sanitized) topic: %s. wrapped error is: %w", topic, err)

			return topicArn, sanitizedTopic, err
		}
	}

	// record topic ARN.
	s.topicArns[sanitizedTopic] = topicArn

	return topicArn, sanitizedTopic, err
}

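// createQueue creates an SQS queue (FIFO if configured) tagged with the original,
// unsanitized queue name, and returns its ARN and URL.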
func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQueueInfo, error) {
	sanitizedName := nameToAWSSanitizedName(queueName, s.metadata.Fifo)
	sqsCreateQueueInput := &sqs.CreateQueueInput{
		QueueName: aws.String(sanitizedName),
		Tags:      map[string]*string{awsSqsQueueNameKey: aws.String(queueName)},
	}

	if s.metadata.Fifo {
		attributes := map[string]*string{"FifoQueue": aws.String("true"), "ContentBasedDeduplication": aws.String("true")}
		sqsCreateQueueInput.SetAttributes(attributes)
	}

	ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout)
	createQueueResponse, err := s.sqsClient.CreateQueueWithContext(ctx, sqsCreateQueueInput)
	cancel()
	if err != nil {
		return nil, fmt.Errorf("error creating an SQS queue: %w", err)
	}

	ctx, cancel = context.WithTimeout(parentCtx, s.opsTimeout)
	queueAttributesResponse, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{
		AttributeNames: []*string{aws.String("QueueArn")},
		QueueUrl:       createQueueResponse.QueueUrl,
	})
	cancel()
	if err != nil {
		return nil, fmt.Errorf("error fetching queue attributes for %s: %w", queueName, err)
	}

	return &sqsQueueInfo{
		arn: *(queueAttributesResponse.Attributes["QueueArn"]),
		url: *(createQueueResponse.QueueUrl),
	}, nil
}

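// getQueueArn resolves the URL and ARN of an existing SQS queue by name.
// Used when entity management is disabled.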
func (s *snsSqs) getQueueArn(parentCtx context.Context, queueName string) (*sqsQueueInfo, error) {
	ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout)
	queueURLOutput, err := s.sqsClient.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.AccountID)})
	cancel()
	if err != nil {
		return nil, fmt.Errorf("error: %w while getting url of queue: %s", err, queueName)
	}
	url := queueURLOutput.QueueUrl

	ctx, cancel = context.WithTimeout(parentCtx, s.opsTimeout)
	getQueueOutput, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}})
	cancel()
	if err != nil {
		return nil, fmt.Errorf("error: %w while getting information for queue: %s, with url: %s", err, queueName, *url)
	}

	return &sqsQueueInfo{arn: *getQueueOutput.Attributes["QueueArn"], url: *url}, nil
}

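// getOrCreateQueue returns the cached queue info for queueName, or creates (or fetches)
// the queue and caches the result.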
func (s *snsSqs) getOrCreateQueue(ctx context.Context, queueName string) (*sqsQueueInfo, error) {
	var (
		err       error
		queueInfo *sqsQueueInfo
	)

	if cachedQueueInfo, ok := s.queues[queueName]; ok {
		s.logger.Debugf("Found queue ARN for %s: %s", queueName, cachedQueueInfo.arn)

		return cachedQueueInfo, nil
	}
	// creating queues is idempotent, the names serve as unique keys among a given region.
	s.logger.Debugf("No SQS queue ARN found for %s\nCreating SQS queue", queueName)

	sanitizedName := nameToAWSSanitizedName(queueName, s.metadata.Fifo)

	if !s.metadata.DisableEntityManagement {
		queueInfo, err = s.createQueue(ctx, sanitizedName)
		if err != nil {
			s.logger.Errorf("Error creating queue %s: %v", queueName, err)

			return nil, err
		}
	} else {
		queueInfo, err = s.getQueueArn(ctx, sanitizedName)
		if err != nil {
			s.logger.Errorf("error fetching info for queue %s: %v", queueName, err)

			return nil, err
		}
	}

	s.queues[queueName] = queueInfo
	s.logger.Debugf("created SQS queue: %s with arn: %s", queueName, queueInfo.arn)

	return queueInfo, nil
}

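// getMessageGroupID returns the FIFO message group ID: either the one configured in the
// metadata, or a per-daprd identifier derived from the component instance ID, pubsub name and topic.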
func (s *snsSqs) getMessageGroupID(req *pubsub.PublishRequest) *string {
	if len(s.metadata.FifoMessageGroupID) > 0 {
		return &s.metadata.FifoMessageGroupID
	}
	// each daprd, of a given PubSub, of a given publisher application publishes to a message group ID of its own.
	// for example: for a daprd serving the SNS/SQS PubSub component we generate a unique id -> A; that component serves
	// a given PubSub deployment named B, and component A publishes to SNS on behalf of a dapr application named C (effectively to topic C).
	// therefore the message group ID created for publishing messages in the aforementioned setup is "A:B:C".
	fifoMessageGroupID := fmt.Sprintf("%s:%s:%s", s.id, req.PubsubName, req.Topic)

	return &fifoMessageGroupID
}

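// createSnsSqsSubscription subscribes the SQS queue to the SNS topic and returns the subscription ARN.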
func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, topicArn string) (string, error) {
	ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout)
	subscribeOutput, err := s.snsClient.SubscribeWithContext(ctx, &sns.SubscribeInput{
		Attributes:            nil,
		Endpoint:              aws.String(queueArn), // create SQS queue per subscription.
		Protocol:              aws.String("sqs"),
		ReturnSubscriptionArn: nil,
		TopicArn:              aws.String(topicArn),
	})
	cancel()
	if err != nil {
		wrappedErr := fmt.Errorf("error subscribing to sns topic arn: %s, to queue arn: %s: %w", topicArn, queueArn, err)
		s.logger.Error(wrappedErr)

		return "", wrappedErr
	}

	return *subscribeOutput.SubscriptionArn, nil
}

func (s *snsSqs) getSnsSqsSubscriptionArn(parentCtx context.Context, topicArn string) (string, error) {
	ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout)
	listSubscriptionsOutput, err := s.snsClient.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)})
	cancel()
	if err != nil {
		return "", fmt.Errorf("error listing subscriptions for topic arn: %v: %w", topicArn, err)
	}

	for _, subscription := range listSubscriptionsOutput.Subscriptions {
		if *subscription.TopicArn == topicArn {
			return *subscription.SubscriptionArn, nil
		}
	}

	return "", fmt.Errorf("sns-sqs subscription not found for topic arn: %s", topicArn)
}

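// getOrCreateSnsSqsSubscription returns the cached subscription ARN for the queue/topic pair,
// or creates (or fetches) the subscription and caches the result.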
func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, topicArn string) (subscriptionArn string, err error) {
	compositeKey := fmt.Sprintf("%s:%s", queueArn, topicArn)
	if cachedSubscriptionArn, ok := s.subscriptions[compositeKey]; ok {
		s.logger.Debugf("Found subscription of queue arn: %s to topic arn: %s: %s", queueArn, topicArn, cachedSubscriptionArn)

		return cachedSubscriptionArn, nil
	}

	s.logger.Debugf("No subscription ARN found of queue arn: %s to topic arn: %s\nCreating subscription", queueArn, topicArn)

	if !s.metadata.DisableEntityManagement {
		subscriptionArn, err = s.createSnsSqsSubscription(ctx, queueArn, topicArn)
		if err != nil {
			s.logger.Errorf("Error creating subscription %s: %v", subscriptionArn, err)

			return "", err
		}
	} else {
		subscriptionArn, err = s.getSnsSqsSubscriptionArn(ctx, topicArn)
		if err != nil {
			s.logger.Errorf("error fetching info for topic ARN %s: %v", topicArn, err)

			return "", err
		}
	}

	s.subscriptions[compositeKey] = subscriptionArn
	s.logger.Debugf("Subscribed to topic %s: %s", topicArn, subscriptionArn)

	return subscriptionArn, nil
}

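// acknowledgeMessage deletes the message from the SQS queue, marking it as successfully processed.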
func (s *snsSqs) acknowledgeMessage(parentCtx context.Context, queueURL string, receiptHandle *string) error {
	ctx, cancelFn := context.WithCancel(parentCtx)
	_, err := s.sqsClient.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{
		QueueUrl:      aws.String(queueURL),
		ReceiptHandle: receiptHandle,
	})
	cancelFn()
	if err != nil {
		return fmt.Errorf("error deleting message: %w", err)
	}

	return nil
}

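// resetMessageVisibilityTimeout makes the message immediately visible again so that
// another consumer can attempt to process it.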
func (s *snsSqs) resetMessageVisibilityTimeout(parentCtx context.Context, queueURL string, receiptHandle *string) error {
	ctx, cancelFn := context.WithCancel(parentCtx)
	// reset the visibility timeout to zero so that the message becomes visible again and another consumer can attempt processing.
	_, err := s.sqsClient.ChangeMessageVisibilityWithContext(ctx, &sqs.ChangeMessageVisibilityInput{
		QueueUrl:          aws.String(queueURL),
		ReceiptHandle:     receiptHandle,
		VisibilityTimeout: aws.Int64(0),
	})
	cancelFn()
	if err != nil {
		return fmt.Errorf("error changing message visibility timeout: %w", err)
	}

	return nil
}

func (s *snsSqs) parseReceiveCount(message *sqs.Message) (int64, error) {
	// if this message has been received > x times, delete from queue, it's borked.
	recvCount, ok := message.Attributes[sqs.MessageSystemAttributeNameApproximateReceiveCount]

	if !ok {
		return 0, fmt.Errorf(
			"no ApproximateReceiveCount returned with response, will not attempt further processing: %v", message)
	}

	recvCountInt, err := strconv.ParseInt(*recvCount, 10, 32)
	if err != nil {
		return 0, fmt.Errorf("error parsing ApproximateReceiveCount from message %v: %w", message, err)
	}

	return recvCountInt, nil
}

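// validateMessage enforces the configured retry limit: a message that exceeded it is deleted,
// made visible again, or reported as already belonging to the dead-letters queue, depending on the configuration.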
func (s *snsSqs) validateMessage(ctx context.Context, message *sqs.Message, queueInfo, deadLettersQueueInfo *sqsQueueInfo) error {
	recvCount, err := s.parseReceiveCount(message)
	if err != nil {
		return err
	}

	messageRetryLimit := s.metadata.MessageRetryLimit
	if deadLettersQueueInfo == nil && recvCount >= messageRetryLimit {
		// if we are over the allowable retry limit, there is no dead-letters queue, and deletes are not disabled, delete the message from the queue.
		if !s.metadata.DisableDeleteOnRetryLimit {
			if innerErr := s.acknowledgeMessage(ctx, queueInfo.url, message.ReceiptHandle); innerErr != nil {
				return fmt.Errorf("error acknowledging message after receiving the message too many times: %w", innerErr)
			}
			return fmt.Errorf("message received greater than %v times, deleting this message without further processing", messageRetryLimit)
		}
		// if we are over the allowable retry limit, there is no dead-letters queue, and deletes are disabled, do not delete the message from the queue.
		// reset the already "consumed" message visibility clock.
		s.logger.Debugf("message received greater than %v times. deletion past the threshold is disabled. noop", messageRetryLimit)
		if err := s.resetMessageVisibilityTimeout(ctx, queueInfo.url, message.ReceiptHandle); err != nil {
			return fmt.Errorf("error resetting message visibility timeout: %w", err)
		}

		return nil
	}

	// ... else, there is no need to actively do something if we reached the limit defined in messageReceiveLimit as the message had
	// already been moved to the dead-letters queue by SQS. meaning, the below condition should not be reached as SQS would not send
	// a message if we've already surpassed the messageRetryLimit value.
	if deadLettersQueueInfo != nil && recvCount > messageRetryLimit {
		awsErr := fmt.Errorf(
			"message received greater than %v times, this message should have been moved without further processing to dead-letters queue: %v", messageRetryLimit, s.metadata.SqsDeadLettersQueueName)

		return awsErr
	}

	return nil
}

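// callHandler unwraps the SNS envelope, resolves the registered handler by the sanitized
// topic name, invokes it, and acknowledges the message on success.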
func (s *snsSqs) callHandler(ctx context.Context, message *sqs.Message, queueInfo *sqsQueueInfo) error {
	// otherwise, try to handle the message.
	var snsMessagePayload snsMessage
	err := json.Unmarshal([]byte(*(message.Body)), &snsMessagePayload)
	if err != nil {
		return fmt.Errorf("error unmarshalling message: %w", err)
	}

	// snsMessagePayload.TopicArn can only carry a sanitized topic name as we conform to AWS naming standards.
	// so that the user can tell where the message came from, the original (unsanitized) name is carried over
	// in the pubsub.NewMessage Topic field.
	sanitizedTopic := snsMessagePayload.parseTopicArn()
	// get a handler by sanitized topic name and perform validations
	var (
		handler *SubscriptionTopicHandler
		loadOK  bool
	)
	if handler, loadOK = s.subscriptionManager.GetSubscriptionTopicHandler(sanitizedTopic); loadOK {
		if len(handler.requestTopic) == 0 {
			return fmt.Errorf("handler topic name is missing")
		}
	} else {
		return fmt.Errorf("handler for (sanitized) topic: %s was not found", sanitizedTopic)
	}

	s.logger.Debugf("Processing SNS message id: %s of (sanitized) topic: %s", *message.MessageId, sanitizedTopic)

	// call the handler with its own subscription context
	err = handler.handler(handler.ctx, &pubsub.NewMessage{
		Data:  []byte(snsMessagePayload.Message),
		Topic: handler.requestTopic,
	})
	if err != nil {
		return fmt.Errorf("error handling message: %w", err)
	}
	// otherwise, there was no error, acknowledge the message.
	return s.acknowledgeMessage(ctx, queueInfo.url, message.ReceiptHandle)
}

// consumeSubscription is responsible for polling messages from the queue and calling the handler.
// it is passed as a callback to the subscription manager, which initializes the context of the handler.
func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLettersQueueInfo *sqsQueueInfo) {
	sqsPullExponentialBackoff := s.backOffConfig.NewBackOffWithContext(ctx)

	receiveMessageInput := &sqs.ReceiveMessageInput{
		// use this property to decide when a message should be discarded.
		AttributeNames: []*string{
			aws.String(sqs.MessageSystemAttributeNameApproximateReceiveCount),
		},
		MaxNumberOfMessages: aws.Int64(s.metadata.MessageMaxNumber),
		QueueUrl:            aws.String(queueInfo.url),
		VisibilityTimeout:   aws.Int64(s.metadata.MessageVisibilityTimeout),
		WaitTimeSeconds:     aws.Int64(s.metadata.MessageWaitTimeSeconds),
	}

	for {
		// If the context is canceled, stop requesting messages
		if ctx.Err() != nil {
			break
		}

		// Internally, by default, the aws go sdk performs 3 retries with exponential backoff to contact
		// sqs and try to pull messages. Since we are iteratively short polling (based on the defined
		// s.metadata.messageWaitTimeSeconds) the sdk backoff is not effective as it gets reset per each polling
		// iteration. Therefore, a global backoff (on top of the internal backoff) is used (sqsPullExponentialBackoff).
		messageResponse, err := s.sqsClient.ReceiveMessageWithContext(ctx, receiveMessageInput)
		if err != nil {
			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || ctx.Err() != nil {
				s.logger.Warnf("context canceled; stopping consuming from queue arn: %v", queueInfo.arn)

				continue
			}

			var awsErr awserr.Error
			if errors.As(err, &awsErr) {
				s.logger.Errorf("AWS operation error while consuming from queue arn: %v with error: %v. retrying...", queueInfo.arn, awsErr.Error())
			} else {
				s.logger.Errorf("error consuming from queue arn: %v with error: %v. retrying...", queueInfo.arn, err)
			}
			time.Sleep(sqsPullExponentialBackoff.NextBackOff())

			continue
		}
		// error either recovered or did not happen at all. resetting the backoff counter (and duration).
		sqsPullExponentialBackoff.Reset()

		if len(messageResponse.Messages) < 1 {
			continue
		}
		s.logger.Debugf("%v message(s) received on queue %s", len(messageResponse.Messages), queueInfo.arn)

		var wg sync.WaitGroup
		for _, message := range messageResponse.Messages {
			if err := s.validateMessage(ctx, message, queueInfo, deadLettersQueueInfo); err != nil {
				s.logger.Errorf("message is not valid for further processing by the handler. error is: %v", err)

				continue
			}

			f := func(message *sqs.Message) {
				defer wg.Done()
				if err := s.callHandler(ctx, message, queueInfo); err != nil {
					s.logger.Errorf("error while handling received message. error is: %v", err)
				}
			}

			wg.Add(1)
			switch s.metadata.ConcurrencyMode {
			case pubsub.Single:
				f(message)
			case pubsub.Parallel:
				wg.Add(1)
				go func(message *sqs.Message) {
					defer wg.Done()
					f(message)
				}(message)
			}
		}
		wg.Wait()
	}
}

func (s *snsSqs) createDeadLettersQueueAttributes(queueInfo, deadLettersQueueInfo *sqsQueueInfo) (*sqs.SetQueueAttributesInput, error) {
	policy := map[string]string{
		"deadLetterTargetArn": deadLettersQueueInfo.arn,
		"maxReceiveCount":     strconv.FormatInt(s.metadata.MessageReceiveLimit, 10),
	}

	b, err := json.Marshal(policy)
	if err != nil {
		wrappedErr := fmt.Errorf("error marshalling dead-letters queue policy: %w", err)
		s.logger.Error(wrappedErr)

		return nil, wrappedErr
	}

	sqsSetQueueAttributesInput := &sqs.SetQueueAttributesInput{
		QueueUrl: &queueInfo.url,
		Attributes: map[string]*string{
			sqs.QueueAttributeNameRedrivePolicy: aws.String(string(b)),
		},
	}

	return sqsSetQueueAttributesInput, nil
}

func (s *snsSqs) setDeadLettersQueueAttributes(parentCtx context.Context, queueInfo, deadLettersQueueInfo *sqsQueueInfo) error {
	if s.metadata.DisableEntityManagement {
		return nil
	}

	var sqsSetQueueAttributesInput *sqs.SetQueueAttributesInput
	sqsSetQueueAttributesInput, derr := s.createDeadLettersQueueAttributes(queueInfo, deadLettersQueueInfo)
	if derr != nil {
		wrappedErr := fmt.Errorf("error creating queue attributes for dead-letter queue: %w", derr)
		s.logger.Error(wrappedErr)

		return wrappedErr
	}

	ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout)
	_, derr = s.sqsClient.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput)
	cancelFn()
	if derr != nil {
		wrappedErr := fmt.Errorf("error updating queue attributes with dead-letter queue: %w", derr)
		s.logger.Error(wrappedErr)

		return wrappedErr
	}

	return nil
}

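// restrictQueuePublishPolicyToOnlySNS adds a policy statement to the SQS queue that only
// permits the given SNS topic to send messages to it.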
func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, sqsQueueInfo *sqsQueueInfo, snsARN string) error {
	// not creating any policies if disableEntityManagement is true.
	if s.metadata.DisableEntityManagement {
		return nil
	}

	ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout)
	// only permit SNS to send messages to SQS using the created subscription.
	getQueueAttributesOutput, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{
		QueueUrl:       &sqsQueueInfo.url,
		AttributeNames: []*string{aws.String(sqs.QueueAttributeNamePolicy)},
	})
	cancelFn()
	if err != nil {
		return fmt.Errorf("error getting queue attributes: %w", err)
	}

	policy := &policy{Version: "2012-10-17"}
	if policyStr, ok := getQueueAttributesOutput.Attributes[sqs.QueueAttributeNamePolicy]; ok {
		// look for the current statement if it exists, else add it and store.
		if err = json.Unmarshal([]byte(*policyStr), policy); err != nil {
			return fmt.Errorf("error unmarshalling sqs policy: %w", err)
		}
	}
	conditionExists := policy.tryInsertCondition(sqsQueueInfo.arn, snsARN)
	if conditionExists {
		return nil
	}

	b, uerr := json.Marshal(policy)
	if uerr != nil {
		return fmt.Errorf("failed serializing new sqs policy: %w", uerr)
	}

	ctx, cancelFn = context.WithTimeout(parentCtx, s.opsTimeout)
	_, err = s.sqsClient.SetQueueAttributesWithContext(ctx, &(sqs.SetQueueAttributesInput{
		Attributes: map[string]*string{
			"Policy": aws.String(string(b)),
		},
		QueueUrl: &sqsQueueInfo.url,
	}))
	cancelFn()
	if err != nil {
		return fmt.Errorf("error setting queue subscription policy: %w", err)
	}

	return nil
}

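// Subscribe sets up the SNS topic, the SQS queue (and optional dead-letters queue), the
// queue policy and the SNS-SQS subscription, then registers the handler with the
// subscription manager, which starts polling the queue.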
func (s *snsSqs) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error {
	if s.closed.Load() {
		return errors.New("component is closed")
	}

	s.topicsLocker.Lock(req.Topic)
	defer s.topicsLocker.Unlock(req.Topic)

	// subscribers declare a topic ARN and an SQS queue to use.
	// these should be idempotent - queues should not be created if they exist.
	topicArn, sanitizedName, err := s.getOrCreateTopic(ctx, req.Topic)
	if err != nil {
		wrappedErr := fmt.Errorf("error getting topic ARN for %s: %w", req.Topic, err)
		s.logger.Error(wrappedErr)

		return wrappedErr
	}

	// this is the ID of the application, it is supplied via runtime as "consumerID".
	var queueInfo *sqsQueueInfo
	queueInfo, err = s.getOrCreateQueue(ctx, s.metadata.SqsQueueName)
	if err != nil {
		wrappedErr := fmt.Errorf("error retrieving SQS queue: %w", err)
		s.logger.Error(wrappedErr)

		return wrappedErr
	}

	// only after an SQS queue and SNS topic have been set up do we restrict the SendMessage action to SNS as the sole source,
	// to prevent anyone but SNS from publishing messages to SQS.
	err = s.restrictQueuePublishPolicyToOnlySNS(ctx, queueInfo, topicArn)
	if err != nil {
		wrappedErr := fmt.Errorf("error setting sns-sqs subscription policy: %w", err)
		s.logger.Error(wrappedErr)

		return wrappedErr
	}

	// apply the dead letters queue attributes to the current queue.
	var deadLettersQueueInfo *sqsQueueInfo
	var derr error

	if len(s.metadata.SqsDeadLettersQueueName) > 0 {
		deadLettersQueueInfo, derr = s.getOrCreateQueue(ctx, s.metadata.SqsDeadLettersQueueName)
		if derr != nil {
			wrappedErr := fmt.Errorf("error retrieving SQS dead-letter queue: %w", derr)
			s.logger.Error(wrappedErr)

			return wrappedErr
		}

		err = s.setDeadLettersQueueAttributes(ctx, queueInfo, deadLettersQueueInfo)
		if err != nil {
			wrappedErr := fmt.Errorf("error creating dead-letter queue: %w", err)
			s.logger.Error(wrappedErr)

			return wrappedErr
		}
	}

	// subscription creation is idempotent. Subscriptions are unique by topic/queue.
	_, err = s.getOrCreateSnsSqsSubscription(ctx, queueInfo.arn, topicArn)
	if err != nil {
		wrappedErr := fmt.Errorf("error subscribing topic: %s, to queue: %s, with error: %w", topicArn, queueInfo.arn, err)
		s.logger.Error(wrappedErr)

		return wrappedErr
	}

	// start the subscription manager
	s.subscriptionManager.Init(queueInfo, deadLettersQueueInfo, s.consumeSubscription)

	s.subscriptionManager.Subscribe(&SubscriptionTopicHandler{
		topic:        sanitizedName,
		requestTopic: req.Topic,
		handler:      handler,
		ctx:          ctx,
	})

	return nil
}

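// Publish publishes the request payload to the SNS topic, creating the topic first if needed
// (unless entity management is disabled).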
func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error {
	if s.closed.Load() {
		return errors.New("component is closed")
	}

	topicArn, _, err := s.getOrCreateTopic(ctx, req.Topic)
	if err != nil {
		return fmt.Errorf("error getting topic ARN for %s: %w", req.Topic, err)
	}

	message := string(req.Data)
	snsPublishInput := &sns.PublishInput{
		Message:  aws.String(message),
		TopicArn: aws.String(topicArn),
	}
	if s.metadata.Fifo {
		snsPublishInput.MessageGroupId = s.getMessageGroupID(req)
	}

	// sns client has internal exponential backoffs.
	_, err = s.snsClient.PublishWithContext(ctx, snsPublishInput)
	if err != nil {
		wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %w", req.Topic, topicArn, err)
		s.logger.Error(wrappedErr)

		return wrappedErr
	}

	return nil
}

// Close should always be called to release the resources used by the SNS/SQS
// client. Blocks until all goroutines have returned.
func (s *snsSqs) Close() error {
	if s.closed.CompareAndSwap(false, true) {
		s.subscriptionManager.Close()
	}

	return nil
}

func (s *snsSqs) Features() []pubsub.Feature {
	return nil
}

// GetComponentMetadata returns the metadata of the component.
func (s *snsSqs) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
	metadataStruct := snsSqsMetadata{}
	metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.PubSubType)
	return
}