refactored aws sqs policy inserting (#1807)

* refactored aws sqs policy inserting

Signed-off-by: Xingru <x.xingruxu@student.maastrichtuniversity.nl>

* Add tests and fixed tryInsertCondition not inserting bug

Signed-off-by: Xingru <x.xingruxu@student.maastrichtuniversity.nl>

* fixed lint in snssqs_test.go

Signed-off-by: Xingru <x.xingruxu@student.maastrichtuniversity.nl>

* fixed lint

Signed-off-by: Xingru <x.xingruxu@student.maastrichtuniversity.nl>

* fixed lint error in policy.go

Signed-off-by: Xingru <x.xingruxu@student.maastrichtuniversity.nl>

Signed-off-by: Xingru <x.xingruxu@student.maastrichtuniversity.nl>
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Bernd Verst <4535280+berndverst@users.noreply.github.com>
Co-authored-by: Loong Dai <long.dai@intel.com>
Co-authored-by: Dapr Bot <56698301+dapr-bot@users.noreply.github.com>
Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
This commit is contained in:
Xingru 2022-09-14 02:11:58 +02:00 committed by GitHub
parent 9c9df2ff76
commit ed483dc88b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 159 additions and 32 deletions

View File

@ -13,12 +13,47 @@ limitations under the License.
package snssqs
import (
"encoding/json"
"reflect"
)
type arnEquals struct {
AwsSourceArn string `json:"aws:SourceArn"`
AwsSourceArn awsSourceArn `json:"aws:SourceArn"`
}
type awsSourceArn []string
// UnmarshalJSON This is a custom unmarshaler for awsSourceArn for handling a special case
// where aws flatten awsSourceArn into a string when it only contains one element
func (a *awsSourceArn) UnmarshalJSON(data []byte) error {
var i interface{}
err := json.Unmarshal(data, &i)
if err != nil {
return err
}
items := reflect.ValueOf(i)
switch items.Kind() {
case reflect.String:
*a = append(*a, items.String())
case reflect.Slice:
*a = make([]string, 0, items.Len())
for i := 0; i < items.Len(); i++ {
item := items.Index(i)
switch item.Kind() {
case reflect.String:
*a = append(*a, item.String())
case reflect.Interface:
*a = append(*a, item.Interface().(string))
}
}
}
return nil
}
type condition struct {
ArnEquals arnEquals
ForAllValuesArnEquals arnEquals `json:"ForAllValues:ArnEquals"`
}
type principal struct {
@ -38,20 +73,33 @@ type policy struct {
Statement []statement
}
func (p *policy) statementExists(other *statement) bool {
for _, s := range p.Statement {
if s.Effect == other.Effect &&
s.Principal.Service == other.Principal.Service &&
s.Action == other.Action &&
s.Resource == other.Resource &&
s.Condition.ArnEquals.AwsSourceArn == other.Condition.ArnEquals.AwsSourceArn {
return true
func (p *policy) tryInsertCondition(sqsArn string, snsArn string) bool {
for i, s := range p.Statement {
// if there is a statement for sqsArn
if s.Resource == sqsArn {
// check if the snsArn already exists
for _, a := range s.Condition.ForAllValuesArnEquals.AwsSourceArn {
if a == snsArn {
return true
}
}
// insert it if it does not exist
p.Statement[i].Condition.ForAllValuesArnEquals.AwsSourceArn = append(p.Statement[i].Condition.ForAllValuesArnEquals.AwsSourceArn, snsArn)
return false
}
}
// insert a new statement if no statement for the sqsArn
newStatement := &statement{
Effect: "Allow",
Principal: principal{Service: "sns.amazonaws.com"},
Action: "sqs:SendMessage",
Resource: sqsArn,
Condition: condition{
ForAllValuesArnEquals: arnEquals{
AwsSourceArn: []string{snsArn},
},
},
}
p.Statement = append(p.Statement, *newStatement)
return false
}
func (p *policy) addStatement(other *statement) {
p.Statement = append(p.Statement, *other)
}

View File

@ -731,31 +731,18 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context,
return fmt.Errorf("error getting queue attributes: %w", err)
}
newStatement := &statement{
Effect: "Allow",
Principal: principal{Service: "sns.amazonaws.com"},
Action: "sqs:SendMessage",
Resource: sqsQueueInfo.arn,
Condition: condition{
ArnEquals: arnEquals{
AwsSourceArn: snsARN,
},
},
}
policy := &policy{Version: "2012-10-17"}
if policyStr, ok := getQueueAttributesOutput.Attributes[sqs.QueueAttributeNamePolicy]; ok {
// look for the current statement if exists, else add it and store.
if err = json.Unmarshal([]byte(*policyStr), policy); err != nil {
return fmt.Errorf("error unmarshalling sqs policy: %w", err)
}
if policy.statementExists(newStatement) {
// nothing to do.
return nil
}
}
conditionExists := policy.tryInsertCondition(sqsQueueInfo.arn, snsARN)
if conditionExists {
return nil
}
policy.addStatement(newStatement)
b, uerr := json.Marshal(policy)
if uerr != nil {
return fmt.Errorf("failed serializing new sqs policy: %w", uerr)

View File

@ -14,6 +14,7 @@ limitations under the License.
package snssqs
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
@ -402,6 +403,97 @@ func Test_replaceNameToAWSSanitizedExistingFifoName_Trimmed(t *testing.T) {
r.Equal("012345678901234567890123456789012345678901234567890123456789012345678901234.fifo", v)
}
func Test_UnmarshalJSON_UnmarshallsToArray(t *testing.T) {
t.Parallel()
r := require.New(t)
s := `
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Principal": {
"Service": "sns.amazonaws.com"
},
"Action": "sqs:SendMessage",
"Resource": "sqsArn",
"Condition": {
"ForAllValues:ArnEquals": {
"aws:SourceArn": "snsArn"
}
}
}
]
}
`
p := &policy{}
err := json.Unmarshal([]byte(s), p)
r.Equal(err, nil)
statement := p.Statement[0]
r.Equal(len(statement.Condition.ForAllValuesArnEquals.AwsSourceArn), 1)
r.Equal(statement.Condition.ForAllValuesArnEquals.AwsSourceArn[0], "snsArn")
}
func Test_tryInsertCondition(t *testing.T) {
t.Parallel()
r := require.New(t)
policy := &policy{Version: "2012-10-17"}
sqsArn := "sqsArn"
snsArns := []string{"snsArns1", "snsArns2", "snsArns3", "snsArns4"}
for _, snsArn := range snsArns {
policy.tryInsertCondition(sqsArn, snsArn)
}
r.Equal(len(policy.Statement), 1)
insertedStatement := policy.Statement[0]
r.Equal(insertedStatement.Resource, sqsArn)
r.Equal(len(insertedStatement.Condition.ForAllValuesArnEquals.AwsSourceArn), len(snsArns))
r.ElementsMatch(insertedStatement.Condition.ForAllValuesArnEquals.AwsSourceArn, snsArns)
}
func Test_policy_compatible(t *testing.T) {
t.Parallel()
r := require.New(t)
sqsArn := "sqsArn"
snsArn := "snsArn"
oldPolicy := `
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Principal": {
"Service": "sns.amazonaws.com"
},
"Action": "sqs:SendMessage",
"Resource": "sqsArn",
"Condition": {
"ArnEquals": {
"aws:SourceArn": "snsArn"
}
}
}
]
}
`
policy := &policy{Version: "2012-10-17"}
err := json.Unmarshal([]byte(oldPolicy), policy)
r.Equal(err, nil)
policy.tryInsertCondition(sqsArn, snsArn)
r.Equal(len(policy.Statement), 1)
insertedStatement := policy.Statement[0]
r.Equal(insertedStatement.Resource, sqsArn)
r.Equal(len(insertedStatement.Condition.ForAllValuesArnEquals.AwsSourceArn), 1)
r.Equal(insertedStatement.Condition.ForAllValuesArnEquals.AwsSourceArn[0], snsArn)
}
func Test_buildARN_DefaultPartition(t *testing.T) {
t.Parallel()
r := require.New(t)