allow specifying sg; re-use existing sg; fix race conditions with ip assigning and eventual consistency with sgs

Signed-off-by: Evan Hazlett <ejhazlett@gmail.com>
This commit is contained in:
Evan Hazlett 2015-01-22 16:36:51 -05:00
parent 85f1cb9725
commit 834fa414b4
5 changed files with 108 additions and 59 deletions

View File

@ -32,28 +32,29 @@ const (
) )
type Driver struct { type Driver struct {
Id string Id string
AccessKey string AccessKey string
SecretKey string SecretKey string
SessionToken string SessionToken string
Region string Region string
AMI string AMI string
SSHKeyID int SSHKeyID int
KeyName string KeyName string
InstanceId string InstanceId string
InstanceType string InstanceType string
IPAddress string IPAddress string
MachineName string MachineName string
SecurityGroupId string SecurityGroupName string
ReservationId string SecurityGroupId string
RootSize int64 ReservationId string
VpcId string RootSize int64
SubnetId string VpcId string
Zone string SubnetId string
CaCertPath string Zone string
PrivateKeyPath string CaCertPath string
storePath string PrivateKeyPath string
keyPath string storePath string
keyPath string
} }
type CreateFlags struct { type CreateFlags struct {
@ -123,6 +124,12 @@ func GetCreateFlags() []cli.Flag {
Value: "", Value: "",
EnvVar: "AWS_SUBNET_ID", EnvVar: "AWS_SUBNET_ID",
}, },
cli.StringFlag{
Name: "amazonec2-security-group-name",
Usage: "AWS VPC security group name",
Value: "docker-machine",
EnvVar: "AWS_SECURITY_GROUP_NAME",
},
cli.StringFlag{ cli.StringFlag{
Name: "amazonec2-instance-type", Name: "amazonec2-instance-type",
Usage: "AWS instance type", Usage: "AWS instance type",
@ -152,6 +159,7 @@ func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error {
d.InstanceType = flags.String("amazonec2-instance-type") d.InstanceType = flags.String("amazonec2-instance-type")
d.VpcId = flags.String("amazonec2-vpc-id") d.VpcId = flags.String("amazonec2-vpc-id")
d.SubnetId = flags.String("amazonec2-subnet-id") d.SubnetId = flags.String("amazonec2-subnet-id")
d.SecurityGroupName = flags.String("amazonec2-security-group-name")
zone := flags.String("amazonec2-zone") zone := flags.String("amazonec2-zone")
d.Zone = zone[:] d.Zone = zone[:]
d.RootSize = int64(flags.Int("amazonec2-root-size")) d.RootSize = int64(flags.Int("amazonec2-root-size"))
@ -182,7 +190,7 @@ func (d *Driver) Create() error {
return fmt.Errorf("unable to create key pair: %s", err) return fmt.Errorf("unable to create key pair: %s", err)
} }
if err := d.configureSecurityGroup(); err != nil { if err := d.configureSecurityGroup(d.SecurityGroupName); err != nil {
return err return err
} }
@ -427,8 +435,21 @@ func (d *Driver) updateDriver() error {
if err != nil { if err != nil {
return err return err
} }
d.InstanceId = inst.InstanceId // wait for ipaddress
d.IPAddress = inst.IpAddress for {
i, err := d.getInstance()
if err != nil {
return err
}
if i.IpAddress == "" {
time.Sleep(1 * time.Second)
continue
}
d.InstanceId = inst.InstanceId
d.IPAddress = inst.IpAddress
break
}
return nil return nil
} }
@ -500,18 +521,19 @@ func (d *Driver) terminate() error {
return nil return nil
} }
func (d *Driver) configureSecurityGroup() error { func (d *Driver) configureSecurityGroup(groupName string) error {
log.Debugf("configuring security group in %s", d.VpcId) log.Debugf("configuring security group in %s", d.VpcId)
var securityGroup *amz.SecurityGroup
groups, err := d.getClient().GetSecurityGroups() groups, err := d.getClient().GetSecurityGroups()
if err != nil { if err != nil {
return err return err
} }
var securityGroup *amz.SecurityGroup
for _, grp := range groups { for _, grp := range groups {
if grp.GroupName == machineSecurityGroupName { if grp.GroupName == groupName {
log.Debugf("found existing security group (%s) in %s", machineSecurityGroupName, d.VpcId) log.Debugf("found existing security group (%s) in %s", groupName, d.VpcId)
securityGroup = &grp securityGroup = &grp
break break
} }
@ -519,21 +541,32 @@ func (d *Driver) configureSecurityGroup() error {
// if not found, create // if not found, create
if securityGroup == nil { if securityGroup == nil {
log.Debugf("creating security group (%s) in %s", machineSecurityGroupName, d.VpcId) log.Debugf("creating security group (%s) in %s", groupName, d.VpcId)
group, err := d.getClient().CreateSecurityGroup(machineSecurityGroupName, "Docker Machine", d.VpcId) group, err := d.getClient().CreateSecurityGroup(groupName, "Docker Machine", d.VpcId)
if err != nil { if err != nil {
return err return err
} }
securityGroup = group securityGroup = group
// wait until created (dat eventual consistency)
log.Debugf("waiting for group (%s) to become available", group.GroupId)
for {
_, err := d.getClient().GetSecurityGroupById(group.GroupId)
if err == nil {
break
}
log.Debug(err)
time.Sleep(1 * time.Second)
}
} }
d.SecurityGroupId = securityGroup.GroupId d.SecurityGroupId = securityGroup.GroupId
log.Debugf("configuring authorization %s", ipRange) log.Debugf("configuring security group authorization for %s", ipRange)
perms := configureSecurityGroupPermissions(securityGroup) perms := configureSecurityGroupPermissions(securityGroup)
if len(perms) != 0 { if len(perms) != 0 {
log.Debugf("authorizing group %s with permissions: %v", securityGroup.GroupName, perms)
if err := d.getClient().AuthorizeSecurityGroup(d.SecurityGroupId, perms); err != nil { if err := d.getClient().AuthorizeSecurityGroup(d.SecurityGroupId, perms); err != nil {
return err return err
} }
@ -559,10 +592,10 @@ func configureSecurityGroupPermissions(group *amz.SecurityGroup) []amz.IpPermiss
if !hasSshPort { if !hasSshPort {
perm := amz.IpPermission{ perm := amz.IpPermission{
Protocol: "tcp", IpProtocol: "tcp",
FromPort: 22, FromPort: 22,
ToPort: 22, ToPort: 22,
IpRange: ipRange, IpRange: ipRange,
} }
perms = append(perms, perm) perms = append(perms, perm)
@ -570,10 +603,10 @@ func configureSecurityGroupPermissions(group *amz.SecurityGroup) []amz.IpPermiss
if !hasDockerPort { if !hasDockerPort {
perm := amz.IpPermission{ perm := amz.IpPermission{
Protocol: "tcp", IpProtocol: "tcp",
FromPort: dockerPort, FromPort: dockerPort,
ToPort: dockerPort, ToPort: dockerPort,
IpRange: ipRange, IpRange: ipRange,
} }
perms = append(perms, perm) perms = append(perms, perm)

View File

@ -32,9 +32,9 @@ func TestConfigureSecurityGroupPermissionsSshOnly(t *testing.T) {
group.IpPermissions = []amz.IpPermission{ group.IpPermissions = []amz.IpPermission{
{ {
Protocol: "tcp", IpProtocol: "tcp",
FromPort: testSshPort, FromPort: testSshPort,
ToPort: testSshPort, ToPort: testSshPort,
}, },
} }
@ -54,9 +54,9 @@ func TestConfigureSecurityGroupPermissionsDockerOnly(t *testing.T) {
group.IpPermissions = []amz.IpPermission{ group.IpPermissions = []amz.IpPermission{
{ {
Protocol: "tcp", IpProtocol: "tcp",
FromPort: testDockerPort, FromPort: testDockerPort,
ToPort: testDockerPort, ToPort: testDockerPort,
}, },
} }
@ -76,14 +76,14 @@ func TestConfigureSecurityGroupPermissionsDockerAndSsh(t *testing.T) {
group.IpPermissions = []amz.IpPermission{ group.IpPermissions = []amz.IpPermission{
{ {
Protocol: "tcp", IpProtocol: "tcp",
FromPort: testSshPort, FromPort: testSshPort,
ToPort: testSshPort, ToPort: testSshPort,
}, },
{ {
Protocol: "tcp", IpProtocol: "tcp",
FromPort: testDockerPort, FromPort: testDockerPort,
ToPort: testDockerPort, ToPort: testDockerPort,
}, },
} }

View File

@ -347,7 +347,7 @@ func (e *EC2) AuthorizeSecurityGroup(groupId string, permissions []IpPermission)
for index, perm := range permissions { for index, perm := range permissions {
n := index + 1 // amazon starts counting from 1 not 0 n := index + 1 // amazon starts counting from 1 not 0
v.Set(fmt.Sprintf("IpPermissions.%d.IpProtocol", n), perm.Protocol) v.Set(fmt.Sprintf("IpPermissions.%d.IpProtocol", n), perm.IpProtocol)
v.Set(fmt.Sprintf("IpPermissions.%d.FromPort", n), strconv.Itoa(perm.FromPort)) v.Set(fmt.Sprintf("IpPermissions.%d.FromPort", n), strconv.Itoa(perm.FromPort))
v.Set(fmt.Sprintf("IpPermissions.%d.ToPort", n), strconv.Itoa(perm.ToPort)) v.Set(fmt.Sprintf("IpPermissions.%d.ToPort", n), strconv.Itoa(perm.ToPort))
v.Set(fmt.Sprintf("IpPermissions.%d.IpRanges.1.CidrIp", n), perm.IpRange) v.Set(fmt.Sprintf("IpPermissions.%d.IpRanges.1.CidrIp", n), perm.IpRange)
@ -401,6 +401,21 @@ func (e *EC2) GetSecurityGroups() ([]SecurityGroup, error) {
return sgs, nil return sgs, nil
} }
func (e *EC2) GetSecurityGroupById(id string) (*SecurityGroup, error) {
groups, err := e.GetSecurityGroups()
if err != nil {
return nil, err
}
for _, g := range groups {
if g.GroupId == id {
return &g, nil
}
}
return nil, nil
}
func (e *EC2) GetSubnets() ([]Subnet, error) { func (e *EC2) GetSubnets() ([]Subnet, error) {
subnets := []Subnet{} subnets := []Subnet{}
resp, err := e.performStandardAction("DescribeSubnets") resp, err := e.performStandardAction("DescribeSubnets")

View File

@ -1,8 +1,8 @@
package amz package amz
type IpPermission struct { type IpPermission struct {
Protocol string IpProtocol string `xml:"ipProtocol"`
FromPort int FromPort int `xml:"fromPort"`
ToPort int ToPort int `xml:"toPort"`
IpRange string IpRange string `xml:"ipRanges"`
} }

View File

@ -15,6 +15,7 @@ type SecurityGroup struct {
GroupName string `xml:"groupName"` GroupName string `xml:"groupName"`
GroupId string `xml:"groupId"` GroupId string `xml:"groupId"`
VpcId string `xml:"vpcId"` VpcId string `xml:"vpcId"`
IpPermissions []IpPermission `xml:"ipPermissions,omitempty"` OwnerId string `xml:"ownerId"`
IpPermissionsEgress []IpPermission `xml:"ipPermissionsEgress,omitempty"` IpPermissions []IpPermission `xml:"ipPermissions>item,omitempty"`
IpPermissionsEgress []IpPermission `xml:"ipPermissionsEgress>item,omitempty"`
} }