diff --git a/drivers/amazonec2/amazonec2.go b/drivers/amazonec2/amazonec2.go index 66d3267242..d853379f36 100644 --- a/drivers/amazonec2/amazonec2.go +++ b/drivers/amazonec2/amazonec2.go @@ -15,7 +15,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" "github.com/docker/machine/libmachine/drivers" @@ -57,6 +56,7 @@ var ( type Driver struct { *drivers.BaseDriver clientFactory func() Ec2Client + awsCredentials awsCredentials Id string AccessKey string SecretKey string @@ -226,6 +226,7 @@ func NewDriver(hostName, storePath string) *Driver { MachineName: hostName, StorePath: storePath, }, + awsCredentials: &defaultAWSCredentials{}, } driver.clientFactory = driver.buildClient @@ -237,7 +238,7 @@ func (d *Driver) buildClient() Ec2Client { config := aws.NewConfig() alogger := AwsLogger() config = config.WithRegion(d.Region) - config = config.WithCredentials(credentials.NewStaticCredentials(d.AccessKey, d.SecretKey, d.SessionToken)) + config = config.WithCredentials(d.awsCredentials.NewStaticCredentials(d.AccessKey, d.SecretKey, d.SessionToken)) config = config.WithLogger(alogger) config = config.WithLogLevel(aws.LogDebugWithHTTPBody) return ec2.New(session.New(config)) diff --git a/drivers/amazonec2/amazonec2_test.go b/drivers/amazonec2/amazonec2_test.go index 6e9c0ee850..fb0b6f2b34 100644 --- a/drivers/amazonec2/amazonec2_test.go +++ b/drivers/amazonec2/amazonec2_test.go @@ -240,34 +240,10 @@ func TestSetConfigFromFlags(t *testing.T) { assert.Empty(t, checkFlags.InvalidFlags) } -type fakeEC2WithDescribe struct { - *ec2.EC2 - output *ec2.DescribeAccountAttributesOutput - err error -} - -func (f *fakeEC2WithDescribe) DescribeAccountAttributes(input *ec2.DescribeAccountAttributesInput) (*ec2.DescribeAccountAttributesOutput, error) { - return f.output, f.err -} - func TestFindDefaultVPC(t *testing.T) { - defaultVpc := "default-vpc" - vpcName := "vpc-9999" - driver := NewDriver("machineFoo", "path") driver.clientFactory = func() Ec2Client { - return &fakeEC2WithDescribe{ - output: &ec2.DescribeAccountAttributesOutput{ - AccountAttributes: []*ec2.AccountAttribute{ - { - AttributeName: &defaultVpc, - AttributeValues: []*ec2.AccountAttributeValue{ - {AttributeValue: &vpcName}, - }, - }, - }, - }, - } + return &fakeEC2WithLogin{} } vpc, err := driver.getDefaultVPCId() @@ -305,3 +281,12 @@ func TestDescribeAccountAttributeFails(t *testing.T) { assert.EqualError(t, err, "Not Found") assert.Empty(t, vpc) } + +} + +} + +} + +} + diff --git a/drivers/amazonec2/awscredentials.go b/drivers/amazonec2/awscredentials.go new file mode 100644 index 0000000000..15e3a326d3 --- /dev/null +++ b/drivers/amazonec2/awscredentials.go @@ -0,0 +1,19 @@ +package amazonec2 + +import "github.com/aws/aws-sdk-go/aws/credentials" + +type awsCredentials interface { + NewStaticCredentials(id, secret, token string) *credentials.Credentials + + NewSharedCredentials(filename, profile string) *credentials.Credentials +} + +type defaultAWSCredentials struct{} + +func (c *defaultAWSCredentials) NewStaticCredentials(id, secret, token string) *credentials.Credentials { + return credentials.NewStaticCredentials(id, secret, token) +} + +func (c *defaultAWSCredentials) NewSharedCredentials(filename, profile string) *credentials.Credentials { + return credentials.NewSharedCredentials(filename, profile) +} diff --git a/drivers/amazonec2/stub_test.go b/drivers/amazonec2/stub_test.go new file mode 100644 index 0000000000..680da8bd4c --- /dev/null +++ b/drivers/amazonec2/stub_test.go @@ -0,0 +1,90 @@ +package amazonec2 + +import ( + "errors" + + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/service/ec2" +) + +type fakeEC2 struct { + *ec2.EC2 +} + +type errorProvider struct{} + +func (p *errorProvider) Retrieve() (credentials.Value, error) { + return credentials.Value{}, errors.New("bad credentials") +} + +func (p *errorProvider) IsExpired() bool { + return true +} + +type okProvider struct { + accessKeyID string + secretAccessKey string + sessionToken string +} + +func (p *okProvider) Retrieve() (credentials.Value, error) { + return credentials.Value{ + AccessKeyID: p.accessKeyID, + SecretAccessKey: p.secretAccessKey, + SessionToken: p.sessionToken, + }, nil +} + +func (p *okProvider) IsExpired() bool { + return true +} + +type cliCredentials struct{} + +func (c *cliCredentials) NewStaticCredentials(id, secret, token string) *credentials.Credentials { + return credentials.NewCredentials(&okProvider{id, secret, token}) +} + +func (c *cliCredentials) NewSharedCredentials(filename, profile string) *credentials.Credentials { + return credentials.NewCredentials(&errorProvider{}) +} + +type fileCredentials struct{} + +func (c *fileCredentials) NewStaticCredentials(id, secret, token string) *credentials.Credentials { + return nil +} + +func (c *fileCredentials) NewSharedCredentials(filename, profile string) *credentials.Credentials { + return credentials.NewCredentials(&okProvider{"access", "secret", "token"}) +} + +type fakeEC2WithDescribe struct { + *fakeEC2 + output *ec2.DescribeAccountAttributesOutput + err error +} + +func (f *fakeEC2WithDescribe) DescribeAccountAttributes(input *ec2.DescribeAccountAttributesInput) (*ec2.DescribeAccountAttributesOutput, error) { + return f.output, f.err +} + +type fakeEC2WithLogin struct { + *fakeEC2 +} + +func (f *fakeEC2WithLogin) DescribeAccountAttributes(input *ec2.DescribeAccountAttributesInput) (*ec2.DescribeAccountAttributesOutput, error) { + defaultVpc := "default-vpc" + vpcName := "vpc-9999" + + return &ec2.DescribeAccountAttributesOutput{ + AccountAttributes: []*ec2.AccountAttribute{ + { + AttributeName: &defaultVpc, + AttributeValues: []*ec2.AccountAttributeValue{ + {AttributeValue: &vpcName}, + }, + }, + }, + }, nil +}