diff --git a/drivers/amazonec2/amazonec2.go b/drivers/amazonec2/amazonec2.go index 9a2e0af4d1..6033f486bb 100644 --- a/drivers/amazonec2/amazonec2.go +++ b/drivers/amazonec2/amazonec2.go @@ -3,6 +3,7 @@ package amazonec2 import ( "crypto/md5" "crypto/rand" + "errors" "fmt" "io" "io/ioutil" @@ -52,6 +53,7 @@ var ( type Driver struct { *drivers.BaseDriver + clientFactory func() Ec2Client Id string AccessKey string SecretKey string @@ -83,6 +85,10 @@ type Driver struct { Monitoring bool } +type clientFactory interface { + build(d *Driver) Ec2Client +} + func (d *Driver) GetCreateFlags() []mcnflag.Flag { return []mcnflag.Flag{ mcnflag.StringFlag{ @@ -201,9 +207,9 @@ func (d *Driver) GetCreateFlags() []mcnflag.Flag { } } -func NewDriver(hostName, storePath string) drivers.Driver { +func NewDriver(hostName, storePath string) *Driver { id := generateId() - return &Driver{ + driver := &Driver{ Id: id, AMI: defaultAmiId, Region: defaultRegion, @@ -218,6 +224,24 @@ func NewDriver(hostName, storePath string) drivers.Driver { StorePath: storePath, }, } + + driver.clientFactory = driver.buildClient + + return driver +} + +func (d *Driver) buildClient() Ec2Client { + config := aws.NewConfig() + alogger := AwsLogger() + config = config.WithRegion(d.Region) + config = config.WithCredentials(credentials.NewStaticCredentials(d.AccessKey, d.SecretKey, d.SessionToken)) + config = config.WithLogger(alogger) + config = config.WithLogLevel(aws.LogDebugWithHTTPBody) + return ec2.New(session.New(config)) +} + +func (d *Driver) getClient() Ec2Client { + return d.clientFactory() } func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error { @@ -661,16 +685,6 @@ func (d *Driver) Remove() error { return nil } -func (d *Driver) getClient() *ec2.EC2 { - config := aws.NewConfig() - alogger := AwsLogger() - config = config.WithRegion(d.Region) - config = config.WithCredentials(credentials.NewStaticCredentials(d.AccessKey, d.SecretKey, d.SessionToken)) - config = config.WithLogger(alogger) - config = config.WithLogLevel(aws.LogDebugWithHTTPBody) - return ec2.New(session.New(config)) -} - func (d *Driver) getInstance() (*ec2.Instance, error) { instances, err := d.getClient().DescribeInstances(&ec2.DescribeInstancesInput{ InstanceIds: []*string{&d.InstanceId}, diff --git a/drivers/amazonec2/amazonec2_test.go b/drivers/amazonec2/amazonec2_test.go index 70e3021c26..98170f056e 100644 --- a/drivers/amazonec2/amazonec2_test.go +++ b/drivers/amazonec2/amazonec2_test.go @@ -5,6 +5,8 @@ import ( "os" "testing" + "errors" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/docker/machine/commands/commandstest" @@ -85,8 +87,7 @@ func getTestDriver() (*Driver, error) { d := NewDriver(machineTestName, storePath) d.SetConfigFromFlags(getDefaultTestDriverFlags()) - drv := d.(*Driver) - return drv, nil + return d, nil } func TestConfigureSecurityGroupPermissionsEmpty(t *testing.T) { diff --git a/drivers/amazonec2/ec2client.go b/drivers/amazonec2/ec2client.go new file mode 100644 index 0000000000..ecb2a50b19 --- /dev/null +++ b/drivers/amazonec2/ec2client.go @@ -0,0 +1,51 @@ +package amazonec2 + +import "github.com/aws/aws-sdk-go/service/ec2" + +type Ec2Client interface { + DescribeAccountAttributes(input *ec2.DescribeAccountAttributesInput) (*ec2.DescribeAccountAttributesOutput, error) + + DescribeSubnets(input *ec2.DescribeSubnetsInput) (*ec2.DescribeSubnetsOutput, error) + + CreateTags(input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error) + + //SecurityGroup + + CreateSecurityGroup(input *ec2.CreateSecurityGroupInput) (*ec2.CreateSecurityGroupOutput, error) + + AuthorizeSecurityGroupIngress(input *ec2.AuthorizeSecurityGroupIngressInput) (*ec2.AuthorizeSecurityGroupIngressOutput, error) + + DescribeSecurityGroups(input *ec2.DescribeSecurityGroupsInput) (*ec2.DescribeSecurityGroupsOutput, error) + + DeleteSecurityGroup(input *ec2.DeleteSecurityGroupInput) (*ec2.DeleteSecurityGroupOutput, error) + + //KeyPair + + DeleteKeyPair(input *ec2.DeleteKeyPairInput) (*ec2.DeleteKeyPairOutput, error) + + ImportKeyPair(input *ec2.ImportKeyPairInput) (*ec2.ImportKeyPairOutput, error) + + DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error) + + //Instances + + DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) + + StartInstances(input *ec2.StartInstancesInput) (*ec2.StartInstancesOutput, error) + + RebootInstances(input *ec2.RebootInstancesInput) (*ec2.RebootInstancesOutput, error) + + StopInstances(input *ec2.StopInstancesInput) (*ec2.StopInstancesOutput, error) + + RunInstances(input *ec2.RunInstancesInput) (*ec2.Reservation, error) + + TerminateInstances(input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error) + + //SpotInstances + + RequestSpotInstances(input *ec2.RequestSpotInstancesInput) (*ec2.RequestSpotInstancesOutput, error) + + DescribeSpotInstanceRequests(input *ec2.DescribeSpotInstanceRequestsInput) (*ec2.DescribeSpotInstanceRequestsOutput, error) + + WaitUntilSpotInstanceRequestFulfilled(input *ec2.DescribeSpotInstanceRequestsInput) error +}