From 3b4b168051fdf00e06a89505bce97369749a2587 Mon Sep 17 00:00:00 2001 From: Stefan Scherer Date: Wed, 13 Jan 2016 00:46:11 +0100 Subject: [PATCH] Check grand parent if not directly called from a windows shell Signed-off-by: Stefan Scherer --- libmachine/shell/shell_windows.go | 28 +++++++++++++++++--------- libmachine/shell/shell_windows_test.go | 22 +++++++++++++++++--- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/libmachine/shell/shell_windows.go b/libmachine/shell/shell_windows.go index e46c5681d9..89cd2c8b0a 100644 --- a/libmachine/shell/shell_windows.go +++ b/libmachine/shell/shell_windows.go @@ -37,24 +37,22 @@ func getProcessEntry(pid int) (pe *syscall.ProcessEntry32, err error) { } } -// startedBy returns the exe file name of the parent process. -func startedBy() (exefile string, err error) { - ppid := os.Getppid() - - pe, err := getProcessEntry(ppid) +// getNameAndItsPpid returns the exe file name its parent process id. +func getNameAndItsPpid(pid int) (exefile string, parentid int, err error) { + pe, err := getProcessEntry(pid) if err != nil { - return "", err + return "", 0, err } name := syscall.UTF16ToString(pe.ExeFile[:]) - return name, nil + return name, int(pe.ParentProcessID), nil } func Detect() (string, error) { shell := os.Getenv("SHELL") if shell == "" { - shell, err := startedBy() + shell, shellppid, err := getNameAndItsPpid(os.Getppid()) if err != nil { return "cmd", err // defaulting to cmd } @@ -63,8 +61,18 @@ func Detect() (string, error) { } else if strings.Contains(strings.ToLower(shell), "cmd") { return "cmd", nil } else { - fmt.Printf("You can further specify your shell with either 'cmd' or 'powershell' with the --shell flag.\n\n") - return "cmd", nil // this could be either powershell or cmd, defaulting to cmd + shell, _, err := getNameAndItsPpid(shellppid) + if err != nil { + return "cmd", err // defaulting to cmd + } + if strings.Contains(strings.ToLower(shell), "powershell") { + return "powershell", nil + } else if strings.Contains(strings.ToLower(shell), "cmd") { + return "cmd", nil + } else { + fmt.Printf("You can further specify your shell with either 'cmd' or 'powershell' with the --shell flag.\n\n") + return "cmd", nil // this could be either powershell or cmd, defaulting to cmd + } } } diff --git a/libmachine/shell/shell_windows_test.go b/libmachine/shell/shell_windows_test.go index 945ebc51b4..81c0c50705 100644 --- a/libmachine/shell/shell_windows_test.go +++ b/libmachine/shell/shell_windows_test.go @@ -13,13 +13,29 @@ func TestDetect(t *testing.T) { shell, err := Detect() - assert.Equal(t, "cmd", shell) + assert.Equal(t, "powershell", shell) assert.NoError(t, err) } -func TestStartedBy(t *testing.T) { - shell, err := startedBy() +func TestGetNameAndItsPpidOfCurrent(t *testing.T) { + shell, shellppid, err := getNameAndItsPpid(os.Getpid()) + + assert.Equal(t, "shell.test.exe", shell) + assert.Equal(t, os.Getppid(), shellppid) + assert.NoError(t, err) +} + +func TestGetNameAndItsPpidOfParent(t *testing.T) { + shell, _, err := getNameAndItsPpid(os.Getppid()) assert.Equal(t, "go.exe", shell) assert.NoError(t, err) } + +func TestGetNameAndItsPpidOfGrandParent(t *testing.T) { + shell, shellppid, err := getNameAndItsPpid(os.Getppid()) + shell, shellppid, err = getNameAndItsPpid(shellppid) + + assert.Equal(t, "powershell.exe", shell) + assert.NoError(t, err) +}