From 8a452a9629969b5368aace64df8e296a506c5156 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Fri, 23 Oct 2015 13:17:12 +0200 Subject: [PATCH] FIX #2064 Support local windows path + Use assert in tests + Improve parsing + Simplify code and tests Signed-off-by: David Gageot --- commands/scp.go | 172 ++++++++++++++------------ commands/scp_test.go | 282 +++++++++++-------------------------------- 2 files changed, 167 insertions(+), 287 deletions(-) diff --git a/commands/scp.go b/commands/scp.go index 93b3259971..78d44fdd2d 100644 --- a/commands/scp.go +++ b/commands/scp.go @@ -8,17 +8,13 @@ import ( "strings" "github.com/docker/machine/cli" - "github.com/docker/machine/libmachine/host" "github.com/docker/machine/libmachine/log" "github.com/docker/machine/libmachine/persist" ) var ( - errMalformedInput = errors.New("The input was malformed") errWrongNumberArguments = errors.New("Improper number of arguments") -) -var ( // TODO: possibly move this to ssh package baseSSHArgs = []string{ "-o", "IdentitiesOnly=yes", @@ -26,76 +22,86 @@ var ( "-o", "UserKnownHostsFile=/dev/null", "-o", "LogLevel=quiet", // suppress "Warning: Permanently added '[localhost]:2022' (ECDSA) to the list of known hosts." } - - hostLoader HostLoader ) -// TODO: Remove this hack in favor of better strategy. Currently the -// HostLoader interface wraps the loadHost() function for easier testing. -type HostLoader interface { - LoadHost(persist.Store, string) (*host.Host, error) +// HostInfo gives the mandatory information to connect to a host. +type HostInfo interface { + GetMachineName() string + + GetIP() (string, error) + + GetSSHUsername() string + + GetSSHKeyPath() string } -type ScpHostLoader struct{} - -func (s *ScpHostLoader) LoadHost(store persist.Store, name string) (*host.Host, error) { - return loadHost(store, name) +// HostInfoLoader loads host information. +type HostInfoLoader interface { + load(name string) (HostInfo, error) } -func getInfoForScpArg(hostAndPath string, store persist.Store) (*host.Host, string, []string, error) { - // TODO: What to do about colon in filepath? - splitInfo := strings.Split(hostAndPath, ":") +type storeHostInfoLoader struct { + store persist.Store +} - // Host path. e.g. "/tmp/foo" - if len(splitInfo) == 1 { - return nil, splitInfo[0], nil, nil +func (s *storeHostInfoLoader) load(name string) (HostInfo, error) { + host, err := loadHost(s.store, name) + if err != nil { + return nil, fmt.Errorf("Error loading host: %s", err) } - // Remote path. e.g. "machinename:/usr/bin/cmatrix" - if len(splitInfo) == 2 { - path := splitInfo[1] - host, err := hostLoader.LoadHost(store, splitInfo[0]) - if err != nil { - return nil, "", nil, fmt.Errorf("Error loading host: %s", err) - } - args := []string{ - "-i", - host.Driver.GetSSHKeyPath(), - } - return host, path, args, nil - } - - return nil, "", nil, errMalformedInput + return host.Driver, nil } -func generateLocationArg(host *host.Host, path string) (string, error) { - locationPrefix := "" - if host != nil { - ip, err := host.Driver.GetIP() - if err != nil { - return "", err - } - locationPrefix = fmt.Sprintf("%s@%s:", host.Driver.GetSSHUsername(), ip) +func cmdScp(c *cli.Context) error { + args := c.Args() + if len(args) != 2 { + cli.ShowCommandHelp(c, "scp") + return errWrongNumberArguments } - return locationPrefix + path, nil + + src := args[0] + dest := args[1] + + store := getStore(c) + hostInfoLoader := &storeHostInfoLoader{store} + + cmd, err := getScpCmd(src, dest, c.Bool("recursive"), hostInfoLoader) + if err != nil { + return err + } + + if err := runCmdWithStdIo(*cmd); err != nil { + return err + } + + return runCmdWithStdIo(*cmd) } -func getScpCmd(src, dest string, sshArgs []string, store persist.Store) (*exec.Cmd, error) { +func getScpCmd(src, dest string, recursive bool, hostInfoLoader HostInfoLoader) (*exec.Cmd, error) { cmdPath, err := exec.LookPath("scp") if err != nil { return nil, errors.New("Error: You must have a copy of the scp binary locally to use the scp feature.") } - srcHost, srcPath, srcOpts, err := getInfoForScpArg(src, store) + srcHost, srcPath, srcOpts, err := getInfoForScpArg(src, hostInfoLoader) if err != nil { return nil, err } - destHost, destPath, destOpts, err := getInfoForScpArg(dest, store) + destHost, destPath, destOpts, err := getInfoForScpArg(dest, hostInfoLoader) if err != nil { return nil, err } + // TODO: Check that "-3" flag is available in user's version of scp. + // It is on every system I've checked, but the manual mentioned it's "newer" + sshArgs := baseSSHArgs + sshArgs = append(sshArgs, "-3") + if recursive { + sshArgs = append(sshArgs, "-r") + } + // Append needed -i / private key flags to command. sshArgs = append(sshArgs, srcOpts...) sshArgs = append(sshArgs, destOpts...) @@ -105,6 +111,7 @@ func getScpCmd(src, dest string, sshArgs []string, store persist.Store) (*exec.C if err != nil { return nil, err } + sshArgs = append(sshArgs, locationArg) locationArg, err = generateLocationArg(destHost, destPath) if err != nil { @@ -117,6 +124,47 @@ func getScpCmd(src, dest string, sshArgs []string, store persist.Store) (*exec.C return cmd, nil } +func getInfoForScpArg(hostAndPath string, hostInfoLoader HostInfoLoader) (HostInfo, string, []string, error) { + // Local path. e.g. "/tmp/foo" + if !strings.Contains(hostAndPath, ":") { + return nil, hostAndPath, nil, nil + } + + // Path with hostname. e.g. "hostname:/usr/bin/cmatrix" + parts := strings.SplitN(hostAndPath, ":", 2) + hostName := parts[0] + path := parts[1] + if hostName == "localhost" { + return nil, path, nil, nil + } + + // Remote path + hostInfo, err := hostInfoLoader.load(hostName) + if err != nil { + return nil, "", nil, fmt.Errorf("Error loading host: %s", err) + } + + args := []string{ + "-i", + hostInfo.GetSSHKeyPath(), + } + return hostInfo, path, args, nil +} + +func generateLocationArg(hostInfo HostInfo, path string) (string, error) { + if hostInfo == nil { + return path, nil + } + + ip, err := hostInfo.GetIP() + if err != nil { + return "", err + } + + location := fmt.Sprintf("%s@%s:%s", hostInfo.GetSSHUsername(), ip, path) + return location, nil +} + func runCmdWithStdIo(cmd exec.Cmd) error { cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout @@ -124,33 +172,3 @@ func runCmdWithStdIo(cmd exec.Cmd) error { return cmd.Run() } - -func cmdScp(c *cli.Context) error { - hostLoader = &ScpHostLoader{} - - args := c.Args() - if len(args) != 2 { - cli.ShowCommandHelp(c, "scp") - return errWrongNumberArguments - } - - // TODO: Check that "-3" flag is available in user's version of scp. - // It is on every system I've checked, but the manual mentioned it's "newer" - sshArgs := append(baseSSHArgs, "-3") - - if c.Bool("recursive") { - sshArgs = append(sshArgs, "-r") - } - - src := args[0] - dest := args[1] - - store := getStore(c) - - cmd, err := getScpCmd(src, dest, sshArgs, store) - if err != nil { - return err - } - - return runCmdWithStdIo(*cmd) -} diff --git a/commands/scp_test.go b/commands/scp_test.go index 540642323e..221e5e0fea 100644 --- a/commands/scp_test.go +++ b/commands/scp_test.go @@ -1,253 +1,115 @@ package commands import ( - "errors" - "fmt" "os/exec" - "reflect" "testing" - "github.com/docker/machine/libmachine/drivers" - "github.com/docker/machine/libmachine/host" - "github.com/docker/machine/libmachine/mcnflag" - "github.com/docker/machine/libmachine/persist" - "github.com/docker/machine/libmachine/state" + "github.com/stretchr/testify/assert" ) -type ScpFakeDriver struct { - MockState state.State +type MockHostInfo struct { + name string + ip string + sshUsername string + sshKeyPath string } -type ScpFakeStore struct{} - -type ScpFakeHostLoader struct{} - -func (d ScpFakeDriver) GetCreateFlags() []mcnflag.Flag { - return []mcnflag.Flag{} +func (h *MockHostInfo) GetMachineName() string { + return h.name } -func (d ScpFakeDriver) DriverName() string { - return "fake" +func (h *MockHostInfo) GetIP() (string, error) { + return h.ip, nil } -func (d ScpFakeDriver) SetConfigFromFlags(flags drivers.DriverOptions) error { - return nil +func (h *MockHostInfo) GetSSHUsername() string { + return h.sshUsername } -func (d ScpFakeDriver) GetURL() (string, error) { - return "", nil +func (h *MockHostInfo) GetSSHKeyPath() string { + return h.sshKeyPath } -func (d ScpFakeDriver) GetIP() (string, error) { - return "12.34.56.78", nil +type MockHostInfoLoader struct { + hostInfo MockHostInfo } -func (d ScpFakeDriver) GetState() (state.State, error) { - return d.MockState, nil +func (l *MockHostInfoLoader) load(name string) (HostInfo, error) { + info := l.hostInfo + info.name = name + return &info, nil } -func (d ScpFakeDriver) GetMachineName() string { - return "myfunhost" +func TestGetInfoForLocalScpArg(t *testing.T) { + host, path, opts, err := getInfoForScpArg("/tmp/foo", nil) + assert.Nil(t, host) + assert.Equal(t, "/tmp/foo", path) + assert.Nil(t, opts) + assert.NoError(t, err) + + host, path, opts, err = getInfoForScpArg("localhost:C:\\path", nil) + assert.Nil(t, host) + assert.Equal(t, "C:\\path", path) + assert.Nil(t, opts) + assert.NoError(t, err) } -func (d ScpFakeDriver) GetSSHHostname() (string, error) { - return "12.34.56.76", nil +func TestGetInfoForRemoteScpArg(t *testing.T) { + hostInfoLoader := MockHostInfoLoader{MockHostInfo{ + sshKeyPath: "/fake/keypath/id_rsa", + }} + + host, path, opts, err := getInfoForScpArg("myfunhost:/home/docker/foo", &hostInfoLoader) + assert.Equal(t, "myfunhost", host.GetMachineName()) + assert.Equal(t, "/home/docker/foo", path) + assert.Equal(t, []string{"-i", "/fake/keypath/id_rsa"}, opts) + assert.NoError(t, err) + + host, path, opts, err = getInfoForScpArg("myfunhost:C:\\path", &hostInfoLoader) + assert.Equal(t, "myfunhost", host.GetMachineName()) + assert.Equal(t, "C:\\path", path) + assert.NoError(t, err) } -func (d ScpFakeDriver) GetSSHPort() (int, error) { - return 22, nil -} - -func (d ScpFakeDriver) PreCreateCheck() error { - return nil -} - -func (d ScpFakeDriver) Create() error { - return nil -} - -func (d ScpFakeDriver) Remove() error { - return nil -} - -func (d ScpFakeDriver) Start() error { - return nil -} - -func (d ScpFakeDriver) Stop() error { - return nil -} - -func (d ScpFakeDriver) Restart() error { - return nil -} - -func (d ScpFakeDriver) Kill() error { - return nil -} - -func (d ScpFakeDriver) Upgrade() error { - return nil -} - -func (d ScpFakeDriver) StartDocker() error { - return nil -} - -func (d ScpFakeDriver) StopDocker() error { - return nil -} - -func (d ScpFakeDriver) GetDockerConfigDir() string { - return "" -} - -func (d ScpFakeDriver) GetSSHCommand(args ...string) (*exec.Cmd, error) { - return &exec.Cmd{}, nil -} - -func (d ScpFakeDriver) GetSSHUsername() string { - return "root" -} - -func (d ScpFakeDriver) GetSSHKeyPath() string { - return "/fake/keypath/id_rsa" -} - -func (d ScpFakeDriver) ResolveStorePath(file string) string { - return "/tmp/store/machines/fake" -} - -func (s ScpFakeStore) Exists(name string) (bool, error) { - return true, nil -} - -func (s ScpFakeStore) GetActive() (*host.Host, error) { - return nil, nil -} - -func (s ScpFakeStore) List() ([]*host.Host, error) { - return nil, nil -} - -func (s ScpFakeStore) Load(name string) (*host.Host, error) { - return nil, nil -} - -func (s ScpFakeStore) Remove(name string) error { - return nil -} - -func (s ScpFakeStore) Save(host *host.Host) error { - return nil -} - -func (s ScpFakeStore) NewHost(driver drivers.Driver) (*host.Host, error) { - return nil, nil -} - -func (fshl *ScpFakeHostLoader) LoadHost(store persist.Store, name string) (*host.Host, error) { - if name == "myfunhost" { - return &host.Host{ - Name: "myfunhost", - Driver: ScpFakeDriver{}, - }, nil - } - return nil, errors.New("Host not found") -} - -func TestGetInfoForScpArg(t *testing.T) { - store := ScpFakeStore{} - hostLoader = &ScpFakeHostLoader{} - - expectedPath := "/tmp/foo" - host, path, opts, err := getInfoForScpArg("/tmp/foo", store) - if err != nil { - t.Fatalf("Unexpected error in local getInfoForScpArg call: %s", err) - } - if path != expectedPath { - t.Fatalf("Path %s not equal to expected path %s", path, expectedPath) - } - if host != nil { - t.Fatal("host should be nil") - } - if opts != nil { - t.Fatal("opts should be nil") - } - - host, path, opts, err = getInfoForScpArg("myfunhost:/home/docker/foo", store) - if err != nil { - t.Fatalf("Unexpected error in machine-based getInfoForScpArg call: %s", err) - } - expectedOpts := []string{ - "-i", - "/fake/keypath/id_rsa", - } - for i := range opts { - if expectedOpts[i] != opts[i] { - t.Fatalf("Mismatch in returned opts: %s != %s", expectedOpts[i], opts[i]) - } - } - if host.Name != "myfunhost" { - t.Fatalf("Expected host.Name to be myfunhost, got %s", host.Name) - } - if path != "/home/docker/foo" { - t.Fatalf("Expected path to be /home/docker/foo, got %s", path) - } - - host, path, opts, err = getInfoForScpArg("foo:bar:widget", store) - if err != errMalformedInput { - t.Fatalf("Didn't get back an error when we were expecting it for malformed args") - } -} - -func TestGenerateLocationArg(t *testing.T) { - host := host.Host{ - Driver: ScpFakeDriver{}, - } - - // local arg +func TestHostLocation(t *testing.T) { arg, err := generateLocationArg(nil, "/home/docker/foo") - if err != nil { - t.Fatalf("Unexpected error generating location arg for local: %s", err) - } - if arg != "/home/docker/foo" { - t.Fatalf("Expected arg to be /home/docker/foo, was %s", arg) + + assert.Equal(t, "/home/docker/foo", arg) + assert.NoError(t, err) +} + +func TestRemoteLocation(t *testing.T) { + hostInfo := MockHostInfo{ + ip: "12.34.56.78", + sshUsername: "root", } - arg, err = generateLocationArg(&host, "/home/docker/foo") - if err != nil { - t.Fatalf("Unexpected error generating location arg for remote: %s", err) - } - if arg != "root@12.34.56.78:/home/docker/foo" { - t.Fatalf("Expected arg to be root@12.34.56.78, instead it was %s", arg) - } + arg, err := generateLocationArg(&hostInfo, "/home/docker/foo") + + assert.Equal(t, "root@12.34.56.78:/home/docker/foo", arg) + assert.NoError(t, err) } func TestGetScpCmd(t *testing.T) { - // TODO: This is a little "integration-ey". Perhaps - // make an ScpDispatcher (name?) interface so that the reliant - // methods can be mocked. + hostInfoLoader := MockHostInfoLoader{MockHostInfo{ + ip: "12.34.56.78", + sshUsername: "root", + sshKeyPath: "/fake/keypath/id_rsa", + }} + + cmd, err := getScpCmd("/tmp/foo", "myfunhost:/home/docker/foo", true, &hostInfoLoader) + expectedArgs := append( baseSSHArgs, "-3", + "-r", "-i", "/fake/keypath/id_rsa", "/tmp/foo", "root@12.34.56.78:/home/docker/foo", ) expectedCmd := exec.Command("/usr/bin/scp", expectedArgs...) - store := ScpFakeStore{} - cmd, err := getScpCmd("/tmp/foo", "myfunhost:/home/docker/foo", append(baseSSHArgs, "-3"), store) - if err != nil { - t.Fatalf("Unexpected err getting scp command: %s", err) - } - - correct := reflect.DeepEqual(expectedCmd, cmd) - if !correct { - fmt.Println(expectedCmd) - fmt.Println(cmd) - t.Fatal("Expected scp cmd structs to be equal but there was mismatch") - } + assert.Equal(t, expectedCmd, cmd) + assert.NoError(t, err) }