Generic and Base slight cleanups

- tests for GetIP
- extract default values into consts (user & port)
- better error handling (cert permissions change)
- unexport Driver for generic (linting)
- ordering of methods and variables for better readability

Signed-off-by: Olivier Gambier <olivier@docker.com>
This commit is contained in:
Olivier Gambier 2015-10-20 18:56:34 -07:00
parent 99aacc7b79
commit 19625def22
3 changed files with 137 additions and 113 deletions

View File

@ -1,10 +1,12 @@
package generic
import (
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strconv"
"time"
"github.com/docker/machine/libmachine/drivers"
@ -16,21 +18,50 @@ import (
type Driver struct {
*drivers.BaseDriver
SSHKey string
sourceSSHKey string
}
const (
defaultSSHUser = "root"
defaultSSHPort = 22
driverName = "generic"
defaultTimeout = 1 * time.Second
)
var (
defaultSSHKey = filepath.Join(mcnutils.GetHomeDir(), ".ssh", "id_rsa")
defaultSourceSSHKey = filepath.Join(mcnutils.GetHomeDir(), ".ssh", "id_rsa")
)
// GetCreateFlags registers the flags this driver adds to
// "docker hosts create"
// NewDriver creates and returns a new instance of the driver
func NewDriver(hostName, storePath string) drivers.Driver {
return &Driver{
BaseDriver: &drivers.BaseDriver{
MachineName: hostName,
StorePath: storePath,
},
sourceSSHKey: defaultSourceSSHKey,
}
}
func (d *Driver) Create() error {
log.Info("Importing SSH key...")
// TODO: validate the key is a valid key
if err := mcnutils.CopyFile(d.sourceSSHKey, d.GetSSHKeyPath()); err != nil {
return fmt.Errorf("unable to copy ssh key: %s", err)
}
if err := os.Chmod(d.GetSSHKeyPath(), 0600); err != nil {
return fmt.Errorf("unable to set permissions on the ssh key: %s", err)
}
log.Debugf("IP: %s", d.IPAddress)
return nil
}
func (d *Driver) DriverName() string {
return driverName
}
func (d *Driver) GetCreateFlags() []mcnflag.Flag {
return []mcnflag.Flag{
mcnflag.StringFlag{
@ -40,77 +71,31 @@ func (d *Driver) GetCreateFlags() []mcnflag.Flag {
mcnflag.StringFlag{
Name: "generic-ssh-user",
Usage: "SSH user",
Value: defaultSSHUser,
Value: drivers.DefaultSSHUser,
},
mcnflag.StringFlag{
Name: "generic-ssh-key",
Usage: "SSH private key path",
Value: defaultSSHKey,
Value: defaultSourceSSHKey,
},
mcnflag.IntFlag{
Name: "generic-ssh-port",
Usage: "SSH port",
Value: defaultSSHPort,
Value: drivers.DefaultSSHPort,
},
}
}
// NewDriver creates and returns a new instance of the driver
func NewDriver(hostName, storePath string) drivers.Driver {
return &Driver{
SSHKey: defaultSSHKey,
BaseDriver: &drivers.BaseDriver{
SSHUser: defaultSSHUser,
SSHPort: defaultSSHPort,
MachineName: hostName,
StorePath: storePath,
},
func (d *Driver) GetState() (state.State, error) {
address := net.JoinHostPort(d.IPAddress, strconv.Itoa(d.SSHPort))
_, err := net.DialTimeout("tcp", address, defaultTimeout)
var st state.State
if err != nil {
st = state.Stopped
} else {
st = state.Running
}
}
func (d *Driver) DriverName() string {
return "generic"
}
func (d *Driver) GetSSHUsername() string {
return d.SSHUser
}
func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error {
d.IPAddress = flags.String("generic-ip-address")
d.SSHUser = flags.String("generic-ssh-user")
d.SSHKey = flags.String("generic-ssh-key")
d.SSHPort = flags.Int("generic-ssh-port")
if d.IPAddress == "" {
return fmt.Errorf("generic driver requires the --generic-ip-address option")
}
if d.SSHKey == "" {
return fmt.Errorf("generic driver requires the --generic-ssh-key option")
}
return nil
}
func (d *Driver) PreCreateCheck() error {
return nil
}
func (d *Driver) Create() error {
log.Infof("Importing SSH key...")
if err := mcnutils.CopyFile(d.SSHKey, d.GetSSHKeyPath()); err != nil {
return fmt.Errorf("unable to copy ssh key: %s", err)
}
if err := os.Chmod(d.GetSSHKeyPath(), 0600); err != nil {
return err
}
log.Debugf("IP: %s", d.IPAddress)
return nil
return st, nil
}
func (d *Driver) GetURL() (string, error) {
@ -121,27 +106,10 @@ func (d *Driver) GetURL() (string, error) {
return fmt.Sprintf("tcp://%s:2376", ip), nil
}
}
}
func (d *Driver) GetState() (state.State, error) {
addr := fmt.Sprintf("%s:%d", d.IPAddress, d.SSHPort)
_, err := net.DialTimeout("tcp", addr, defaultTimeout)
var st state.State
if err != nil {
st = state.Stopped
} else {
st = state.Running
}
return st, nil
}
func (d *Driver) Start() error {
return fmt.Errorf("generic driver does not support start")
}
func (d *Driver) Stop() error {
return fmt.Errorf("generic driver does not support stop")
func (d *Driver) Kill() error {
log.Debug("Killing...")
_, err := drivers.RunSSHCommandFromDriver(d, "sudo shutdown -P now")
return err
}
func (d *Driver) Remove() error {
@ -150,24 +118,31 @@ func (d *Driver) Remove() error {
func (d *Driver) Restart() error {
log.Debug("Restarting...")
_, err := drivers.RunSSHCommandFromDriver(d, "sudo shutdown -r now")
return err
}
if _, err := drivers.RunSSHCommandFromDriver(d, "sudo shutdown -r now"); err != nil {
return err
func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error {
d.IPAddress = flags.String("generic-ip-address")
d.SSHUser = flags.String("generic-ssh-user")
d.sourceSSHKey = flags.String("generic-ssh-key")
d.SSHPort = flags.Int("generic-ssh-port")
if d.IPAddress == "" {
return errors.New("generic driver requires the --generic-ip-address option")
}
if d.sourceSSHKey == "" {
return errors.New("generic driver requires the --generic-ssh-key option")
}
return nil
}
func (d *Driver) Kill() error {
log.Debug("Killing...")
if _, err := drivers.RunSSHCommandFromDriver(d, "sudo shutdown -P now"); err != nil {
return err
}
return nil
func (d *Driver) Start() error {
return errors.New("generic driver does not support start")
}
func (d *Driver) publicSSHKeyPath() string {
return d.GetSSHKeyPath() + ".pub"
func (d *Driver) Stop() error {
return errors.New("generic driver does not support stop")
}

View File

@ -1,23 +1,29 @@
package drivers
import "path/filepath"
import (
"errors"
"fmt"
"net"
"path/filepath"
)
const (
DefaultSSHUser = "root"
DefaultSSHPort = 22
)
// BaseDriver - Embed this struct into drivers to provide the common set
// of fields and functions.
type BaseDriver struct {
IPAddress string
SSHUser string
SSHPort int
MachineName string
SSHKeyPath string
SSHPort int
SSHUser string
StorePath string
SwarmMaster bool
SwarmHost string
SwarmDiscovery string
StorePath string
}
// GetSSHKeyPath -
func (d *BaseDriver) GetSSHKeyPath() string {
return filepath.Join(d.StorePath, "machines", d.MachineName, "id_rsa")
}
// DriverName returns the name of the driver
@ -28,27 +34,37 @@ func (d *BaseDriver) DriverName() string {
// GetIP returns the ip
func (d *BaseDriver) GetIP() (string, error) {
if d.IPAddress == "" {
return "", fmt.Errorf("IP address is not set")
return "", errors.New("IP address is not set")
}
ip := net.ParseIP(d.IPAddress)
if ip == nil {
return "", fmt.Errorf("IP address is invalid")
return "", fmt.Errorf("IP address is invalid: %s", d.IPAddress)
}
return d.IPAddress, nil
}
// GetMachineName returns the machine name
func (d *BaseDriver) GetMachineName() string {
return d.MachineName
}
// ResolveStorePath -
func (d *BaseDriver) ResolveStorePath(file string) string {
return filepath.Join(d.StorePath, "machines", d.MachineName, file)
// GetSSHHostname returns hostname for use with ssh
func (d *BaseDriver) GetSSHHostname() (string, error) {
return d.GetIP()
}
// GetSSHKeyPath returns the ssh key path
func (d *BaseDriver) GetSSHKeyPath() string {
if d.SSHKeyPath == "" {
d.SSHKeyPath = d.ResolveStorePath("id_rsa")
}
return d.SSHKeyPath
}
// GetSSHPort returns the ssh port, 22 if not specified
func (d *BaseDriver) GetSSHPort() (int, error) {
if d.SSHPort == 0 {
d.SSHPort = 22
d.SSHPort = DefaultSSHPort
}
return d.SSHPort, nil
@ -57,7 +73,7 @@ func (d *BaseDriver) GetSSHPort() (int, error) {
// GetSSHUsername returns the ssh user name, root if not specified
func (d *BaseDriver) GetSSHUsername() string {
if d.SSHUser == "" {
d.SSHUser = "root"
d.SSHUser = DefaultSSHUser
}
return d.SSHUser
@ -67,3 +83,8 @@ func (d *BaseDriver) GetSSHUsername() string {
func (d *BaseDriver) PreCreateCheck() error {
return nil
}
// ResolveStorePath returns the store path where the machine is
func (d *BaseDriver) ResolveStorePath(file string) string {
return filepath.Join(d.StorePath, "machines", d.MachineName, file)
}

View File

@ -0,0 +1,28 @@
package drivers
import (
"errors"
"fmt"
"github.com/stretchr/testify/assert"
"testing"
)
func TestIP(t *testing.T) {
cases := []struct {
baseDriver *BaseDriver
expectedIp string
expectedErr error
}{
{&BaseDriver{}, "", errors.New("IP address is not set")},
{&BaseDriver{IPAddress: "2001:4860:0:2001::68"}, "2001:4860:0:2001::68", nil},
{&BaseDriver{IPAddress: "192.168.0.1"}, "192.168.0.1", nil},
{&BaseDriver{IPAddress: "::1"}, "::1", nil},
{&BaseDriver{IPAddress: "whatever"}, "", fmt.Errorf("IP address is invalid: %s", "whatever")},
}
for _, c := range cases {
ip, err := c.baseDriver.GetIP()
assert.Equal(t, c.expectedIp, ip)
assert.Equal(t, c.expectedErr, err)
}
}