Rename StringOrSlice to StringOrSet, sort lists

This commit is contained in:
Peter Rifel 2024-02-12 21:37:27 -06:00
parent 21804bf631
commit f098401c49
No known key found for this signature in database
5 changed files with 62 additions and 60 deletions

View File

@ -60,19 +60,19 @@ func addCertManagerPermissions(b *iam.PolicyBuilder, p *iam.Policy) {
Action: stringorslice.Of("route53:ChangeResourceRecordSets", Action: stringorslice.Of("route53:ChangeResourceRecordSets",
"route53:ListResourceRecordSets", "route53:ListResourceRecordSets",
), ),
Resource: stringorslice.Slice(zoneResources), Resource: stringorslice.Set(zoneResources),
}) })
p.Statement = append(p.Statement, &iam.Statement{ p.Statement = append(p.Statement, &iam.Statement{
Effect: iam.StatementEffectAllow, Effect: iam.StatementEffectAllow,
Action: stringorslice.Slice([]string{"route53:GetChange"}), Action: stringorslice.Set([]string{"route53:GetChange"}),
Resource: stringorslice.Slice([]string{fmt.Sprintf("arn:%v:route53:::change/*", b.Partition)}), Resource: stringorslice.Set([]string{fmt.Sprintf("arn:%v:route53:::change/*", b.Partition)}),
}) })
wildcard := stringorslice.Slice([]string{"*"}) wildcard := stringorslice.Set([]string{"*"})
p.Statement = append(p.Statement, &iam.Statement{ p.Statement = append(p.Statement, &iam.Statement{
Effect: iam.StatementEffectAllow, Effect: iam.StatementEffectAllow,
Action: stringorslice.Slice([]string{"route53:ListHostedZonesByName"}), Action: stringorslice.Set([]string{"route53:ListHostedZonesByName"}),
Resource: wildcard, Resource: wildcard,
}) })
} }

View File

