refactored aws sqs policy inserting (#1807)
* refactored aws sqs policy inserting Signed-off-by: Xingru <x.xingruxu@student.maastrichtuniversity.nl> * Add tests and fixed tryInsertCondition not inserting bug Signed-off-by: Xingru <x.xingruxu@student.maastrichtuniversity.nl> * fixed lint in snssqs_test.go Signed-off-by: Xingru <x.xingruxu@student.maastrichtuniversity.nl> * fixed lint Signed-off-by: Xingru <x.xingruxu@student.maastrichtuniversity.nl> * fixed lint error in policy.go Signed-off-by: Xingru <x.xingruxu@student.maastrichtuniversity.nl> Signed-off-by: Xingru <x.xingruxu@student.maastrichtuniversity.nl> Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: Bernd Verst <4535280+berndverst@users.noreply.github.com> Co-authored-by: Loong Dai <long.dai@intel.com> Co-authored-by: Dapr Bot <56698301+dapr-bot@users.noreply.github.com> Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
This commit is contained in:
parent
9c9df2ff76
commit
ed483dc88b
|
@ -13,12 +13,47 @@ limitations under the License.
|
|||
|
||||
package snssqs
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type arnEquals struct {
|
||||
AwsSourceArn string `json:"aws:SourceArn"`
|
||||
AwsSourceArn awsSourceArn `json:"aws:SourceArn"`
|
||||
}
|
||||
|
||||
type awsSourceArn []string
|
||||
|
||||
// UnmarshalJSON This is a custom unmarshaler for awsSourceArn for handling a special case
|
||||
// where aws flatten awsSourceArn into a string when it only contains one element
|
||||
func (a *awsSourceArn) UnmarshalJSON(data []byte) error {
|
||||
var i interface{}
|
||||
err := json.Unmarshal(data, &i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
items := reflect.ValueOf(i)
|
||||
switch items.Kind() {
|
||||
case reflect.String:
|
||||
*a = append(*a, items.String())
|
||||
case reflect.Slice:
|
||||
*a = make([]string, 0, items.Len())
|
||||
for i := 0; i < items.Len(); i++ {
|
||||
item := items.Index(i)
|
||||
switch item.Kind() {
|
||||
case reflect.String:
|
||||
*a = append(*a, item.String())
|
||||
case reflect.Interface:
|
||||
*a = append(*a, item.Interface().(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type condition struct {
|
||||
ArnEquals arnEquals
|
||||
ForAllValuesArnEquals arnEquals `json:"ForAllValues:ArnEquals"`
|
||||
}
|
||||
|
||||
type principal struct {
|
||||
|
@ -38,20 +73,33 @@ type policy struct {
|
|||
Statement []statement
|
||||
}
|
||||
|
||||
func (p *policy) statementExists(other *statement) bool {
|
||||
for _, s := range p.Statement {
|
||||
if s.Effect == other.Effect &&
|
||||
s.Principal.Service == other.Principal.Service &&
|
||||
s.Action == other.Action &&
|
||||
s.Resource == other.Resource &&
|
||||
s.Condition.ArnEquals.AwsSourceArn == other.Condition.ArnEquals.AwsSourceArn {
|
||||
return true
|
||||
func (p *policy) tryInsertCondition(sqsArn string, snsArn string) bool {
|
||||
for i, s := range p.Statement {
|
||||
// if there is a statement for sqsArn
|
||||
if s.Resource == sqsArn {
|
||||
// check if the snsArn already exists
|
||||
for _, a := range s.Condition.ForAllValuesArnEquals.AwsSourceArn {
|
||||
if a == snsArn {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// insert it if it does not exist
|
||||
p.Statement[i].Condition.ForAllValuesArnEquals.AwsSourceArn = append(p.Statement[i].Condition.ForAllValuesArnEquals.AwsSourceArn, snsArn)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// insert a new statement if no statement for the sqsArn
|
||||
newStatement := &statement{
|
||||
Effect: "Allow",
|
||||
Principal: principal{Service: "sns.amazonaws.com"},
|
||||
Action: "sqs:SendMessage",
|
||||
Resource: sqsArn,
|
||||
Condition: condition{
|
||||
ForAllValuesArnEquals: arnEquals{
|
||||
AwsSourceArn: []string{snsArn},
|
||||
},
|
||||
},
|
||||
}
|
||||
p.Statement = append(p.Statement, *newStatement)
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *policy) addStatement(other *statement) {
|
||||
p.Statement = append(p.Statement, *other)
|
||||
}
|
||||
|
|
|
@ -731,31 +731,18 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context,
|
|||
return fmt.Errorf("error getting queue attributes: %w", err)
|
||||
}
|
||||
|
||||
newStatement := &statement{
|
||||
Effect: "Allow",
|
||||
Principal: principal{Service: "sns.amazonaws.com"},
|
||||
Action: "sqs:SendMessage",
|
||||
Resource: sqsQueueInfo.arn,
|
||||
Condition: condition{
|
||||
ArnEquals: arnEquals{
|
||||
AwsSourceArn: snsARN,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
policy := &policy{Version: "2012-10-17"}
|
||||
if policyStr, ok := getQueueAttributesOutput.Attributes[sqs.QueueAttributeNamePolicy]; ok {
|
||||
// look for the current statement if exists, else add it and store.
|
||||
if err = json.Unmarshal([]byte(*policyStr), policy); err != nil {
|
||||
return fmt.Errorf("error unmarshalling sqs policy: %w", err)
|
||||
}
|
||||
if policy.statementExists(newStatement) {
|
||||
// nothing to do.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
conditionExists := policy.tryInsertCondition(sqsQueueInfo.arn, snsARN)
|
||||
if conditionExists {
|
||||
return nil
|
||||
}
|
||||
|
||||
policy.addStatement(newStatement)
|
||||
b, uerr := json.Marshal(policy)
|
||||
if uerr != nil {
|
||||
return fmt.Errorf("failed serializing new sqs policy: %w", uerr)
|
||||
|
|
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||
package snssqs
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -402,6 +403,97 @@ func Test_replaceNameToAWSSanitizedExistingFifoName_Trimmed(t *testing.T) {
|
|||
r.Equal("012345678901234567890123456789012345678901234567890123456789012345678901234.fifo", v)
|
||||
}
|
||||
|
||||
func Test_UnmarshalJSON_UnmarshallsToArray(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := require.New(t)
|
||||
|
||||
s := `
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Principal": {
|
||||
"Service": "sns.amazonaws.com"
|
||||
},
|
||||
"Action": "sqs:SendMessage",
|
||||
"Resource": "sqsArn",
|
||||
"Condition": {
|
||||
"ForAllValues:ArnEquals": {
|
||||
"aws:SourceArn": "snsArn"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
`
|
||||
p := &policy{}
|
||||
|
||||
err := json.Unmarshal([]byte(s), p)
|
||||
r.Equal(err, nil)
|
||||
|
||||
statement := p.Statement[0]
|
||||
r.Equal(len(statement.Condition.ForAllValuesArnEquals.AwsSourceArn), 1)
|
||||
r.Equal(statement.Condition.ForAllValuesArnEquals.AwsSourceArn[0], "snsArn")
|
||||
}
|
||||
|
||||
func Test_tryInsertCondition(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := require.New(t)
|
||||
|
||||
policy := &policy{Version: "2012-10-17"}
|
||||
sqsArn := "sqsArn"
|
||||
snsArns := []string{"snsArns1", "snsArns2", "snsArns3", "snsArns4"}
|
||||
|
||||
for _, snsArn := range snsArns {
|
||||
policy.tryInsertCondition(sqsArn, snsArn)
|
||||
}
|
||||
|
||||
r.Equal(len(policy.Statement), 1)
|
||||
insertedStatement := policy.Statement[0]
|
||||
r.Equal(insertedStatement.Resource, sqsArn)
|
||||
r.Equal(len(insertedStatement.Condition.ForAllValuesArnEquals.AwsSourceArn), len(snsArns))
|
||||
r.ElementsMatch(insertedStatement.Condition.ForAllValuesArnEquals.AwsSourceArn, snsArns)
|
||||
}
|
||||
|
||||
func Test_policy_compatible(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := require.New(t)
|
||||
|
||||
sqsArn := "sqsArn"
|
||||
snsArn := "snsArn"
|
||||
oldPolicy := `
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Principal": {
|
||||
"Service": "sns.amazonaws.com"
|
||||
},
|
||||
"Action": "sqs:SendMessage",
|
||||
"Resource": "sqsArn",
|
||||
"Condition": {
|
||||
"ArnEquals": {
|
||||
"aws:SourceArn": "snsArn"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
`
|
||||
policy := &policy{Version: "2012-10-17"}
|
||||
err := json.Unmarshal([]byte(oldPolicy), policy)
|
||||
r.Equal(err, nil)
|
||||
|
||||
policy.tryInsertCondition(sqsArn, snsArn)
|
||||
r.Equal(len(policy.Statement), 1)
|
||||
insertedStatement := policy.Statement[0]
|
||||
r.Equal(insertedStatement.Resource, sqsArn)
|
||||
r.Equal(len(insertedStatement.Condition.ForAllValuesArnEquals.AwsSourceArn), 1)
|
||||
r.Equal(insertedStatement.Condition.ForAllValuesArnEquals.AwsSourceArn[0], snsArn)
|
||||
}
|
||||
|
||||
func Test_buildARN_DefaultPartition(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := require.New(t)
|
||||
|
|
Loading…
Reference in New Issue