From c57c8a329eaacc905bdb49ff3bca9f2c5fe08bd7 Mon Sep 17 00:00:00 2001
From: Brian Goff <cpuguy83@gmail.com>
Date: Sun, 29 Nov 2015 11:07:29 -0500
Subject: [PATCH] Ensure CloseWrite is called for hijacked TLS conns

Golang's `*tls.Conn` does not support `CloseWrite`, this means that
connections using TLS will not be able to properly close on hijacked
connections.

This copies Go's tls.Dial and instead returns an internal
`tlsClientConn` type that does store the raw net.Conn and implements
`CloseWrite`.
Implementation is mostly copied from
`github.com/docker/docker/api/client/hijack.go`

Signed-off-by: Brian Goff <cpuguy83@gmail.com>
---
 api/utils.go | 109 ++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 104 insertions(+), 5 deletions(-)

diff --git a/api/utils.go b/api/utils.go
index 0268552ad3..c65191ef5d 100644
--- a/api/utils.go
+++ b/api/utils.go
@@ -10,6 +10,7 @@ import (
 	"net/http"
 	"strconv"
 	"strings"
+	"time"
 
 	log "github.com/Sirupsen/logrus"
 	"github.com/docker/swarm/cluster"
@@ -130,6 +131,108 @@ func proxy(tlsConfig *tls.Config, addr string, w http.ResponseWriter, r *http.Re
 	return proxyAsync(tlsConfig, addr, w, r, nil)
 }
 
+type tlsClientConn struct {
+	*tls.Conn
+	rawConn net.Conn
+}
+
+func (c *tlsClientConn) CloseWrite() error {
+	// Go standard tls.Conn doesn't provide the CloseWrite() method so we do it
+	// on its underlying connection.
+	if cwc, ok := c.rawConn.(interface {
+		CloseWrite() error
+	}); ok {
+		log.Debug("Calling CloseWrite on Hijacked TLS Conn")
+		return cwc.CloseWrite()
+	}
+	return nil
+}
+
+// We need to copy Go's implementation of tls.Dial (pkg/cryptor/tls/tls.go) in
+// order to return our custom tlsClientCon struct which holds both the tls.Conn
+// object _and_ its underlying raw connection. The rationale for this is that
+// we need to be able to close the write end of the connection when attaching,
+// which tls.Conn does not provide.
+func tlsDialWithDialer(dialer *net.Dialer, network, addr string, config *tls.Config) (net.Conn, error) {
+	// We want the Timeout and Deadline values from dialer to cover the
+	// whole process: TCP connection and TLS handshake. This means that we
+	// also need to start our own timers now.
+	timeout := dialer.Timeout
+
+	if !dialer.Deadline.IsZero() {
+		deadlineTimeout := dialer.Deadline.Sub(time.Now())
+		if timeout == 0 || deadlineTimeout < timeout {
+			timeout = deadlineTimeout
+		}
+	}
+
+	var errChannel chan error
+
+	if timeout != 0 {
+		errChannel = make(chan error, 2)
+		time.AfterFunc(timeout, func() {
+			errChannel <- errors.New("")
+		})
+	}
+
+	rawConn, err := dialer.Dial(network, addr)
+	if err != nil {
+		return nil, err
+	}
+	// When we set up a TCP connection for hijack, there could be long periods
+	// of inactivity (a long running command with no output) that in certain
+	// network setups may cause ECONNTIMEOUT, leaving the client in an unknown
+	// state. Setting TCP KeepAlive on the socket connection will prohibit
+	// ECONNTIMEOUT unless the socket connection truly is broken
+	if tcpConn, ok := rawConn.(*net.TCPConn); ok {
+		tcpConn.SetKeepAlive(true)
+		tcpConn.SetKeepAlivePeriod(30 * time.Second)
+	}
+
+	colonPos := strings.LastIndex(addr, ":")
+	if colonPos == -1 {
+		colonPos = len(addr)
+	}
+	hostname := addr[:colonPos]
+
+	// If no ServerName is set, infer the ServerName
+	// from the hostname we're connecting to.
+	if config.ServerName == "" {
+		// Make a copy to avoid polluting argument or default.
+		c := *config
+		c.ServerName = hostname
+		config = &c
+	}
+
+	conn := tls.Client(rawConn, config)
+
+	if timeout == 0 {
+		err = conn.Handshake()
+	} else {
+		go func() {
+			errChannel <- conn.Handshake()
+		}()
+
+		err = <-errChannel
+	}
+
+	if err != nil {
+		rawConn.Close()
+		return nil, err
+	}
+
+	// This is Docker difference with standard's crypto/tls package: returned a
+	// wrapper which holds both the TLS and raw connections.
+	return &tlsClientConn{conn, rawConn}, nil
+}
+
+func dialHijack(tlsConfig *tls.Config, addr string) (net.Conn, error) {
+	if tlsConfig == nil {
+		return net.Dial("tcp", addr)
+	}
+	return tlsDialWithDialer(new(net.Dialer), "tcp", addr, tlsConfig)
+}
+
 func hijack(tlsConfig *tls.Config, addr string, w http.ResponseWriter, r *http.Request) error {
 	if parts := strings.SplitN(addr, "://", 2); len(parts) == 2 {
 		addr = parts[1]
@@ -142,11 +245,7 @@ func hijack(tlsConfig *tls.Config, addr string, w http.ResponseWriter, r *http.R
 		err error
 	)
 
-	if tlsConfig != nil {
-		d, err = tls.Dial("tcp", addr, tlsConfig)
-	} else {
-		d, err = net.Dial("tcp", addr)
-	}
+	d, err = dialHijack(tlsConfig, addr)
 	if err != nil {
 		return err
 	}