@ -78,7 +78,7 @@ func (p *Policy) AddEC2CreateAction(actions, resources []string) {
&Statement{ &Statement{
Effect: StatementEffectAllow, Effect: StatementEffectAllow,
Action: stringorslice.String("ec2:CreateTags"), Action: stringorslice.String("ec2:CreateTags"),
Resource: stringorslice.Slice(actualResources), Resource: stringorslice.Set(actualResources),
Condition: Condition{ Condition: Condition{
"StringEquals": map[string]interface{}{ "StringEquals": map[string]interface{}{
"aws:RequestTag/KubernetesCluster": p.clusterName, "aws:RequestTag/KubernetesCluster": p.clusterName,
@ -89,11 +89,11 @@ func (p *Policy) AddEC2CreateAction(actions, resources []string) {
&Statement{ &Statement{
Effect: StatementEffectAllow, Effect: StatementEffectAllow,
Action: stringorslice.Slice([]string{ Action: stringorslice.Set([]string{
"ec2:CreateTags", "ec2:CreateTags",
"ec2:DeleteTags", // aws.go, tag.go "ec2:DeleteTags", // aws.go, tag.go
}), }),
Resource: stringorslice.Slice(actualResources), Resource: stringorslice.Set(actualResources),
Condition: Condition{ Condition: Condition{
"Null": map[string]string{ "Null": map[string]string{
"aws:RequestTag/KubernetesCluster": "true", "aws:RequestTag/KubernetesCluster": "true",
@ -176,8 +176,8 @@ type Condition map[string]interface{}
type Statement struct { type Statement struct {
Effect StatementEffect Effect StatementEffect
Principal Principal Principal Principal
Action stringorslice.StringOrSlice Action stringorslice.StringOrSet
Resource stringorslice.StringOrSlice Resource stringorslice.StringOrSet
Condition Condition Condition Condition
} }
@ -600,7 +600,7 @@ func (b *PolicyBuilder) AddS3Permissions(p *Policy) (*Policy, error) {
"s3:ListBucket", "s3:ListBucket",
"s3:ListBucketVersions", "s3:ListBucketVersions",
), ),
Resource: stringorslice.Slice([]string{ Resource: stringorslice.Set([]string{
fmt.Sprintf("arn:%v:s3:::%v", p.partition, s3Bucket), fmt.Sprintf("arn:%v:s3:::%v", p.partition, s3Bucket),
}), }),
}) })
@ -612,7 +612,7 @@ func (b *PolicyBuilder) AddS3Permissions(p *Policy) (*Policy, error) {
func (b *PolicyBuilder) buildS3WriteStatements(p *Policy, iamS3Path string) { func (b *PolicyBuilder) buildS3WriteStatements(p *Policy, iamS3Path string) {
p.Statement = append(p.Statement, &Statement{ p.Statement = append(p.Statement, &Statement{
Effect: StatementEffectAllow, Effect: StatementEffectAllow,
Action: stringorslice.Slice([]string{ Action: stringorslice.Set([]string{
"s3:GetObject", "s3:GetObject",
"s3:DeleteObject", "s3:DeleteObject",
"s3:DeleteObjectVersion", "s3:DeleteObjectVersion",
@ -640,7 +640,7 @@ func (b *PolicyBuilder) buildS3GetStatements(p *Policy, iamS3Path string) error
p.Statement = append(p.Statement, &Statement{ p.Statement = append(p.Statement, &Statement{
Effect: StatementEffectAllow, Effect: StatementEffectAllow,
Action: stringorslice.Slice([]string{"s3:Get*"}), Action: stringorslice.Set([]string{"s3:Get*"}),
Resource: stringorslice.Of(resources...), Resource: stringorslice.Of(resources...),
}) })
} }
@ -803,7 +803,7 @@ func addEtcdManagerPermissions(p *Policy) {
Action: stringorslice.Of( Action: stringorslice.Of(
"ec2:AttachVolume", "ec2:AttachVolume",
), ),
Resource: stringorslice.Slice([]string{"*"}), Resource: stringorslice.Set([]string{"*"}),
Condition: Condition{ Condition: Condition{
"StringEquals": map[string]string{ "StringEquals": map[string]string{
"aws:ResourceTag/k8s.io/role/master": "1", "aws:ResourceTag/k8s.io/role/master": "1",
@ -1064,19 +1064,19 @@ func AddDNSControllerPermissions(b *PolicyBuilder, p *Policy) {
Action: stringorslice.Of("route53:ChangeResourceRecordSets", Action: stringorslice.Of("route53:ChangeResourceRecordSets",
"route53:ListResourceRecordSets", "route53:ListResourceRecordSets",
"route53:GetHostedZone"), "route53:GetHostedZone"),
Resource: stringorslice.Slice([]string{fmt.Sprintf("arn:%v:route53:::hostedzone/%v", b.Partition, hostedZoneID)}), Resource: stringorslice.Set([]string{fmt.Sprintf("arn:%v:route53:::hostedzone/%v", b.Partition, hostedZoneID)}),
}) })
p.Statement = append(p.Statement, &Statement{ p.Statement = append(p.Statement, &Statement{
Effect: StatementEffectAllow, Effect: StatementEffectAllow,
Action: stringorslice.Slice([]string{"route53:GetChange"}), Action: stringorslice.Set([]string{"route53:GetChange"}),
Resource: stringorslice.Slice([]string{fmt.Sprintf("arn:%v:route53:::change/*", b.Partition)}), Resource: stringorslice.Set([]string{fmt.Sprintf("arn:%v:route53:::change/*", b.Partition)}),
}) })
wildcard := stringorslice.Slice([]string{"*"}) wildcard := stringorslice.Set([]string{"*"})
p.Statement = append(p.Statement, &Statement{ p.Statement = append(p.Statement, &Statement{
Effect: StatementEffectAllow, Effect: StatementEffectAllow,
Action: stringorslice.Slice([]string{"route53:ListHostedZones", "route53:ListTagsForResource"}), Action: stringorslice.Set([]string{"route53:ListHostedZones", "route53:ListTagsForResource"}),
Resource: wildcard, Resource: wildcard,
}) })
} }
@ -1169,10 +1169,10 @@ func addAmazonVPCCNIPermissions(p *Policy) {
p.Statement = append(p.Statement, p.Statement = append(p.Statement,
&Statement{ &Statement{
Effect: StatementEffectAllow, Effect: StatementEffectAllow,
Action: stringorslice.Slice([]string{ Action: stringorslice.Set([]string{
"ec2:CreateTags", "ec2:CreateTags",
}), }),
Resource: stringorslice.Slice([]string{ Resource: stringorslice.Set([]string{
strings.Join([]string{"arn:", p.partition, ":ec2:*:*:network-interface/*"}, ""), strings.Join([]string{"arn:", p.partition, ":ec2:*:*:network-interface/*"}, ""),
}), }),
}, },

View File

@ -49,7 +49,7 @@ func TestRoundTrip(t *testing.T) {
Action: stringorslice.Of("ec2:DescribeRegions", "ec2:DescribeInstances"), Action: stringorslice.Of("ec2:DescribeRegions", "ec2:DescribeInstances"),
Resource: stringorslice.Of("a", "b"), Resource: stringorslice.Of("a", "b"),
}, },
JSON: "{\"Action\":[\"ec2:DescribeRegions\",\"ec2:DescribeInstances\"],\"Effect\":\"Deny\",\"Resource\":[\"a\",\"b\"]}", JSON: "{\"Action\":[\"ec2:DescribeInstances\",\"ec2:DescribeRegions\"],\"Effect\":\"Deny\",\"Resource\":[\"a\",\"b\"]}",
}, },
{ {
IAM: &Statement{ IAM: &Statement{

View File

@ -18,45 +18,54 @@ package stringorslice
import ( import (
"encoding/json" "encoding/json"
"sort"
"strings" "strings"
"k8s.io/apimachinery/pkg/util/sets"
) )
// StringOrSlice is a type that holds a []string, but marshals to a []string or a string. // StringOrSet is a type that holds a []string, but marshals to a []string or a string.
type StringOrSlice struct { type StringOrSet struct {
values []string values sets.Set[string]
forceEncodeAsArray bool forceEncodeAsArray bool
} }
func (s *StringOrSlice) IsEmpty() bool { func (s *StringOrSet) IsEmpty() bool {
return len(s.values) == 0 return len(s.values) == 0
} }
// Slice will build a value that marshals to a JSON array // Set will build a value that marshals to a JSON array
func Slice(v []string) StringOrSlice { func Set(v []string) StringOrSet {
return StringOrSlice{values: v, forceEncodeAsArray: true} values := sets.Set[string]{}
values.Insert(v...)
return StringOrSet{values: values, forceEncodeAsArray: true}
} }
// Of will build a value that marshals to a JSON array if len(v) > 1, // Of will build a value that marshals to a JSON array if len(v) > 1,
// otherwise to a single string // otherwise to a single string
func Of(v ...string) StringOrSlice { func Of(v ...string) StringOrSet {
if v == nil { if v == nil {
v = []string{} v = []string{}
} }
return StringOrSlice{values: v} values := sets.Set[string]{}
values.Insert(v...)
return StringOrSet{values: values}
} }
// String will build a value that marshals to a single string // String will build a value that marshals to a single string
func String(v string) StringOrSlice { func String(v string) StringOrSet {
return StringOrSlice{values: []string{v}, forceEncodeAsArray: false} return StringOrSet{values: sets.New(v), forceEncodeAsArray: false}
} }
// UnmarshalJSON implements the json.Unmarshaller interface. // UnmarshalJSON implements the json.Unmarshaller interface.
func (s *StringOrSlice) UnmarshalJSON(value []byte) error { func (s *StringOrSet) UnmarshalJSON(value []byte) error {
if value[0] == '[' { if value[0] == '[' {
s.forceEncodeAsArray = true s.forceEncodeAsArray = true
if err := json.Unmarshal(value, &s.values); err != nil { var vals []string
if err := json.Unmarshal(value, &vals); err != nil {
return nil return nil
} }
s.values = sets.New(vals...)
return nil return nil
} }
s.forceEncodeAsArray = false s.forceEncodeAsArray = false
@ -64,46 +73,39 @@ func (s *StringOrSlice) UnmarshalJSON(value []byte) error {
if err := json.Unmarshal(value, &stringValue); err != nil { if err := json.Unmarshal(value, &stringValue); err != nil {
return err return err
} }
s.values = []string{stringValue} s.values = sets.New(stringValue)
return nil return nil
} }
// String returns the string value, or the Itoa of the int value. // String returns the string value, or the Itoa of the int value.
func (s StringOrSlice) String() string { func (s StringOrSet) String() string {
return strings.Join(s.values, ",") return strings.Join(sets.List[string](s.values), ",")
} }
func (v *StringOrSlice) Value() []string { func (v *StringOrSet) Value() []string {
return v.values vals := sets.List[string](v.values)
sort.Strings(vals)
return vals
} }
func (l StringOrSlice) Equal(r StringOrSlice) bool { func (l StringOrSet) Equal(r StringOrSet) bool {
if len(l.values) != len(r.values) { return l.values.Equal(r.values)
return false
}
for i := 0; i < len(l.values); i++ {
if l.values[i] != r.values[i] {
return false
}
}
return true
} }
// MarshalJSON implements the json.Marshaller interface. // MarshalJSON implements the json.Marshaller interface.
func (v StringOrSlice) MarshalJSON() ([]byte, error) { func (v StringOrSet) MarshalJSON() ([]byte, error) {
encodeAsJSONArray := v.forceEncodeAsArray encodeAsJSONArray := v.forceEncodeAsArray
if len(v.values) > 1 { if len(v.values) > 1 {
encodeAsJSONArray = true encodeAsJSONArray = true
} }
values := v.values values := v.Value()
if values == nil { if values == nil {
values = []string{} values = []string{}
} }
if encodeAsJSONArray { if encodeAsJSONArray {
return json.Marshal(values) return json.Marshal(values)
} else if len(v.values) == 1 { } else if len(values) == 1 {
s := v.values[0] return json.Marshal(&values[0])
return json.Marshal(&s)
} else { } else {
return json.Marshal(values) return json.Marshal(values)
} }

View File

@ -25,7 +25,7 @@ import (
func TestRoundTrip(t *testing.T) { func TestRoundTrip(t *testing.T) {
grid := []struct { grid := []struct {
Value StringOrSlice Value StringOrSet
JSON string JSON string
}{ }{
{ {
@ -37,7 +37,7 @@ func TestRoundTrip(t *testing.T) {
JSON: "\"a\"", JSON: "\"a\"",
}, },
{ {
Value: Slice([]string{"a"}), Value: Set([]string{"a"}),
JSON: "[\"a\"]", JSON: "[\"a\"]",
}, },
{ {
@ -45,7 +45,7 @@ func TestRoundTrip(t *testing.T) {
JSON: "[\"a\",\"b\"]", JSON: "[\"a\",\"b\"]",
}, },
{ {
Value: Slice([]string{"a", "b"}), Value: Set([]string{"a", "b"}),
JSON: "[\"a\",\"b\"]", JSON: "[\"a\",\"b\"]",
}, },
{ {
@ -53,7 +53,7 @@ func TestRoundTrip(t *testing.T) {
JSON: "[]", JSON: "[]",
}, },
{ {
Value: Slice(nil), Value: Set(nil),
JSON: "[]", JSON: "[]",
}, },
} }
@ -69,7 +69,7 @@ func TestRoundTrip(t *testing.T) {
t.Errorf("Unexpected JSON encoding. Actual=%q, Expected=%q", string(actualJSON), g.JSON) t.Errorf("Unexpected JSON encoding. Actual=%q, Expected=%q", string(actualJSON), g.JSON)
} }
parsed := &StringOrSlice{} parsed := &StringOrSet{}
err = json.Unmarshal([]byte(g.JSON), parsed) err = json.Unmarshal([]byte(g.JSON), parsed)
if err != nil { if err != nil {
t.Errorf("error decoding StringOrSlice %s to json: %v", g.JSON, err) t.Errorf("error decoding StringOrSlice %s to json: %v", g.JSON, err)