Improve removal of orphaned security group rules

This commit is contained in:
John Gardiner Myers 2023-10-06 22:21:33 -07:00
parent 3f1ee1e820
commit 9f40d59545
3 changed files with 54 additions and 27 deletions

View File

@ -303,8 +303,9 @@ type SecurityGroupInfo struct {
func (b *AWSModelContext) GetSecurityGroups(role kops.InstanceGroupRole) ([]SecurityGroupInfo, error) { func (b *AWSModelContext) GetSecurityGroups(role kops.InstanceGroupRole) ([]SecurityGroupInfo, error) {
var baseGroup *awstasks.SecurityGroup var baseGroup *awstasks.SecurityGroup
if role == kops.InstanceGroupRoleControlPlane { name := b.SecurityGroupName(role)
name := b.SecurityGroupName(role) switch role {
case kops.InstanceGroupRoleControlPlane:
baseGroup = &awstasks.SecurityGroup{ baseGroup = &awstasks.SecurityGroup{
Name: fi.PtrTo(name), Name: fi.PtrTo(name),
VPC: b.LinkToVPC(), VPC: b.LinkToVPC(),
@ -319,13 +320,15 @@ func (b *AWSModelContext) GetSecurityGroups(role kops.InstanceGroupRole) ([]Secu
"port=4789", // VXLAN "port=4789", // VXLAN
"port=179", // Calico "port=179", // Calico
"port=8443", // k8s api secondary listener "port=8443", // k8s api secondary listener
"port=3:4", // ICMP
"port=-1", // ICMPv6
// TODO: UDP vs TCP // TODO: UDP vs TCP vs ICMP vs ICMPv6
// TODO: Protocol 4 for calico // TODO: Protocol 4 for calico
}, },
} }
baseGroup.Tags = b.CloudTags(name, false) baseGroup.Tags = b.CloudTags(name, false)
} else if role == kops.InstanceGroupRoleNode { case kops.InstanceGroupRoleNode:
name := b.SecurityGroupName(role) name := b.SecurityGroupName(role)
baseGroup = &awstasks.SecurityGroup{ baseGroup = &awstasks.SecurityGroup{
Name: fi.PtrTo(name), Name: fi.PtrTo(name),
@ -334,19 +337,22 @@ func (b *AWSModelContext) GetSecurityGroups(role kops.InstanceGroupRole) ([]Secu
RemoveExtraRules: []string{"port=22"}, RemoveExtraRules: []string{"port=22"},
} }
baseGroup.Tags = b.CloudTags(name, false) baseGroup.Tags = b.CloudTags(name, false)
} else if role == kops.InstanceGroupRoleBastion { case kops.InstanceGroupRoleBastion:
name := b.SecurityGroupName(role) name := b.SecurityGroupName(role)
baseGroup = &awstasks.SecurityGroup{ baseGroup = &awstasks.SecurityGroup{
Name: fi.PtrTo(name), Name: fi.PtrTo(name),
VPC: b.LinkToVPC(), VPC: b.LinkToVPC(),
Description: fi.PtrTo("Security group for bastion"), Description: fi.PtrTo("Security group for bastion"),
RemoveExtraRules: []string{"port=22"}, RemoveExtraRules: []string{
"port=22", // SSH
"port=3:4", // ICMP
"port=-1", // ICMPv6
},
} }
baseGroup.Tags = b.CloudTags(name, false) baseGroup.Tags = b.CloudTags(name, false)
} else { default:
return nil, fmt.Errorf("not a supported security group type") return nil, fmt.Errorf("not a supported security group type")
} }
var groups []SecurityGroupInfo var groups []SecurityGroupInfo
done := make(map[string]bool) done := make(map[string]bool)

View File

@ -386,12 +386,23 @@ func ParseRemovalRule(rule string) (RemovalRule, error) {
if len(tokens) == 2 { if len(tokens) == 2 {
if tokens[0] == "port" { if tokens[0] == "port" {
port, err := strconv.Atoi(tokens[1]) ports := strings.SplitN(tokens[1], ":", 2)
fromPort, err := strconv.Atoi(ports[0])
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot parse rule %q", rule) return nil, fmt.Errorf("cannot parse rule %q", rule)
} }
toPort := fromPort
if len(ports) > 1 {
toPort, err = strconv.Atoi(ports[1])
if err != nil {
return nil, fmt.Errorf("cannot parse rule %q", rule)
}
}
return &PortRemovalRule{Port: port}, nil return &PortRemovalRule{
FromPort: fromPort,
ToPort: toPort,
}, nil
} else { } else {
return nil, fmt.Errorf("cannot parse rule %q", rule) return nil, fmt.Errorf("cannot parse rule %q", rule)
} }
@ -400,7 +411,8 @@ func ParseRemovalRule(rule string) (RemovalRule, error) {
} }
type PortRemovalRule struct { type PortRemovalRule struct {
Port int FromPort int
ToPort int
} }
var _ RemovalRule = &PortRemovalRule{} var _ RemovalRule = &PortRemovalRule{}
@ -411,10 +423,10 @@ func (r *PortRemovalRule) String() string {
func (r *PortRemovalRule) Matches(permission *ec2.SecurityGroupRule) bool { func (r *PortRemovalRule) Matches(permission *ec2.SecurityGroupRule) bool {
// Check if port matches // Check if port matches
if permission.FromPort == nil || *permission.FromPort != int64(r.Port) { if permission.FromPort == nil || *permission.FromPort != int64(r.FromPort) {
return false return false
} }
if permission.ToPort == nil || *permission.ToPort != int64(r.Port) { if permission.ToPort == nil || *permission.ToPort != int64(r.ToPort) {
return false return false
} }
return true return true

View File

@ -34,8 +34,10 @@ func TestParseRemovalRule(t *testing.T) {
testNotParse(t, "port=a") testNotParse(t, "port=a")
testNotParse(t, "port=22-23") testNotParse(t, "port=22-23")
testParsesAsPort(t, "port=22", 22) testParsesAsPort(t, "port=22", 22, 22)
testParsesAsPort(t, "port=443", 443) testParsesAsPort(t, "port=443", 443, 443)
testParsesAsPort(t, "port=22:23", 22, 23)
testParsesAsPort(t, "port=-1", -1, -1)
} }
func testNotParse(t *testing.T, rule string) { func testNotParse(t *testing.T, rule string) {
@ -45,7 +47,7 @@ func testNotParse(t *testing.T, rule string) {
} }
} }
func testParsesAsPort(t *testing.T, rule string, port int) { func testParsesAsPort(t *testing.T, rule string, fromPort int, toPort int) {
r, err := ParseRemovalRule(rule) r, err := ParseRemovalRule(rule)
if err != nil { if err != nil {
t.Fatalf("unexpected failure to parse rule %q: %v", rule, err) t.Fatalf("unexpected failure to parse rule %q: %v", rule, err)
@ -54,26 +56,33 @@ func testParsesAsPort(t *testing.T, rule string, port int) {
if !ok { if !ok {
t.Fatalf("unexpected rule type for rule %q: %T", r, err) t.Fatalf("unexpected rule type for rule %q: %T", r, err)
} }
if portRemovalRule.Port != port { if portRemovalRule.FromPort != fromPort {
t.Fatalf("unexpected port for %q, expecting %d, got %q", rule, port, r) t.Fatalf("unexpected fromPort for %q, expecting %d, got %q", rule, fromPort, r)
}
if portRemovalRule.ToPort != toPort {
t.Fatalf("unexpected toPort for %q, expecting %d, got %q", rule, toPort, r)
} }
} }
func TestPortRemovalRule(t *testing.T) { func TestPortRemovalRule(t *testing.T) {
r := &PortRemovalRule{Port: 22} r := &PortRemovalRule{FromPort: 22, ToPort: 23}
testMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(22), ToPort: aws.Int64(22)}) testMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(22), ToPort: aws.Int64(23)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(0), ToPort: aws.Int64(0)}) testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(0), ToPort: aws.Int64(0)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(22), ToPort: aws.Int64(22)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(23), ToPort: aws.Int64(23)}) testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(23), ToPort: aws.Int64(23)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(20), ToPort: aws.Int64(22)}) testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(20), ToPort: aws.Int64(23)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(22), ToPort: aws.Int64(23)}) testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(22), ToPort: aws.Int64(24)})
testNotMatches(t, r, &ec2.SecurityGroupRule{ToPort: aws.Int64(22)}) testNotMatches(t, r, &ec2.SecurityGroupRule{ToPort: aws.Int64(23)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(22)}) testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(22)})
testNotMatches(t, r, &ec2.SecurityGroupRule{}) testNotMatches(t, r, &ec2.SecurityGroupRule{})
r = &PortRemovalRule{FromPort: -1, ToPort: -1}
testMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(-1), ToPort: aws.Int64(-1)})
} }
func TestPortRemovalRule_Zero(t *testing.T) { func TestPortRemovalRule_Zero(t *testing.T) {
r := &PortRemovalRule{Port: 0} r := &PortRemovalRule{FromPort: 0, ToPort: 0}
testMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(0), ToPort: aws.Int64(0)}) testMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(0), ToPort: aws.Int64(0)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(0), ToPort: aws.Int64(20)}) testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(0), ToPort: aws.Int64(20)})