diff --git a/pkg/registry/generic/rest/proxy.go b/pkg/registry/generic/rest/proxy.go index b38e1468e..8286597c6 100644 --- a/pkg/registry/generic/rest/proxy.go +++ b/pkg/registry/generic/rest/proxy.go @@ -17,8 +17,6 @@ limitations under the License. package rest import ( - "bufio" - "bytes" "fmt" "io" "net" @@ -146,10 +144,13 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R rawResponse []byte err error ) + if h.InterceptRedirects && utilfeature.DefaultFeatureGate.Enabled(genericfeatures.StreamingProxyRedirects) { - backendConn, rawResponse, err = h.connectBackendWithRedirects(req) + backendConn, rawResponse, err = utilnet.ConnectWithRedirects(req.Method, h.Location, req.Header, req.Body, h) } else { - backendConn, err = h.connectBackend(req.Method, h.Location, req.Header, req.Body) + clone := utilnet.CloneRequest(req) + clone.URL = h.Location + backendConn, err = h.Dial(clone) } if err != nil { h.Responder.Error(err) @@ -214,112 +215,22 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R return true } -// connectBackend dials the backend at location and forwards a copy of the client request. -func (h *UpgradeAwareProxyHandler) connectBackend(method string, location *url.URL, header http.Header, body io.Reader) (conn net.Conn, err error) { - defer func() { - if err != nil && conn != nil { - conn.Close() - conn = nil - } - }() - - beReq, err := http.NewRequest(method, location.String(), body) +// Dial dials the backend at req.URL and writes req to it. +func (h *UpgradeAwareProxyHandler) Dial(req *http.Request) (net.Conn, error) { + conn, err := proxy.DialURL(req.URL, h.Transport) if err != nil { - return nil, err - } - beReq.Header = header - - conn, err = proxy.DialURL(location, h.Transport) - if err != nil { - return conn, fmt.Errorf("error dialing backend: %v", err) + return nil, fmt.Errorf("error dialing backend: %v", err) } - if err = beReq.Write(conn); err != nil { - return conn, fmt.Errorf("error sending request: %v", err) + if err = req.Write(conn); err != nil { + conn.Close() + return nil, fmt.Errorf("error sending request: %v", err) } return conn, err } -// connectBackendWithRedirects dials the backend and forwards a copy of the client request. If the -// client responds with a redirect, it is followed. The raw response bytes are returned, and should -// be forwarded back to the client. -func (h *UpgradeAwareProxyHandler) connectBackendWithRedirects(req *http.Request) (net.Conn, []byte, error) { - const ( - maxRedirects = 10 - maxResponseSize = 4096 - ) - var ( - initialReq = req - rawResponse = bytes.NewBuffer(make([]byte, 0, 256)) - location = h.Location - intermediateConn net.Conn - err error - ) - defer func() { - if intermediateConn != nil { - intermediateConn.Close() - } - }() - -redirectLoop: - for redirects := 0; ; redirects++ { - if redirects > maxRedirects { - return nil, nil, fmt.Errorf("too many redirects (%d)", redirects) - } - - if redirects == 0 { - intermediateConn, err = h.connectBackend(req.Method, location, req.Header, req.Body) - } else { - // Redirected requests switch to "GET" according to the HTTP spec: - // https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3 - intermediateConn, err = h.connectBackend("GET", location, initialReq.Header, nil) - } - - if err != nil { - return nil, nil, err - } - - // Peek at the backend response. - rawResponse.Reset() - respReader := bufio.NewReader(io.TeeReader( - io.LimitReader(intermediateConn, maxResponseSize), // Don't read more than maxResponseSize bytes. - rawResponse)) // Save the raw response. - resp, err := http.ReadResponse(respReader, req) - if err != nil { - // Unable to read the backend response; let the client handle it. - glog.Warningf("Error reading backend response: %v", err) - break redirectLoop - } - resp.Body.Close() // Unused. - - switch resp.StatusCode { - case http.StatusFound: - // Redirect, continue. - default: - // Don't redirect. - break redirectLoop - } - - // Reset the connection. - intermediateConn.Close() - intermediateConn = nil - - // Prepare to follow the redirect. - redirectStr := resp.Header.Get("Location") - if redirectStr == "" { - return nil, nil, fmt.Errorf("%d response missing Location header", resp.StatusCode) - } - location, err = h.Location.Parse(redirectStr) - if err != nil { - return nil, nil, fmt.Errorf("malformed Location header: %v", err) - } - } - - backendConn := intermediateConn - intermediateConn = nil // Don't close the connection when we return it. - return backendConn, rawResponse.Bytes(), nil -} +var _ utilnet.Dialer = &UpgradeAwareProxyHandler{} func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL, internalTransport http.RoundTripper) http.RoundTripper { scheme := url.Scheme diff --git a/pkg/registry/generic/rest/proxy_test.go b/pkg/registry/generic/rest/proxy_test.go index a43279fb9..96ebed4d0 100644 --- a/pkg/registry/generic/rest/proxy_test.go +++ b/pkg/registry/generic/rest/proxy_test.go @@ -432,6 +432,7 @@ func TestProxyUpgrade(t *testing.T) { Location: serverURL, Transport: tc.ProxyTransport, InterceptRedirects: redirect, + Responder: &noErrorsAllowed{t: t}, } proxy := httptest.NewServer(proxyHandler) defer proxy.Close() @@ -459,6 +460,14 @@ func TestProxyUpgrade(t *testing.T) { } } +type noErrorsAllowed struct { + t *testing.T +} + +func (r *noErrorsAllowed) Error(err error) { + r.t.Error(err) +} + func TestProxyUpgradeErrorResponse(t *testing.T) { var ( responder *fakeResponder diff --git a/pkg/util/proxy/dial.go b/pkg/util/proxy/dial.go index 55ca0e32d..3cb890dd0 100644 --- a/pkg/util/proxy/dial.go +++ b/pkg/util/proxy/dial.go @@ -32,7 +32,7 @@ import ( func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) { dialAddr := netutil.CanonicalAddr(url) - dialer, _ := utilnet.Dialer(transport) + dialer, _ := utilnet.DialerFor(transport) switch url.Scheme { case "http": diff --git a/pkg/util/proxy/dial_test.go b/pkg/util/proxy/dial_test.go index ee143b1e2..f268ecd80 100644 --- a/pkg/util/proxy/dial_test.go +++ b/pkg/util/proxy/dial_test.go @@ -102,7 +102,7 @@ func TestDialURL(t *testing.T) { TLSClientConfig: tlsConfigCopy, } - extractedDial, err := utilnet.Dialer(transport) + extractedDial, err := utilnet.DialerFor(transport) if err != nil { t.Fatal(err) }