diff --git a/pubsub/aws/snssqs/policy.go b/pubsub/aws/snssqs/policy.go index e1b6e6095..51b756aec 100644 --- a/pubsub/aws/snssqs/policy.go +++ b/pubsub/aws/snssqs/policy.go @@ -13,12 +13,47 @@ limitations under the License. package snssqs +import ( + "encoding/json" + "reflect" +) + type arnEquals struct { - AwsSourceArn string `json:"aws:SourceArn"` + AwsSourceArn awsSourceArn `json:"aws:SourceArn"` +} + +type awsSourceArn []string + +// UnmarshalJSON This is a custom unmarshaler for awsSourceArn for handling a special case +// where aws flatten awsSourceArn into a string when it only contains one element +func (a *awsSourceArn) UnmarshalJSON(data []byte) error { + var i interface{} + err := json.Unmarshal(data, &i) + if err != nil { + return err + } + + items := reflect.ValueOf(i) + switch items.Kind() { + case reflect.String: + *a = append(*a, items.String()) + case reflect.Slice: + *a = make([]string, 0, items.Len()) + for i := 0; i < items.Len(); i++ { + item := items.Index(i) + switch item.Kind() { + case reflect.String: + *a = append(*a, item.String()) + case reflect.Interface: + *a = append(*a, item.Interface().(string)) + } + } + } + return nil } type condition struct { - ArnEquals arnEquals + ForAllValuesArnEquals arnEquals `json:"ForAllValues:ArnEquals"` } type principal struct { @@ -38,20 +73,33 @@ type policy struct { Statement []statement } -func (p *policy) statementExists(other *statement) bool { - for _, s := range p.Statement { - if s.Effect == other.Effect && - s.Principal.Service == other.Principal.Service && - s.Action == other.Action && - s.Resource == other.Resource && - s.Condition.ArnEquals.AwsSourceArn == other.Condition.ArnEquals.AwsSourceArn { - return true +func (p *policy) tryInsertCondition(sqsArn string, snsArn string) bool { + for i, s := range p.Statement { + // if there is a statement for sqsArn + if s.Resource == sqsArn { + // check if the snsArn already exists + for _, a := range s.Condition.ForAllValuesArnEquals.AwsSourceArn { + if a == snsArn { + return true + } + } + // insert it if it does not exist + p.Statement[i].Condition.ForAllValuesArnEquals.AwsSourceArn = append(p.Statement[i].Condition.ForAllValuesArnEquals.AwsSourceArn, snsArn) + return false } } - + // insert a new statement if no statement for the sqsArn + newStatement := &statement{ + Effect: "Allow", + Principal: principal{Service: "sns.amazonaws.com"}, + Action: "sqs:SendMessage", + Resource: sqsArn, + Condition: condition{ + ForAllValuesArnEquals: arnEquals{ + AwsSourceArn: []string{snsArn}, + }, + }, + } + p.Statement = append(p.Statement, *newStatement) return false } - -func (p *policy) addStatement(other *statement) { - p.Statement = append(p.Statement, *other) -} diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 608d1f0ea..b11a06ad8 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -731,31 +731,18 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, return fmt.Errorf("error getting queue attributes: %w", err) } - newStatement := &statement{ - Effect: "Allow", - Principal: principal{Service: "sns.amazonaws.com"}, - Action: "sqs:SendMessage", - Resource: sqsQueueInfo.arn, - Condition: condition{ - ArnEquals: arnEquals{ - AwsSourceArn: snsARN, - }, - }, - } - policy := &policy{Version: "2012-10-17"} if policyStr, ok := getQueueAttributesOutput.Attributes[sqs.QueueAttributeNamePolicy]; ok { // look for the current statement if exists, else add it and store. if err = json.Unmarshal([]byte(*policyStr), policy); err != nil { return fmt.Errorf("error unmarshalling sqs policy: %w", err) } - if policy.statementExists(newStatement) { - // nothing to do. - return nil - } + } + conditionExists := policy.tryInsertCondition(sqsQueueInfo.arn, snsARN) + if conditionExists { + return nil } - policy.addStatement(newStatement) b, uerr := json.Marshal(policy) if uerr != nil { return fmt.Errorf("failed serializing new sqs policy: %w", uerr) diff --git a/pubsub/aws/snssqs/snssqs_test.go b/pubsub/aws/snssqs/snssqs_test.go index 95104b742..617987245 100644 --- a/pubsub/aws/snssqs/snssqs_test.go +++ b/pubsub/aws/snssqs/snssqs_test.go @@ -14,6 +14,7 @@ limitations under the License. package snssqs import ( + "encoding/json" "testing" "github.com/stretchr/testify/require" @@ -402,6 +403,97 @@ func Test_replaceNameToAWSSanitizedExistingFifoName_Trimmed(t *testing.T) { r.Equal("012345678901234567890123456789012345678901234567890123456789012345678901234.fifo", v) } +func Test_UnmarshalJSON_UnmarshallsToArray(t *testing.T) { + t.Parallel() + r := require.New(t) + + s := ` +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": { + "Service": "sns.amazonaws.com" + }, + "Action": "sqs:SendMessage", + "Resource": "sqsArn", + "Condition": { + "ForAllValues:ArnEquals": { + "aws:SourceArn": "snsArn" + } + } + } + ] +} +` + p := &policy{} + + err := json.Unmarshal([]byte(s), p) + r.Equal(err, nil) + + statement := p.Statement[0] + r.Equal(len(statement.Condition.ForAllValuesArnEquals.AwsSourceArn), 1) + r.Equal(statement.Condition.ForAllValuesArnEquals.AwsSourceArn[0], "snsArn") +} + +func Test_tryInsertCondition(t *testing.T) { + t.Parallel() + r := require.New(t) + + policy := &policy{Version: "2012-10-17"} + sqsArn := "sqsArn" + snsArns := []string{"snsArns1", "snsArns2", "snsArns3", "snsArns4"} + + for _, snsArn := range snsArns { + policy.tryInsertCondition(sqsArn, snsArn) + } + + r.Equal(len(policy.Statement), 1) + insertedStatement := policy.Statement[0] + r.Equal(insertedStatement.Resource, sqsArn) + r.Equal(len(insertedStatement.Condition.ForAllValuesArnEquals.AwsSourceArn), len(snsArns)) + r.ElementsMatch(insertedStatement.Condition.ForAllValuesArnEquals.AwsSourceArn, snsArns) +} + +func Test_policy_compatible(t *testing.T) { + t.Parallel() + r := require.New(t) + + sqsArn := "sqsArn" + snsArn := "snsArn" + oldPolicy := ` +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": { + "Service": "sns.amazonaws.com" + }, + "Action": "sqs:SendMessage", + "Resource": "sqsArn", + "Condition": { + "ArnEquals": { + "aws:SourceArn": "snsArn" + } + } + } + ] +} +` + policy := &policy{Version: "2012-10-17"} + err := json.Unmarshal([]byte(oldPolicy), policy) + r.Equal(err, nil) + + policy.tryInsertCondition(sqsArn, snsArn) + r.Equal(len(policy.Statement), 1) + insertedStatement := policy.Statement[0] + r.Equal(insertedStatement.Resource, sqsArn) + r.Equal(len(insertedStatement.Condition.ForAllValuesArnEquals.AwsSourceArn), 1) + r.Equal(insertedStatement.Condition.ForAllValuesArnEquals.AwsSourceArn[0], snsArn) +} + func Test_buildARN_DefaultPartition(t *testing.T) { t.Parallel() r := require.New(t)