From ae2d344c2b60f3e06b5e7d20640559afa23f44a4 Mon Sep 17 00:00:00 2001 From: Olivier Gambier Date: Thu, 29 Oct 2015 11:14:09 -0700 Subject: [PATCH] Carry on commits from #2033 A couple of small cleanup and enhancements that were dropped after the revert. Signed-off-by: Olivier Gambier --- drivers/digitalocean/digitalocean.go | 7 ---- drivers/exoscale/exoscale.go | 7 ---- drivers/generic/generic.go | 60 ++++++++++------------------ libmachine/drivers/base.go | 55 ++++++++++++++++++------- libmachine/drivers/base_test.go | 28 +++++++++++++ 5 files changed, 89 insertions(+), 68 deletions(-) create mode 100644 libmachine/drivers/base_test.go diff --git a/drivers/digitalocean/digitalocean.go b/drivers/digitalocean/digitalocean.go index 366802d1f5..bd45560715 100644 --- a/drivers/digitalocean/digitalocean.go +++ b/drivers/digitalocean/digitalocean.go @@ -230,13 +230,6 @@ func (d *Driver) GetURL() (string, error) { return fmt.Sprintf("tcp://%s:2376", ip), nil } -func (d *Driver) GetIP() (string, error) { - if d.IPAddress == "" { - return "", fmt.Errorf("IP address is not set") - } - return d.IPAddress, nil -} - func (d *Driver) GetState() (state.State, error) { droplet, _, err := d.getClient().Droplets.Get(d.DropletID) if err != nil { diff --git a/drivers/exoscale/exoscale.go b/drivers/exoscale/exoscale.go index e083e3f0d5..7c65ea978c 100644 --- a/drivers/exoscale/exoscale.go +++ b/drivers/exoscale/exoscale.go @@ -150,13 +150,6 @@ func (d *Driver) GetURL() (string, error) { return fmt.Sprintf("tcp://%s:2376", ip), nil } -func (d *Driver) GetIP() (string, error) { - if d.IPAddress == "" { - return "", fmt.Errorf("IP address is not set") - } - return d.IPAddress, nil -} - func (d *Driver) GetState() (state.State, error) { client := egoscale.NewClient(d.URL, d.ApiKey, d.ApiSecretKey) vm, err := client.GetVirtualMachine(d.Id) diff --git a/drivers/generic/generic.go b/drivers/generic/generic.go index 57a47f1702..99e35b0910 100644 --- a/drivers/generic/generic.go +++ b/drivers/generic/generic.go @@ -1,10 +1,12 @@ package generic import ( + "errors" "fmt" "net" "os" "path/filepath" + "strconv" "time" "github.com/docker/machine/libmachine/drivers" @@ -20,13 +22,11 @@ type Driver struct { } const ( - defaultSSHUser = "root" - defaultSSHPort = 22 defaultTimeout = 1 * time.Second ) var ( - defaultSSHKey = filepath.Join(mcnutils.GetHomeDir(), ".ssh", "id_rsa") + defaultSourceSSHKey = filepath.Join(mcnutils.GetHomeDir(), ".ssh", "id_rsa") ) // GetCreateFlags registers the flags this driver adds to @@ -40,17 +40,17 @@ func (d *Driver) GetCreateFlags() []mcnflag.Flag { mcnflag.StringFlag{ Name: "generic-ssh-user", Usage: "SSH user", - Value: defaultSSHUser, + Value: drivers.DefaultSSHUser, }, mcnflag.StringFlag{ Name: "generic-ssh-key", Usage: "SSH private key path", - Value: defaultSSHKey, + Value: defaultSourceSSHKey, }, mcnflag.IntFlag{ Name: "generic-ssh-port", Usage: "SSH port", - Value: defaultSSHPort, + Value: drivers.DefaultSSHPort, }, } } @@ -58,13 +58,11 @@ func (d *Driver) GetCreateFlags() []mcnflag.Flag { // NewDriver creates and returns a new instance of the driver func NewDriver(hostName, storePath string) drivers.Driver { return &Driver{ - SSHKey: defaultSSHKey, BaseDriver: &drivers.BaseDriver{ - SSHUser: defaultSSHUser, - SSHPort: defaultSSHPort, MachineName: hostName, StorePath: storePath, }, + SSHKey: defaultSourceSSHKey, } } @@ -87,29 +85,26 @@ func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error { d.SSHPort = flags.Int("generic-ssh-port") if d.IPAddress == "" { - return fmt.Errorf("generic driver requires the --generic-ip-address option") + return errors.New("generic driver requires the --generic-ip-address option") } if d.SSHKey == "" { - return fmt.Errorf("generic driver requires the --generic-ssh-key option") + return errors.New("generic driver requires the --generic-ssh-key option") } return nil } -func (d *Driver) PreCreateCheck() error { - return nil -} - func (d *Driver) Create() error { - log.Infof("Importing SSH key...") + log.Info("Importing SSH key...") + // TODO: validate the key is a valid key if err := mcnutils.CopyFile(d.SSHKey, d.GetSSHKeyPath()); err != nil { return fmt.Errorf("unable to copy ssh key: %s", err) } if err := os.Chmod(d.GetSSHKeyPath(), 0600); err != nil { - return err + return fmt.Errorf("unable to set permissions on the ssh key: %s", err) } log.Debugf("IP: %s", d.IPAddress) @@ -125,16 +120,10 @@ func (d *Driver) GetURL() (string, error) { return fmt.Sprintf("tcp://%s:2376", ip), nil } -func (d *Driver) GetIP() (string, error) { - if d.IPAddress == "" { - return "", fmt.Errorf("IP address is not set") - } - return d.IPAddress, nil -} - func (d *Driver) GetState() (state.State, error) { - addr := fmt.Sprintf("%s:%d", d.IPAddress, d.SSHPort) - _, err := net.DialTimeout("tcp", addr, defaultTimeout) + + address := net.JoinHostPort(d.IPAddress, strconv.Itoa(d.SSHPort)) + _, err := net.DialTimeout("tcp", address, defaultTimeout) var st state.State if err != nil { st = state.Stopped @@ -145,11 +134,11 @@ func (d *Driver) GetState() (state.State, error) { } func (d *Driver) Start() error { - return fmt.Errorf("generic driver does not support start") + return errors.New("generic driver does not support start") } func (d *Driver) Stop() error { - return fmt.Errorf("generic driver does not support stop") + return errors.New("generic driver does not support stop") } func (d *Driver) Remove() error { @@ -158,22 +147,15 @@ func (d *Driver) Remove() error { func (d *Driver) Restart() error { log.Debug("Restarting...") - - if _, err := drivers.RunSSHCommandFromDriver(d, "sudo shutdown -r now"); err != nil { - return err - } - - return nil + _, err := drivers.RunSSHCommandFromDriver(d, "sudo shutdown -r now") + return err } func (d *Driver) Kill() error { log.Debug("Killing...") - if _, err := drivers.RunSSHCommandFromDriver(d, "sudo shutdown -P now"); err != nil { - return err - } - - return nil + _, err := drivers.RunSSHCommandFromDriver(d, "sudo shutdown -P now") + return err } func (d *Driver) publicSSHKeyPath() string { diff --git a/libmachine/drivers/base.go b/libmachine/drivers/base.go index 17bd097dd8..38dd106581 100644 --- a/libmachine/drivers/base.go +++ b/libmachine/drivers/base.go @@ -1,23 +1,29 @@ package drivers -import "path/filepath" +import ( + "errors" + "fmt" + "net" + "path/filepath" +) + +const ( + DefaultSSHUser = "root" + DefaultSSHPort = 22 +) // BaseDriver - Embed this struct into drivers to provide the common set // of fields and functions. type BaseDriver struct { IPAddress string + MachineName string SSHUser string SSHPort int - MachineName string + SSHKeyPath string + StorePath string SwarmMaster bool SwarmHost string SwarmDiscovery string - StorePath string -} - -// GetSSHKeyPath - -func (d *BaseDriver) GetSSHKeyPath() string { - return filepath.Join(d.StorePath, "machines", d.MachineName, "id_rsa") } // DriverName returns the name of the driver @@ -25,20 +31,35 @@ func (d *BaseDriver) DriverName() string { return "unknown" } -// GetIP returns the ip +// GetMachineName returns the machine name func (d *BaseDriver) GetMachineName() string { return d.MachineName } -// ResolveStorePath - -func (d *BaseDriver) ResolveStorePath(file string) string { - return filepath.Join(d.StorePath, "machines", d.MachineName, file) +// GetIP returns the ip +func (d *BaseDriver) GetIP() (string, error) { + if d.IPAddress == "" { + return "", errors.New("IP address is not set") + } + ip := net.ParseIP(d.IPAddress) + if ip == nil { + return "", fmt.Errorf("IP address is invalid: %s", d.IPAddress) + } + return d.IPAddress, nil +} + +// GetSSHKeyPath returns the ssh key path +func (d *BaseDriver) GetSSHKeyPath() string { + if d.SSHKeyPath == "" { + d.SSHKeyPath = d.ResolveStorePath("id_rsa") + } + return d.SSHKeyPath } // GetSSHPort returns the ssh port, 22 if not specified func (d *BaseDriver) GetSSHPort() (int, error) { if d.SSHPort == 0 { - d.SSHPort = 22 + d.SSHPort = DefaultSSHPort } return d.SSHPort, nil @@ -47,9 +68,8 @@ func (d *BaseDriver) GetSSHPort() (int, error) { // GetSSHUsername returns the ssh user name, root if not specified func (d *BaseDriver) GetSSHUsername() string { if d.SSHUser == "" { - d.SSHUser = "root" + d.SSHUser = DefaultSSHUser } - return d.SSHUser } @@ -57,3 +77,8 @@ func (d *BaseDriver) GetSSHUsername() string { func (d *BaseDriver) PreCreateCheck() error { return nil } + +// ResolveStorePath returns the store path where the machine is +func (d *BaseDriver) ResolveStorePath(file string) string { + return filepath.Join(d.StorePath, "machines", d.MachineName, file) +} diff --git a/libmachine/drivers/base_test.go b/libmachine/drivers/base_test.go new file mode 100644 index 0000000000..3d62f8bb2e --- /dev/null +++ b/libmachine/drivers/base_test.go @@ -0,0 +1,28 @@ +package drivers + +import ( + "errors" + "fmt" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestIP(t *testing.T) { + cases := []struct { + baseDriver *BaseDriver + expectedIp string + expectedErr error + }{ + {&BaseDriver{}, "", errors.New("IP address is not set")}, + {&BaseDriver{IPAddress: "2001:4860:0:2001::68"}, "2001:4860:0:2001::68", nil}, + {&BaseDriver{IPAddress: "192.168.0.1"}, "192.168.0.1", nil}, + {&BaseDriver{IPAddress: "::1"}, "::1", nil}, + {&BaseDriver{IPAddress: "whatever"}, "", fmt.Errorf("IP address is invalid: %s", "whatever")}, + } + + for _, c := range cases { + ip, err := c.baseDriver.GetIP() + assert.Equal(t, c.expectedIp, ip) + assert.Equal(t, c.expectedErr, err) + } +}