From 478b3779f209b7b2c241239c5617a31af45dd66c Mon Sep 17 00:00:00 2001 From: Andy Goldstein Date: Thu, 13 Apr 2017 11:31:47 -0400 Subject: [PATCH] Add redirect support to SpdyRoundTripper Add support for following redirects to the SpdyRoundTripper. This is necessary for clients using it directly (e.g. the apiserver talking directly to the kubelet) because the CRI streaming server issues a redirect for streaming requests. Also extract common logic for following redirects. Kubernetes-commit: 715d5d9c91c669cf33c0bf9a9c9d352c6c4228a6 --- pkg/registry/generic/rest/proxy.go | 115 +++--------------------- pkg/registry/generic/rest/proxy_test.go | 9 ++ pkg/util/proxy/dial.go | 2 +- pkg/util/proxy/dial_test.go | 2 +- 4 files changed, 24 insertions(+), 104 deletions(-) diff --git a/pkg/registry/generic/rest/proxy.go b/pkg/registry/generic/rest/proxy.go index b38e1468e..8286597c6 100644 --- a/pkg/registry/generic/rest/proxy.go +++ b/pkg/registry/generic/rest/proxy.go @@ -17,8 +17,6 @@ limitations under the License. package rest import ( - "bufio" - "bytes" "fmt" "io" "net" @@ -146,10 +144,13 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R rawResponse []byte err error ) + if h.InterceptRedirects && utilfeature.DefaultFeatureGate.Enabled(genericfeatures.StreamingProxyRedirects) { - backendConn, rawResponse, err = h.connectBackendWithRedirects(req) + backendConn, rawResponse, err = utilnet.ConnectWithRedirects(req.Method, h.Location, req.Header, req.Body, h) } else { - backendConn, err = h.connectBackend(req.Method, h.Location, req.Header, req.Body) + clone := utilnet.CloneRequest(req) + clone.URL = h.Location + backendConn, err = h.Dial(clone) } if err != nil { h.Responder.Error(err) @@ -214,112 +215,22 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R return true } -// connectBackend dials the backend at location and forwards a copy of the client request. -func (h *UpgradeAwareProxyHandler) connectBackend(method string, location *url.URL, header http.Header, body io.Reader) (conn net.Conn, err error) { - defer func() { - if err != nil && conn != nil { - conn.Close() - conn = nil - } - }() - - beReq, err := http.NewRequest(method, location.String(), body) +// Dial dials the backend at req.URL and writes req to it. +func (h *UpgradeAwareProxyHandler) Dial(req *http.Request) (net.Conn, error) { + conn, err := proxy.DialURL(req.URL, h.Transport) if err != nil { - return nil, err - } - beReq.Header = header - - conn, err = proxy.DialURL(location, h.Transport) - if err != nil { - return conn, fmt.Errorf("error dialing backend: %v", err) + return nil, fmt.Errorf("error dialing backend: %v", err) } - if err = beReq.Write(conn); err != nil { - return conn, fmt.Errorf("error sending request: %v", err) + if err = req.Write(conn); err != nil { + conn.Close() + return nil, fmt.Errorf("error sending request: %v", err) } return conn, err } -// connectBackendWithRedirects dials the backend and forwards a copy of the client request. If the -// client responds with a redirect, it is followed. The raw response bytes are returned, and should -// be forwarded back to the client. -func (h *UpgradeAwareProxyHandler) connectBackendWithRedirects(req *http.Request) (net.Conn, []byte, error) { - const ( - maxRedirects = 10 - maxResponseSize = 4096 - ) - var ( - initialReq = req - rawResponse = bytes.NewBuffer(make([]byte, 0, 256)) - location = h.Location - intermediateConn net.Conn - err error - ) - defer func() { - if intermediateConn != nil { - intermediateConn.Close() - } - }() - -redirectLoop: - for redirects := 0; ; redirects++ { - if redirects > maxRedirects { - return nil, nil, fmt.Errorf("too many redirects (%d)", redirects) - } - - if redirects == 0 { - intermediateConn, err = h.connectBackend(req.Method, location, req.Header, req.Body) - } else { - // Redirected requests switch to "GET" according to the HTTP spec: - // https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3 - intermediateConn, err = h.connectBackend("GET", location, initialReq.Header, nil) - } - - if err != nil { - return nil, nil, err - } - - // Peek at the backend response. - rawResponse.Reset() - respReader := bufio.NewReader(io.TeeReader( - io.LimitReader(intermediateConn, maxResponseSize), // Don't read more than maxResponseSize bytes. - rawResponse)) // Save the raw response. - resp, err := http.ReadResponse(respReader, req) - if err != nil { - // Unable to read the backend response; let the client handle it. - glog.Warningf("Error reading backend response: %v", err) - break redirectLoop - } - resp.Body.Close() // Unused. - - switch resp.StatusCode { - case http.StatusFound: - // Redirect, continue. - default: - // Don't redirect. - break redirectLoop - } - - // Reset the connection. - intermediateConn.Close() - intermediateConn = nil - - // Prepare to follow the redirect. - redirectStr := resp.Header.Get("Location") - if redirectStr == "" { - return nil, nil, fmt.Errorf("%d response missing Location header", resp.StatusCode) - } - location, err = h.Location.Parse(redirectStr) - if err != nil { - return nil, nil, fmt.Errorf("malformed Location header: %v", err) - } - } - - backendConn := intermediateConn - intermediateConn = nil // Don't close the connection when we return it. - return backendConn, rawResponse.Bytes(), nil -} +var _ utilnet.Dialer = &UpgradeAwareProxyHandler{} func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL, internalTransport http.RoundTripper) http.RoundTripper { scheme := url.Scheme diff --git a/pkg/registry/generic/rest/proxy_test.go b/pkg/registry/generic/rest/proxy_test.go index a43279fb9..96ebed4d0 100644 --- a/pkg/registry/generic/rest/proxy_test.go +++ b/pkg/registry/generic/rest/proxy_test.go @@ -432,6 +432,7 @@ func TestProxyUpgrade(t *testing.T) { Location: serverURL, Transport: tc.ProxyTransport, InterceptRedirects: redirect, + Responder: &noErrorsAllowed{t: t}, } proxy := httptest.NewServer(proxyHandler) defer proxy.Close() @@ -459,6 +460,14 @@ func TestProxyUpgrade(t *testing.T) { } } +type noErrorsAllowed struct { + t *testing.T +} + +func (r *noErrorsAllowed) Error(err error) { + r.t.Error(err) +} + func TestProxyUpgradeErrorResponse(t *testing.T) { var ( responder *fakeResponder diff --git a/pkg/util/proxy/dial.go b/pkg/util/proxy/dial.go index 55ca0e32d..3cb890dd0 100644 --- a/pkg/util/proxy/dial.go +++ b/pkg/util/proxy/dial.go @@ -32,7 +32,7 @@ import ( func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) { dialAddr := netutil.CanonicalAddr(url) - dialer, _ := utilnet.Dialer(transport) + dialer, _ := utilnet.DialerFor(transport) switch url.Scheme { case "http": diff --git a/pkg/util/proxy/dial_test.go b/pkg/util/proxy/dial_test.go index ee143b1e2..f268ecd80 100644 --- a/pkg/util/proxy/dial_test.go +++ b/pkg/util/proxy/dial_test.go @@ -102,7 +102,7 @@ func TestDialURL(t *testing.T) { TLSClientConfig: tlsConfigCopy, } - extractedDial, err := utilnet.Dialer(transport) + extractedDial, err := utilnet.DialerFor(transport) if err != nil { t.Fatal(err) }