Merge pull request #909 from sthulb/ssh-client

SSH Client
This commit is contained in:
Evan Hazlett 2015-04-21 07:58:39 -07:00
commit 51044b3c3f
11 changed files with 276 additions and 184 deletions

View File

@ -1,19 +1,20 @@
package commands package commands
import ( import (
"io"
"os" "os"
"os/exec" "strings"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"github.com/codegangsta/cli" "github.com/codegangsta/cli"
"github.com/docker/machine/drivers" "github.com/docker/machine/drivers"
"github.com/docker/machine/ssh"
) )
func cmdSsh(c *cli.Context) { func cmdSsh(c *cli.Context) {
var ( var (
err error err error
sshCmd *exec.Cmd
) )
name := c.Args().First() name := c.Args().First()
@ -59,19 +60,18 @@ func cmdSsh(c *cli.Context) {
} }
} }
var output ssh.Output
if len(c.Args()) <= 1 { if len(c.Args()) <= 1 {
sshCmd, err = host.GetSSHCommand() err = host.CreateSSHShell()
} else { } else {
sshCmd, err = host.GetSSHCommand(c.Args()[1:]...) output, err = host.RunSSHCommand(strings.Join(c.Args()[1:], " "))
io.Copy(os.Stderr, output.Stderr)
io.Copy(os.Stdout, output.Stdout)
} }
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
sshCmd.Stdin = os.Stdin
sshCmd.Stdout = os.Stdout
sshCmd.Stderr = os.Stderr
if err := sshCmd.Run(); err != nil {
log.Fatal(err)
}
} }

View File

