Snssqs subscription policy (#1259)

* bugfix for sns topic deletion upon termination

* removed upstream github workflow files

* gitignore

* restrict SQS send message policy

* linting mostly of unwrapped errors

* refactoring

* pr changes

* Update .gitignore

* Update dapr-bot-schedule.yml

* Update dapr-bot-schedule.yml

Co-authored-by: Yaron Schneider <yaronsc@microsoft.com>
Co-authored-by: Artur Souza <artursouza.ms@outlook.com>
Co-authored-by: Dapr Bot <56698301+dapr-bot@users.noreply.github.com>
This commit is contained in:
Amit Mor 2021-11-02 19:38:10 +02:00 committed by GitHub
parent f6a64f73fe
commit e9deaf3781
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 86 additions and 68 deletions

View File

@ -18,11 +18,11 @@ import (
)
type snsSqs struct {
// key is the topic name, value is the ARN of the topic
// key is the topic name, value is the ARN of the topic.
topics map[string]string
// key is the sanitized topic name, value is the actual topic name
// key is the sanitized topic name, value is the actual topic name.
topicSanitized map[string]string
// key is the topic name, value holds the ARN of the queue and its url
// key is the topic name, value holds the ARN of the queue and its url.
queues map[string]*sqsQueueInfo
snsClient *sns.SNS
sqsClient *sqs.SQS
@ -37,31 +37,31 @@ type sqsQueueInfo struct {
}
type snsSqsMetadata struct {
// name of the queue for this application. The is provided by the runtime as "consumerID"
// name of the queue for this application. The is provided by the runtime as "consumerID".
sqsQueueName string
// name of the dead letter queue for this application
// name of the dead letter queue for this application.
sqsDeadLettersQueueName string
// aws endpoint for the component to use.
Endpoint string
// access key to use for accessing sqs/sns
// access key to use for accessing sqs/sns.
AccessKey string
// secret key to use for accessing sqs/sns
// secret key to use for accessing sqs/sns.
SecretKey string
// aws session token to use.
SessionToken string
// aws region in which SNS/SQS should create resources
// aws region in which SNS/SQS should create resources.
Region string
// amount of time in seconds that a message is hidden from receive requests after it is sent to a subscriber. Default: 10
// amount of time in seconds that a message is hidden from receive requests after it is sent to a subscriber. Default: 10.
messageVisibilityTimeout int64
// number of times to resend a message after processing of that message fails before removing that message from the queue. Default: 10
// number of times to resend a message after processing of that message fails before removing that message from the queue. Default: 10.
messageRetryLimit int64
// if sqsDeadLettersQueueName is set to a value, then the messageReceiveLimit defines the number of times a message is received
// before it is moved to the dead-letters queue. This value must be smaller than messageRetryLimit
// before it is moved to the dead-letters queue. This value must be smaller than messageRetryLimit.
messageReceiveLimit int64
// amount of time to await receipt of a message before making another request. Default: 1
// amount of time to await receipt of a message before making another request. Default: 1.
messageWaitTimeSeconds int64
// maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10
// maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10.
messageMaxNumber int64
}
@ -194,7 +194,7 @@ func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata,
md.messageReceiveLimit = messageReceiveLimit
}
// XOR on having either a valid messageReceiveLimit and invalid sqsDeadLettersQueueName, and vice versa
// XOR on having either a valid messageReceiveLimit and invalid sqsDeadLettersQueueName, and vice versa.
if (md.messageReceiveLimit > 0 || len(md.sqsDeadLettersQueueName) > 0) && !(md.messageReceiveLimit > 0 && len(md.sqsDeadLettersQueueName) > 0) {
return nil, errors.New("to use SQS dead letters queue, messageReceiveLimit and sqsDeadLettersQueueName must both be set to a value")
}
@ -243,13 +243,13 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error {
s.metadata = md
// both Publish and Subscribe need reference the topic ARN
// track these ARNs in this map
// track these ARNs in this map.
s.topics = make(map[string]string)
s.topicSanitized = make(map[string]string)
s.queues = make(map[string]*sqsQueueInfo)
sess, err := aws_auth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint)
if err != nil {
return err
return fmt.Errorf("error creating an AWS client: %w", err)
}
s.snsClient = sns.New(sess)
s.sqsClient = sqs.New(sess)
@ -264,7 +264,7 @@ func (s *snsSqs) createTopic(topic string) (string, string, error) {
Tags: []*sns.Tag{{Key: aws.String(awsSnsTopicNameKey), Value: aws.String(topic)}},
})
if err != nil {
return "", "", err
return "", "", fmt.Errorf("error while creating an SNS topic: %w", err)
}
return *(createTopicResponse.TopicArn), sanitizedName, nil
@ -276,12 +276,12 @@ func (s *snsSqs) getOrCreateTopic(topic string) (string, error) {
topicArn, ok := s.topics[topic]
if ok {
s.logger.Debugf("Found existing topic ARN for topic %s: %s", topic, topicArn)
s.logger.Debugf("found existing topic ARN for topic %s: %s", topic, topicArn)
return topicArn, nil
}
s.logger.Debugf("No topic ARN found for %s\n Creating topic instead.", topic)
s.logger.Debugf("no topic ARN found for %s\n Creating topic instead.", topic)
topicArn, sanitizedName, err := s.createTopic(topic)
if err != nil {
@ -290,7 +290,7 @@ func (s *snsSqs) getOrCreateTopic(topic string) (string, error) {
return "", err
}
// record topic ARN
// record topic ARN.
s.topics[topic] = topicArn
s.topicSanitized[sanitizedName] = topic
@ -303,7 +303,7 @@ func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) {
Tags: map[string]*string{awsSqsQueueNameKey: aws.String(queueName)},
})
if err != nil {
return nil, err
return nil, fmt.Errorf("error creaing an SQS queue: %w", err)
}
queueAttributesResponse, err := s.sqsClient.GetQueueAttributes(&sqs.GetQueueAttributesInput{
@ -314,25 +314,6 @@ func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) {
s.logger.Errorf("error fetching queue attributes for %s: %v", queueName, err)
}
// add permissions to allow SNS to send messages to this queue
_, err = s.sqsClient.SetQueueAttributes(&(sqs.SetQueueAttributesInput{
Attributes: map[string]*string{
"Policy": aws.String(fmt.Sprintf(`{
"Statement": [{
"Effect":"Allow",
"Principal":"*",
"Action":"sqs:SendMessage",
"Resource":"%s"
}]
}`, *(queueAttributesResponse.Attributes["QueueArn"]))),
},
QueueUrl: createQueueResponse.QueueUrl,
}))
if err != nil {
return nil, err
}
return &sqsQueueInfo{
arn: *(queueAttributesResponse.Attributes["QueueArn"]),
url: *(createQueueResponse.QueueUrl),
@ -347,7 +328,7 @@ func (s *snsSqs) getOrCreateQueue(queueName string) (*sqsQueueInfo, error) {
return queueArn, nil
}
// creating queues is idempotent, the names serve as unique keys among a given region
// creating queues is idempotent, the names serve as unique keys among a given region.
s.logger.Debugf("No queue arn found for %s\nCreating queue", queueName)
queueInfo, err := s.createQueue(queueName)
@ -375,9 +356,10 @@ func (s *snsSqs) Publish(req *pubsub.PublishRequest) error {
})
if err != nil {
s.logger.Errorf("error publishing topic %s with topic ARN %s: %v", req.Topic, topicArn, err)
wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %v", req.Topic, topicArn, err)
s.logger.Error(wrappedErr)
return err
return wrappedErr
}
return nil
@ -398,11 +380,11 @@ func (s *snsSqs) acknowledgeMessage(queueURL string, receiptHandle *string) erro
ReceiptHandle: receiptHandle,
})
return err
return fmt.Errorf("error deleting SQS message: %w", err)
}
func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueueInfo *sqsQueueInfo, handler pubsub.Handler) error {
// if this message has been received > x times, delete from queue, it's borked
// if this message has been received > x times, delete from queue, it's borked.
recvCount, ok := message.Attributes[sqs.MessageSystemAttributeNameApproximateReceiveCount]
if !ok {
@ -425,7 +407,7 @@ func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueue
"message received greater than %v times, deleting this message without further processing", s.metadata.messageRetryLimit)
}
// ... else, there is no need to actively do something if we reached the limit defined in messageReceiveLimit as the message had
// already been moved to the dead-letters queue by SQS
// already been moved to the dead-letters queue by SQS.
if deadLettersQueueInfo != nil && recvCountInt >= s.metadata.messageReceiveLimit {
s.logger.Warnf(
"message received greater than %v times, moving this message without further processing to dead-letters queue: %v", s.metadata.messageReceiveLimit, s.metadata.sqsDeadLettersQueueName)
@ -450,7 +432,7 @@ func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueue
return fmt.Errorf("error handling message: %w", err)
}
// otherwise, there was no error, acknowledge the message
// otherwise, there was no error, acknowledge the message.
return s.acknowledgeMessage(queueInfo.url, message.ReceiptHandle)
}
@ -458,7 +440,7 @@ func (s *snsSqs) consumeSubscription(queueInfo, deadLettersQueueInfo *sqsQueueIn
go func() {
for {
messageResponse, err := s.sqsClient.ReceiveMessage(&sqs.ReceiveMessageInput{
// use this property to decide when a message should be discarded
// use this property to decide when a message should be discarded.
AttributeNames: []*string{
aws.String(sqs.MessageSystemAttributeNameApproximateReceiveCount),
},
@ -473,7 +455,7 @@ func (s *snsSqs) consumeSubscription(queueInfo, deadLettersQueueInfo *sqsQueueIn
continue
}
// retry receiving messages
// retry receiving messages.
if len(messageResponse.Messages) < 1 {
s.logger.Debug("No messages received, requesting again")
@ -495,7 +477,7 @@ func (s *snsSqs) createDeadLettersQueue() (*sqsQueueInfo, error) {
var deadLettersQueueInfo *sqsQueueInfo
deadLettersQueueInfo, err := s.getOrCreateQueue(s.metadata.sqsDeadLettersQueueName)
if err != nil {
s.logger.Errorf("error retrieving SQS dead-letter queue: %v", err)
s.logger.Errorf("error retrieving SQS dead-letter queue: %w", err)
return nil, err
}
@ -511,9 +493,10 @@ func (s *snsSqs) createQueueAttributesWithDeadLetters(queueInfo, deadLettersQueu
b, err := json.Marshal(policy)
if err != nil {
s.logger.Errorf("error marshalling dead-letters queue policy: %v", err)
wrappedErr := fmt.Errorf("error marshalling dead-letters queue policy: %w", err)
s.logger.Error(wrappedErr)
return nil, err
return nil, wrappedErr
}
sqsSetQueueAttributesInput := &sqs.SetQueueAttributesInput{
@ -526,33 +509,68 @@ func (s *snsSqs) createQueueAttributesWithDeadLetters(queueInfo, deadLettersQueu
return sqsSetQueueAttributesInput, nil
}
func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(sqsQueueInfo *sqsQueueInfo, snsARN string) error {
// only permit SNS to send messages to SQS using the created subscription.
if _, err := s.sqsClient.SetQueueAttributes(&(sqs.SetQueueAttributesInput{
Attributes: map[string]*string{
"Policy": aws.String(fmt.Sprintf(`{
"Version": "2012-10-17",
"Statement": [{
"Effect":"Allow",
"Principal":{"Service": "sns.amazonaws.com"},
"Action":"sqs:SendMessage",
"Resource":"%s",
"Condition": {
"ArnEquals":{
"aws:SourceArn":"%s"
}
}
}]
}`, sqsQueueInfo.arn, snsARN)),
},
QueueUrl: &sqsQueueInfo.url,
})); err != nil {
return fmt.Errorf("error setting queue subscription policy: %w", err)
}
return nil
}
func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler) error {
// subscribers declare a topic ARN
// and declare a SQS queue to use
// these should be idempotent
// queues should not be created if they exist
// subscribers declare a topic ARN and declare a SQS queue to use
// these should be idempotent - queues should not be created if they exist.
topicArn, err := s.getOrCreateTopic(req.Topic)
if err != nil {
s.logger.Errorf("error getting topic ARN for %s: %v", req.Topic, err)
s.logger.Errorf("error getting topic ARN for %s: %w", req.Topic, err)
return err
}
// this is the ID of the application, it is supplied via runtime as "consumerID"
// this is the ID of the application, it is supplied via runtime as "consumerID".
var queueInfo *sqsQueueInfo
queueInfo, err = s.getOrCreateQueue(s.metadata.sqsQueueName)
if err != nil {
s.logger.Errorf("error retrieving SQS queue: %v", err)
s.logger.Errorf("error retrieving SQS queue: %w", err)
return err
}
// only after a SQS queue and SNS topic had been setup, we restrict the SendMessage action to SNS as sole source
// to prevent anyone but SNS to publish message to SQS.
err = s.restrictQueuePublishPolicyToOnlySNS(queueInfo, topicArn)
if err != nil {
s.logger.Errorf("error setting sns-sqs subscription policy: %w", err)
return err
}
// apply the dead letters queue attributes to the current queue.
var deadLettersQueueInfo *sqsQueueInfo
if len(s.metadata.sqsDeadLettersQueueName) > 0 {
var derr error
deadLettersQueueInfo, derr = s.createDeadLettersQueue()
if derr != nil {
s.logger.Errorf("error creating dead-letter queue: %v", derr)
s.logger.Errorf("error creating dead-letter queue: %w", derr)
return derr
}
@ -560,21 +578,20 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
var sqsSetQueueAttributesInput *sqs.SetQueueAttributesInput
sqsSetQueueAttributesInput, derr = s.createQueueAttributesWithDeadLetters(queueInfo, deadLettersQueueInfo)
if derr != nil {
s.logger.Errorf("error creatubg queue attributes for dead-letter queue: %v", derr)
s.logger.Errorf("error creatubg queue attributes for dead-letter queue: %w", derr)
return derr
}
_, derr = s.sqsClient.SetQueueAttributes(sqsSetQueueAttributesInput)
if derr != nil {
s.logger.Errorf("error updating queue attributes with dead-letter queue: %v", derr)
wrappedErr := fmt.Errorf("error updating queue attributes with dead-letter queue: %w", derr)
s.logger.Error(wrappedErr)
return derr
return wrappedErr
}
}
// apply the dead letters queue attributes to the current queue
// subscription creation is idempotent. Subscriptions are unique by topic/queue
// subscription creation is idempotent. Subscriptions are unique by topic/queue.
subscribeOutput, err := s.snsClient.Subscribe(&sns.SubscribeInput{
Attributes: nil,
Endpoint: &queueInfo.arn, // create SQS queue per subscription
@ -583,9 +600,10 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
TopicArn: &topicArn,
})
if err != nil {
s.logger.Errorf("error subscribing to topic %s: %v", req.Topic, err)
wrappedErr := fmt.Errorf("error subscribing to topic %s: %w", req.Topic, err)
s.logger.Error(wrappedErr)
return err
return wrappedErr
}
s.subscriptions = append(s.subscriptions, subscribeOutput.SubscriptionArn)