Improve removal of orphaned security group rules

This commit is contained in:
John Gardiner Myers 2023-10-06 22:21:33 -07:00
parent 3f1ee1e820
commit 9f40d59545
3 changed files with 54 additions and 27 deletions

View File

@ -303,8 +303,9 @@ type SecurityGroupInfo struct {
func (b *AWSModelContext) GetSecurityGroups(role kops.InstanceGroupRole) ([]SecurityGroupInfo, error) {
var baseGroup *awstasks.SecurityGroup
if role == kops.InstanceGroupRoleControlPlane {
name := b.SecurityGroupName(role)
name := b.SecurityGroupName(role)
switch role {
case kops.InstanceGroupRoleControlPlane:
baseGroup = &awstasks.SecurityGroup{
Name: fi.PtrTo(name),
VPC: b.LinkToVPC(),
@ -319,13 +320,15 @@ func (b *AWSModelContext) GetSecurityGroups(role kops.InstanceGroupRole) ([]Secu
"port=4789", // VXLAN
"port=179", // Calico
"port=8443", // k8s api secondary listener
"port=3:4", // ICMP
"port=-1", // ICMPv6
// TODO: UDP vs TCP
// TODO: UDP vs TCP vs ICMP vs ICMPv6
// TODO: Protocol 4 for calico
},
}
baseGroup.Tags = b.CloudTags(name, false)
} else if role == kops.InstanceGroupRoleNode {
case kops.InstanceGroupRoleNode:
name := b.SecurityGroupName(role)
baseGroup = &awstasks.SecurityGroup{
Name: fi.PtrTo(name),
@ -334,19 +337,22 @@ func (b *AWSModelContext) GetSecurityGroups(role kops.InstanceGroupRole) ([]Secu
RemoveExtraRules: []string{"port=22"},
}
baseGroup.Tags = b.CloudTags(name, false)
} else if role == kops.InstanceGroupRoleBastion {
case kops.InstanceGroupRoleBastion:
name := b.SecurityGroupName(role)
baseGroup = &awstasks.SecurityGroup{
Name: fi.PtrTo(name),
VPC: b.LinkToVPC(),
Description: fi.PtrTo("Security group for bastion"),
RemoveExtraRules: []string{"port=22"},
Name: fi.PtrTo(name),
VPC: b.LinkToVPC(),
Description: fi.PtrTo("Security group for bastion"),
RemoveExtraRules: []string{
"port=22", // SSH
"port=3:4", // ICMP
"port=-1", // ICMPv6
},
}
baseGroup.Tags = b.CloudTags(name, false)
} else {
default:
return nil, fmt.Errorf("not a supported security group type")
}
var groups []SecurityGroupInfo
done := make(map[string]bool)

View File

@ -386,12 +386,23 @@ func ParseRemovalRule(rule string) (RemovalRule, error) {
if len(tokens) == 2 {
if tokens[0] == "port" {
port, err := strconv.Atoi(tokens[1])
ports := strings.SplitN(tokens[1], ":", 2)
fromPort, err := strconv.Atoi(ports[0])
if err != nil {
return nil, fmt.Errorf("cannot parse rule %q", rule)
}
toPort := fromPort
if len(ports) > 1 {
toPort, err = strconv.Atoi(ports[1])
if err != nil {
return nil, fmt.Errorf("cannot parse rule %q", rule)
}
}
return &PortRemovalRule{Port: port}, nil
return &PortRemovalRule{
FromPort: fromPort,
ToPort: toPort,
}, nil
} else {
return nil, fmt.Errorf("cannot parse rule %q", rule)
}
@ -400,7 +411,8 @@ func ParseRemovalRule(rule string) (RemovalRule, error) {
}
type PortRemovalRule struct {
Port int
FromPort int
ToPort int
}
var _ RemovalRule = &PortRemovalRule{}
@ -411,10 +423,10 @@ func (r *PortRemovalRule) String() string {
func (r *PortRemovalRule) Matches(permission *ec2.SecurityGroupRule) bool {
// Check if port matches
if permission.FromPort == nil || *permission.FromPort != int64(r.Port) {
if permission.FromPort == nil || *permission.FromPort != int64(r.FromPort) {
return false
}
if permission.ToPort == nil || *permission.ToPort != int64(r.Port) {
if permission.ToPort == nil || *permission.ToPort != int64(r.ToPort) {
return false
}
return true

View File

@ -34,8 +34,10 @@ func TestParseRemovalRule(t *testing.T) {
testNotParse(t, "port=a")
testNotParse(t, "port=22-23")
testParsesAsPort(t, "port=22", 22)
testParsesAsPort(t, "port=443", 443)
testParsesAsPort(t, "port=22", 22, 22)
testParsesAsPort(t, "port=443", 443, 443)
testParsesAsPort(t, "port=22:23", 22, 23)
testParsesAsPort(t, "port=-1", -1, -1)
}
func testNotParse(t *testing.T, rule string) {
@ -45,7 +47,7 @@ func testNotParse(t *testing.T, rule string) {
}
}
func testParsesAsPort(t *testing.T, rule string, port int) {
func testParsesAsPort(t *testing.T, rule string, fromPort int, toPort int) {
r, err := ParseRemovalRule(rule)
if err != nil {
t.Fatalf("unexpected failure to parse rule %q: %v", rule, err)
@ -54,26 +56,33 @@ func testParsesAsPort(t *testing.T, rule string, port int) {
if !ok {
t.Fatalf("unexpected rule type for rule %q: %T", r, err)
}
if portRemovalRule.Port != port {
t.Fatalf("unexpected port for %q, expecting %d, got %q", rule, port, r)
if portRemovalRule.FromPort != fromPort {
t.Fatalf("unexpected fromPort for %q, expecting %d, got %q", rule, fromPort, r)
}
if portRemovalRule.ToPort != toPort {
t.Fatalf("unexpected toPort for %q, expecting %d, got %q", rule, toPort, r)
}
}
func TestPortRemovalRule(t *testing.T) {
r := &PortRemovalRule{Port: 22}
testMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(22), ToPort: aws.Int64(22)})
r := &PortRemovalRule{FromPort: 22, ToPort: 23}
testMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(22), ToPort: aws.Int64(23)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(0), ToPort: aws.Int64(0)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(22), ToPort: aws.Int64(22)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(23), ToPort: aws.Int64(23)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(20), ToPort: aws.Int64(22)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(22), ToPort: aws.Int64(23)})
testNotMatches(t, r, &ec2.SecurityGroupRule{ToPort: aws.Int64(22)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(20), ToPort: aws.Int64(23)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(22), ToPort: aws.Int64(24)})
testNotMatches(t, r, &ec2.SecurityGroupRule{ToPort: aws.Int64(23)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(22)})
testNotMatches(t, r, &ec2.SecurityGroupRule{})
r = &PortRemovalRule{FromPort: -1, ToPort: -1}
testMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(-1), ToPort: aws.Int64(-1)})
}
func TestPortRemovalRule_Zero(t *testing.T) {
r := &PortRemovalRule{Port: 0}
r := &PortRemovalRule{FromPort: 0, ToPort: 0}
testMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(0), ToPort: aws.Int64(0)})
testNotMatches(t, r, &ec2.SecurityGroupRule{FromPort: aws.Int64(0), ToPort: aws.Int64(20)})