diff --git a/ssh/client.go b/ssh/client.go index 175a0c391e..6a99dca7bb 100644 --- a/ssh/client.go +++ b/ssh/client.go @@ -2,12 +2,14 @@ package ssh import ( "bytes" + "errors" "fmt" "io" "io/ioutil" "os" "github.com/docker/docker/pkg/term" + "github.com/docker/machine/log" "golang.org/x/crypto/ssh" ) @@ -17,6 +19,10 @@ type Client struct { Port int } +const ( + maxDialAttempts = 10 +) + func NewClient(user string, host string, port int, auth *Auth) (*Client, error) { config, err := NewConfig(user, auth) if err != nil { @@ -58,16 +64,26 @@ func NewConfig(user string, auth *Auth) (*ssh.ClientConfig, error) { } func (client *Client) Run(command string) (Output, error) { - var output Output + var ( + output Output + conn *ssh.Client + err error + ) - conn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", client.Hostname, client.Port), client.Config) - if err != nil { - return output, err + for i := 0; ; i++ { + conn, err = ssh.Dial("tcp", fmt.Sprintf("%s:%d", client.Hostname, client.Port), client.Config) + if err != nil { + log.Errorf("Error dialing TCP: %s", err) + if i == maxDialAttempts { + return output, errors.New("Max SSH/TCP dial attempts exceeded") + } + } + break } session, err := conn.NewSession() if err != nil { - return output, err + return output, fmt.Errorf("Error getting new session: %s", err) } defer session.Close()