mirror of https://github.com/docker/docs.git
amazonec2: Convert EC2 API calls to official SDK
This does an almost 1-to-1 translation of API calls. The differences are as follows: 1. Use the SDK waiter for spot instance request fulfillment 2. Uses the toplevel private/public ip fields instead of the networkinterface's fields 3. Recognizes the 'Terminated' state as an error explicitly instead of implicitly. 4. Uses filters on DescribeSecurityGroups to find the correct one more efficiently and to limit to a given VPC. Other than that, it really should be identical apart from the perhaps obvious error message differences. Signed-off-by: Euan <euank@euank.com>
This commit is contained in:
parent
7f2e3c1d19
commit
8d98d2b7b7
|
@ -12,7 +12,11 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/docker/machine/drivers/amazonec2/amz"
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/ec2"
|
||||
"github.com/docker/machine/libmachine/drivers"
|
||||
"github.com/docker/machine/libmachine/log"
|
||||
"github.com/docker/machine/libmachine/mcnflag"
|
||||
|
@ -35,6 +39,10 @@ const (
|
|||
defaultSpotPrice = "0.50"
|
||||
)
|
||||
|
||||
const (
|
||||
keypairNotFoundCode = "InvalidKeyPair.NotFound"
|
||||
)
|
||||
|
||||
var (
|
||||
dockerPort = 2376
|
||||
swarmPort = 3376
|
||||
|
@ -233,19 +241,21 @@ func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error {
|
|||
}
|
||||
|
||||
if d.SubnetId != "" && d.VpcId != "" {
|
||||
filters := []amz.Filter{
|
||||
subnetFilter := []*ec2.Filter{
|
||||
{
|
||||
Name: "subnet-id",
|
||||
Value: d.SubnetId,
|
||||
Name: aws.String("subnet-id"),
|
||||
Values: []*string{&d.SubnetId},
|
||||
},
|
||||
}
|
||||
|
||||
subnets, err := d.getClient().GetSubnets(filters)
|
||||
subnets, err := d.getClient().DescribeSubnets(&ec2.DescribeSubnetsInput{
|
||||
Filters: subnetFilter,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if subnets[0].VpcId != d.VpcId {
|
||||
if *subnets.Subnets[0].VpcId != d.VpcId {
|
||||
return fmt.Errorf("SubnetId: %s does not belong to VpcId: %s", d.SubnetId, d.VpcId)
|
||||
}
|
||||
}
|
||||
|
@ -275,44 +285,52 @@ func (d *Driver) DriverName() string {
|
|||
|
||||
func (d *Driver) checkPrereqs() error {
|
||||
// check for existing keypair
|
||||
key, err := d.getClient().GetKeyPair(d.MachineName)
|
||||
key, err := d.getClient().DescribeKeyPairs(&ec2.DescribeKeyPairsInput{
|
||||
KeyNames: []*string{&d.MachineName},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == keypairNotFoundCode {
|
||||
// Not a real error for 'NotFound' since we're checking existance anyways
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if key != nil {
|
||||
if err == nil && len(key.KeyPairs) != 0 {
|
||||
return fmt.Errorf("There is already a keypair with the name %s. Please either remove that keypair or use a different machine name.", d.MachineName)
|
||||
}
|
||||
|
||||
regionZone := d.Region + d.Zone
|
||||
if d.SubnetId == "" {
|
||||
filters := []amz.Filter{
|
||||
filters := []*ec2.Filter{
|
||||
{
|
||||
Name: "availabilityZone",
|
||||
Value: regionZone,
|
||||
Name: aws.String("availability-zone"),
|
||||
Values: []*string{®ionZone},
|
||||
},
|
||||
{
|
||||
Name: "vpc-id",
|
||||
Value: d.VpcId,
|
||||
Name: aws.String("vpc-id"),
|
||||
Values: []*string{&d.VpcId},
|
||||
},
|
||||
}
|
||||
|
||||
subnets, err := d.getClient().GetSubnets(filters)
|
||||
subnets, err := d.getClient().DescribeSubnets(&ec2.DescribeSubnetsInput{
|
||||
Filters: filters,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(subnets) == 0 {
|
||||
if len(subnets.Subnets) == 0 {
|
||||
return fmt.Errorf("unable to find a subnet in the zone: %s", regionZone)
|
||||
}
|
||||
|
||||
d.SubnetId = subnets[0].SubnetId
|
||||
d.SubnetId = *subnets.Subnets[0].SubnetId
|
||||
|
||||
// try to find default
|
||||
if len(subnets) > 1 {
|
||||
for _, subnet := range subnets {
|
||||
if subnet.DefaultForAz {
|
||||
d.SubnetId = subnet.SubnetId
|
||||
if len(subnets.Subnets) > 1 {
|
||||
for _, subnet := range subnets.Subnets {
|
||||
if *subnet.DefaultForAz {
|
||||
d.SubnetId = *subnet.SubnetId
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -354,53 +372,123 @@ func (d *Driver) Create() error {
|
|||
return err
|
||||
}
|
||||
|
||||
bdm := &amz.BlockDeviceMapping{
|
||||
DeviceName: "/dev/sda1",
|
||||
VolumeSize: d.RootSize,
|
||||
DeleteOnTermination: true,
|
||||
VolumeType: "gp2",
|
||||
bdm := &ec2.BlockDeviceMapping{
|
||||
DeviceName: aws.String("/dev/sda1"),
|
||||
Ebs: &ec2.EbsBlockDevice{
|
||||
VolumeSize: aws.Int64(d.RootSize),
|
||||
VolumeType: aws.String("gp2"),
|
||||
DeleteOnTermination: aws.Bool(true),
|
||||
},
|
||||
}
|
||||
netSpecs := []*ec2.InstanceNetworkInterfaceSpecification{{
|
||||
DeviceIndex: aws.Int64(0), // eth0
|
||||
Groups: []*string{&d.SecurityGroupId},
|
||||
SubnetId: &d.SubnetId,
|
||||
AssociatePublicIpAddress: aws.Bool(!d.PrivateIPOnly),
|
||||
}}
|
||||
|
||||
regionZone := d.Region + d.Zone
|
||||
log.Debugf("launching instance in subnet %s", d.SubnetId)
|
||||
var instance amz.EC2Instance
|
||||
|
||||
var instance *ec2.Instance
|
||||
|
||||
if d.RequestSpotInstance {
|
||||
spotInstanceRequestId, err := d.getClient().RequestSpotInstances(d.AMI, d.InstanceType, d.Zone, 1, d.SecurityGroupId, d.KeyName, d.SubnetId, bdm, d.IamInstanceProfile, d.SpotPrice, d.Monitoring)
|
||||
spotInstanceRequest, err := d.getClient().RequestSpotInstances(&ec2.RequestSpotInstancesInput{
|
||||
LaunchSpecification: &ec2.RequestSpotLaunchSpecification{
|
||||
ImageId: &d.AMI,
|
||||
Placement: &ec2.SpotPlacement{
|
||||
AvailabilityZone: ®ionZone,
|
||||
},
|
||||
KeyName: &d.KeyName,
|
||||
InstanceType: &d.InstanceType,
|
||||
NetworkInterfaces: netSpecs,
|
||||
Monitoring: &ec2.RunInstancesMonitoringEnabled{Enabled: aws.Bool(d.Monitoring)},
|
||||
IamInstanceProfile: &ec2.IamInstanceProfileSpecification{
|
||||
Name: &d.IamInstanceProfile,
|
||||
},
|
||||
BlockDeviceMappings: []*ec2.BlockDeviceMapping{bdm},
|
||||
},
|
||||
InstanceCount: aws.Int64(1),
|
||||
SpotPrice: &d.SpotPrice,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error request spot instance: %s", err)
|
||||
}
|
||||
var instanceId string
|
||||
var spotInstanceRequestStatus string
|
||||
|
||||
log.Info("Waiting for spot instance...")
|
||||
// check until fulfilled
|
||||
for instanceId == "" {
|
||||
time.Sleep(time.Second * 5)
|
||||
spotInstanceRequestStatus, instanceId, err = d.getClient().DescribeSpotInstanceRequests(spotInstanceRequestId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error describe spot instance request: %s", err)
|
||||
}
|
||||
log.Debugf("spot instance request status: %s", spotInstanceRequestStatus)
|
||||
}
|
||||
instance, err = d.getClient().GetInstance(instanceId)
|
||||
err = d.getClient().WaitUntilSpotInstanceRequestFulfilled(&ec2.DescribeSpotInstanceRequestsInput{
|
||||
SpotInstanceRequestIds: []*string{spotInstanceRequest.SpotInstanceRequests[0].SpotInstanceRequestId},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error get instance: %s", err)
|
||||
return fmt.Errorf("Error fulfilling spot request: %v", err)
|
||||
}
|
||||
log.Info("Created spot instance request %v", *spotInstanceRequest.SpotInstanceRequests[0].SpotInstanceRequestId)
|
||||
// resolve instance id
|
||||
for i := 0; i < 3; i++ {
|
||||
// Even though the waiter succeeded, eventual consistency means we could
|
||||
// get a describe output that does not include this information. Try a
|
||||
// few times just in case
|
||||
var resolvedSpotInstance *ec2.DescribeSpotInstanceRequestsOutput
|
||||
resolvedSpotInstance, err = d.getClient().DescribeSpotInstanceRequests(&ec2.DescribeSpotInstanceRequestsInput{
|
||||
SpotInstanceRequestIds: []*string{spotInstanceRequest.SpotInstanceRequests[0].SpotInstanceRequestId},
|
||||
})
|
||||
if err != nil {
|
||||
// Unexpected; no need to retry
|
||||
return fmt.Errorf("Error describing previously made spot instance request: %v", err)
|
||||
}
|
||||
maybeInstanceId := resolvedSpotInstance.SpotInstanceRequests[0].InstanceId
|
||||
if maybeInstanceId != nil {
|
||||
var instances *ec2.DescribeInstancesOutput
|
||||
instances, err = d.getClient().DescribeInstances(&ec2.DescribeInstancesInput{
|
||||
InstanceIds: []*string{maybeInstanceId},
|
||||
})
|
||||
if err != nil {
|
||||
// Retry if we get an id from spot instance but EC2 doesn't recognize it yet; see above, eventual consistency possible
|
||||
continue
|
||||
}
|
||||
instance = instances.Reservations[0].Instances[0]
|
||||
err = nil
|
||||
break
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error resolving spot instance to real instance: %v", err)
|
||||
}
|
||||
} else {
|
||||
inst, err := d.getClient().RunInstance(d.AMI, d.InstanceType, d.Zone, 1, 1, d.SecurityGroupId, d.KeyName, d.SubnetId, bdm, d.IamInstanceProfile, d.PrivateIPOnly, d.Monitoring)
|
||||
inst, err := d.getClient().RunInstances(&ec2.RunInstancesInput{
|
||||
ImageId: &d.AMI,
|
||||
MinCount: aws.Int64(1),
|
||||
MaxCount: aws.Int64(1),
|
||||
Placement: &ec2.Placement{
|
||||
AvailabilityZone: ®ionZone,
|
||||
},
|
||||
KeyName: &d.KeyName,
|
||||
InstanceType: &d.InstanceType,
|
||||
NetworkInterfaces: netSpecs,
|
||||
Monitoring: &ec2.RunInstancesMonitoringEnabled{Enabled: aws.Bool(d.Monitoring)},
|
||||
IamInstanceProfile: &ec2.IamInstanceProfileSpecification{
|
||||
Name: &d.IamInstanceProfile,
|
||||
},
|
||||
BlockDeviceMappings: []*ec2.BlockDeviceMapping{bdm},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error launching instance: %s", err)
|
||||
}
|
||||
instance = inst
|
||||
instance = inst.Instances[0]
|
||||
}
|
||||
|
||||
d.InstanceId = instance.InstanceId
|
||||
d.InstanceId = *instance.InstanceId
|
||||
|
||||
log.Debug("waiting for ip address to become available")
|
||||
if err := mcnutils.WaitFor(d.instanceIpAvailable); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(instance.NetworkInterfaceSet) > 0 {
|
||||
d.PrivateIPAddress = instance.NetworkInterfaceSet[0].PrivateIpAddress
|
||||
if instance.PrivateIpAddress != nil {
|
||||
d.PrivateIPAddress = *instance.PrivateIpAddress
|
||||
}
|
||||
|
||||
d.waitForInstance()
|
||||
|
@ -412,12 +500,15 @@ func (d *Driver) Create() error {
|
|||
)
|
||||
|
||||
log.Debug("Settings tags for instance")
|
||||
tags := map[string]string{
|
||||
"Name": d.MachineName,
|
||||
}
|
||||
|
||||
if err := d.getClient().CreateTags(d.InstanceId, tags); err != nil {
|
||||
return err
|
||||
_, err := d.getClient().CreateTags(&ec2.CreateTagsInput{
|
||||
Resources: []*string{&d.InstanceId},
|
||||
Tags: []*ec2.Tag{{
|
||||
Key: aws.String("Name"),
|
||||
Value: &d.MachineName,
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to tag instance %s: %s", d.InstanceId, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -441,14 +532,23 @@ func (d *Driver) GetIP() (string, error) {
|
|||
}
|
||||
|
||||
if d.PrivateIPOnly {
|
||||
return inst.PrivateIpAddress, nil
|
||||
if inst.PrivateIpAddress == nil {
|
||||
return "", fmt.Errorf("No private IP for instance %v", *inst.InstanceId)
|
||||
}
|
||||
return *inst.PrivateIpAddress, nil
|
||||
}
|
||||
|
||||
if d.UsePrivateIP {
|
||||
return inst.PrivateIpAddress, nil
|
||||
if inst.PrivateIpAddress == nil {
|
||||
return "", fmt.Errorf("No private IP for instance %v", *inst.InstanceId)
|
||||
}
|
||||
return *inst.PrivateIpAddress, nil
|
||||
}
|
||||
|
||||
return inst.IpAddress, nil
|
||||
if inst.PublicIpAddress == nil {
|
||||
return "", fmt.Errorf("No IP for instance %v", *inst.InstanceId)
|
||||
}
|
||||
return *inst.PublicIpAddress, nil
|
||||
}
|
||||
|
||||
func (d *Driver) GetState() (state.State, error) {
|
||||
|
@ -456,18 +556,21 @@ func (d *Driver) GetState() (state.State, error) {
|
|||
if err != nil {
|
||||
return state.Error, err
|
||||
}
|
||||
switch inst.InstanceState.Name {
|
||||
case "pending":
|
||||
switch *inst.State.Name {
|
||||
case ec2.InstanceStateNamePending:
|
||||
return state.Starting, nil
|
||||
case "running":
|
||||
case ec2.InstanceStateNameRunning:
|
||||
return state.Running, nil
|
||||
case "stopping":
|
||||
case ec2.InstanceStateNameStopping:
|
||||
return state.Stopping, nil
|
||||
case "shutting-down":
|
||||
case ec2.InstanceStateNameShuttingDown:
|
||||
return state.Stopping, nil
|
||||
case "stopped":
|
||||
case ec2.InstanceStateNameStopped:
|
||||
return state.Stopped, nil
|
||||
case ec2.InstanceStateNameTerminated:
|
||||
return state.Error, nil
|
||||
default:
|
||||
log.Warnf("unrecognized instance state: %v", *inst.State.Name)
|
||||
return state.Error, nil
|
||||
}
|
||||
}
|
||||
|
@ -487,7 +590,10 @@ func (d *Driver) GetSSHUsername() string {
|
|||
}
|
||||
|
||||
func (d *Driver) Start() error {
|
||||
if err := d.getClient().StartInstance(d.InstanceId); err != nil {
|
||||
_, err := d.getClient().StartInstances(&ec2.StartInstancesInput{
|
||||
InstanceIds: []*string{&d.InstanceId},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -499,10 +605,11 @@ func (d *Driver) Start() error {
|
|||
}
|
||||
|
||||
func (d *Driver) Stop() error {
|
||||
if err := d.getClient().StopInstance(d.InstanceId, false); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
_, err := d.getClient().StopInstances(&ec2.StopInstancesInput{
|
||||
InstanceIds: []*string{&d.InstanceId},
|
||||
Force: aws.Bool(false),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Driver) Remove() error {
|
||||
|
@ -520,31 +627,35 @@ func (d *Driver) Remove() error {
|
|||
}
|
||||
|
||||
func (d *Driver) Restart() error {
|
||||
if err := d.getClient().RestartInstance(d.InstanceId); err != nil {
|
||||
return fmt.Errorf("unable to restart instance: %s", err)
|
||||
}
|
||||
return nil
|
||||
_, err := d.getClient().RebootInstances(&ec2.RebootInstancesInput{
|
||||
InstanceIds: []*string{&d.InstanceId},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Driver) Kill() error {
|
||||
if err := d.getClient().StopInstance(d.InstanceId, true); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
_, err := d.getClient().StopInstances(&ec2.StopInstancesInput{
|
||||
InstanceIds: []*string{&d.InstanceId},
|
||||
Force: aws.Bool(true),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Driver) getClient() *amz.EC2 {
|
||||
auth := amz.GetAuth(d.AccessKey, d.SecretKey, d.SessionToken)
|
||||
return amz.NewEC2(auth, d.Region)
|
||||
func (d *Driver) getClient() *ec2.EC2 {
|
||||
config := aws.NewConfig()
|
||||
config = config.WithRegion(d.Region)
|
||||
config = config.WithCredentials(credentials.NewStaticCredentials(d.AccessKey, d.SecretKey, d.SessionToken))
|
||||
return ec2.New(session.New(config))
|
||||
}
|
||||
|
||||
func (d *Driver) getInstance() (*amz.EC2Instance, error) {
|
||||
instance, err := d.getClient().GetInstance(d.InstanceId)
|
||||
func (d *Driver) getInstance() (*ec2.Instance, error) {
|
||||
instances, err := d.getClient().DescribeInstances(&ec2.DescribeInstancesInput{
|
||||
InstanceIds: []*string{&d.InstanceId},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &instance, nil
|
||||
return instances.Reservations[0].Instances[0], nil
|
||||
}
|
||||
|
||||
func (d *Driver) instanceIsRunning() bool {
|
||||
|
@ -567,7 +678,6 @@ func (d *Driver) waitForInstance() error {
|
|||
}
|
||||
|
||||
func (d *Driver) createKeyPair() error {
|
||||
|
||||
if err := ssh.GenerateSSHKey(d.GetSSHKeyPath()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -580,11 +690,13 @@ func (d *Driver) createKeyPair() error {
|
|||
keyName := d.MachineName
|
||||
|
||||
log.Debugf("creating key pair: %s", keyName)
|
||||
|
||||
if err := d.getClient().ImportKeyPair(keyName, string(publicKey)); err != nil {
|
||||
_, err = d.getClient().ImportKeyPair(&ec2.ImportKeyPairInput{
|
||||
KeyName: &keyName,
|
||||
PublicKeyMaterial: publicKey,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.KeyName = keyName
|
||||
return nil
|
||||
}
|
||||
|
@ -595,10 +707,12 @@ func (d *Driver) terminate() error {
|
|||
}
|
||||
|
||||
log.Debugf("terminating instance: %s", d.InstanceId)
|
||||
if err := d.getClient().TerminateInstance(d.InstanceId); err != nil {
|
||||
_, err := d.getClient().TerminateInstances(&ec2.TerminateInstancesInput{
|
||||
InstanceIds: []*string{&d.InstanceId},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to terminate instance: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -608,9 +722,15 @@ func (d *Driver) isSwarmMaster() bool {
|
|||
|
||||
func (d *Driver) securityGroupAvailableFunc(id string) func() bool {
|
||||
return func() bool {
|
||||
_, err := d.getClient().GetSecurityGroupById(id)
|
||||
if err == nil {
|
||||
|
||||
securityGroup, err := d.getClient().DescribeSecurityGroups(&ec2.DescribeSecurityGroupsInput{
|
||||
GroupIds: []*string{&id},
|
||||
})
|
||||
if err == nil && len(securityGroup.SecurityGroups) > 0 {
|
||||
return true
|
||||
} else if err == nil {
|
||||
log.Debugf("No security group with id %v found", id)
|
||||
return false
|
||||
}
|
||||
log.Debug(err)
|
||||
return false
|
||||
|
@ -620,92 +740,112 @@ func (d *Driver) securityGroupAvailableFunc(id string) func() bool {
|
|||
func (d *Driver) configureSecurityGroup(groupName string) error {
|
||||
log.Debugf("configuring security group in %s", d.VpcId)
|
||||
|
||||
var securityGroup *amz.SecurityGroup
|
||||
|
||||
groups, err := d.getClient().GetSecurityGroups()
|
||||
var group *ec2.SecurityGroup
|
||||
filters := []*ec2.Filter{
|
||||
{
|
||||
Name: aws.String("group-name"),
|
||||
Values: []*string{&groupName},
|
||||
},
|
||||
{
|
||||
Name: aws.String("vpc-id"),
|
||||
Values: []*string{&d.VpcId},
|
||||
},
|
||||
}
|
||||
groups, err := d.getClient().DescribeSecurityGroups(&ec2.DescribeSecurityGroupsInput{
|
||||
Filters: filters,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, grp := range groups {
|
||||
if grp.GroupName == groupName {
|
||||
log.Debugf("found existing security group (%s) in %s", groupName, d.VpcId)
|
||||
securityGroup = &grp
|
||||
break
|
||||
}
|
||||
if len(groups.SecurityGroups) > 0 {
|
||||
log.Debugf("found existing security group (%s) in %s", groupName, d.VpcId)
|
||||
group = groups.SecurityGroups[0]
|
||||
}
|
||||
|
||||
// if not found, create
|
||||
if securityGroup == nil {
|
||||
if group == nil {
|
||||
log.Debugf("creating security group (%s) in %s", groupName, d.VpcId)
|
||||
group, err := d.getClient().CreateSecurityGroup(groupName, "Docker Machine", d.VpcId)
|
||||
groupResp, err := d.getClient().CreateSecurityGroup(&ec2.CreateSecurityGroupInput{
|
||||
GroupName: &groupName,
|
||||
Description: aws.String("Docker Machine"),
|
||||
VpcId: &d.VpcId,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
securityGroup = group
|
||||
// Manually translate into the security group construct
|
||||
group = &ec2.SecurityGroup{
|
||||
GroupId: groupResp.GroupId,
|
||||
VpcId: aws.String(d.VpcId),
|
||||
GroupName: aws.String(groupName),
|
||||
}
|
||||
// wait until created (dat eventual consistency)
|
||||
log.Debugf("waiting for group (%s) to become available", group.GroupId)
|
||||
if err := mcnutils.WaitFor(d.securityGroupAvailableFunc(group.GroupId)); err != nil {
|
||||
log.Debugf("waiting for group (%s) to become available", *group.GroupId)
|
||||
if err := mcnutils.WaitFor(d.securityGroupAvailableFunc(*group.GroupId)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
d.SecurityGroupId = securityGroup.GroupId
|
||||
d.SecurityGroupId = *group.GroupId
|
||||
|
||||
perms := d.configureSecurityGroupPermissions(securityGroup)
|
||||
perms := d.configureSecurityGroupPermissions(group)
|
||||
|
||||
if len(perms) != 0 {
|
||||
log.Debugf("authorizing group %s with permissions: %v", securityGroup.GroupName, perms)
|
||||
if err := d.getClient().AuthorizeSecurityGroup(d.SecurityGroupId, perms); err != nil {
|
||||
log.Debugf("authorizing group %s with permissions: %v", groupName, perms)
|
||||
_, err := d.getClient().AuthorizeSecurityGroupIngress(&ec2.AuthorizeSecurityGroupIngressInput{
|
||||
GroupId: group.GroupId,
|
||||
IpPermissions: perms,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Driver) configureSecurityGroupPermissions(group *amz.SecurityGroup) []amz.IpPermission {
|
||||
func (d *Driver) configureSecurityGroupPermissions(group *ec2.SecurityGroup) []*ec2.IpPermission {
|
||||
hasSshPort := false
|
||||
hasDockerPort := false
|
||||
hasSwarmPort := false
|
||||
for _, p := range group.IpPermissions {
|
||||
switch p.FromPort {
|
||||
switch *p.FromPort {
|
||||
case 22:
|
||||
hasSshPort = true
|
||||
case dockerPort:
|
||||
case int64(dockerPort):
|
||||
hasDockerPort = true
|
||||
case swarmPort:
|
||||
case int64(swarmPort):
|
||||
hasSwarmPort = true
|
||||
}
|
||||
}
|
||||
|
||||
perms := []amz.IpPermission{}
|
||||
perms := []*ec2.IpPermission{}
|
||||
|
||||
if !hasSshPort {
|
||||
perms = append(perms, amz.IpPermission{
|
||||
IpProtocol: "tcp",
|
||||
FromPort: 22,
|
||||
ToPort: 22,
|
||||
IpRange: ipRange,
|
||||
perms = append(perms, &ec2.IpPermission{
|
||||
IpProtocol: aws.String("tcp"),
|
||||
FromPort: aws.Int64(22),
|
||||
ToPort: aws.Int64(22),
|
||||
IpRanges: []*ec2.IpRange{{CidrIp: aws.String(ipRange)}},
|
||||
})
|
||||
}
|
||||
|
||||
if !hasDockerPort {
|
||||
perms = append(perms, amz.IpPermission{
|
||||
IpProtocol: "tcp",
|
||||
FromPort: dockerPort,
|
||||
ToPort: dockerPort,
|
||||
IpRange: ipRange,
|
||||
perms = append(perms, &ec2.IpPermission{
|
||||
IpProtocol: aws.String("tcp"),
|
||||
FromPort: aws.Int64(int64(dockerPort)),
|
||||
ToPort: aws.Int64(int64(dockerPort)),
|
||||
IpRanges: []*ec2.IpRange{{CidrIp: aws.String(ipRange)}},
|
||||
})
|
||||
}
|
||||
|
||||
if !hasSwarmPort && d.SwarmMaster {
|
||||
perms = append(perms, amz.IpPermission{
|
||||
IpProtocol: "tcp",
|
||||
FromPort: swarmPort,
|
||||
ToPort: swarmPort,
|
||||
IpRange: ipRange,
|
||||
perms = append(perms, &ec2.IpPermission{
|
||||
IpProtocol: aws.String("tcp"),
|
||||
FromPort: aws.Int64(int64(swarmPort)),
|
||||
ToPort: aws.Int64(int64(swarmPort)),
|
||||
IpRanges: []*ec2.IpRange{{CidrIp: aws.String(ipRange)}},
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -717,7 +857,10 @@ func (d *Driver) configureSecurityGroupPermissions(group *amz.SecurityGroup) []a
|
|||
func (d *Driver) deleteSecurityGroup() error {
|
||||
log.Debugf("deleting security group %s", d.SecurityGroupId)
|
||||
|
||||
if err := d.getClient().DeleteSecurityGroup(d.SecurityGroupId); err != nil {
|
||||
_, err := d.getClient().DeleteSecurityGroup(&ec2.DeleteSecurityGroupInput{
|
||||
GroupId: &d.SecurityGroupId,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -727,7 +870,10 @@ func (d *Driver) deleteSecurityGroup() error {
|
|||
func (d *Driver) deleteKeyPair() error {
|
||||
log.Debugf("deleting key pair: %s", d.KeyName)
|
||||
|
||||
if err := d.getClient().DeleteKeyPair(d.KeyName); err != nil {
|
||||
_, err := d.getClient().DeleteKeyPair(&ec2.DeleteKeyPairInput{
|
||||
KeyName: &d.KeyName,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -5,9 +5,10 @@ import (
|
|||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/service/ec2"
|
||||
"github.com/docker/machine/commands/commandstest"
|
||||
"github.com/docker/machine/commands/mcndirs"
|
||||
"github.com/docker/machine/drivers/amazonec2/amz"
|
||||
"github.com/docker/machine/libmachine/drivers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -24,10 +25,10 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
securityGroup = amz.SecurityGroup{
|
||||
GroupName: "test-group",
|
||||
GroupId: "12345",
|
||||
VpcId: "12345",
|
||||
securityGroup = &ec2.SecurityGroup{
|
||||
GroupName: aws.String("test-group"),
|
||||
GroupId: aws.String("12345"),
|
||||
VpcId: aws.String("12345"),
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -96,7 +97,7 @@ func TestConfigureSecurityGroupPermissionsEmpty(t *testing.T) {
|
|||
defer cleanup()
|
||||
|
||||
group := securityGroup
|
||||
perms := d.configureSecurityGroupPermissions(&group)
|
||||
perms := d.configureSecurityGroupPermissions(group)
|
||||
if len(perms) != 2 {
|
||||
t.Fatalf("expected 2 permissions; received %d", len(perms))
|
||||
}
|
||||
|
@ -111,20 +112,20 @@ func TestConfigureSecurityGroupPermissionsSshOnly(t *testing.T) {
|
|||
|
||||
group := securityGroup
|
||||
|
||||
group.IpPermissions = []amz.IpPermission{
|
||||
group.IpPermissions = []*ec2.IpPermission{
|
||||
{
|
||||
IpProtocol: "tcp",
|
||||
FromPort: testSSHPort,
|
||||
ToPort: testSSHPort,
|
||||
IpProtocol: aws.String("tcp"),
|
||||
FromPort: aws.Int64(int64(testSSHPort)),
|
||||
ToPort: aws.Int64(int64(testSSHPort)),
|
||||
},
|
||||
}
|
||||
|
||||
perms := d.configureSecurityGroupPermissions(&group)
|
||||
perms := d.configureSecurityGroupPermissions(group)
|
||||
if len(perms) != 1 {
|
||||
t.Fatalf("expected 1 permission; received %d", len(perms))
|
||||
}
|
||||
|
||||
receivedPort := perms[0].FromPort
|
||||
receivedPort := *perms[0].FromPort
|
||||
if receivedPort != testDockerPort {
|
||||
t.Fatalf("expected permission on port %d; received port %d", testDockerPort, receivedPort)
|
||||
}
|
||||
|
@ -139,20 +140,20 @@ func TestConfigureSecurityGroupPermissionsDockerOnly(t *testing.T) {
|
|||
|
||||
group := securityGroup
|
||||
|
||||
group.IpPermissions = []amz.IpPermission{
|
||||
group.IpPermissions = []*ec2.IpPermission{
|
||||
{
|
||||
IpProtocol: "tcp",
|
||||
FromPort: testDockerPort,
|
||||
ToPort: testDockerPort,
|
||||
IpProtocol: aws.String("tcp"),
|
||||
FromPort: aws.Int64((testDockerPort)),
|
||||
ToPort: aws.Int64((testDockerPort)),
|
||||
},
|
||||
}
|
||||
|
||||
perms := d.configureSecurityGroupPermissions(&group)
|
||||
perms := d.configureSecurityGroupPermissions(group)
|
||||
if len(perms) != 1 {
|
||||
t.Fatalf("expected 1 permission; received %d", len(perms))
|
||||
}
|
||||
|
||||
receivedPort := perms[0].FromPort
|
||||
receivedPort := *perms[0].FromPort
|
||||
if receivedPort != testSSHPort {
|
||||
t.Fatalf("expected permission on port %d; received port %d", testSSHPort, receivedPort)
|
||||
}
|
||||
|
@ -167,20 +168,20 @@ func TestConfigureSecurityGroupPermissionsDockerAndSsh(t *testing.T) {
|
|||
|
||||
group := securityGroup
|
||||
|
||||
group.IpPermissions = []amz.IpPermission{
|
||||
group.IpPermissions = []*ec2.IpPermission{
|
||||
{
|
||||
IpProtocol: "tcp",
|
||||
FromPort: testSSHPort,
|
||||
ToPort: testSSHPort,
|
||||
IpProtocol: aws.String("tcp"),
|
||||
FromPort: aws.Int64(testSSHPort),
|
||||
ToPort: aws.Int64(testSSHPort),
|
||||
},
|
||||
{
|
||||
IpProtocol: "tcp",
|
||||
FromPort: testDockerPort,
|
||||
ToPort: testDockerPort,
|
||||
IpProtocol: aws.String("tcp"),
|
||||
FromPort: aws.Int64(testDockerPort),
|
||||
ToPort: aws.Int64(testDockerPort),
|
||||
},
|
||||
}
|
||||
|
||||
perms := d.configureSecurityGroupPermissions(&group)
|
||||
perms := d.configureSecurityGroupPermissions(group)
|
||||
if len(perms) != 0 {
|
||||
t.Fatalf("expected 0 permissions; received %d", len(perms))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue