diff --git a/commands/detectshell.go b/commands/detectshell.go new file mode 100644 index 0000000000..12be8d4190 --- /dev/null +++ b/commands/detectshell.go @@ -0,0 +1,24 @@ +// +build !windows + +package commands + +import ( + "fmt" + "os" + "path/filepath" +) + +func detectShell() (string, error) { + shell := os.Getenv("SHELL") + + if shell == "" { + fmt.Printf("The default lines below are for a sh/bash shell, you can specify the shell you're using, with the --shell flag.\n\n") + return "", ErrUnknownShell + } + + if os.Getenv("__fish_bin_dir") != "" { + return "fish", nil + } + + return filepath.Base(shell), nil +} diff --git a/commands/detectshell_test.go b/commands/detectshell_test.go new file mode 100644 index 0000000000..8c44b1caef --- /dev/null +++ b/commands/detectshell_test.go @@ -0,0 +1,29 @@ +package commands + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDetectBash(t *testing.T) { + originalShell := os.Getenv("SHELL") + os.Setenv("SHELL", "/bin/bash") + defer os.Setenv("SHELL", originalShell) + shell, err := detectShell() + assert.Nil(t, err) + assert.Equal(t, "bash", shell) +} + +func TestDetectFish(t *testing.T) { + originalShell := os.Getenv("SHELL") + os.Setenv("SHELL", "/bin/bash") + defer os.Setenv("SHELL", originalShell) + originalFishdir := os.Getenv("__fish_bin_dir") + os.Setenv("__fish_bin_dir", "/usr/local/Cellar/fish/2.2.0/bin") + defer os.Setenv("__fish_bin_dir", originalFishdir) + shell, err := detectShell() + assert.Nil(t, err) + assert.Equal(t, "fish", shell) +} diff --git a/commands/detectshell_unix_test.go b/commands/detectshell_unix_test.go new file mode 100644 index 0000000000..ed495c8e74 --- /dev/null +++ b/commands/detectshell_unix_test.go @@ -0,0 +1,21 @@ +// +build !windows + +package commands + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUnknowShell(t *testing.T) { + originalShell := os.Getenv("SHELL") + os.Setenv("SHELL", "") + defer os.Setenv("SHELL", originalShell) + shell, err := detectShell() + fmt.Println(shell) + assert.Equal(t, err, ErrUnknownShell) + assert.Equal(t, "", shell) +} diff --git a/commands/detectshell_windows.go b/commands/detectshell_windows.go new file mode 100644 index 0000000000..faf7c90a45 --- /dev/null +++ b/commands/detectshell_windows.go @@ -0,0 +1,76 @@ +package commands + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "syscall" + "unsafe" +) + +// re-implementation of private function in https://github.com/golang/go/blob/master/src/syscall/syscall_windows.go#L945 +func getProcessEntry(pid int) (pe *syscall.ProcessEntry32, err error) { + snapshot, err := syscall.CreateToolhelp32Snapshot(syscall.TH32CS_SNAPPROCESS, 0) + if err != nil { + return nil, err + } + defer syscall.CloseHandle(syscall.Handle(snapshot)) + + var processEntry syscall.ProcessEntry32 + processEntry.Size = uint32(unsafe.Sizeof(processEntry)) + err = syscall.Process32First(snapshot, &processEntry) + if err != nil { + return nil, err + } + + for { + if processEntry.ProcessID == uint32(pid) { + pe = &processEntry + return + } + + err = syscall.Process32Next(snapshot, &processEntry) + if err != nil { + return nil, err + } + } +} + +// startedBy returns the exe file name of the parent process. +func startedBy() (exefile string, err error) { + ppid := os.Getppid() + + pe, err := getProcessEntry(ppid) + if err != nil { + return "", err + } + + name := syscall.UTF16ToString(pe.ExeFile[:]) + return name, nil +} + +func detectShell() (string, error) { + shell := os.Getenv("SHELL") + + if shell == "" { + shell, err := startedBy() + if err != nil { + return "cmd", err // defaulting to cmd + } + if strings.Contains(strings.ToLower(shell), "powershell") { + return "powershell", nil + } else if strings.Contains(strings.ToLower(shell), "cmd") { + return "cmd", nil + } else { + fmt.Printf("You can further specify your shell with either 'cmd' or 'powershell' with the --shell flag.\n\n") + return "cmd", nil // this could be either powershell or cmd, defaulting to cmd + } + } + + if os.Getenv("__fish_bin_dir") != "" { + return "fish", nil + } + + return filepath.Base(shell), nil +} diff --git a/commands/env_windows_test.go b/commands/detectshell_windows_test.go similarity index 67% rename from commands/env_windows_test.go rename to commands/detectshell_windows_test.go index 77df1e16cf..a65d95c74f 100644 --- a/commands/env_windows_test.go +++ b/commands/detectshell_windows_test.go @@ -15,3 +15,10 @@ func TestDetect(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "cmd", shell) } + +func TestStartedBy(t *testing.T) { + shell, err := startedBy() + assert.Nil(t, err) + assert.NotNil(t, shell) + assert.Equal(t, "go.exe", shell) +} diff --git a/commands/env.go b/commands/env.go index 55ec74c100..e4fda1fe20 100644 --- a/commands/env.go +++ b/commands/env.go @@ -5,7 +5,6 @@ import ( "fmt" "os" "path/filepath" - "runtime" "strings" "text/template" @@ -251,22 +250,3 @@ func (g *EnvUsageHintGenerator) GenerateUsageHint(userShell string, args []strin return fmt.Sprintf("%s Run this command to configure your shell: \n%s %s\n", comment, comment, cmd) } - -func detectShell() (string, error) { - shell := os.Getenv("SHELL") - - if shell == "" { - if runtime.GOOS == "windows" { - fmt.Printf("You can further specify your shell with either 'cmd' or 'powershell' with the --shell flag.\n\n") - return "cmd", nil // this could be either powershell or cmd, defaulting to cmd - } - fmt.Printf("The default lines below are for a sh/bash shell, you can specify the shell you're using, with the --shell flag.\n\n") - return "", ErrUnknownShell - } - - if os.Getenv("__fish_bin_dir") != "" { - return "fish", nil - } - - return filepath.Base(shell), nil -} diff --git a/commands/env_test.go b/commands/env_test.go index a7f9477955..8e6570f58d 100644 --- a/commands/env_test.go +++ b/commands/env_test.go @@ -5,8 +5,6 @@ import ( "path/filepath" "testing" - "fmt" - "github.com/docker/machine/commands/commandstest" "github.com/docker/machine/commands/mcndirs" "github.com/docker/machine/drivers/fakedriver" @@ -552,34 +550,3 @@ func TestShellCfgUnset(t *testing.T) { os.Setenv(test.noProxyVar, "") } } - -func TestDetectBash(t *testing.T) { - originalShell := os.Getenv("SHELL") - os.Setenv("SHELL", "/bin/bash") - defer os.Setenv("SHELL", originalShell) - shell, err := detectShell() - assert.Nil(t, err) - assert.Equal(t, "bash", shell) -} - -func TestDetectFish(t *testing.T) { - originalShell := os.Getenv("SHELL") - os.Setenv("SHELL", "/bin/bash") - defer os.Setenv("SHELL", originalShell) - originalFishdir := os.Getenv("__fish_bin_dir") - os.Setenv("__fish_bin_dir", "/usr/local/Cellar/fish/2.2.0/bin") - defer os.Setenv("__fish_bin_dir", originalFishdir) - shell, err := detectShell() - assert.Nil(t, err) - assert.Equal(t, "fish", shell) -} - -func TestUnknowShell(t *testing.T) { - originalShell := os.Getenv("SHELL") - os.Setenv("SHELL", "") - defer os.Setenv("SHELL", originalShell) - shell, err := detectShell() - fmt.Println(shell) - assert.Equal(t, err, ErrUnknownShell) - assert.Equal(t, "", shell) -}