@ -3,7 +3,6 @@ package drivers
import ( import (
"errors" "errors"
"fmt" "fmt"
"os/exec"
"sort" "sort"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
@ -174,21 +173,32 @@ type DriverOptions interface {
Bool(key string) bool Bool(key string) bool
} }
func GetSSHCommandFromDriver(d Driver, args ...string) (*exec.Cmd, error) { func RunSSHCommandFromDriver(d Driver, args string) (ssh.Output, error) {
var output ssh.Output
host, err := d.GetSSHHostname() host, err := d.GetSSHHostname()
if err != nil { if err != nil {
return nil, err return output, err
} }
port, err := d.GetSSHPort() port, err := d.GetSSHPort()
if err != nil { if err != nil {
return nil, err return output, err
} }
user := d.GetSSHUsername() user := d.GetSSHUsername()
keyPath := d.GetSSHKeyPath() keyPath := d.GetSSHKeyPath()
return ssh.GetSSHCommand(host, port, user, keyPath, args...), nil auth := &ssh.Auth{
Keys: []string{keyPath},
}
client, err := ssh.NewClient(user, host, port, auth)
if err != nil {
return output, err
}
return client.Run(args)
} }
func MachineInState(d Driver, desiredState state.State) func() bool { func MachineInState(d Driver, desiredState state.State) func() bool {

View File

@ -234,9 +234,17 @@ func (c *ComputeUtil) deleteInstance() error {
func (c *ComputeUtil) executeCommands(commands []string, ip, sshKeyPath string) error { func (c *ComputeUtil) executeCommands(commands []string, ip, sshKeyPath string) error {
for _, command := range commands { for _, command := range commands {
cmd := ssh.GetSSHCommand(ip, 22, c.userName, sshKeyPath, command) auth := &ssh.Auth{
if err := cmd.Run(); err != nil { Keys: []string{sshKeyPath},
return fmt.Errorf("error executing command: %v %v", command, err) }
client, err := ssh.NewClient(c.userName, ip, 22, auth)
if err != nil {
return err
}
if _, err := client.Run(command); err != nil {
return err
} }
} }
return nil return nil

View File

@ -452,19 +452,17 @@ func (d *Driver) GetIP() (string, error) {
if s != state.Running { if s != state.Running {
return "", drivers.ErrHostIsNotRunning return "", drivers.ErrHostIsNotRunning
} }
cmd, err := drivers.GetSSHCommandFromDriver(d, "ip addr show dev eth1") output, err := drivers.RunSSHCommandFromDriver(d, "ip addr show dev eth1")
if err != nil { if err != nil {
return "", err return "", err
} }
// reset to nil as if using from Host Stdout is already set when using DEBUG var buf bytes.Buffer
cmd.Stdout = nil if _, err := buf.ReadFrom(output.Stdout); err != nil {
b, err := cmd.Output()
if err != nil {
return "", err return "", err
} }
out := string(b)
out := buf.String()
log.Debugf("SSH returned: %s\nEND SSH\n", out) log.Debugf("SSH returned: %s\nEND SSH\n", out)
// parse to find: inet 192.168.59.103/24 brd 192.168.59.255 scope global eth1 // parse to find: inet 192.168.59.103/24 brd 192.168.59.255 scope global eth1
lines := strings.Split(out, "\n") lines := strings.Split(out, "\n")

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"os/exec"
"path/filepath" "path/filepath"
"regexp" "regexp"
@ -127,23 +126,49 @@ func (h *Host) Create(name string) error {
return nil return nil
} }
func (h *Host) GetSSHCommand(args ...string) (*exec.Cmd, error) { func (h *Host) RunSSHCommand(command string) (ssh.Output, error) {
var output ssh.Output
addr, err := h.Driver.GetSSHHostname() addr, err := h.Driver.GetSSHHostname()
if err != nil { if err != nil {
return nil, err return output, err
} }
user := h.Driver.GetSSHUsername()
port, err := h.Driver.GetSSHPort() port, err := h.Driver.GetSSHPort()
if err != nil { if err != nil {
return nil, err return output, err
} }
keyPath := h.Driver.GetSSHKeyPath() auth := &ssh.Auth{
Keys: []string{h.Driver.GetSSHKeyPath()},
}
cmd := ssh.GetSSHCommand(addr, port, user, keyPath, args...) client, err := ssh.NewClient(h.Driver.GetSSHUsername(), addr, port, auth)
return cmd, nil
return client.Run(command)
}
func (h *Host) CreateSSHShell() error {
addr, err := h.Driver.GetSSHHostname()
if err != nil {
return err
}
port, err := h.Driver.GetSSHPort()
if err != nil {
return err
}
auth := &ssh.Auth{
Keys: []string{h.Driver.GetSSHKeyPath()},
}
client, err := ssh.NewClient(h.Driver.GetSSHUsername(), addr, port, auth)
if err != nil {
return err
}
return client.Shell()
} }
func (h *Host) Start() error { func (h *Host) Start() error {
@ -326,15 +351,11 @@ func sshAvailableFunc(h *Host) func() bool {
log.Debugf("Error waiting for TCP waiting for SSH: %s", err) log.Debugf("Error waiting for TCP waiting for SSH: %s", err)
return false return false
} }
cmd, err := h.GetSSHCommand("exit 0")
if err != nil { if _, err := h.RunSSHCommand("exit 0"); err != nil {
log.Debugf("Error getting ssh command 'exit 0' : %s", err) log.Debugf("Error getting ssh command 'exit 0' : %s", err)
return false return false
} }
if err := cmd.Run(); err != nil {
log.Debugf("Error running ssh command 'exit 0' : %s", err)
return false
}
return true return true
} }
} }

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"os/exec"
"path" "path"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
@ -12,6 +11,7 @@ import (
"github.com/docker/machine/libmachine/auth" "github.com/docker/machine/libmachine/auth"
"github.com/docker/machine/libmachine/provision/pkgaction" "github.com/docker/machine/libmachine/provision/pkgaction"
"github.com/docker/machine/libmachine/swarm" "github.com/docker/machine/libmachine/swarm"
"github.com/docker/machine/ssh"
"github.com/docker/machine/state" "github.com/docker/machine/state"
"github.com/docker/machine/utils" "github.com/docker/machine/utils"
) )
@ -36,16 +36,13 @@ type Boot2DockerProvisioner struct {
func (provisioner *Boot2DockerProvisioner) Service(name string, action pkgaction.ServiceAction) error { func (provisioner *Boot2DockerProvisioner) Service(name string, action pkgaction.ServiceAction) error {
var ( var (
cmd *exec.Cmd
err error err error
) )
cmd, err = provisioner.SSHCommand(fmt.Sprintf("sudo /etc/init.d/%s %s", name, action.String()))
if err != nil { if _, err = provisioner.SSHCommand(fmt.Sprintf("sudo /etc/init.d/%s %s", name, action.String())); err != nil {
return err
}
if err := cmd.Run(); err != nil {
return err return err
} }
return nil return nil
} }
@ -101,15 +98,13 @@ func (provisioner *Boot2DockerProvisioner) Package(name string, action pkgaction
} }
func (provisioner *Boot2DockerProvisioner) Hostname() (string, error) { func (provisioner *Boot2DockerProvisioner) Hostname() (string, error) {
cmd, err := provisioner.SSHCommand(fmt.Sprintf("hostname")) output, err := provisioner.SSHCommand(fmt.Sprintf("hostname"))
if err != nil { if err != nil {
return "", err return "", err
} }
var so bytes.Buffer var so bytes.Buffer
cmd.Stdout = &so if _, err := so.ReadFrom(output.Stdout); err != nil {
if err := cmd.Run(); err != nil {
return "", err return "", err
} }
@ -117,16 +112,15 @@ func (provisioner *Boot2DockerProvisioner) Hostname() (string, error) {
} }
func (provisioner *Boot2DockerProvisioner) SetHostname(hostname string) error { func (provisioner *Boot2DockerProvisioner) SetHostname(hostname string) error {
cmd, err := provisioner.SSHCommand(fmt.Sprintf( if _, err := provisioner.SSHCommand(fmt.Sprintf(
"sudo hostname %s && echo %q | sudo tee /var/lib/boot2docker/etc/hostname", "sudo hostname %s && echo %q | sudo tee /var/lib/boot2docker/etc/hostname",
hostname, hostname,
hostname, hostname,
)) )); err != nil {
if err != nil {
return err return err
} }
return cmd.Run() return nil
} }
func (provisioner *Boot2DockerProvisioner) GetDockerOptionsDir() string { func (provisioner *Boot2DockerProvisioner) GetDockerOptionsDir() string {
@ -188,8 +182,8 @@ func (provisioner *Boot2DockerProvisioner) Provision(swarmOptions swarm.SwarmOpt
return nil return nil
} }
func (provisioner *Boot2DockerProvisioner) SSHCommand(args ...string) (*exec.Cmd, error) { func (provisioner *Boot2DockerProvisioner) SSHCommand(args string) (ssh.Output, error) {
return drivers.GetSSHCommandFromDriver(provisioner.Driver, args...) return drivers.RunSSHCommandFromDriver(provisioner.Driver, args)
} }
func (provisioner *Boot2DockerProvisioner) GetDriver() drivers.Driver { func (provisioner *Boot2DockerProvisioner) GetDriver() drivers.Driver {

View File

@ -3,12 +3,12 @@ package provision
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"os/exec"
"github.com/docker/machine/drivers" "github.com/docker/machine/drivers"
"github.com/docker/machine/libmachine/auth" "github.com/docker/machine/libmachine/auth"
"github.com/docker/machine/libmachine/provision/pkgaction" "github.com/docker/machine/libmachine/provision/pkgaction"
"github.com/docker/machine/libmachine/swarm" "github.com/docker/machine/libmachine/swarm"
"github.com/docker/machine/ssh"
) )
var provisioners = make(map[string]*RegisteredProvisioner) var provisioners = make(map[string]*RegisteredProvisioner)
@ -48,7 +48,7 @@ type Provisioner interface {
GetDriver() drivers.Driver GetDriver() drivers.Driver
// Short-hand for accessing an SSH command from the driver. // Short-hand for accessing an SSH command from the driver.
SSHCommand(args ...string) (*exec.Cmd, error) SSHCommand(args string) (ssh.Output, error)
// Set the OS Release info depending on how it's represented // Set the OS Release info depending on how it's represented
// internally // internally
@ -68,18 +68,13 @@ func DetectProvisioner(d drivers.Driver) (Provisioner, error) {
var ( var (
osReleaseOut bytes.Buffer osReleaseOut bytes.Buffer
) )
catOsReleaseCmd, err := drivers.GetSSHCommandFromDriver(d, "cat /etc/os-release") catOsReleaseOutput, err := drivers.RunSSHCommandFromDriver(d, "cat /etc/os-release")
if err != nil { if err != nil {
return nil, fmt.Errorf("Error getting SSH command: %s", err) return nil, fmt.Errorf("Error getting SSH command: %s", err)
} }
// Normally I would just use Output() for this, but d.GetSSHCommand if _, err := osReleaseOut.ReadFrom(catOsReleaseOutput.Stdout); err != nil {
// defaults to sending the output of the command to stdout in debug return nil, err
// mode, so that will be broken if we don't set it ourselves.
catOsReleaseCmd.Stdout = &osReleaseOut
if err := catOsReleaseCmd.Run(); err != nil {
return nil, fmt.Errorf("Error running SSH command to get /etc/os-release: %s", err)
} }
osReleaseInfo, err := NewOsRelease(osReleaseOut.Bytes()) osReleaseInfo, err := NewOsRelease(osReleaseOut.Bytes())

View File

@ -3,13 +3,13 @@ package provision
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"os/exec"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"github.com/docker/machine/drivers" "github.com/docker/machine/drivers"
"github.com/docker/machine/libmachine/auth" "github.com/docker/machine/libmachine/auth"
"github.com/docker/machine/libmachine/provision/pkgaction" "github.com/docker/machine/libmachine/provision/pkgaction"
"github.com/docker/machine/libmachine/swarm" "github.com/docker/machine/libmachine/swarm"
"github.com/docker/machine/ssh"
"github.com/docker/machine/utils" "github.com/docker/machine/utils"
) )
@ -38,12 +38,7 @@ type UbuntuProvisioner struct {
func (provisioner *UbuntuProvisioner) Service(name string, action pkgaction.ServiceAction) error { func (provisioner *UbuntuProvisioner) Service(name string, action pkgaction.ServiceAction) error {
command := fmt.Sprintf("sudo service %s %s", name, action.String()) command := fmt.Sprintf("sudo service %s %s", name, action.String())
cmd, err := provisioner.SSHCommand(command) if _, err := provisioner.SSHCommand(command); err != nil {
if err != nil {
return err
}
if err := cmd.Run(); err != nil {
return err return err
} }
@ -70,12 +65,7 @@ func (provisioner *UbuntuProvisioner) Package(name string, action pkgaction.Pack
command := fmt.Sprintf("DEBIAN_FRONTEND=noninteractive sudo -E apt-get %s -y %s", packageAction, name) command := fmt.Sprintf("DEBIAN_FRONTEND=noninteractive sudo -E apt-get %s -y %s", packageAction, name)
cmd, err := provisioner.SSHCommand(command) if _, err := provisioner.SSHCommand(command); err != nil {
if err != nil {
return err
}
if err := cmd.Run(); err != nil {
return err return err
} }
@ -83,15 +73,10 @@ func (provisioner *UbuntuProvisioner) Package(name string, action pkgaction.Pack
} }
func (provisioner *UbuntuProvisioner) dockerDaemonResponding() bool { func (provisioner *UbuntuProvisioner) dockerDaemonResponding() bool {
cmd, err := provisioner.SSHCommand("sudo docker version") if _, err := provisioner.SSHCommand("sudo docker version"); err != nil {
if err != nil {
log.Warn("Error getting SSH command to check if the daemon is up: %s", err) log.Warn("Error getting SSH command to check if the daemon is up: %s", err)
return false return false
} }
if err := cmd.Run(); err != nil {
log.Debug("Error checking for daemon up: %s", err)
return false
}
// The daemon is up if the command worked. Carry on. // The daemon is up if the command worked. Carry on.
return true return true
@ -128,15 +113,13 @@ func (provisioner *UbuntuProvisioner) Provision(swarmOptions swarm.SwarmOptions,
} }
func (provisioner *UbuntuProvisioner) Hostname() (string, error) { func (provisioner *UbuntuProvisioner) Hostname() (string, error) {
cmd, err := provisioner.SSHCommand("hostname") output, err := provisioner.SSHCommand("hostname")
if err != nil { if err != nil {
return "", err return "", err
} }
var so bytes.Buffer var so bytes.Buffer
cmd.Stdout = &so if _, err := so.ReadFrom(output.Stdout); err != nil {
if err := cmd.Run(); err != nil {
return "", err return "", err
} }
@ -144,26 +127,24 @@ func (provisioner *UbuntuProvisioner) Hostname() (string, error) {
} }
func (provisioner *UbuntuProvisioner) SetHostname(hostname string) error { func (provisioner *UbuntuProvisioner) SetHostname(hostname string) error {
cmd, err := provisioner.SSHCommand(fmt.Sprintf( if _, err := provisioner.SSHCommand(fmt.Sprintf(
"sudo hostname %s && echo %q | sudo tee /etc/hostname && echo \"127.0.0.1 %s\" | sudo tee -a /etc/hosts", "sudo hostname %s && echo %q | sudo tee /etc/hostname && echo \"127.0.0.1 %s\" | sudo tee -a /etc/hosts",
hostname, hostname,
hostname, hostname,
hostname, hostname,
)) )); err != nil {
if err != nil {
return err return err
} }
return cmd.Run() return nil
} }
func (provisioner *UbuntuProvisioner) GetDockerOptionsDir() string { func (provisioner *UbuntuProvisioner) GetDockerOptionsDir() string {
return "/etc/docker" return "/etc/docker"
} }
func (provisioner *UbuntuProvisioner) SSHCommand(args ...string) (*exec.Cmd, error) { func (provisioner *UbuntuProvisioner) SSHCommand(args string) (ssh.Output, error) {
return drivers.GetSSHCommandFromDriver(provisioner.Driver, args...) return drivers.RunSSHCommandFromDriver(provisioner.Driver, args)
} }
func (provisioner *UbuntuProvisioner) CompatibleWithHost() bool { func (provisioner *UbuntuProvisioner) CompatibleWithHost() bool {

View File

@ -25,18 +25,13 @@ type DockerOptions struct {
func installDockerGeneric(p Provisioner) error { func installDockerGeneric(p Provisioner) error {
// install docker - until cloudinit we use ubuntu everywhere so we // install docker - until cloudinit we use ubuntu everywhere so we
// just install it using the docker repos // just install it using the docker repos
cmd, err := p.SSHCommand("if ! type docker; then curl -sSL https://get.docker.com | sh -; fi") if output, err := p.SSHCommand("if ! type docker; then curl -sSL https://get.docker.com | sh -; fi"); err != nil {
if err != nil { var buf bytes.Buffer
if _, err := buf.ReadFrom(output.Stderr); err != nil {
return err return err
} }
// HACK: the script above will output debug to stderr; we save it and return fmt.Errorf("error installing docker: %s\n", buf.String())
// then check if the command returned an error; if so, we show the debug
var buf bytes.Buffer
cmd.Stderr = &buf
if err := cmd.Run(); err != nil {
return fmt.Errorf("error installing docker: %s\n%s\n", err, string(buf.Bytes()))
} }
return nil return nil
@ -99,11 +94,7 @@ func ConfigureAuth(p Provisioner, authOptions auth.AuthOptions) error {
dockerDir := p.GetDockerOptionsDir() dockerDir := p.GetDockerOptionsDir()
cmd, err := p.SSHCommand(fmt.Sprintf("sudo mkdir -p %s", dockerDir)) if _, err := p.SSHCommand(fmt.Sprintf("sudo mkdir -p %s", dockerDir)); err != nil {
if err != nil {
return err
}
if err := cmd.Run(); err != nil {
return err return err
} }
@ -132,27 +123,15 @@ func ConfigureAuth(p Provisioner, authOptions auth.AuthOptions) error {
machineServerKeyPath := path.Join(dockerDir, "server-key.pem") machineServerKeyPath := path.Join(dockerDir, "server-key.pem")
authOptions.ServerKeyRemotePath = machineServerKeyPath authOptions.ServerKeyRemotePath = machineServerKeyPath
cmd, err = p.SSHCommand(fmt.Sprintf("echo \"%s\" | sudo tee %s", string(caCert), machineCaCertPath)) if _, err = p.SSHCommand(fmt.Sprintf("echo \"%s\" | sudo tee %s", string(caCert), machineCaCertPath)); err != nil {
if err != nil {
return err
}
if err := cmd.Run(); err != nil {
return err return err
} }
cmd, err = p.SSHCommand(fmt.Sprintf("echo \"%s\" | sudo tee %s", string(serverKey), machineServerKeyPath)) if _, err = p.SSHCommand(fmt.Sprintf("echo \"%s\" | sudo tee %s", string(serverKey), machineServerKeyPath)); err != nil {
if err != nil {
return err
}
if err := cmd.Run(); err != nil {
return err return err
} }
cmd, err = p.SSHCommand(fmt.Sprintf("echo \"%s\" | sudo tee %s", string(serverCert), machineServerCertPath)) if _, err = p.SSHCommand(fmt.Sprintf("echo \"%s\" | sudo tee %s", string(serverCert), machineServerCertPath)); err != nil {
if err != nil {
return err
}
if err := cmd.Run(); err != nil {
return err return err
} }
@ -179,11 +158,7 @@ func ConfigureAuth(p Provisioner, authOptions auth.AuthOptions) error {
return err return err
} }
cmd, err = p.SSHCommand(fmt.Sprintf("echo \"%s\" | sudo tee -a %s", dkrcfg.EngineOptions, dkrcfg.EngineOptionsPath)) if _, err = p.SSHCommand(fmt.Sprintf("echo \"%s\" | sudo tee -a %s", dkrcfg.EngineOptions, dkrcfg.EngineOptionsPath)); err != nil {
if err != nil {
return err
}
if err := cmd.Run(); err != nil {
return err return err
} }
@ -234,11 +209,7 @@ func configureSwarm(p Provisioner, swarmOptions swarm.SwarmOptions) error {
return err return err
} }
cmd, err := p.SSHCommand(fmt.Sprintf("sudo docker pull %s", swarm.DockerImage)) if _, err := p.SSHCommand(fmt.Sprintf("sudo docker pull %s", swarm.DockerImage)); err != nil {
if err != nil {
return err
}
if err := cmd.Run(); err != nil {
return err return err
} }
@ -248,12 +219,8 @@ func configureSwarm(p Provisioner, swarmOptions swarm.SwarmOptions) error {
if swarmOptions.Master { if swarmOptions.Master {
log.Debug("launching swarm master") log.Debug("launching swarm master")
log.Debugf("master args: %s", masterArgs) log.Debugf("master args: %s", masterArgs)
cmd, err = p.SSHCommand(fmt.Sprintf("sudo docker run -d -p %s:%s --restart=always --name swarm-agent-master -v %s:%s %s manage %s", if _, err = p.SSHCommand(fmt.Sprintf("sudo docker run -d -p %s:%s --restart=always --name swarm-agent-master -v %s:%s %s manage %s",
port, port, dockerDir, dockerDir, swarm.DockerImage, masterArgs)) port, port, dockerDir, dockerDir, swarm.DockerImage, masterArgs)); err != nil {
if err != nil {
return err
}
if err := cmd.Run(); err != nil {
return err return err
} }
} }
@ -261,12 +228,8 @@ func configureSwarm(p Provisioner, swarmOptions swarm.SwarmOptions) error {
// start node agent // start node agent
log.Debug("launching swarm node") log.Debug("launching swarm node")
log.Debugf("node args: %s", nodeArgs) log.Debugf("node args: %s", nodeArgs)
cmd, err = p.SSHCommand(fmt.Sprintf("sudo docker run -d --restart=always --name swarm-agent -v %s:%s %s join %s", if _, err = p.SSHCommand(fmt.Sprintf("sudo docker run -d --restart=always --name swarm-agent -v %s:%s %s join %s",
dockerDir, dockerDir, swarm.DockerImage, nodeArgs)) dockerDir, dockerDir, swarm.DockerImage, nodeArgs)); err != nil {
if err != nil {
return err
}
if err := cmd.Run(); err != nil {
return err return err
} }

154
ssh/client.go Normal file
View File

@ -0,0 +1,154 @@
package ssh
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"github.com/docker/docker/pkg/term"
"golang.org/x/crypto/ssh"
)
type Client struct {
Config *ssh.ClientConfig
Hostname string
Port int
}
func NewClient(user string, host string, port int, auth *Auth) (*Client, error) {
config, err := NewConfig(user, auth)
if err != nil {
return nil, err
}
return &Client{
Config: config,
Hostname: host,
Port: port,
}, nil
}
func NewConfig(user string, auth *Auth) (*ssh.ClientConfig, error) {
var authMethods []ssh.AuthMethod
for _, k := range auth.Keys {
key, err := ioutil.ReadFile(k)
if err != nil {
return nil, err
}
privateKey, err := ssh.ParsePrivateKey(key)
if err != nil {
return nil, err
}
authMethods = append(authMethods, ssh.PublicKeys(privateKey))
}
for _, p := range auth.Passwords {
authMethods = append(authMethods, ssh.Password(p))
}
return &ssh.ClientConfig{
User: user,
Auth: authMethods,
}, nil
}
func (client *Client) Run(command string) (Output, error) {
var output Output
conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", client.Hostname, client.Port), client.Config)
if err != nil {
return output, err
}
session, err := conn.NewSession()
if err != nil {
return output, err
}
defer session.Close()
var stdout, stderr bytes.Buffer
session.Stdout = &stdout
session.Stderr = &stderr
output = Output{
Stdout: &stdout,
Stderr: &stderr,
}
return output, session.Run(command)
}
func (client *Client) Shell() error {
conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", client.Hostname, client.Port), client.Config)
if err != nil {
return err
}
session, err := conn.NewSession()
if err != nil {
return err
}
defer session.Close()
session.Stdout = os.Stdout
session.Stderr = os.Stderr
session.Stdin = os.Stdin
modes := ssh.TerminalModes{
ssh.ECHO: 1,
}
var termWidth, termHeight int
fd := os.Stdin.Fd()
if term.IsTerminal(fd) {
var oldState *term.State
oldState, err = term.MakeRaw(fd)
if err != nil {
return err
}
defer term.RestoreTerminal(fd, oldState)
winsize, err := term.GetWinsize(fd)
if err != nil {
termWidth = 80
termHeight = 24
} else {
termWidth = int(winsize.Width)
termHeight = int(winsize.Height)
}
}
if err := session.RequestPty("xterm", termHeight, termWidth, modes); err != nil {
return err
}
if err := session.Shell(); err != nil {
return err
}
session.Wait()
return nil
}
type Auth struct {
Passwords []string
Keys []string
}
type Output struct {
Stdout io.Reader
Stderr io.Reader
}

View File

@ -1,41 +1,9 @@
package ssh package ssh
import ( import (
"fmt"
"net" "net"
"os"
"os/exec"
"strings"
log "github.com/Sirupsen/logrus"
) )
func GetSSHCommand(host string, port int, user string, sshKey string, args ...string) *exec.Cmd {
defaultSSHArgs := []string{
"-o", "IdentitiesOnly=yes",
"-o", "StrictHostKeyChecking=no", // don't bother checking in ~/.ssh/known_hosts
"-o", "UserKnownHostsFile=/dev/null", // don't write anything to ~/.ssh/known_hosts
"-o", "ConnectionAttempts=3", // retry 3 times if SSH connection fails
"-o", "ConnectTimeout=10", // timeout after 10 seconds
"-o", "LogLevel=quiet", // suppress "Warning: Permanently added '[localhost]:2022' (ECDSA) to the list of known hosts."
"-p", fmt.Sprintf("%d", port),
"-i", sshKey,
fmt.Sprintf("%s@%s", user, host),
}
sshArgs := append(defaultSSHArgs, args...)
cmd := exec.Command("ssh", sshArgs...)
cmd.Stderr = os.Stderr
if os.Getenv("DEBUG") != "" {
cmd.Stdout = os.Stdout
}
log.Debugf("executing: %v", strings.Join(cmd.Args, " "))
return cmd
}
func WaitForTCP(addr string) error { func WaitForTCP(addr string) error {
for { for {
conn, err := net.Dial("tcp", addr) conn, err := net.Dial("tcp", addr)