Merge pull request #2821 from hypriot/check-grand-parent-for-windows-shell

Detect shell from grand parent process name if not directly called from a windows shell
This commit is contained in:
Nathan LeClaire 2016-01-20 15:55:23 -08:00
commit 474363d27a
2 changed files with 37 additions and 13 deletions

View File

@ -37,24 +37,22 @@ func getProcessEntry(pid int) (pe *syscall.ProcessEntry32, err error) {
}
}
// startedBy returns the exe file name of the parent process.
func startedBy() (exefile string, err error) {
ppid := os.Getppid()
pe, err := getProcessEntry(ppid)
// getNameAndItsPpid returns the exe file name its parent process id.
func getNameAndItsPpid(pid int) (exefile string, parentid int, err error) {
pe, err := getProcessEntry(pid)
if err != nil {
return "", err
return "", 0, err
}
name := syscall.UTF16ToString(pe.ExeFile[:])
return name, nil
return name, int(pe.ParentProcessID), nil
}
func Detect() (string, error) {
shell := os.Getenv("SHELL")
if shell == "" {
shell, err := startedBy()
shell, shellppid, err := getNameAndItsPpid(os.Getppid())
if err != nil {
return "cmd", err // defaulting to cmd
}
@ -63,8 +61,18 @@ func Detect() (string, error) {
} else if strings.Contains(strings.ToLower(shell), "cmd") {
return "cmd", nil
} else {
fmt.Printf("You can further specify your shell with either 'cmd' or 'powershell' with the --shell flag.\n\n")
return "cmd", nil // this could be either powershell or cmd, defaulting to cmd
shell, _, err := getNameAndItsPpid(shellppid)
if err != nil {
return "cmd", err // defaulting to cmd
}
if strings.Contains(strings.ToLower(shell), "powershell") {
return "powershell", nil
} else if strings.Contains(strings.ToLower(shell), "cmd") {
return "cmd", nil
} else {
fmt.Printf("You can further specify your shell with either 'cmd' or 'powershell' with the --shell flag.\n\n")
return "cmd", nil // this could be either powershell or cmd, defaulting to cmd
}
}
}

View File

@ -13,13 +13,29 @@ func TestDetect(t *testing.T) {
shell, err := Detect()
assert.Equal(t, "cmd", shell)
assert.Equal(t, "powershell", shell)
assert.NoError(t, err)
}
func TestStartedBy(t *testing.T) {
shell, err := startedBy()
func TestGetNameAndItsPpidOfCurrent(t *testing.T) {
shell, shellppid, err := getNameAndItsPpid(os.Getpid())
assert.Equal(t, "shell.test.exe", shell)
assert.Equal(t, os.Getppid(), shellppid)
assert.NoError(t, err)
}
func TestGetNameAndItsPpidOfParent(t *testing.T) {
shell, _, err := getNameAndItsPpid(os.Getppid())
assert.Equal(t, "go.exe", shell)
assert.NoError(t, err)
}
func TestGetNameAndItsPpidOfGrandParent(t *testing.T) {
shell, shellppid, err := getNameAndItsPpid(os.Getppid())
shell, shellppid, err = getNameAndItsPpid(shellppid)
assert.Equal(t, "powershell.exe", shell)
assert.NoError(t, err)
}