diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 58849aaeb..46cbc6dfe 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -65,6 +65,27 @@ type snsSqsMetadata struct { messageMaxNumber int64 } +type arnEquals struct { + AwsSourceArn string `json:"aws:SourceArn"` +} + +type condition struct { + ArnEquals arnEquals +} + +type statement struct { + Effect string + Principal string + Action string + Resource string + Condition condition +} + +type policy struct { + Version string + Statement []statement +} + const ( awsSqsQueueNameKey = "dapr-queue-name" awsSnsTopicNameKey = "dapr-topic-name" @@ -122,6 +143,23 @@ func nameToAWSSanitizedName(name string) string { return string(s[:j]) } +func (p *policy) statementExists(other *statement) bool { + for _, s := range p.Statement { + if s.Effect == other.Effect && + s.Principal == other.Principal && + s.Action == other.Action && + s.Resource == other.Resource && + s.Condition.ArnEquals.AwsSourceArn == other.Condition.ArnEquals.AwsSourceArn { + return true + } + } + return false +} + +func (p *policy) addStatement(other *statement) { + p.Statement = append(p.Statement, *other) +} + func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata, error) { md := snsSqsMetadata{} props := metadata.Properties @@ -356,7 +394,7 @@ func (s *snsSqs) Publish(req *pubsub.PublishRequest) error { }) if err != nil { - wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %v", req.Topic, topicArn, err) + wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %w", req.Topic, topicArn, err) s.logger.Error(wrappedErr) return wrappedErr @@ -375,12 +413,14 @@ func parseTopicArn(arn string) string { } func (s *snsSqs) acknowledgeMessage(queueURL string, receiptHandle *string) error { - _, err := s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{ + if _, err := s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{ QueueUrl: &queueURL, ReceiptHandle: receiptHandle, - }) + }); err != nil { + return fmt.Errorf("error deleting SQS message: %w", err) + } - return fmt.Errorf("error deleting SQS message: %w", err) + return nil } func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueueInfo *sqsQueueInfo, handler pubsub.Handler) error { @@ -413,7 +453,7 @@ func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueue "message received greater than %v times, moving this message without further processing to dead-letters queue: %v", s.metadata.messageReceiveLimit, s.metadata.sqsDeadLettersQueueName) } - // otherwise try to handle the message + // otherwise try to handle the message. var messageBody snsMessage err = json.Unmarshal([]byte(*(message.Body)), &messageBody) @@ -511,22 +551,44 @@ func (s *snsSqs) createQueueAttributesWithDeadLetters(queueInfo, deadLettersQueu func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(sqsQueueInfo *sqsQueueInfo, snsARN string) error { // only permit SNS to send messages to SQS using the created subscription. - if _, err := s.sqsClient.SetQueueAttributes(&(sqs.SetQueueAttributesInput{ + getQueueAttributesOutput, err := s.sqsClient.GetQueueAttributes(&sqs.GetQueueAttributesInput{QueueUrl: &sqsQueueInfo.url, AttributeNames: []*string{aws.String(sqs.QueueAttributeNamePolicy)}}) + if err != nil { + return fmt.Errorf("error getting queue attributes: %w", err) + } + + newStatement := &statement{ + Effect: "Allow", + Principal: `{"Service": "sns.amazonaws.com"}`, + Action: "sqs:SendMessage", + Resource: sqsQueueInfo.arn, + Condition: condition{ + ArnEquals: arnEquals{ + AwsSourceArn: snsARN, + }, + }, + } + + policy := &policy{Version: "2012-11-05"} + if policyStr, ok := getQueueAttributesOutput.Attributes[sqs.QueueAttributeNamePolicy]; ok { + // look for the current statement if exists, else add it and store. + if err = json.Unmarshal([]byte(*policyStr), policy); err != nil { + return fmt.Errorf("error unmarshalling sqs policy: %w", err) + } + if policy.statementExists(newStatement) { + // nothing to do. + return nil + } + } + + policy.addStatement(newStatement) + b, uerr := json.Marshal(policy) + if uerr != nil { + return fmt.Errorf("failed serializing new sqs policy: %w", uerr) + } + + if _, err = s.sqsClient.SetQueueAttributes(&(sqs.SetQueueAttributesInput{ Attributes: map[string]*string{ - "Policy": aws.String(fmt.Sprintf(`{ - "Version": "2012-10-17", - "Statement": [{ - "Effect":"Allow", - "Principal":{"Service": "sns.amazonaws.com"}, - "Action":"sqs:SendMessage", - "Resource":"%s", - "Condition": { - "ArnEquals":{ - "aws:SourceArn":"%s" - } - } - }] - }`, sqsQueueInfo.arn, snsARN)), + "Policy": aws.String(string(b)), }, QueueUrl: &sqsQueueInfo.url, })); err != nil { @@ -594,7 +656,7 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler) // subscription creation is idempotent. Subscriptions are unique by topic/queue. subscribeOutput, err := s.snsClient.Subscribe(&sns.SubscribeInput{ Attributes: nil, - Endpoint: &queueInfo.arn, // create SQS queue per subscription + Endpoint: &queueInfo.arn, // create SQS queue per subscription. Protocol: aws.String("sqs"), ReturnSubscriptionArn: nil, TopicArn: &topicArn,