amazonec2: Convert EC2 API calls to official SDK

This does an almost 1-to-1 translation of API calls.

The differences are as follows:
1. Use the SDK waiter for spot instance request fulfillment
2. Uses the toplevel private/public ip fields instead of the
   networkinterface's fields
3. Recognizes the 'Terminated' state as an error explicitly instead of
   implicitly.
4. Uses filters on DescribeSecurityGroups to find the correct one more
   efficiently and to limit to a given VPC.

Other than that, it really should be identical apart from the perhaps
obvious error message differences.

Signed-off-by: Euan <euank@euank.com>
This commit is contained in:
Euan 2015-11-25 00:42:39 -08:00 committed by Jean-Laurent de Morlhon
parent 7f2e3c1d19
commit 8d98d2b7b7
2 changed files with 303 additions and 156 deletions

View File

@ -12,7 +12,11 @@ import (
"strings" "strings"
"time" "time"
"github.com/docker/machine/drivers/amazonec2/amz" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/docker/machine/libmachine/drivers" "github.com/docker/machine/libmachine/drivers"
"github.com/docker/machine/libmachine/log" "github.com/docker/machine/libmachine/log"
"github.com/docker/machine/libmachine/mcnflag" "github.com/docker/machine/libmachine/mcnflag"
@ -35,6 +39,10 @@ const (
defaultSpotPrice = "0.50" defaultSpotPrice = "0.50"
) )
const (
keypairNotFoundCode = "InvalidKeyPair.NotFound"
)
var ( var (
dockerPort = 2376 dockerPort = 2376
swarmPort = 3376 swarmPort = 3376
@ -233,19 +241,21 @@ func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error {
} }
if d.SubnetId != "" && d.VpcId != "" { if d.SubnetId != "" && d.VpcId != "" {
filters := []amz.Filter{ subnetFilter := []*ec2.Filter{
{ {
Name: "subnet-id", Name: aws.String("subnet-id"),
Value: d.SubnetId, Values: []*string{&d.SubnetId},
}, },
} }
subnets, err := d.getClient().GetSubnets(filters) subnets, err := d.getClient().DescribeSubnets(&ec2.DescribeSubnetsInput{
Filters: subnetFilter,
})
if err != nil { if err != nil {
return err return err
} }
if subnets[0].VpcId != d.VpcId { if *subnets.Subnets[0].VpcId != d.VpcId {
return fmt.Errorf("SubnetId: %s does not belong to VpcId: %s", d.SubnetId, d.VpcId) return fmt.Errorf("SubnetId: %s does not belong to VpcId: %s", d.SubnetId, d.VpcId)
} }
} }
@ -275,44 +285,52 @@ func (d *Driver) DriverName() string {
func (d *Driver) checkPrereqs() error { func (d *Driver) checkPrereqs() error {
// check for existing keypair // check for existing keypair
key, err := d.getClient().GetKeyPair(d.MachineName) key, err := d.getClient().DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
KeyNames: []*string{&d.MachineName},
})
if err != nil { if err != nil {
return err if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == keypairNotFoundCode {
// Not a real error for 'NotFound' since we're checking existance anyways
} else {
return err
}
} }
if key != nil { if err == nil && len(key.KeyPairs) != 0 {
return fmt.Errorf("There is already a keypair with the name %s. Please either remove that keypair or use a different machine name.", d.MachineName) return fmt.Errorf("There is already a keypair with the name %s. Please either remove that keypair or use a different machine name.", d.MachineName)
} }
regionZone := d.Region + d.Zone regionZone := d.Region + d.Zone
if d.SubnetId == "" { if d.SubnetId == "" {
filters := []amz.Filter{ filters := []*ec2.Filter{
{ {
Name: "availabilityZone", Name: aws.String("availability-zone"),
Value: regionZone, Values: []*string{&regionZone},
}, },
{ {
Name: "vpc-id", Name: aws.String("vpc-id"),
Value: d.VpcId, Values: []*string{&d.VpcId},
}, },
} }
subnets, err := d.getClient().GetSubnets(filters) subnets, err := d.getClient().DescribeSubnets(&ec2.DescribeSubnetsInput{
Filters: filters,
})
if err != nil { if err != nil {
return err return err
} }
if len(subnets) == 0 { if len(subnets.Subnets) == 0 {
return fmt.Errorf("unable to find a subnet in the zone: %s", regionZone) return fmt.Errorf("unable to find a subnet in the zone: %s", regionZone)
} }
d.SubnetId = subnets[0].SubnetId d.SubnetId = *subnets.Subnets[0].SubnetId
// try to find default // try to find default
if len(subnets) > 1 { if len(subnets.Subnets) > 1 {
for _, subnet := range subnets { for _, subnet := range subnets.Subnets {
if subnet.DefaultForAz { if *subnet.DefaultForAz {
d.SubnetId = subnet.SubnetId d.SubnetId = *subnet.SubnetId
break break
} }
} }
@ -354,53 +372,123 @@ func (d *Driver) Create() error {
return err return err
} }
bdm := &amz.BlockDeviceMapping{ bdm := &ec2.BlockDeviceMapping{
DeviceName: "/dev/sda1", DeviceName: aws.String("/dev/sda1"),
VolumeSize: d.RootSize, Ebs: &ec2.EbsBlockDevice{
DeleteOnTermination: true, VolumeSize: aws.Int64(d.RootSize),
VolumeType: "gp2", VolumeType: aws.String("gp2"),
DeleteOnTermination: aws.Bool(true),
},
} }
netSpecs := []*ec2.InstanceNetworkInterfaceSpecification{{
DeviceIndex: aws.Int64(0), // eth0
Groups: []*string{&d.SecurityGroupId},
SubnetId: &d.SubnetId,
AssociatePublicIpAddress: aws.Bool(!d.PrivateIPOnly),
}}
regionZone := d.Region + d.Zone
log.Debugf("launching instance in subnet %s", d.SubnetId) log.Debugf("launching instance in subnet %s", d.SubnetId)
var instance amz.EC2Instance
var instance *ec2.Instance
if d.RequestSpotInstance { if d.RequestSpotInstance {
spotInstanceRequestId, err := d.getClient().RequestSpotInstances(d.AMI, d.InstanceType, d.Zone, 1, d.SecurityGroupId, d.KeyName, d.SubnetId, bdm, d.IamInstanceProfile, d.SpotPrice, d.Monitoring) spotInstanceRequest, err := d.getClient().RequestSpotInstances(&ec2.RequestSpotInstancesInput{
LaunchSpecification: &ec2.RequestSpotLaunchSpecification{
ImageId: &d.AMI,
Placement: &ec2.SpotPlacement{
AvailabilityZone: &regionZone,
},
KeyName: &d.KeyName,
InstanceType: &d.InstanceType,
NetworkInterfaces: netSpecs,
Monitoring: &ec2.RunInstancesMonitoringEnabled{Enabled: aws.Bool(d.Monitoring)},
IamInstanceProfile: &ec2.IamInstanceProfileSpecification{
Name: &d.IamInstanceProfile,
},
BlockDeviceMappings: []*ec2.BlockDeviceMapping{bdm},
},
InstanceCount: aws.Int64(1),
SpotPrice: &d.SpotPrice,
})
if err != nil { if err != nil {
return fmt.Errorf("Error request spot instance: %s", err) return fmt.Errorf("Error request spot instance: %s", err)
} }
var instanceId string
var spotInstanceRequestStatus string
log.Info("Waiting for spot instance...") log.Info("Waiting for spot instance...")
// check until fulfilled err = d.getClient().WaitUntilSpotInstanceRequestFulfilled(&ec2.DescribeSpotInstanceRequestsInput{
for instanceId == "" { SpotInstanceRequestIds: []*string{spotInstanceRequest.SpotInstanceRequests[0].SpotInstanceRequestId},
time.Sleep(time.Second * 5) })
spotInstanceRequestStatus, instanceId, err = d.getClient().DescribeSpotInstanceRequests(spotInstanceRequestId)
if err != nil {
return fmt.Errorf("Error describe spot instance request: %s", err)
}
log.Debugf("spot instance request status: %s", spotInstanceRequestStatus)
}
instance, err = d.getClient().GetInstance(instanceId)
if err != nil { if err != nil {
return fmt.Errorf("Error get instance: %s", err) return fmt.Errorf("Error fulfilling spot request: %v", err)
}
log.Info("Created spot instance request %v", *spotInstanceRequest.SpotInstanceRequests[0].SpotInstanceRequestId)
// resolve instance id
for i := 0; i < 3; i++ {
// Even though the waiter succeeded, eventual consistency means we could
// get a describe output that does not include this information. Try a
// few times just in case
var resolvedSpotInstance *ec2.DescribeSpotInstanceRequestsOutput
resolvedSpotInstance, err = d.getClient().DescribeSpotInstanceRequests(&ec2.DescribeSpotInstanceRequestsInput{
SpotInstanceRequestIds: []*string{spotInstanceRequest.SpotInstanceRequests[0].SpotInstanceRequestId},
})
if err != nil {
// Unexpected; no need to retry
return fmt.Errorf("Error describing previously made spot instance request: %v", err)
}
maybeInstanceId := resolvedSpotInstance.SpotInstanceRequests[0].InstanceId
if maybeInstanceId != nil {
var instances *ec2.DescribeInstancesOutput
instances, err = d.getClient().DescribeInstances(&ec2.DescribeInstancesInput{
InstanceIds: []*string{maybeInstanceId},
})
if err != nil {
// Retry if we get an id from spot instance but EC2 doesn't recognize it yet; see above, eventual consistency possible
continue
}
instance = instances.Reservations[0].Instances[0]
err = nil
break
}
time.Sleep(5 * time.Second)
}
if err != nil {
return fmt.Errorf("Error resolving spot instance to real instance: %v", err)
} }
} else { } else {
inst, err := d.getClient().RunInstance(d.AMI, d.InstanceType, d.Zone, 1, 1, d.SecurityGroupId, d.KeyName, d.SubnetId, bdm, d.IamInstanceProfile, d.PrivateIPOnly, d.Monitoring) inst, err := d.getClient().RunInstances(&ec2.RunInstancesInput{
ImageId: &d.AMI,
MinCount: aws.Int64(1),
MaxCount: aws.Int64(1),
Placement: &ec2.Placement{
AvailabilityZone: &regionZone,
},
KeyName: &d.KeyName,
InstanceType: &d.InstanceType,
NetworkInterfaces: netSpecs,
Monitoring: &ec2.RunInstancesMonitoringEnabled{Enabled: aws.Bool(d.Monitoring)},
IamInstanceProfile: &ec2.IamInstanceProfileSpecification{
Name: &d.IamInstanceProfile,
},
BlockDeviceMappings: []*ec2.BlockDeviceMapping{bdm},
})
if err != nil { if err != nil {
return fmt.Errorf("Error launching instance: %s", err) return fmt.Errorf("Error launching instance: %s", err)
} }
instance = inst instance = inst.Instances[0]
} }
d.InstanceId = instance.InstanceId d.InstanceId = *instance.InstanceId
log.Debug("waiting for ip address to become available") log.Debug("waiting for ip address to become available")
if err := mcnutils.WaitFor(d.instanceIpAvailable); err != nil { if err := mcnutils.WaitFor(d.instanceIpAvailable); err != nil {
return err return err
} }
if len(instance.NetworkInterfaceSet) > 0 { if instance.PrivateIpAddress != nil {
d.PrivateIPAddress = instance.NetworkInterfaceSet[0].PrivateIpAddress d.PrivateIPAddress = *instance.PrivateIpAddress
} }
d.waitForInstance() d.waitForInstance()
@ -412,12 +500,15 @@ func (d *Driver) Create() error {
) )
log.Debug("Settings tags for instance") log.Debug("Settings tags for instance")
tags := map[string]string{ _, err := d.getClient().CreateTags(&ec2.CreateTagsInput{
"Name": d.MachineName, Resources: []*string{&d.InstanceId},
} Tags: []*ec2.Tag{{
Key: aws.String("Name"),
if err := d.getClient().CreateTags(d.InstanceId, tags); err != nil { Value: &d.MachineName,
return err }},
})
if err != nil {
return fmt.Errorf("Unable to tag instance %s: %s", d.InstanceId, err)
} }
return nil return nil
@ -441,14 +532,23 @@ func (d *Driver) GetIP() (string, error) {
} }
if d.PrivateIPOnly { if d.PrivateIPOnly {
return inst.PrivateIpAddress, nil if inst.PrivateIpAddress == nil {
return "", fmt.Errorf("No private IP for instance %v", *inst.InstanceId)
}
return *inst.PrivateIpAddress, nil
} }
if d.UsePrivateIP { if d.UsePrivateIP {
return inst.PrivateIpAddress, nil if inst.PrivateIpAddress == nil {
return "", fmt.Errorf("No private IP for instance %v", *inst.InstanceId)
}
return *inst.PrivateIpAddress, nil
} }
return inst.IpAddress, nil if inst.PublicIpAddress == nil {
return "", fmt.Errorf("No IP for instance %v", *inst.InstanceId)
}
return *inst.PublicIpAddress, nil
} }
func (d *Driver) GetState() (state.State, error) { func (d *Driver) GetState() (state.State, error) {
@ -456,18 +556,21 @@ func (d *Driver) GetState() (state.State, error) {
if err != nil { if err != nil {
return state.Error, err return state.Error, err
} }
switch inst.InstanceState.Name { switch *inst.State.Name {
case "pending": case ec2.InstanceStateNamePending:
return state.Starting, nil return state.Starting, nil
case "running": case ec2.InstanceStateNameRunning:
return state.Running, nil return state.Running, nil
case "stopping": case ec2.InstanceStateNameStopping:
return state.Stopping, nil return state.Stopping, nil
case "shutting-down": case ec2.InstanceStateNameShuttingDown:
return state.Stopping, nil return state.Stopping, nil
case "stopped": case ec2.InstanceStateNameStopped:
return state.Stopped, nil return state.Stopped, nil
case ec2.InstanceStateNameTerminated:
return state.Error, nil
default: default:
log.Warnf("unrecognized instance state: %v", *inst.State.Name)
return state.Error, nil return state.Error, nil
} }
} }
@ -487,7 +590,10 @@ func (d *Driver) GetSSHUsername() string {
} }
func (d *Driver) Start() error { func (d *Driver) Start() error {
if err := d.getClient().StartInstance(d.InstanceId); err != nil { _, err := d.getClient().StartInstances(&ec2.StartInstancesInput{
InstanceIds: []*string{&d.InstanceId},
})
if err != nil {
return err return err
} }
@ -499,10 +605,11 @@ func (d *Driver) Start() error {
} }
func (d *Driver) Stop() error { func (d *Driver) Stop() error {
if err := d.getClient().StopInstance(d.InstanceId, false); err != nil { _, err := d.getClient().StopInstances(&ec2.StopInstancesInput{
return err InstanceIds: []*string{&d.InstanceId},
} Force: aws.Bool(false),
return nil })
return err
} }
func (d *Driver) Remove() error { func (d *Driver) Remove() error {
@ -520,31 +627,35 @@ func (d *Driver) Remove() error {
} }
func (d *Driver) Restart() error { func (d *Driver) Restart() error {
if err := d.getClient().RestartInstance(d.InstanceId); err != nil { _, err := d.getClient().RebootInstances(&ec2.RebootInstancesInput{
return fmt.Errorf("unable to restart instance: %s", err) InstanceIds: []*string{&d.InstanceId},
} })
return nil return err
} }
func (d *Driver) Kill() error { func (d *Driver) Kill() error {
if err := d.getClient().StopInstance(d.InstanceId, true); err != nil { _, err := d.getClient().StopInstances(&ec2.StopInstancesInput{
return err InstanceIds: []*string{&d.InstanceId},
} Force: aws.Bool(true),
return nil })
return err
} }
func (d *Driver) getClient() *amz.EC2 { func (d *Driver) getClient() *ec2.EC2 {
auth := amz.GetAuth(d.AccessKey, d.SecretKey, d.SessionToken) config := aws.NewConfig()
return amz.NewEC2(auth, d.Region) config = config.WithRegion(d.Region)
config = config.WithCredentials(credentials.NewStaticCredentials(d.AccessKey, d.SecretKey, d.SessionToken))
return ec2.New(session.New(config))
} }
func (d *Driver) getInstance() (*amz.EC2Instance, error) { func (d *Driver) getInstance() (*ec2.Instance, error) {
instance, err := d.getClient().GetInstance(d.InstanceId) instances, err := d.getClient().DescribeInstances(&ec2.DescribeInstancesInput{
InstanceIds: []*string{&d.InstanceId},
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
return instances.Reservations[0].Instances[0], nil
return &instance, nil
} }
func (d *Driver) instanceIsRunning() bool { func (d *Driver) instanceIsRunning() bool {
@ -567,7 +678,6 @@ func (d *Driver) waitForInstance() error {
} }
func (d *Driver) createKeyPair() error { func (d *Driver) createKeyPair() error {
if err := ssh.GenerateSSHKey(d.GetSSHKeyPath()); err != nil { if err := ssh.GenerateSSHKey(d.GetSSHKeyPath()); err != nil {
return err return err
} }
@ -580,11 +690,13 @@ func (d *Driver) createKeyPair() error {
keyName := d.MachineName keyName := d.MachineName
log.Debugf("creating key pair: %s", keyName) log.Debugf("creating key pair: %s", keyName)
_, err = d.getClient().ImportKeyPair(&ec2.ImportKeyPairInput{
if err := d.getClient().ImportKeyPair(keyName, string(publicKey)); err != nil { KeyName: &keyName,
PublicKeyMaterial: publicKey,
})
if err != nil {
return err return err
} }
d.KeyName = keyName d.KeyName = keyName
return nil return nil
} }
@ -595,10 +707,12 @@ func (d *Driver) terminate() error {
} }
log.Debugf("terminating instance: %s", d.InstanceId) log.Debugf("terminating instance: %s", d.InstanceId)
if err := d.getClient().TerminateInstance(d.InstanceId); err != nil { _, err := d.getClient().TerminateInstances(&ec2.TerminateInstancesInput{
InstanceIds: []*string{&d.InstanceId},
})
if err != nil {
return fmt.Errorf("unable to terminate instance: %s", err) return fmt.Errorf("unable to terminate instance: %s", err)
} }
return nil return nil
} }
@ -608,9 +722,15 @@ func (d *Driver) isSwarmMaster() bool {
func (d *Driver) securityGroupAvailableFunc(id string) func() bool { func (d *Driver) securityGroupAvailableFunc(id string) func() bool {
return func() bool { return func() bool {
_, err := d.getClient().GetSecurityGroupById(id)
if err == nil { securityGroup, err := d.getClient().DescribeSecurityGroups(&ec2.DescribeSecurityGroupsInput{
GroupIds: []*string{&id},
})
if err == nil && len(securityGroup.SecurityGroups) > 0 {
return true return true
} else if err == nil {
log.Debugf("No security group with id %v found", id)
return false
} }
log.Debug(err) log.Debug(err)
return false return false
@ -620,92 +740,112 @@ func (d *Driver) securityGroupAvailableFunc(id string) func() bool {
func (d *Driver) configureSecurityGroup(groupName string) error { func (d *Driver) configureSecurityGroup(groupName string) error {
log.Debugf("configuring security group in %s", d.VpcId) log.Debugf("configuring security group in %s", d.VpcId)
var securityGroup *amz.SecurityGroup var group *ec2.SecurityGroup
filters := []*ec2.Filter{
groups, err := d.getClient().GetSecurityGroups() {
Name: aws.String("group-name"),
Values: []*string{&groupName},
},
{
Name: aws.String("vpc-id"),
Values: []*string{&d.VpcId},
},
}
groups, err := d.getClient().DescribeSecurityGroups(&ec2.DescribeSecurityGroupsInput{
Filters: filters,
})
if err != nil { if err != nil {
return err return err
} }
for _, grp := range groups { if len(groups.SecurityGroups) > 0 {
if grp.GroupName == groupName { log.Debugf("found existing security group (%s) in %s", groupName, d.VpcId)
log.Debugf("found existing security group (%s) in %s", groupName, d.VpcId) group = groups.SecurityGroups[0]
securityGroup = &grp
break
}
} }
// if not found, create // if not found, create
if securityGroup == nil { if group == nil {
log.Debugf("creating security group (%s) in %s", groupName, d.VpcId) log.Debugf("creating security group (%s) in %s", groupName, d.VpcId)
group, err := d.getClient().CreateSecurityGroup(groupName, "Docker Machine", d.VpcId) groupResp, err := d.getClient().CreateSecurityGroup(&ec2.CreateSecurityGroupInput{
GroupName: &groupName,
Description: aws.String("Docker Machine"),
VpcId: &d.VpcId,
})
if err != nil { if err != nil {
return err return err
} }
securityGroup = group // Manually translate into the security group construct
group = &ec2.SecurityGroup{
GroupId: groupResp.GroupId,
VpcId: aws.String(d.VpcId),
GroupName: aws.String(groupName),
}
// wait until created (dat eventual consistency) // wait until created (dat eventual consistency)
log.Debugf("waiting for group (%s) to become available", group.GroupId) log.Debugf("waiting for group (%s) to become available", *group.GroupId)
if err := mcnutils.WaitFor(d.securityGroupAvailableFunc(group.GroupId)); err != nil { if err := mcnutils.WaitFor(d.securityGroupAvailableFunc(*group.GroupId)); err != nil {
return err return err
} }
} }
d.SecurityGroupId = securityGroup.GroupId d.SecurityGroupId = *group.GroupId
perms := d.configureSecurityGroupPermissions(securityGroup) perms := d.configureSecurityGroupPermissions(group)
if len(perms) != 0 { if len(perms) != 0 {
log.Debugf("authorizing group %s with permissions: %v", securityGroup.GroupName, perms) log.Debugf("authorizing group %s with permissions: %v", groupName, perms)
if err := d.getClient().AuthorizeSecurityGroup(d.SecurityGroupId, perms); err != nil { _, err := d.getClient().AuthorizeSecurityGroupIngress(&ec2.AuthorizeSecurityGroupIngressInput{
GroupId: group.GroupId,
IpPermissions: perms,
})
if err != nil {
return err return err
} }
} }
return nil return nil
} }
func (d *Driver) configureSecurityGroupPermissions(group *amz.SecurityGroup) []amz.IpPermission { func (d *Driver) configureSecurityGroupPermissions(group *ec2.SecurityGroup) []*ec2.IpPermission {
hasSshPort := false hasSshPort := false
hasDockerPort := false hasDockerPort := false
hasSwarmPort := false hasSwarmPort := false
for _, p := range group.IpPermissions { for _, p := range group.IpPermissions {
switch p.FromPort { switch *p.FromPort {
case 22: case 22:
hasSshPort = true hasSshPort = true
case dockerPort: case int64(dockerPort):
hasDockerPort = true hasDockerPort = true
case swarmPort: case int64(swarmPort):
hasSwarmPort = true hasSwarmPort = true
} }
} }
perms := []amz.IpPermission{} perms := []*ec2.IpPermission{}
if !hasSshPort { if !hasSshPort {
perms = append(perms, amz.IpPermission{ perms = append(perms, &ec2.IpPermission{
IpProtocol: "tcp", IpProtocol: aws.String("tcp"),
FromPort: 22, FromPort: aws.Int64(22),
ToPort: 22, ToPort: aws.Int64(22),
IpRange: ipRange, IpRanges: []*ec2.IpRange{{CidrIp: aws.String(ipRange)}},
}) })
} }
if !hasDockerPort { if !hasDockerPort {
perms = append(perms, amz.IpPermission{ perms = append(perms, &ec2.IpPermission{
IpProtocol: "tcp", IpProtocol: aws.String("tcp"),
FromPort: dockerPort, FromPort: aws.Int64(int64(dockerPort)),
ToPort: dockerPort, ToPort: aws.Int64(int64(dockerPort)),
IpRange: ipRange, IpRanges: []*ec2.IpRange{{CidrIp: aws.String(ipRange)}},
}) })
} }
if !hasSwarmPort && d.SwarmMaster { if !hasSwarmPort && d.SwarmMaster {
perms = append(perms, amz.IpPermission{ perms = append(perms, &ec2.IpPermission{
IpProtocol: "tcp", IpProtocol: aws.String("tcp"),
FromPort: swarmPort, FromPort: aws.Int64(int64(swarmPort)),
ToPort: swarmPort, ToPort: aws.Int64(int64(swarmPort)),
IpRange: ipRange, IpRanges: []*ec2.IpRange{{CidrIp: aws.String(ipRange)}},
}) })
} }
@ -717,7 +857,10 @@ func (d *Driver) configureSecurityGroupPermissions(group *amz.SecurityGroup) []a
func (d *Driver) deleteSecurityGroup() error { func (d *Driver) deleteSecurityGroup() error {
log.Debugf("deleting security group %s", d.SecurityGroupId) log.Debugf("deleting security group %s", d.SecurityGroupId)
if err := d.getClient().DeleteSecurityGroup(d.SecurityGroupId); err != nil { _, err := d.getClient().DeleteSecurityGroup(&ec2.DeleteSecurityGroupInput{
GroupId: &d.SecurityGroupId,
})
if err != nil {
return err return err
} }
@ -727,7 +870,10 @@ func (d *Driver) deleteSecurityGroup() error {
func (d *Driver) deleteKeyPair() error { func (d *Driver) deleteKeyPair() error {
log.Debugf("deleting key pair: %s", d.KeyName) log.Debugf("deleting key pair: %s", d.KeyName)
if err := d.getClient().DeleteKeyPair(d.KeyName); err != nil { _, err := d.getClient().DeleteKeyPair(&ec2.DeleteKeyPairInput{
KeyName: &d.KeyName,
})
if err != nil {
return err return err
} }

View File

@ -5,9 +5,10 @@ import (
"os" "os"
"testing" "testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/docker/machine/commands/commandstest" "github.com/docker/machine/commands/commandstest"
"github.com/docker/machine/commands/mcndirs" "github.com/docker/machine/commands/mcndirs"
"github.com/docker/machine/drivers/amazonec2/amz"
"github.com/docker/machine/libmachine/drivers" "github.com/docker/machine/libmachine/drivers"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -24,10 +25,10 @@ const (
) )
var ( var (
securityGroup = amz.SecurityGroup{ securityGroup = &ec2.SecurityGroup{
GroupName: "test-group", GroupName: aws.String("test-group"),
GroupId: "12345", GroupId: aws.String("12345"),
VpcId: "12345", VpcId: aws.String("12345"),
} }
) )
@ -96,7 +97,7 @@ func TestConfigureSecurityGroupPermissionsEmpty(t *testing.T) {
defer cleanup() defer cleanup()
group := securityGroup group := securityGroup
perms := d.configureSecurityGroupPermissions(&group) perms := d.configureSecurityGroupPermissions(group)
if len(perms) != 2 { if len(perms) != 2 {
t.Fatalf("expected 2 permissions; received %d", len(perms)) t.Fatalf("expected 2 permissions; received %d", len(perms))
} }
@ -111,20 +112,20 @@ func TestConfigureSecurityGroupPermissionsSshOnly(t *testing.T) {
group := securityGroup group := securityGroup
group.IpPermissions = []amz.IpPermission{ group.IpPermissions = []*ec2.IpPermission{
{ {
IpProtocol: "tcp", IpProtocol: aws.String("tcp"),
FromPort: testSSHPort, FromPort: aws.Int64(int64(testSSHPort)),
ToPort: testSSHPort, ToPort: aws.Int64(int64(testSSHPort)),
}, },
} }
perms := d.configureSecurityGroupPermissions(&group) perms := d.configureSecurityGroupPermissions(group)
if len(perms) != 1 { if len(perms) != 1 {
t.Fatalf("expected 1 permission; received %d", len(perms)) t.Fatalf("expected 1 permission; received %d", len(perms))
} }
receivedPort := perms[0].FromPort receivedPort := *perms[0].FromPort
if receivedPort != testDockerPort { if receivedPort != testDockerPort {
t.Fatalf("expected permission on port %d; received port %d", testDockerPort, receivedPort) t.Fatalf("expected permission on port %d; received port %d", testDockerPort, receivedPort)
} }
@ -139,20 +140,20 @@ func TestConfigureSecurityGroupPermissionsDockerOnly(t *testing.T) {
group := securityGroup group := securityGroup
group.IpPermissions = []amz.IpPermission{ group.IpPermissions = []*ec2.IpPermission{
{ {
IpProtocol: "tcp", IpProtocol: aws.String("tcp"),
FromPort: testDockerPort, FromPort: aws.Int64((testDockerPort)),
ToPort: testDockerPort, ToPort: aws.Int64((testDockerPort)),
}, },
} }
perms := d.configureSecurityGroupPermissions(&group) perms := d.configureSecurityGroupPermissions(group)
if len(perms) != 1 { if len(perms) != 1 {
t.Fatalf("expected 1 permission; received %d", len(perms)) t.Fatalf("expected 1 permission; received %d", len(perms))
} }
receivedPort := perms[0].FromPort receivedPort := *perms[0].FromPort
if receivedPort != testSSHPort { if receivedPort != testSSHPort {
t.Fatalf("expected permission on port %d; received port %d", testSSHPort, receivedPort) t.Fatalf("expected permission on port %d; received port %d", testSSHPort, receivedPort)
} }
@ -167,20 +168,20 @@ func TestConfigureSecurityGroupPermissionsDockerAndSsh(t *testing.T) {
group := securityGroup group := securityGroup
group.IpPermissions = []amz.IpPermission{ group.IpPermissions = []*ec2.IpPermission{
{ {
IpProtocol: "tcp", IpProtocol: aws.String("tcp"),
FromPort: testSSHPort, FromPort: aws.Int64(testSSHPort),
ToPort: testSSHPort, ToPort: aws.Int64(testSSHPort),
}, },
{ {
IpProtocol: "tcp", IpProtocol: aws.String("tcp"),
FromPort: testDockerPort, FromPort: aws.Int64(testDockerPort),
ToPort: testDockerPort, ToPort: aws.Int64(testDockerPort),
}, },
} }
perms := d.configureSecurityGroupPermissions(&group) perms := d.configureSecurityGroupPermissions(group)
if len(perms) != 0 { if len(perms) != 0 {
t.Fatalf("expected 0 permissions; received %d", len(perms)) t.Fatalf("expected 0 permissions; received %d", len(perms))
} }