Carry on commits from #2033

A couple of small cleanup and enhancements that were dropped after the revert.

Signed-off-by: Olivier Gambier <olivier@docker.com>
This commit is contained in:
Olivier Gambier 2015-10-29 11:14:09 -07:00
parent d855c35059
commit ae2d344c2b
5 changed files with 89 additions and 68 deletions

View File

@ -230,13 +230,6 @@ func (d *Driver) GetURL() (string, error) {
return fmt.Sprintf("tcp://%s:2376", ip), nil return fmt.Sprintf("tcp://%s:2376", ip), nil
} }
func (d *Driver) GetIP() (string, error) {
if d.IPAddress == "" {
return "", fmt.Errorf("IP address is not set")
}
return d.IPAddress, nil
}
func (d *Driver) GetState() (state.State, error) { func (d *Driver) GetState() (state.State, error) {
droplet, _, err := d.getClient().Droplets.Get(d.DropletID) droplet, _, err := d.getClient().Droplets.Get(d.DropletID)
if err != nil { if err != nil {

View File

@ -150,13 +150,6 @@ func (d *Driver) GetURL() (string, error) {
return fmt.Sprintf("tcp://%s:2376", ip), nil return fmt.Sprintf("tcp://%s:2376", ip), nil
} }
func (d *Driver) GetIP() (string, error) {
if d.IPAddress == "" {
return "", fmt.Errorf("IP address is not set")
}
return d.IPAddress, nil
}
func (d *Driver) GetState() (state.State, error) { func (d *Driver) GetState() (state.State, error) {
client := egoscale.NewClient(d.URL, d.ApiKey, d.ApiSecretKey) client := egoscale.NewClient(d.URL, d.ApiKey, d.ApiSecretKey)
vm, err := client.GetVirtualMachine(d.Id) vm, err := client.GetVirtualMachine(d.Id)

View File

@ -1,10 +1,12 @@
package generic package generic
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"time" "time"
"github.com/docker/machine/libmachine/drivers" "github.com/docker/machine/libmachine/drivers"
@ -20,13 +22,11 @@ type Driver struct {
} }
const ( const (
defaultSSHUser = "root"
defaultSSHPort = 22
defaultTimeout = 1 * time.Second defaultTimeout = 1 * time.Second
) )
var ( var (
defaultSSHKey = filepath.Join(mcnutils.GetHomeDir(), ".ssh", "id_rsa") defaultSourceSSHKey = filepath.Join(mcnutils.GetHomeDir(), ".ssh", "id_rsa")
) )
// GetCreateFlags registers the flags this driver adds to // GetCreateFlags registers the flags this driver adds to
@ -40,17 +40,17 @@ func (d *Driver) GetCreateFlags() []mcnflag.Flag {
mcnflag.StringFlag{ mcnflag.StringFlag{
Name: "generic-ssh-user", Name: "generic-ssh-user",
Usage: "SSH user", Usage: "SSH user",
Value: defaultSSHUser, Value: drivers.DefaultSSHUser,
}, },
mcnflag.StringFlag{ mcnflag.StringFlag{
Name: "generic-ssh-key", Name: "generic-ssh-key",
Usage: "SSH private key path", Usage: "SSH private key path",
Value: defaultSSHKey, Value: defaultSourceSSHKey,
}, },
mcnflag.IntFlag{ mcnflag.IntFlag{
Name: "generic-ssh-port", Name: "generic-ssh-port",
Usage: "SSH port", Usage: "SSH port",
Value: defaultSSHPort, Value: drivers.DefaultSSHPort,
}, },
} }
} }
@ -58,13 +58,11 @@ func (d *Driver) GetCreateFlags() []mcnflag.Flag {
// NewDriver creates and returns a new instance of the driver // NewDriver creates and returns a new instance of the driver
func NewDriver(hostName, storePath string) drivers.Driver { func NewDriver(hostName, storePath string) drivers.Driver {
return &Driver{ return &Driver{
SSHKey: defaultSSHKey,
BaseDriver: &drivers.BaseDriver{ BaseDriver: &drivers.BaseDriver{
SSHUser: defaultSSHUser,
SSHPort: defaultSSHPort,
MachineName: hostName, MachineName: hostName,
StorePath: storePath, StorePath: storePath,
}, },
SSHKey: defaultSourceSSHKey,
} }
} }
@ -87,29 +85,26 @@ func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error {
d.SSHPort = flags.Int("generic-ssh-port") d.SSHPort = flags.Int("generic-ssh-port")
if d.IPAddress == "" { if d.IPAddress == "" {
return fmt.Errorf("generic driver requires the --generic-ip-address option") return errors.New("generic driver requires the --generic-ip-address option")
} }
if d.SSHKey == "" { if d.SSHKey == "" {
return fmt.Errorf("generic driver requires the --generic-ssh-key option") return errors.New("generic driver requires the --generic-ssh-key option")
} }
return nil return nil
} }
func (d *Driver) PreCreateCheck() error {
return nil
}
func (d *Driver) Create() error { func (d *Driver) Create() error {
log.Infof("Importing SSH key...") log.Info("Importing SSH key...")
// TODO: validate the key is a valid key
if err := mcnutils.CopyFile(d.SSHKey, d.GetSSHKeyPath()); err != nil { if err := mcnutils.CopyFile(d.SSHKey, d.GetSSHKeyPath()); err != nil {
return fmt.Errorf("unable to copy ssh key: %s", err) return fmt.Errorf("unable to copy ssh key: %s", err)
} }
if err := os.Chmod(d.GetSSHKeyPath(), 0600); err != nil { if err := os.Chmod(d.GetSSHKeyPath(), 0600); err != nil {
return err return fmt.Errorf("unable to set permissions on the ssh key: %s", err)
} }
log.Debugf("IP: %s", d.IPAddress) log.Debugf("IP: %s", d.IPAddress)
@ -125,16 +120,10 @@ func (d *Driver) GetURL() (string, error) {
return fmt.Sprintf("tcp://%s:2376", ip), nil return fmt.Sprintf("tcp://%s:2376", ip), nil
} }
func (d *Driver) GetIP() (string, error) {
if d.IPAddress == "" {
return "", fmt.Errorf("IP address is not set")
}
return d.IPAddress, nil
}
func (d *Driver) GetState() (state.State, error) { func (d *Driver) GetState() (state.State, error) {
addr := fmt.Sprintf("%s:%d", d.IPAddress, d.SSHPort)
_, err := net.DialTimeout("tcp", addr, defaultTimeout) address := net.JoinHostPort(d.IPAddress, strconv.Itoa(d.SSHPort))
_, err := net.DialTimeout("tcp", address, defaultTimeout)
var st state.State var st state.State
if err != nil { if err != nil {
st = state.Stopped st = state.Stopped
@ -145,11 +134,11 @@ func (d *Driver) GetState() (state.State, error) {
} }
func (d *Driver) Start() error { func (d *Driver) Start() error {
return fmt.Errorf("generic driver does not support start") return errors.New("generic driver does not support start")
} }
func (d *Driver) Stop() error { func (d *Driver) Stop() error {
return fmt.Errorf("generic driver does not support stop") return errors.New("generic driver does not support stop")
} }
func (d *Driver) Remove() error { func (d *Driver) Remove() error {
@ -158,22 +147,15 @@ func (d *Driver) Remove() error {
func (d *Driver) Restart() error { func (d *Driver) Restart() error {
log.Debug("Restarting...") log.Debug("Restarting...")
_, err := drivers.RunSSHCommandFromDriver(d, "sudo shutdown -r now")
if _, err := drivers.RunSSHCommandFromDriver(d, "sudo shutdown -r now"); err != nil { return err
return err
}
return nil
} }
func (d *Driver) Kill() error { func (d *Driver) Kill() error {
log.Debug("Killing...") log.Debug("Killing...")
if _, err := drivers.RunSSHCommandFromDriver(d, "sudo shutdown -P now"); err != nil { _, err := drivers.RunSSHCommandFromDriver(d, "sudo shutdown -P now")
return err return err
}
return nil
} }
func (d *Driver) publicSSHKeyPath() string { func (d *Driver) publicSSHKeyPath() string {

View File

@ -1,23 +1,29 @@
package drivers package drivers
import "path/filepath" import (
"errors"
"fmt"
"net"
"path/filepath"
)
const (
DefaultSSHUser = "root"
DefaultSSHPort = 22
)
// BaseDriver - Embed this struct into drivers to provide the common set // BaseDriver - Embed this struct into drivers to provide the common set
// of fields and functions. // of fields and functions.
type BaseDriver struct { type BaseDriver struct {
IPAddress string IPAddress string
MachineName string
SSHUser string SSHUser string
SSHPort int SSHPort int
MachineName string SSHKeyPath string
StorePath string
SwarmMaster bool SwarmMaster bool
SwarmHost string SwarmHost string
SwarmDiscovery string SwarmDiscovery string
StorePath string
}
// GetSSHKeyPath -
func (d *BaseDriver) GetSSHKeyPath() string {
return filepath.Join(d.StorePath, "machines", d.MachineName, "id_rsa")
} }
// DriverName returns the name of the driver // DriverName returns the name of the driver
@ -25,20 +31,35 @@ func (d *BaseDriver) DriverName() string {
return "unknown" return "unknown"
} }
// GetIP returns the ip // GetMachineName returns the machine name
func (d *BaseDriver) GetMachineName() string { func (d *BaseDriver) GetMachineName() string {
return d.MachineName return d.MachineName
} }
// ResolveStorePath - // GetIP returns the ip
func (d *BaseDriver) ResolveStorePath(file string) string { func (d *BaseDriver) GetIP() (string, error) {
return filepath.Join(d.StorePath, "machines", d.MachineName, file) if d.IPAddress == "" {
return "", errors.New("IP address is not set")
}
ip := net.ParseIP(d.IPAddress)
if ip == nil {
return "", fmt.Errorf("IP address is invalid: %s", d.IPAddress)
}
return d.IPAddress, nil
}
// GetSSHKeyPath returns the ssh key path
func (d *BaseDriver) GetSSHKeyPath() string {
if d.SSHKeyPath == "" {
d.SSHKeyPath = d.ResolveStorePath("id_rsa")
}
return d.SSHKeyPath
} }
// GetSSHPort returns the ssh port, 22 if not specified // GetSSHPort returns the ssh port, 22 if not specified
func (d *BaseDriver) GetSSHPort() (int, error) { func (d *BaseDriver) GetSSHPort() (int, error) {
if d.SSHPort == 0 { if d.SSHPort == 0 {
d.SSHPort = 22 d.SSHPort = DefaultSSHPort
} }
return d.SSHPort, nil return d.SSHPort, nil
@ -47,9 +68,8 @@ func (d *BaseDriver) GetSSHPort() (int, error) {
// GetSSHUsername returns the ssh user name, root if not specified // GetSSHUsername returns the ssh user name, root if not specified
func (d *BaseDriver) GetSSHUsername() string { func (d *BaseDriver) GetSSHUsername() string {
if d.SSHUser == "" { if d.SSHUser == "" {
d.SSHUser = "root" d.SSHUser = DefaultSSHUser
} }
return d.SSHUser return d.SSHUser
} }
@ -57,3 +77,8 @@ func (d *BaseDriver) GetSSHUsername() string {
func (d *BaseDriver) PreCreateCheck() error { func (d *BaseDriver) PreCreateCheck() error {
return nil return nil
} }
// ResolveStorePath returns the store path where the machine is
func (d *BaseDriver) ResolveStorePath(file string) string {
return filepath.Join(d.StorePath, "machines", d.MachineName, file)
}

View File

@ -0,0 +1,28 @@
package drivers
import (
"errors"
"fmt"
"github.com/stretchr/testify/assert"
"testing"
)
func TestIP(t *testing.T) {
cases := []struct {
baseDriver *BaseDriver
expectedIp string
expectedErr error
}{
{&BaseDriver{}, "", errors.New("IP address is not set")},
{&BaseDriver{IPAddress: "2001:4860:0:2001::68"}, "2001:4860:0:2001::68", nil},
{&BaseDriver{IPAddress: "192.168.0.1"}, "192.168.0.1", nil},
{&BaseDriver{IPAddress: "::1"}, "::1", nil},
{&BaseDriver{IPAddress: "whatever"}, "", fmt.Errorf("IP address is invalid: %s", "whatever")},
}
for _, c := range cases {
ip, err := c.baseDriver.GetIP()
assert.Equal(t, c.expectedIp, ip)
assert.Equal(t, c.expectedErr, err)
}
}