365 lines
16 KiB
Go
365 lines
16 KiB
Go
/*
|
|
Copyright 2024 The Kubernetes Authors.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"k8s.io/apimachinery/pkg/runtime"
|
|
"k8s.io/apimachinery/pkg/util/httpstream"
|
|
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
|
|
constants "k8s.io/apimachinery/pkg/util/portforward"
|
|
"k8s.io/apimachinery/pkg/util/proxy"
|
|
"k8s.io/apimachinery/pkg/util/wait"
|
|
"k8s.io/apiserver/pkg/registry/rest"
|
|
restconfig "k8s.io/client-go/rest"
|
|
"k8s.io/client-go/tools/portforward"
|
|
)
|
|
|
|
func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
|
|
// Create fake upstream SPDY server, with channel receiving SPDY streams.
|
|
streamChan := make(chan httpstream.Stream)
|
|
defer close(streamChan)
|
|
stopServerChan := make(chan struct{})
|
|
defer close(stopServerChan)
|
|
// Create fake upstream SPDY server.
|
|
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
_, err := httpstream.Handshake(req, w, []string{constants.PortForwardV1Name})
|
|
require.NoError(t, err)
|
|
upgrader := spdy.NewResponseUpgrader()
|
|
conn := upgrader.UpgradeResponse(w, req, justQueueStream(streamChan))
|
|
require.NotNil(t, conn)
|
|
defer conn.Close() //nolint:errcheck
|
|
<-stopServerChan
|
|
}))
|
|
defer spdyServer.Close()
|
|
// Create UpgradeAwareProxy handler, with url/transport pointing to upstream SPDY. Then
|
|
// create TunnelingHandler by injecting upgrade handler. Create TunnelingServer.
|
|
url, err := url.Parse(spdyServer.URL)
|
|
require.NoError(t, err)
|
|
transport, err := fakeTransport()
|
|
require.NoError(t, err)
|
|
upgradeHandler := proxy.NewUpgradeAwareHandler(url, transport, false, true, proxy.NewErrorResponder(&fakeResponder{}))
|
|
tunnelingHandler := NewTunnelingHandler(upgradeHandler)
|
|
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
tunnelingHandler.ServeHTTP(w, req)
|
|
}))
|
|
defer tunnelingServer.Close()
|
|
// Create SPDY client connection containing a TunnelingConnection by upgrading
|
|
// a request to TunnelingHandler using new portforward version 2.
|
|
tunnelingURL, err := url.Parse(tunnelingServer.URL)
|
|
require.NoError(t, err)
|
|
dialer, err := portforward.NewSPDYOverWebsocketDialer(tunnelingURL, &restconfig.Config{Host: tunnelingURL.Host})
|
|
require.NoError(t, err)
|
|
spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, constants.PortForwardV1Name, protocol)
|
|
defer spdyClient.Close() //nolint:errcheck
|
|
// Create a SPDY client stream, which will queue a SPDY server stream
|
|
// on the stream creation channel. Send random data on the client stream
|
|
// reading off the SPDY server stream, and validating it was tunneled.
|
|
randomSize := 1024 * 1024
|
|
randomData := make([]byte, randomSize)
|
|
_, err = rand.Read(randomData)
|
|
require.NoError(t, err)
|
|
var actual []byte
|
|
go func() {
|
|
clientStream, err := spdyClient.CreateStream(http.Header{})
|
|
require.NoError(t, err)
|
|
_, err = io.Copy(clientStream, bytes.NewReader(randomData))
|
|
require.NoError(t, err)
|
|
clientStream.Close() //nolint:errcheck
|
|
}()
|
|
select {
|
|
case serverStream := <-streamChan:
|
|
actual, err = io.ReadAll(serverStream)
|
|
require.NoError(t, err)
|
|
defer serverStream.Close() //nolint:errcheck
|
|
case <-time.After(wait.ForeverTestTimeout):
|
|
t.Fatalf("timeout waiting for spdy stream to arrive on channel.")
|
|
}
|
|
assert.Equal(t, randomData, actual, "error validating tunneled random data")
|
|
}
|
|
|
|
// expectedContentLengthHeaders is the parsed form of the header lines contained
// in contentLengthHeaders.
var expectedContentLengthHeaders = http.Header{
	"Content-Length": {"25"},
	"Date":           {"Sun, 25 Feb 2024 08:09:25 GMT"},
	"Split-Point":    {"split"},
	"X-App-Protocol": {"portforward.k8s.io"},
}

// Raw fixtures for a non-upgrade (400) response that carries a body. The token
// "split" occurs once in the headers and once in the body, so splitting the
// combined payload on it yields writes that straddle the header/body boundary.
const (
	// contentLengthHeaders is the status line and headers, terminated by the
	// empty line that ends an HTTP header section.
	contentLengthHeaders = "HTTP/1.1 400 Error\r\n" +
		"Content-Length: 25\r\n" +
		"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
		"Split-Point: split\r\n" +
		"X-App-Protocol: portforward.k8s.io\r\n" +
		"\r\n"

	// contentLengthBody is 25 bytes long, matching the Content-Length header.
	contentLengthBody = "0123456789split0123456789"

	// contentLengthHeadersAndBody is the complete response as one payload.
	contentLengthHeadersAndBody = contentLengthHeaders + contentLengthBody
)
|
|
|
|
// expectedResponseHeaders is the parsed form of the header lines contained in
// responseHeaders.
var expectedResponseHeaders = http.Header{
	"Date":           {"Sun, 25 Feb 2024 08:09:25 GMT"},
	"Split-Point":    {"split"},
	"X-App-Protocol": {"portforward.k8s.io"},
}

// Raw fixtures for a 101 upgrade response. The token "split" occurs once in the
// headers and once in the body, so splitting the combined payload on it yields
// writes that straddle the header/body boundary.
const (
	// responseHeaders is the status line and headers, terminated by the empty
	// line that ends an HTTP header section.
	responseHeaders = "HTTP/1.1 101 Switching Protocols\r\n" +
		"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
		"Split-Point: split\r\n" +
		"X-App-Protocol: portforward.k8s.io\r\n" +
		"\r\n"

	// responseBody is the data that follows the headers.
	responseBody = "This is extra split data.\n"

	// responseHeadersAndBody is the complete response as one payload.
	responseHeadersAndBody = responseHeaders + responseBody

	// invalidResponseData has the same headers as responseHeaders, but its
	// status line starts with "INVALID/1.1" instead of "HTTP/1.1". The tests
	// expect writing it to return an error.
	invalidResponseData = "INVALID/1.1 101 Switching Protocols\r\n" +
		"Date: Sun, 25 Feb 2024 08:09:25 GMT\r\n" +
		"Split-Point: split\r\n" +
		"X-App-Protocol: portforward.k8s.io\r\n" +
		"\r\n"
)
|
|
|
|
func TestTunnelingHandler_HeaderInterceptingConn(t *testing.T) {
|
|
// Basic http response is intercepted correctly; no extra data sent to net.Conn.
|
|
t.Run("simple-no-body", func(t *testing.T) {
|
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
|
_, err := hic.Write([]byte(responseHeaders))
|
|
require.NoError(t, err)
|
|
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
|
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
|
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
|
assert.Equal(t, "portforward.k8s.io", testConnConstructor.resp.Header.Get("X-App-Protocol"))
|
|
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
|
|
assert.Equal(t, "", string(testConnConstructor.mockConn.written))
|
|
})
|
|
|
|
// Extra data after response headers should be sent to net.Conn.
|
|
t.Run("simple-single-write", func(t *testing.T) {
|
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
|
_, err := hic.Write([]byte(responseHeadersAndBody))
|
|
require.NoError(t, err)
|
|
assert.True(t, hic.initialized)
|
|
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
|
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
|
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
|
|
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
|
})
|
|
|
|
// Partially written headers are buffered and decoded
|
|
t.Run("simple-byte-by-byte", func(t *testing.T) {
|
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
|
// write one byte at a time
|
|
for _, b := range []byte(responseHeadersAndBody) {
|
|
_, err := hic.Write([]byte{b})
|
|
require.NoError(t, err)
|
|
}
|
|
assert.True(t, hic.initialized)
|
|
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
|
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
|
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
|
|
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
|
})
|
|
|
|
// Writes spanning the header/body breakpoint are buffered and decoded
|
|
t.Run("simple-span-headerbody", func(t *testing.T) {
|
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
|
// write one chunk at a time
|
|
for i, chunk := range strings.Split(responseHeadersAndBody, "split") {
|
|
if i > 0 {
|
|
n, err := hic.Write([]byte("split"))
|
|
require.Equal(t, n, len("split"))
|
|
require.NoError(t, err)
|
|
}
|
|
n, err := hic.Write([]byte(chunk))
|
|
require.Equal(t, n, len(chunk))
|
|
require.NoError(t, err)
|
|
}
|
|
assert.True(t, hic.initialized)
|
|
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
|
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
|
assert.Equal(t, responseHeaders, string(testConnConstructor.initializeWriteConn.written), "only headers are written in initializeWrite")
|
|
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
|
})
|
|
|
|
// Tolerate header separators of \n instead of \r\n, and extra data after response headers should be sent to net.Conn.
|
|
t.Run("simple-tolerate-lf", func(t *testing.T) {
|
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
|
_, err := hic.Write([]byte(strings.ReplaceAll(responseHeadersAndBody, "\r", "")))
|
|
require.NoError(t, err)
|
|
assert.True(t, hic.initialized)
|
|
assert.Equal(t, expectedResponseHeaders, testConnConstructor.resp.Header)
|
|
assert.Equal(t, "101 Switching Protocols", testConnConstructor.resp.Status)
|
|
assert.Equal(t, strings.ReplaceAll(responseHeaders, "\r", ""), string(testConnConstructor.initializeWriteConn.written), "only normalized headers are written in initializeWrite")
|
|
assert.Equal(t, responseBody, string(testConnConstructor.mockConn.written), "extra data written to net.Conn")
|
|
})
|
|
|
|
// Content-Length handling
|
|
t.Run("content-length-body", func(t *testing.T) {
|
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
|
_, err := hic.Write([]byte(contentLengthHeadersAndBody))
|
|
require.NoError(t, err)
|
|
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
|
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
|
|
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
|
|
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
|
|
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
|
|
})
|
|
|
|
// Content-Length separately written headers and body
|
|
t.Run("content-length-headers-body", func(t *testing.T) {
|
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
|
_, err := hic.Write([]byte(contentLengthHeaders))
|
|
require.NoError(t, err)
|
|
_, err = hic.Write([]byte(contentLengthBody))
|
|
require.NoError(t, err)
|
|
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
|
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
|
|
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
|
|
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
|
|
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
|
|
})
|
|
|
|
// Content-Length accumulating byte-by-byte
|
|
t.Run("content-length-byte-by-byte", func(t *testing.T) {
|
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
|
for _, b := range []byte(contentLengthHeadersAndBody) {
|
|
_, err := hic.Write([]byte{b})
|
|
require.NoError(t, err)
|
|
}
|
|
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
|
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
|
|
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
|
|
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
|
|
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
|
|
})
|
|
|
|
// Content-Length writes spanning headers / body
|
|
t.Run("content-length-span-headerbody", func(t *testing.T) {
|
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
|
// write one chunk at a time
|
|
for i, chunk := range strings.Split(contentLengthHeadersAndBody, "split") {
|
|
if i > 0 {
|
|
n, err := hic.Write([]byte("split"))
|
|
require.Equal(t, n, len("split"))
|
|
require.NoError(t, err)
|
|
}
|
|
n, err := hic.Write([]byte(chunk))
|
|
require.Equal(t, n, len(chunk))
|
|
require.NoError(t, err)
|
|
}
|
|
assert.True(t, hic.initialized, "successfully parsed http response headers")
|
|
assert.Equal(t, expectedContentLengthHeaders, testConnConstructor.resp.Header)
|
|
assert.Equal(t, "400 Error", testConnConstructor.resp.Status)
|
|
assert.Equal(t, contentLengthHeaders, string(testConnConstructor.initializeWriteConn.written), "headers and content are written in initializeWrite")
|
|
assert.Equal(t, contentLengthBody, string(testConnConstructor.mockConn.written))
|
|
})
|
|
|
|
// Invalid response returns error.
|
|
t.Run("invalid-single-write", func(t *testing.T) {
|
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
|
_, err := hic.Write([]byte(invalidResponseData))
|
|
assert.Error(t, err, "expected error from invalid http response")
|
|
})
|
|
|
|
// Invalid response written byte by byte returns error.
|
|
t.Run("invalid-byte-by-byte", func(t *testing.T) {
|
|
testConnConstructor := &mockConnInitializer{mockConn: &mockConn{}, initializeWriteConn: &mockConn{}}
|
|
hic := &headerInterceptingConn{initializableConn: testConnConstructor}
|
|
var err error
|
|
for _, b := range []byte(invalidResponseData) {
|
|
_, err = hic.Write([]byte{b})
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
assert.Error(t, err, "expected error from invalid http response")
|
|
})
|
|
}
|
|
|
|
type mockConnInitializer struct {
|
|
resp *http.Response
|
|
initializeWriteConn *mockConn
|
|
*mockConn
|
|
}
|
|
|
|
func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response, backendResponseBytes []byte) error {
|
|
m.resp = backendResponse
|
|
_, err := m.initializeWriteConn.Write(backendResponseBytes)
|
|
return err
|
|
}
|
|
|
|
// mockConn implements "net.Conn" interface.
var _ net.Conn = (*mockConn)(nil)

// mockConn is an in-memory conn stub: it records written bytes and treats every
// other operation as a no-op.
type mockConn struct {
	// written accumulates, in call order, every byte passed to Write.
	written []byte
}

// Write appends p to the recorded bytes and reports all of p as written.
func (mc *mockConn) Write(p []byte) (int, error) {
	mc.written = append(mc.written, p...)
	return len(p), nil
}

// Read never yields data. It reports io.EOF rather than (0, nil), because
// callers that loop on Read (io.Copy, io.ReadAll) would otherwise never stop.
func (mc *mockConn) Read(p []byte) (int, error) { return 0, io.EOF }

// Close is a no-op.
func (mc *mockConn) Close() error { return nil }

// LocalAddr returns an empty TCP address.
func (mc *mockConn) LocalAddr() net.Addr { return &net.TCPAddr{} }

// RemoteAddr returns an empty TCP address.
func (mc *mockConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }

// SetDeadline is a no-op.
func (mc *mockConn) SetDeadline(t time.Time) error { return nil }

// SetReadDeadline is a no-op.
func (mc *mockConn) SetReadDeadline(t time.Time) error { return nil }

// SetWriteDeadline is a no-op.
func (mc *mockConn) SetWriteDeadline(t time.Time) error { return nil }
|
|
|
|
// fakeResponder implements "rest.Responder" interface.
|
|
var _ rest.Responder = &fakeResponder{}
|
|
|
|
type fakeResponder struct{}
|
|
|
|
func (fr *fakeResponder) Object(statusCode int, obj runtime.Object) {}
|
|
func (fr *fakeResponder) Error(err error) {}
|
|
|
|
// justQueueStream skips the usual stream validation before
|
|
// queueing the stream on the stream channel.
|
|
func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
|
|
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
|
|
streams <- stream
|
|
return nil
|
|
}
|
|
}
|