diff --git a/drivers/generic/generic.go b/drivers/generic/generic.go index bc84a36915..f94e01bca8 100644 --- a/drivers/generic/generic.go +++ b/drivers/generic/generic.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "os" + "path" "strconv" "time" @@ -87,13 +88,6 @@ func (d *Driver) GetSSHUsername() string { } func (d *Driver) GetSSHKeyPath() string { - if d.SSHKey == "" { - return "" - } - - if d.SSHKeyPath == "" { - d.SSHKeyPath = d.ResolveStorePath("id_rsa") - } return d.SSHKeyPath } @@ -114,8 +108,10 @@ func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error { func (d *Driver) PreCreateCheck() error { if d.SSHKey != "" { if _, err := os.Stat(d.SSHKey); os.IsNotExist(err) { - return fmt.Errorf("Ssh key does not exist: %q", d.SSHKey) + return fmt.Errorf("SSH key does not exist: %q", d.SSHKey) } + + // TODO: validate the key is a valid key } return nil @@ -126,13 +122,14 @@ func (d *Driver) Create() error { log.Info("No SSH key specified. Assuming an existing key at the default location.") } else { log.Info("Importing SSH key...") - // TODO: validate the key is a valid key - if err := mcnutils.CopyFile(d.SSHKey, d.GetSSHKeyPath()); err != nil { - return fmt.Errorf("unable to copy ssh key: %s", err) + + d.SSHKeyPath = d.ResolveStorePath(path.Base(d.SSHKey)) + if err := copySSHKey(d.SSHKey, d.SSHKeyPath); err != nil { + return err } - if err := os.Chmod(d.GetSSHKeyPath(), 0600); err != nil { - return fmt.Errorf("unable to set permissions on the ssh key: %s", err) + if err := copySSHKey(d.SSHKey+".pub", d.SSHKeyPath+".pub"); err != nil { + log.Infof("Couldn't copy SSH public key : %s", err) } } @@ -185,3 +182,15 @@ func (d *Driver) Kill() error { func (d *Driver) Remove() error { return nil } + +func copySSHKey(src, dst string) error { + if err := mcnutils.CopyFile(src, dst); err != nil { + return fmt.Errorf("unable to copy ssh key: %s", err) + } + + if err := os.Chmod(dst, 0600); err != nil { + return fmt.Errorf("unable to set permissions on the ssh key: %s", err) + } + + return nil +} diff --git a/libmachine/mcnutils/utils.go b/libmachine/mcnutils/utils.go index fdaabc968d..b1965b069f 100644 --- a/libmachine/mcnutils/utils.go +++ b/libmachine/mcnutils/utils.go @@ -63,11 +63,7 @@ func CopyFile(src, dst string) error { return err } - if err := os.Chmod(dst, fi.Mode()); err != nil { - return err - } - - return nil + return os.Chmod(dst, fi.Mode()) } func WaitForSpecificOrError(f func() (bool, error), maxAttempts int, waitInterval time.Duration) error {