diff --git a/api/utils.go b/api/utils.go index 0268552ad3..c65191ef5d 100644 --- a/api/utils.go +++ b/api/utils.go @@ -10,6 +10,7 @@ import ( "net/http" "strconv" "strings" + "time" log "github.com/Sirupsen/logrus" "github.com/docker/swarm/cluster" @@ -130,6 +131,108 @@ func proxy(tlsConfig *tls.Config, addr string, w http.ResponseWriter, r *http.Re return proxyAsync(tlsConfig, addr, w, r, nil) } +type tlsClientConn struct { + *tls.Conn + rawConn net.Conn +} + +func (c *tlsClientConn) CloseWrite() error { + // Go standard tls.Conn doesn't provide the CloseWrite() method so we do it + // on its underlying connection. + if cwc, ok := c.rawConn.(interface { + CloseWrite() error + }); ok { + log.Debug("Calling CloseWrite on Hijacked TLS Conn") + return cwc.CloseWrite() + } + return nil +} + +// We need to copy Go's implementation of tls.Dial (pkg/cryptor/tls/tls.go) in +// order to return our custom tlsClientCon struct which holds both the tls.Conn +// object _and_ its underlying raw connection. The rationale for this is that +// we need to be able to close the write end of the connection when attaching, +// which tls.Conn does not provide. +func tlsDialWithDialer(dialer *net.Dialer, network, addr string, config *tls.Config) (net.Conn, error) { + // We want the Timeout and Deadline values from dialer to cover the + // whole process: TCP connection and TLS handshake. This means that we + // also need to start our own timers now. + timeout := dialer.Timeout + + if !dialer.Deadline.IsZero() { + deadlineTimeout := dialer.Deadline.Sub(time.Now()) + if timeout == 0 || deadlineTimeout < timeout { + timeout = deadlineTimeout + } + } + + var errChannel chan error + + if timeout != 0 { + errChannel = make(chan error, 2) + time.AfterFunc(timeout, func() { + errChannel <- errors.New("") + }) + } + + rawConn, err := dialer.Dial(network, addr) + if err != nil { + return nil, err + } + // When we set up a TCP connection for hijack, there could be long periods + // of inactivity (a long running command with no output) that in certain + // network setups may cause ECONNTIMEOUT, leaving the client in an unknown + // state. Setting TCP KeepAlive on the socket connection will prohibit + // ECONNTIMEOUT unless the socket connection truly is broken + if tcpConn, ok := rawConn.(*net.TCPConn); ok { + tcpConn.SetKeepAlive(true) + tcpConn.SetKeepAlivePeriod(30 * time.Second) + } + + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + if config.ServerName == "" { + // Make a copy to avoid polluting argument or default. + c := *config + c.ServerName = hostname + config = &c + } + + conn := tls.Client(rawConn, config) + + if timeout == 0 { + err = conn.Handshake() + } else { + go func() { + errChannel <- conn.Handshake() + }() + + err = <-errChannel + } + + if err != nil { + rawConn.Close() + return nil, err + } + + // This is Docker difference with standard's crypto/tls package: returned a + // wrapper which holds both the TLS and raw connections. + return &tlsClientConn{conn, rawConn}, nil +} + +func dialHijack(tlsConfig *tls.Config, addr string) (net.Conn, error) { + if tlsConfig == nil { + return net.Dial("tcp", addr) + } + return tlsDialWithDialer(new(net.Dialer), "tcp", addr, tlsConfig) +} + func hijack(tlsConfig *tls.Config, addr string, w http.ResponseWriter, r *http.Request) error { if parts := strings.SplitN(addr, "://", 2); len(parts) == 2 { addr = parts[1] @@ -142,11 +245,7 @@ func hijack(tlsConfig *tls.Config, addr string, w http.ResponseWriter, r *http.R err error ) - if tlsConfig != nil { - d, err = tls.Dial("tcp", addr, tlsConfig) - } else { - d, err = net.Dial("tcp", addr) - } + d, err = dialHijack(tlsConfig, addr) if err != nil { return err }