Snssqs subscription policy (#1259)
* bugfix for sns topic deletion upon termination * removed upstream github workflow files * gitignore * restrict SQS send message policy * linting mostly of unwrapped errors * refactoring * pr changes * Update .gitignore * Update dapr-bot-schedule.yml * Update dapr-bot-schedule.yml Co-authored-by: Yaron Schneider <yaronsc@microsoft.com> Co-authored-by: Artur Souza <artursouza.ms@outlook.com> Co-authored-by: Dapr Bot <56698301+dapr-bot@users.noreply.github.com>
This commit is contained in:
parent
f6a64f73fe
commit
e9deaf3781
|
|
@ -18,11 +18,11 @@ import (
|
|||
)
|
||||
|
||||
type snsSqs struct {
|
||||
// key is the topic name, value is the ARN of the topic
|
||||
// key is the topic name, value is the ARN of the topic.
|
||||
topics map[string]string
|
||||
// key is the sanitized topic name, value is the actual topic name
|
||||
// key is the sanitized topic name, value is the actual topic name.
|
||||
topicSanitized map[string]string
|
||||
// key is the topic name, value holds the ARN of the queue and its url
|
||||
// key is the topic name, value holds the ARN of the queue and its url.
|
||||
queues map[string]*sqsQueueInfo
|
||||
snsClient *sns.SNS
|
||||
sqsClient *sqs.SQS
|
||||
|
|
@ -37,31 +37,31 @@ type sqsQueueInfo struct {
|
|||
}
|
||||
|
||||
type snsSqsMetadata struct {
|
||||
// name of the queue for this application. The is provided by the runtime as "consumerID"
|
||||
// name of the queue for this application. The is provided by the runtime as "consumerID".
|
||||
sqsQueueName string
|
||||
// name of the dead letter queue for this application
|
||||
// name of the dead letter queue for this application.
|
||||
sqsDeadLettersQueueName string
|
||||
// aws endpoint for the component to use.
|
||||
Endpoint string
|
||||
// access key to use for accessing sqs/sns
|
||||
// access key to use for accessing sqs/sns.
|
||||
AccessKey string
|
||||
// secret key to use for accessing sqs/sns
|
||||
// secret key to use for accessing sqs/sns.
|
||||
SecretKey string
|
||||
// aws session token to use.
|
||||
SessionToken string
|
||||
// aws region in which SNS/SQS should create resources
|
||||
// aws region in which SNS/SQS should create resources.
|
||||
Region string
|
||||
|
||||
// amount of time in seconds that a message is hidden from receive requests after it is sent to a subscriber. Default: 10
|
||||
// amount of time in seconds that a message is hidden from receive requests after it is sent to a subscriber. Default: 10.
|
||||
messageVisibilityTimeout int64
|
||||
// number of times to resend a message after processing of that message fails before removing that message from the queue. Default: 10
|
||||
// number of times to resend a message after processing of that message fails before removing that message from the queue. Default: 10.
|
||||
messageRetryLimit int64
|
||||
// if sqsDeadLettersQueueName is set to a value, then the messageReceiveLimit defines the number of times a message is received
|
||||
// before it is moved to the dead-letters queue. This value must be smaller than messageRetryLimit
|
||||
// before it is moved to the dead-letters queue. This value must be smaller than messageRetryLimit.
|
||||
messageReceiveLimit int64
|
||||
// amount of time to await receipt of a message before making another request. Default: 1
|
||||
// amount of time to await receipt of a message before making another request. Default: 1.
|
||||
messageWaitTimeSeconds int64
|
||||
// maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10
|
||||
// maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10.
|
||||
messageMaxNumber int64
|
||||
}
|
||||
|
||||
|
|
@ -194,7 +194,7 @@ func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata,
|
|||
md.messageReceiveLimit = messageReceiveLimit
|
||||
}
|
||||
|
||||
// XOR on having either a valid messageReceiveLimit and invalid sqsDeadLettersQueueName, and vice versa
|
||||
// XOR on having either a valid messageReceiveLimit and invalid sqsDeadLettersQueueName, and vice versa.
|
||||
if (md.messageReceiveLimit > 0 || len(md.sqsDeadLettersQueueName) > 0) && !(md.messageReceiveLimit > 0 && len(md.sqsDeadLettersQueueName) > 0) {
|
||||
return nil, errors.New("to use SQS dead letters queue, messageReceiveLimit and sqsDeadLettersQueueName must both be set to a value")
|
||||
}
|
||||
|
|
@ -243,13 +243,13 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error {
|
|||
s.metadata = md
|
||||
|
||||
// both Publish and Subscribe need reference the topic ARN
|
||||
// track these ARNs in this map
|
||||
// track these ARNs in this map.
|
||||
s.topics = make(map[string]string)
|
||||
s.topicSanitized = make(map[string]string)
|
||||
s.queues = make(map[string]*sqsQueueInfo)
|
||||
sess, err := aws_auth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error creating an AWS client: %w", err)
|
||||
}
|
||||
s.snsClient = sns.New(sess)
|
||||
s.sqsClient = sqs.New(sess)
|
||||
|
|
@ -264,7 +264,7 @@ func (s *snsSqs) createTopic(topic string) (string, string, error) {
|
|||
Tags: []*sns.Tag{{Key: aws.String(awsSnsTopicNameKey), Value: aws.String(topic)}},
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
return "", "", fmt.Errorf("error while creating an SNS topic: %w", err)
|
||||
}
|
||||
|
||||
return *(createTopicResponse.TopicArn), sanitizedName, nil
|
||||
|
|
@ -276,12 +276,12 @@ func (s *snsSqs) getOrCreateTopic(topic string) (string, error) {
|
|||
topicArn, ok := s.topics[topic]
|
||||
|
||||
if ok {
|
||||
s.logger.Debugf("Found existing topic ARN for topic %s: %s", topic, topicArn)
|
||||
s.logger.Debugf("found existing topic ARN for topic %s: %s", topic, topicArn)
|
||||
|
||||
return topicArn, nil
|
||||
}
|
||||
|
||||
s.logger.Debugf("No topic ARN found for %s\n Creating topic instead.", topic)
|
||||
s.logger.Debugf("no topic ARN found for %s\n Creating topic instead.", topic)
|
||||
|
||||
topicArn, sanitizedName, err := s.createTopic(topic)
|
||||
if err != nil {
|
||||
|
|
@ -290,7 +290,7 @@ func (s *snsSqs) getOrCreateTopic(topic string) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
|
||||
// record topic ARN
|
||||
// record topic ARN.
|
||||
s.topics[topic] = topicArn
|
||||
s.topicSanitized[sanitizedName] = topic
|
||||
|
||||
|
|
@ -303,7 +303,7 @@ func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) {
|
|||
Tags: map[string]*string{awsSqsQueueNameKey: aws.String(queueName)},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("error creaing an SQS queue: %w", err)
|
||||
}
|
||||
|
||||
queueAttributesResponse, err := s.sqsClient.GetQueueAttributes(&sqs.GetQueueAttributesInput{
|
||||
|
|
@ -314,25 +314,6 @@ func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) {
|
|||
s.logger.Errorf("error fetching queue attributes for %s: %v", queueName, err)
|
||||
}
|
||||
|
||||
// add permissions to allow SNS to send messages to this queue
|
||||
_, err = s.sqsClient.SetQueueAttributes(&(sqs.SetQueueAttributesInput{
|
||||
Attributes: map[string]*string{
|
||||
"Policy": aws.String(fmt.Sprintf(`{
|
||||
"Statement": [{
|
||||
"Effect":"Allow",
|
||||
"Principal":"*",
|
||||
"Action":"sqs:SendMessage",
|
||||
"Resource":"%s"
|
||||
}]
|
||||
}`, *(queueAttributesResponse.Attributes["QueueArn"]))),
|
||||
},
|
||||
QueueUrl: createQueueResponse.QueueUrl,
|
||||
}))
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &sqsQueueInfo{
|
||||
arn: *(queueAttributesResponse.Attributes["QueueArn"]),
|
||||
url: *(createQueueResponse.QueueUrl),
|
||||
|
|
@ -347,7 +328,7 @@ func (s *snsSqs) getOrCreateQueue(queueName string) (*sqsQueueInfo, error) {
|
|||
|
||||
return queueArn, nil
|
||||
}
|
||||
// creating queues is idempotent, the names serve as unique keys among a given region
|
||||
// creating queues is idempotent, the names serve as unique keys among a given region.
|
||||
s.logger.Debugf("No queue arn found for %s\nCreating queue", queueName)
|
||||
|
||||
queueInfo, err := s.createQueue(queueName)
|
||||
|
|
@ -375,9 +356,10 @@ func (s *snsSqs) Publish(req *pubsub.PublishRequest) error {
|
|||
})
|
||||
|
||||
if err != nil {
|
||||
s.logger.Errorf("error publishing topic %s with topic ARN %s: %v", req.Topic, topicArn, err)
|
||||
wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %v", req.Topic, topicArn, err)
|
||||
s.logger.Error(wrappedErr)
|
||||
|
||||
return err
|
||||
return wrappedErr
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -398,11 +380,11 @@ func (s *snsSqs) acknowledgeMessage(queueURL string, receiptHandle *string) erro
|
|||
ReceiptHandle: receiptHandle,
|
||||
})
|
||||
|
||||
return err
|
||||
return fmt.Errorf("error deleting SQS message: %w", err)
|
||||
}
|
||||
|
||||
func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueueInfo *sqsQueueInfo, handler pubsub.Handler) error {
|
||||
// if this message has been received > x times, delete from queue, it's borked
|
||||
// if this message has been received > x times, delete from queue, it's borked.
|
||||
recvCount, ok := message.Attributes[sqs.MessageSystemAttributeNameApproximateReceiveCount]
|
||||
|
||||
if !ok {
|
||||
|
|
@ -425,7 +407,7 @@ func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueue
|
|||
"message received greater than %v times, deleting this message without further processing", s.metadata.messageRetryLimit)
|
||||
}
|
||||
// ... else, there is no need to actively do something if we reached the limit defined in messageReceiveLimit as the message had
|
||||
// already been moved to the dead-letters queue by SQS
|
||||
// already been moved to the dead-letters queue by SQS.
|
||||
if deadLettersQueueInfo != nil && recvCountInt >= s.metadata.messageReceiveLimit {
|
||||
s.logger.Warnf(
|
||||
"message received greater than %v times, moving this message without further processing to dead-letters queue: %v", s.metadata.messageReceiveLimit, s.metadata.sqsDeadLettersQueueName)
|
||||
|
|
@ -450,7 +432,7 @@ func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueue
|
|||
return fmt.Errorf("error handling message: %w", err)
|
||||
}
|
||||
|
||||
// otherwise, there was no error, acknowledge the message
|
||||
// otherwise, there was no error, acknowledge the message.
|
||||
return s.acknowledgeMessage(queueInfo.url, message.ReceiptHandle)
|
||||
}
|
||||
|
||||
|
|
@ -458,7 +440,7 @@ func (s *snsSqs) consumeSubscription(queueInfo, deadLettersQueueInfo *sqsQueueIn
|
|||
go func() {
|
||||
for {
|
||||
messageResponse, err := s.sqsClient.ReceiveMessage(&sqs.ReceiveMessageInput{
|
||||
// use this property to decide when a message should be discarded
|
||||
// use this property to decide when a message should be discarded.
|
||||
AttributeNames: []*string{
|
||||
aws.String(sqs.MessageSystemAttributeNameApproximateReceiveCount),
|
||||
},
|
||||
|
|
@ -473,7 +455,7 @@ func (s *snsSqs) consumeSubscription(queueInfo, deadLettersQueueInfo *sqsQueueIn
|
|||
continue
|
||||
}
|
||||
|
||||
// retry receiving messages
|
||||
// retry receiving messages.
|
||||
if len(messageResponse.Messages) < 1 {
|
||||
s.logger.Debug("No messages received, requesting again")
|
||||
|
||||
|
|
@ -495,7 +477,7 @@ func (s *snsSqs) createDeadLettersQueue() (*sqsQueueInfo, error) {
|
|||
var deadLettersQueueInfo *sqsQueueInfo
|
||||
deadLettersQueueInfo, err := s.getOrCreateQueue(s.metadata.sqsDeadLettersQueueName)
|
||||
if err != nil {
|
||||
s.logger.Errorf("error retrieving SQS dead-letter queue: %v", err)
|
||||
s.logger.Errorf("error retrieving SQS dead-letter queue: %w", err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -511,9 +493,10 @@ func (s *snsSqs) createQueueAttributesWithDeadLetters(queueInfo, deadLettersQueu
|
|||
|
||||
b, err := json.Marshal(policy)
|
||||
if err != nil {
|
||||
s.logger.Errorf("error marshalling dead-letters queue policy: %v", err)
|
||||
wrappedErr := fmt.Errorf("error marshalling dead-letters queue policy: %w", err)
|
||||
s.logger.Error(wrappedErr)
|
||||
|
||||
return nil, err
|
||||
return nil, wrappedErr
|
||||
}
|
||||
|
||||
sqsSetQueueAttributesInput := &sqs.SetQueueAttributesInput{
|
||||
|
|
@ -526,33 +509,68 @@ func (s *snsSqs) createQueueAttributesWithDeadLetters(queueInfo, deadLettersQueu
|
|||
return sqsSetQueueAttributesInput, nil
|
||||
}
|
||||
|
||||
func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(sqsQueueInfo *sqsQueueInfo, snsARN string) error {
|
||||
// only permit SNS to send messages to SQS using the created subscription.
|
||||
if _, err := s.sqsClient.SetQueueAttributes(&(sqs.SetQueueAttributesInput{
|
||||
Attributes: map[string]*string{
|
||||
"Policy": aws.String(fmt.Sprintf(`{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [{
|
||||
"Effect":"Allow",
|
||||
"Principal":{"Service": "sns.amazonaws.com"},
|
||||
"Action":"sqs:SendMessage",
|
||||
"Resource":"%s",
|
||||
"Condition": {
|
||||
"ArnEquals":{
|
||||
"aws:SourceArn":"%s"
|
||||
}
|
||||
}
|
||||
}]
|
||||
}`, sqsQueueInfo.arn, snsARN)),
|
||||
},
|
||||
QueueUrl: &sqsQueueInfo.url,
|
||||
})); err != nil {
|
||||
return fmt.Errorf("error setting queue subscription policy: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler) error {
|
||||
// subscribers declare a topic ARN
|
||||
// and declare a SQS queue to use
|
||||
// these should be idempotent
|
||||
// queues should not be created if they exist
|
||||
// subscribers declare a topic ARN and declare a SQS queue to use
|
||||
// these should be idempotent - queues should not be created if they exist.
|
||||
topicArn, err := s.getOrCreateTopic(req.Topic)
|
||||
if err != nil {
|
||||
s.logger.Errorf("error getting topic ARN for %s: %v", req.Topic, err)
|
||||
s.logger.Errorf("error getting topic ARN for %s: %w", req.Topic, err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// this is the ID of the application, it is supplied via runtime as "consumerID"
|
||||
// this is the ID of the application, it is supplied via runtime as "consumerID".
|
||||
var queueInfo *sqsQueueInfo
|
||||
queueInfo, err = s.getOrCreateQueue(s.metadata.sqsQueueName)
|
||||
if err != nil {
|
||||
s.logger.Errorf("error retrieving SQS queue: %v", err)
|
||||
s.logger.Errorf("error retrieving SQS queue: %w", err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// only after a SQS queue and SNS topic had been setup, we restrict the SendMessage action to SNS as sole source
|
||||
// to prevent anyone but SNS to publish message to SQS.
|
||||
err = s.restrictQueuePublishPolicyToOnlySNS(queueInfo, topicArn)
|
||||
if err != nil {
|
||||
s.logger.Errorf("error setting sns-sqs subscription policy: %w", err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// apply the dead letters queue attributes to the current queue.
|
||||
var deadLettersQueueInfo *sqsQueueInfo
|
||||
if len(s.metadata.sqsDeadLettersQueueName) > 0 {
|
||||
var derr error
|
||||
deadLettersQueueInfo, derr = s.createDeadLettersQueue()
|
||||
if derr != nil {
|
||||
s.logger.Errorf("error creating dead-letter queue: %v", derr)
|
||||
s.logger.Errorf("error creating dead-letter queue: %w", derr)
|
||||
|
||||
return derr
|
||||
}
|
||||
|
|
@ -560,21 +578,20 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
|
|||
var sqsSetQueueAttributesInput *sqs.SetQueueAttributesInput
|
||||
sqsSetQueueAttributesInput, derr = s.createQueueAttributesWithDeadLetters(queueInfo, deadLettersQueueInfo)
|
||||
if derr != nil {
|
||||
s.logger.Errorf("error creatubg queue attributes for dead-letter queue: %v", derr)
|
||||
s.logger.Errorf("error creatubg queue attributes for dead-letter queue: %w", derr)
|
||||
|
||||
return derr
|
||||
}
|
||||
_, derr = s.sqsClient.SetQueueAttributes(sqsSetQueueAttributesInput)
|
||||
if derr != nil {
|
||||
s.logger.Errorf("error updating queue attributes with dead-letter queue: %v", derr)
|
||||
wrappedErr := fmt.Errorf("error updating queue attributes with dead-letter queue: %w", derr)
|
||||
s.logger.Error(wrappedErr)
|
||||
|
||||
return derr
|
||||
return wrappedErr
|
||||
}
|
||||
}
|
||||
|
||||
// apply the dead letters queue attributes to the current queue
|
||||
|
||||
// subscription creation is idempotent. Subscriptions are unique by topic/queue
|
||||
// subscription creation is idempotent. Subscriptions are unique by topic/queue.
|
||||
subscribeOutput, err := s.snsClient.Subscribe(&sns.SubscribeInput{
|
||||
Attributes: nil,
|
||||
Endpoint: &queueInfo.arn, // create SQS queue per subscription
|
||||
|
|
@ -583,9 +600,10 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
|
|||
TopicArn: &topicArn,
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Errorf("error subscribing to topic %s: %v", req.Topic, err)
|
||||
wrappedErr := fmt.Errorf("error subscribing to topic %s: %w", req.Topic, err)
|
||||
s.logger.Error(wrappedErr)
|
||||
|
||||
return err
|
||||
return wrappedErr
|
||||
}
|
||||
|
||||
s.subscriptions = append(s.subscriptions, subscribeOutput.SubscriptionArn)
|
||||
|
|
|
|||
Loading…
Reference in New Issue