Merge pull request #119186 from seans3/stream-translator-proxy

Stream Translator Proxy and FallbackExecutor for WebSockets

Kubernetes-commit: 87981480f33790225628824943217bd6bb7564bb
This commit is contained in:
Kubernetes Publisher 2023-10-24 17:10:34 +02:00
commit 618d6a3eb7
7 changed files with 1424 additions and 9 deletions

11
go.mod
View File

@ -18,6 +18,7 @@ require (
github.com/google/uuid v1.3.0
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.8.3
go.etcd.io/etcd/api/v3 v3.5.9
@ -42,8 +43,8 @@ require (
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/square/go-jose.v2 v2.6.0
k8s.io/api v0.0.0-20231023194506-bfce70f1b5c8
k8s.io/apimachinery v0.0.0-20231024034334-1e138bd489ac
k8s.io/client-go v0.0.0-20231024035150-c92537416a96
k8s.io/apimachinery v0.0.0-20231024171030-c18d2bfed439
k8s.io/client-go v0.0.0-20231024171543-e2e59f3539ef
k8s.io/component-base v0.0.0-20231024040035-12d4256eb135
k8s.io/klog/v2 v2.100.1
k8s.io/kms v0.0.0-20231023195612-e039984be9c9
@ -87,9 +88,9 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/moby/spdystream v0.2.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/pquerna/cachecontrol v0.1.0 // indirect
@ -127,8 +128,8 @@ require (
replace (
k8s.io/api => k8s.io/api v0.0.0-20231023194506-bfce70f1b5c8
k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20231024034334-1e138bd489ac
k8s.io/client-go => k8s.io/client-go v0.0.0-20231024035150-c92537416a96
k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20231024171030-c18d2bfed439
k8s.io/client-go => k8s.io/client-go v0.0.0-20231024171543-e2e59f3539ef
k8s.io/component-base => k8s.io/component-base v0.0.0-20231024040035-12d4256eb135
k8s.io/kms => k8s.io/kms v0.0.0-20231023195612-e039984be9c9
)

11
go.sum
View File

@ -42,6 +42,7 @@ github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAE
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df h1:7RFfzj4SSt6nnvCPbCqijJi1nWCd+TqAT3bYCStRC18=
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df/go.mod h1:pSwJ0fSY5KhvocuWSx4fz3BA8OrA1bQn+K1Eli3BRwM=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4B6AGu/h5Sxe66HYVdqdGu2l9Iebqhi/AEoA=
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
@ -237,6 +238,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/moby/spdystream v0.2.0 h1:cjW1zVyyoiM0T7b6UoySUFqzXMoqRckQtXwGPiBhOM8=
github.com/moby/spdystream v0.2.0/go.mod h1:f7i0iNDQJ059oMTcWxx8MA/zKFIuD/lY+0GqbN2Wy8c=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -675,10 +678,10 @@ honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
k8s.io/api v0.0.0-20231023194506-bfce70f1b5c8 h1:U7xcM/WBTkLV+TjNciuW7l+oXM2OHd5/TmVnPKyrmpA=
k8s.io/api v0.0.0-20231023194506-bfce70f1b5c8/go.mod h1:mgYOiLIgrQcsuVxrBI6Pplk91r3sl5ZJ7eUx7UBMTkY=
k8s.io/apimachinery v0.0.0-20231024034334-1e138bd489ac h1:x3g6c1u7CtRoraBlRP2JThB3aHz7vw4FZFXRZsvoIoc=
k8s.io/apimachinery v0.0.0-20231024034334-1e138bd489ac/go.mod h1:mdlGhJWO1mhVzQXm1Lx7D1BvvBIVKlRVy0vvl1LwGjg=
k8s.io/client-go v0.0.0-20231024035150-c92537416a96 h1:76J+c8hyhX3e9eteWOc08cGsJeH5ky0Jmh/naC0ll8g=
k8s.io/client-go v0.0.0-20231024035150-c92537416a96/go.mod h1:hML9Z37ARvWfQt+YEVEMZ3EVJBqM19lCsFXogGW6VX8=
k8s.io/apimachinery v0.0.0-20231024171030-c18d2bfed439 h1:/oxbLzC7mkHNdeFI8AMsTPTwudQu7sz7rnPGIxv2yqM=
k8s.io/apimachinery v0.0.0-20231024171030-c18d2bfed439/go.mod h1:mdlGhJWO1mhVzQXm1Lx7D1BvvBIVKlRVy0vvl1LwGjg=
k8s.io/client-go v0.0.0-20231024171543-e2e59f3539ef h1:dx12CsKyk2cct0NtF7fHMX8cOzb1uJhI5wn2sQDpr60=
k8s.io/client-go v0.0.0-20231024171543-e2e59f3539ef/go.mod h1:3HC2qEcjQxIt5UW1R7vC5RX2sf/wkRWovfAEPkbmPxA=
k8s.io/component-base v0.0.0-20231024040035-12d4256eb135 h1:BxZJ2rg42EI0RbeNV5gb+8tdwYZ1iwxZJW4FmUMMdtc=
k8s.io/component-base v0.0.0-20231024040035-12d4256eb135/go.mod h1:ft9o5mWD7glAMtEqdxl4CmAKA9G6DFYRajW3TPrsQhs=
k8s.io/klog/v2 v2.100.1 h1:7WCHKK6K8fNhTqfBhISHQ97KrnJNFZMcQvKp7gP/tmg=

View File

@ -0,0 +1,167 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"fmt"
"net/http"
"net/url"
"github.com/mxk/go-flowrate/flowrate"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
constants "k8s.io/apimachinery/pkg/util/remotecommand"
"k8s.io/client-go/tools/remotecommand"
"k8s.io/client-go/util/exec"
)
// StreamTranslatorHandler is a handler which translates WebSocket stream data
// to SPDY to proxy to kubelet (and ContainerRuntime).
type StreamTranslatorHandler struct {
// Location is the location of the upstream proxy. It is used as the location to Dial on the upstream server
// for upgrade requests.
Location *url.URL
// Transport provides an optional round tripper to use to proxy. If nil, the default proxy transport is used
Transport http.RoundTripper
// MaxBytesPerSec throttles stream Reader/Writer if necessary
MaxBytesPerSec int64
// Options define the requested streams (e.g. stdin, stdout).
Options Options
}
// NewStreamTranslatorHandler creates a new proxy handler. Responder is required for returning
// errors to the caller.
func NewStreamTranslatorHandler(location *url.URL, transport http.RoundTripper, maxBytesPerSec int64, opts Options) *StreamTranslatorHandler {
return &StreamTranslatorHandler{
Location: location,
Transport: transport,
MaxBytesPerSec: maxBytesPerSec,
Options: opts,
}
}
func (h *StreamTranslatorHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Create WebSocket server, including particular streams requested. If this websocket
// endpoint is not able to be upgraded, the websocket library will return errors
// to the client.
websocketStreams, err := webSocketServerStreams(req, w, h.Options)
if err != nil {
return
}
defer websocketStreams.conn.Close()
// Creating SPDY executor, ensuring redirects are not followed.
spdyRoundTripper, err := spdy.NewRoundTripperWithConfig(spdy.RoundTripperConfig{UpgradeTransport: h.Transport})
if err != nil {
websocketStreams.writeStatus(apierrors.NewInternalError(err)) //nolint:errcheck
return
}
spdyExecutor, err := remotecommand.NewSPDYExecutorRejectRedirects(spdyRoundTripper, spdyRoundTripper, "POST", h.Location)
if err != nil {
websocketStreams.writeStatus(apierrors.NewInternalError(err)) //nolint:errcheck
return
}
// Wire the WebSocket server streams output to the SPDY client input. The stdin/stdout/stderr streams
// can be throttled if the transfer rate exceeds the "MaxBytesPerSec" (zero means unset). Throttling
// the streams instead of the underlying connection *may* not perform the same if two streams
// traveling the same direction (e.g. stdout, stderr) are being maxed out.
opts := remotecommand.StreamOptions{}
if h.Options.Stdin {
stdin := websocketStreams.stdinStream
if h.MaxBytesPerSec > 0 {
stdin = flowrate.NewReader(stdin, h.MaxBytesPerSec)
}
opts.Stdin = stdin
}
if h.Options.Stdout {
stdout := websocketStreams.stdoutStream
if h.MaxBytesPerSec > 0 {
stdout = flowrate.NewWriter(stdout, h.MaxBytesPerSec)
}
opts.Stdout = stdout
}
if h.Options.Stderr {
stderr := websocketStreams.stderrStream
if h.MaxBytesPerSec > 0 {
stderr = flowrate.NewWriter(stderr, h.MaxBytesPerSec)
}
opts.Stderr = stderr
}
if h.Options.Tty {
opts.Tty = true
opts.TerminalSizeQueue = &translatorSizeQueue{resizeChan: websocketStreams.resizeChan}
}
// Start the SPDY client with connected streams. Output from the WebSocket server
// streams will be forwarded into the SPDY client. Report SPDY execution errors
// through the websocket error stream.
err = spdyExecutor.StreamWithContext(req.Context(), opts)
if err != nil {
//nolint:errcheck // Ignore writeStatus returned error
if statusErr, ok := err.(*apierrors.StatusError); ok {
websocketStreams.writeStatus(statusErr)
} else if exitErr, ok := err.(exec.CodeExitError); ok && exitErr.Exited() {
websocketStreams.writeStatus(codeExitToStatusError(exitErr))
} else {
websocketStreams.writeStatus(apierrors.NewInternalError(err))
}
return
}
// Write the success status back to the WebSocket client.
//nolint:errcheck
websocketStreams.writeStatus(&apierrors.StatusError{ErrStatus: metav1.Status{
Status: metav1.StatusSuccess,
}})
}
// translatorSizeQueue feeds the size events from the WebSocket
// resizeChan into the SPDY client input. Implements TerminalSizeQueue
// interface.
type translatorSizeQueue struct {
resizeChan chan remotecommand.TerminalSize
}
func (t *translatorSizeQueue) Next() *remotecommand.TerminalSize {
size, ok := <-t.resizeChan
if !ok {
return nil
}
return &size
}
// codeExitToStatusError converts a passed CodeExitError to the type necessary
// to send through an error stream using "writeStatus".
func codeExitToStatusError(exitErr exec.CodeExitError) *apierrors.StatusError {
rc := exitErr.ExitStatus()
return &apierrors.StatusError{
ErrStatus: metav1.Status{
Status: metav1.StatusFailure,
Reason: constants.NonZeroExitCodeReason,
Details: &metav1.StatusDetails{
Causes: []metav1.StatusCause{
{
Type: constants.ExitCodeCauseType,
Message: fmt.Sprintf("%d", rc),
},
},
},
Message: fmt.Sprintf("command terminated with non-zero exit code: %v", exitErr),
},
}
}

View File

@ -0,0 +1,872 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"bytes"
"context"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"io"
"math"
mrand "math/rand"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"time"
v1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
rcconstants "k8s.io/apimachinery/pkg/util/remotecommand"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/remotecommand"
"k8s.io/client-go/transport"
)
// TestStreamTranslator_LoopbackStdinToStdout returns random data sent on the client's
// STDIN channel back onto the client's STDOUT channel. There are two servers in this test: the
// upstream fake SPDY server, and the StreamTranslator server. The StreamTranslator proxys the
// data received from the websocket client upstream to the SPDY server (by translating the
// websocket data into spdy). The returned data read on the websocket client STDOUT is then
// compared the random data sent on STDIN to ensure they are the same.
func TestStreamTranslator_LoopbackStdinToStdout(t *testing.T) {
// Create upstream fake SPDY server which copies STDIN back onto STDOUT stream.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, err := createSPDYServerStreams(w, req, Options{
Stdin: true,
Stdout: true,
})
if err != nil {
t.Errorf("error on createHTTPStreams: %v", err)
return
}
defer ctx.conn.Close()
// Loopback STDIN data onto STDOUT stream.
_, err = io.Copy(ctx.stdoutStream, ctx.stdinStream)
if err != nil {
t.Fatalf("error copying STDIN to STDOUT: %v", err)
}
}))
defer spdyServer.Close()
// Create StreamTranslatorHandler, which points upstream to fake SPDY server with
// streams STDIN and STDOUT. Create test server from StreamTranslatorHandler.
spdyLocation, err := url.Parse(spdyServer.URL)
if err != nil {
t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL)
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streams := Options{Stdin: true, Stdout: true}
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams)
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
// Generate random data, and set it up to stream on STDIN. The data will be
// returned on the STDOUT buffer.
randomSize := 1024 * 1024
randomData := make([]byte, randomSize)
if _, err := rand.Read(randomData); err != nil {
t.Errorf("unexpected error reading random data: %v", err)
}
var stdout bytes.Buffer
options := &remotecommand.StreamOptions{
Stdin: bytes.NewReader(randomData),
Stdout: &stdout,
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
if err != nil {
t.Errorf("error reading the stream: %v", err)
return
}
// Check the random data sent on STDIN was the same returned on STDOUT.
if !bytes.Equal(randomData, data) {
t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
}
}
// TestStreamTranslator_LoopbackStdinToStderr returns random data sent on the client's
// STDIN channel back onto the client's STDERR channel. There are two servers in this test: the
// upstream fake SPDY server, and the StreamTranslator server. The StreamTranslator proxys the
// data received from the websocket client upstream to the SPDY server (by translating the
// websocket data into spdy). The returned data read on the websocket client STDERR is then
// compared the random data sent on STDIN to ensure they are the same.
func TestStreamTranslator_LoopbackStdinToStderr(t *testing.T) {
// Create upstream fake SPDY server which copies STDIN back onto STDERR stream.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, err := createSPDYServerStreams(w, req, Options{
Stdin: true,
Stderr: true,
})
if err != nil {
t.Errorf("error on createHTTPStreams: %v", err)
return
}
defer ctx.conn.Close()
// Loopback STDIN data onto STDERR stream.
_, err = io.Copy(ctx.stderrStream, ctx.stdinStream)
if err != nil {
t.Fatalf("error copying STDIN to STDERR: %v", err)
}
}))
defer spdyServer.Close()
// Create StreamTranslatorHandler, which points upstream to fake SPDY server with
// streams STDIN and STDERR. Create test server from StreamTranslatorHandler.
spdyLocation, err := url.Parse(spdyServer.URL)
if err != nil {
t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL)
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streams := Options{Stdin: true, Stderr: true}
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams)
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
// Generate random data, and set it up to stream on STDIN. The data will be
// returned on the STDERR buffer.
randomSize := 1024 * 1024
randomData := make([]byte, randomSize)
if _, err := rand.Read(randomData); err != nil {
t.Errorf("unexpected error reading random data: %v", err)
}
var stderr bytes.Buffer
options := &remotecommand.StreamOptions{
Stdin: bytes.NewReader(randomData),
Stderr: &stderr,
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
data, err := io.ReadAll(bytes.NewReader(stderr.Bytes()))
if err != nil {
t.Errorf("error reading the stream: %v", err)
return
}
// Check the random data sent on STDIN was the same returned on STDERR.
if !bytes.Equal(randomData, data) {
t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
}
}
// Returns a random exit code in the range(1-127).
func randomExitCode() int {
errorCode := mrand.Intn(127) // Range: (0 - 126)
errorCode += 1 // Range: (1 - 127)
return errorCode
}
// TestStreamTranslator_ErrorStream tests the error stream by sending an error with a random
// exit code, then validating the error arrives on the error stream.
func TestStreamTranslator_ErrorStream(t *testing.T) {
expectedExitCode := randomExitCode()
// Create upstream fake SPDY server, returning a non-zero exit code
// on error stream within the structured error.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, err := createSPDYServerStreams(w, req, Options{
Stdout: true,
})
if err != nil {
t.Errorf("error on createHTTPStreams: %v", err)
return
}
defer ctx.conn.Close()
// Read/discard STDIN data before returning error on error stream.
_, err = io.Copy(io.Discard, ctx.stdinStream)
if err != nil {
t.Fatalf("error copying STDIN to DISCARD: %v", err)
}
// Force an non-zero exit code error returned on the error stream.
err = ctx.writeStatus(&apierrors.StatusError{ErrStatus: metav1.Status{
Status: metav1.StatusFailure,
Reason: rcconstants.NonZeroExitCodeReason,
Details: &metav1.StatusDetails{
Causes: []metav1.StatusCause{
{
Type: rcconstants.ExitCodeCauseType,
Message: fmt.Sprintf("%d", expectedExitCode),
},
},
},
}})
if err != nil {
t.Fatalf("error writing status: %v", err)
}
}))
defer spdyServer.Close()
// Create StreamTranslatorHandler, which points upstream to fake SPDY server, and
// create a test server using the StreamTranslatorHandler.
spdyLocation, err := url.Parse(spdyServer.URL)
if err != nil {
t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL)
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streams := Options{Stdin: true}
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams)
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
// Generate random data, and set it up to stream on STDIN. The data will be discarded at
// upstream SDPY server.
randomSize := 1024 * 1024
randomData := make([]byte, randomSize)
if _, err := rand.Read(randomData); err != nil {
t.Errorf("unexpected error reading random data: %v", err)
}
options := &remotecommand.StreamOptions{
Stdin: bytes.NewReader(randomData),
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
// Expect exit code error on error stream.
if err == nil {
t.Errorf("expected error, but received none")
}
expectedError := fmt.Sprintf("command terminated with exit code %d", expectedExitCode)
// Compare expected error with exit code to actual error.
if expectedError != err.Error() {
t.Errorf("expected error (%s), got (%s)", expectedError, err)
}
}
}
// TestStreamTranslator_MultipleReadChannels tests two streams (STDOUT, STDERR) reading from
// the connections at the same time.
func TestStreamTranslator_MultipleReadChannels(t *testing.T) {
// Create upstream fake SPDY server which copies STDIN back onto STDOUT and STDERR stream.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, err := createSPDYServerStreams(w, req, Options{
Stdin: true,
Stdout: true,
Stderr: true,
})
if err != nil {
t.Errorf("error on createHTTPStreams: %v", err)
return
}
defer ctx.conn.Close()
// TeeReader copies data read on STDIN onto STDERR.
stdinReader := io.TeeReader(ctx.stdinStream, ctx.stderrStream)
// Also copy STDIN to STDOUT.
_, err = io.Copy(ctx.stdoutStream, stdinReader)
if err != nil {
t.Errorf("error copying STDIN to STDOUT: %v", err)
}
}))
defer spdyServer.Close()
// Create StreamTranslatorHandler, which points upstream to fake SPDY server with
// streams STDIN, STDOUT, and STDERR. Create test server from StreamTranslatorHandler.
spdyLocation, err := url.Parse(spdyServer.URL)
if err != nil {
t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL)
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streams := Options{Stdin: true, Stdout: true, Stderr: true}
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams)
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
// Generate random data, and set it up to stream on STDIN. The data will be
// returned on the STDOUT and STDERR buffer.
randomSize := 1024 * 1024
randomData := make([]byte, randomSize)
if _, err := rand.Read(randomData); err != nil {
t.Errorf("unexpected error reading random data: %v", err)
}
var stdout, stderr bytes.Buffer
options := &remotecommand.StreamOptions{
Stdin: bytes.NewReader(randomData),
Stdout: &stdout,
Stderr: &stderr,
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
if err != nil {
t.Errorf("error reading the stream: %v", err)
return
}
// Check the random data sent on STDIN was the same returned on STDOUT.
if !bytes.Equal(stdoutBytes, randomData) {
t.Errorf("unexpected data received: %d sent: %d", len(stdoutBytes), len(randomData))
}
stderrBytes, err := io.ReadAll(bytes.NewReader(stderr.Bytes()))
if err != nil {
t.Errorf("error reading the stream: %v", err)
return
}
// Check the random data sent on STDIN was the same returned on STDERR.
if !bytes.Equal(stderrBytes, randomData) {
t.Errorf("unexpected data received: %d sent: %d", len(stderrBytes), len(randomData))
}
}
// TestStreamTranslator_ThrottleReadChannels tests two streams (STDOUT, STDERR) using rate limited streams.
func TestStreamTranslator_ThrottleReadChannels(t *testing.T) {
// Create upstream fake SPDY server which copies STDIN back onto STDOUT and STDERR stream.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, err := createSPDYServerStreams(w, req, Options{
Stdin: true,
Stdout: true,
Stderr: true,
})
if err != nil {
t.Errorf("error on createHTTPStreams: %v", err)
return
}
defer ctx.conn.Close()
// TeeReader copies data read on STDIN onto STDERR.
stdinReader := io.TeeReader(ctx.stdinStream, ctx.stderrStream)
// Also copy STDIN to STDOUT.
_, err = io.Copy(ctx.stdoutStream, stdinReader)
if err != nil {
t.Errorf("error copying STDIN to STDOUT: %v", err)
}
}))
defer spdyServer.Close()
// Create StreamTranslatorHandler, which points upstream to fake SPDY server with
// streams STDIN, STDOUT, and STDERR. Create test server from StreamTranslatorHandler.
spdyLocation, err := url.Parse(spdyServer.URL)
if err != nil {
t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL)
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streams := Options{Stdin: true, Stdout: true, Stderr: true}
maxBytesPerSec := 900 * 1024 // slightly less than the 1MB that is being transferred to exercise throttling.
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, int64(maxBytesPerSec), streams)
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
// Generate random data, and set it up to stream on STDIN. The data will be
// returned on the STDOUT and STDERR buffer.
randomSize := 1024 * 1024
randomData := make([]byte, randomSize)
if _, err := rand.Read(randomData); err != nil {
t.Errorf("unexpected error reading random data: %v", err)
}
var stdout, stderr bytes.Buffer
options := &remotecommand.StreamOptions{
Stdin: bytes.NewReader(randomData),
Stdout: &stdout,
Stderr: &stderr,
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
if err != nil {
t.Errorf("error reading the stream: %v", err)
return
}
// Check the random data sent on STDIN was the same returned on STDOUT.
if !bytes.Equal(stdoutBytes, randomData) {
t.Errorf("unexpected data received: %d sent: %d", len(stdoutBytes), len(randomData))
}
stderrBytes, err := io.ReadAll(bytes.NewReader(stderr.Bytes()))
if err != nil {
t.Errorf("error reading the stream: %v", err)
return
}
// Check the random data sent on STDIN was the same returned on STDERR.
if !bytes.Equal(stderrBytes, randomData) {
t.Errorf("unexpected data received: %d sent: %d", len(stderrBytes), len(randomData))
}
}
// fakeTerminalSizeQueue implements TerminalSizeQueue, returning a random set of
// "maxSizes" number of TerminalSizes, storing the TerminalSizes in "sizes" slice.
type fakeTerminalSizeQueue struct {
maxSizes int
terminalSizes []remotecommand.TerminalSize
}
// newTerminalSizeQueue returns a pointer to a fakeTerminalSizeQueue passing
// "max" number of random TerminalSizes created.
func newTerminalSizeQueue(max int) *fakeTerminalSizeQueue {
return &fakeTerminalSizeQueue{
maxSizes: max,
terminalSizes: make([]remotecommand.TerminalSize, 0, max),
}
}
// Next returns a pointer to the next random TerminalSize, or nil if we have
// already returned "maxSizes" TerminalSizes already. Stores the randomly
// created TerminalSize in "terminalSizes" field for later validation.
func (f *fakeTerminalSizeQueue) Next() *remotecommand.TerminalSize {
if len(f.terminalSizes) >= f.maxSizes {
return nil
}
size := randomTerminalSize()
f.terminalSizes = append(f.terminalSizes, size)
return &size
}
// randomTerminalSize returns a TerminalSize with random values in the
// range (0-65535) for the fields Width and Height.
func randomTerminalSize() remotecommand.TerminalSize {
randWidth := uint16(mrand.Intn(int(math.Pow(2, 16))))
randHeight := uint16(mrand.Intn(int(math.Pow(2, 16))))
return remotecommand.TerminalSize{
Width: randWidth,
Height: randHeight,
}
}
// TestStreamTranslator_MultipleWriteChannels
func TestStreamTranslator_TTYResizeChannel(t *testing.T) {
// Create the fake terminal size queue and the actualTerminalSizes which
// will be received at the opposite websocket endpoint.
numSizeQueue := 10000
sizeQueue := newTerminalSizeQueue(numSizeQueue)
actualTerminalSizes := make([]remotecommand.TerminalSize, 0, numSizeQueue)
// Create upstream fake SPDY server which copies STDIN back onto STDERR stream.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, err := createSPDYServerStreams(w, req, Options{
Tty: true,
})
if err != nil {
t.Errorf("error on createHTTPStreams: %v", err)
return
}
defer ctx.conn.Close()
// Read the terminal resize requests, storing them in actualTerminalSizes
for i := 0; i < numSizeQueue; i++ {
actualTerminalSize := <-ctx.resizeChan
actualTerminalSizes = append(actualTerminalSizes, actualTerminalSize)
}
}))
defer spdyServer.Close()
// Create StreamTranslatorHandler, which points upstream to fake SPDY server with
// resize (TTY resize) stream. Create test server from StreamTranslatorHandler.
spdyLocation, err := url.Parse(spdyServer.URL)
if err != nil {
t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL)
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streams := Options{Tty: true}
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams)
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
options := &remotecommand.StreamOptions{
Tty: true,
TerminalSizeQueue: sizeQueue,
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
errorChan <- exec.StreamWithContext(context.Background(), *options)
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
// Validate the random TerminalSizes sent on the resize stream are the same
// as the actual TerminalSizes received at the websocket server.
if len(actualTerminalSizes) != numSizeQueue {
t.Fatalf("expected to receive num terminal resizes (%d), got (%d)",
numSizeQueue, len(actualTerminalSizes))
}
for i, actual := range actualTerminalSizes {
expected := sizeQueue.terminalSizes[i]
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected terminal resize window %v, got %v", expected, actual)
}
}
}
// TestStreamTranslator_WebSocketServerErrors validates that when there is a problem creating
// the websocket server as the first step of the StreamTranslator an error is properly returned.
func TestStreamTranslator_WebSocketServerErrors(t *testing.T) {
spdyLocation, err := url.Parse("http://127.0.0.1")
if err != nil {
t.Fatalf("Unable to parse spdy server URL")
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, Options{})
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutorForProtocols(
&rest.Config{Host: streamTranslatorLocation.Host},
"GET",
streamTranslatorServer.URL,
rcconstants.StreamProtocolV4Name, // RemoteCommand V4 protocol is unsupported
)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client. The WebSocket server within the
// StreamTranslator propagates an error here because the V4 protocol is not supported.
errorChan <- exec.StreamWithContext(context.Background(), remotecommand.StreamOptions{})
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
// Must return "websocket unable to upgrade" (bad handshake) error.
if err == nil {
t.Fatalf("expected error, but received none")
}
if !strings.Contains(err.Error(), "unable to upgrade streaming request") {
t.Errorf("expected websocket bad handshake error, got (%s)", err)
}
}
}
// TestStreamTranslator_BlockRedirects verifies that the StreamTranslator will *not* follow
// redirects; it will thrown an error instead.
func TestStreamTranslator_BlockRedirects(t *testing.T) {
for _, statusCode := range []int{
http.StatusMovedPermanently, // 301
http.StatusFound, // 302
http.StatusSeeOther, // 303
http.StatusTemporaryRedirect, // 307
http.StatusPermanentRedirect, // 308
} {
// Create upstream fake SPDY server which returns a redirect.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Location", "/")
w.WriteHeader(statusCode)
}))
defer spdyServer.Close()
spdyLocation, err := url.Parse(spdyServer.URL)
if err != nil {
t.Fatalf("Unable to parse spdy server URL: %s", spdyServer.URL)
}
spdyTransport, err := fakeTransport()
if err != nil {
t.Fatalf("Unexpected error creating transport: %v", err)
}
streams := Options{Stdout: true}
streamTranslator := NewStreamTranslatorHandler(spdyLocation, spdyTransport, 0, streams)
streamTranslatorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamTranslator.ServeHTTP(w, req)
}))
defer streamTranslatorServer.Close()
// Now create the websocket client (executor), and point it to the "streamTranslatorServer".
streamTranslatorLocation, err := url.Parse(streamTranslatorServer.URL)
if err != nil {
t.Fatalf("Unable to parse StreamTranslator server URL: %s", streamTranslatorServer.URL)
}
exec, err := remotecommand.NewWebSocketExecutor(&rest.Config{Host: streamTranslatorLocation.Host}, "GET", streamTranslatorServer.URL)
if err != nil {
t.Errorf("unexpected error creating websocket executor: %v", err)
}
errorChan := make(chan error)
go func() {
// Start the streaming on the WebSocket "exec" client.
// Should return "redirect not allowed" error.
errorChan <- exec.StreamWithContext(context.Background(), remotecommand.StreamOptions{})
}()
select {
case <-time.After(wait.ForeverTestTimeout):
t.Fatalf("expect stream to be closed after connection is closed.")
case err := <-errorChan:
// Must return "redirect now allowed" error.
if err == nil {
t.Fatalf("expected error, but received none")
}
if !strings.Contains(err.Error(), "redirect not allowed") {
t.Errorf("expected redirect not allowed error, got (%s)", err)
}
}
}
}
// streamContext encapsulates the structures necessary to communicate through
// a SPDY connection, including the Reader/Writer streams.
type streamContext struct {
conn io.Closer
stdinStream io.ReadCloser
stdoutStream io.WriteCloser
stderrStream io.WriteCloser
resizeStream io.ReadCloser
resizeChan chan remotecommand.TerminalSize
writeStatus func(status *apierrors.StatusError) error
}
type streamAndReply struct {
httpstream.Stream
replySent <-chan struct{}
}
// CreateSPDYServerStreams upgrades the passed HTTP request to a SPDY bi-directional streaming
// connection with remote command streams defined in passed options. Returns a streamContext
// structure containing the Reader/Writer streams to communicate through the SDPY connection.
// Returns an error if unable to upgrade the HTTP connection to a SPDY connection.
func createSPDYServerStreams(w http.ResponseWriter, req *http.Request, opts Options) (*streamContext, error) {
_, err := httpstream.Handshake(req, w, []string{rcconstants.StreamProtocolV4Name})
if err != nil {
return nil, err
}
upgrader := spdy.NewResponseUpgrader()
streamCh := make(chan streamAndReply)
conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error {
streamCh <- streamAndReply{Stream: stream, replySent: replySent}
return nil
})
ctx := &streamContext{
conn: conn,
}
// wait for stream
replyChan := make(chan struct{}, 5)
defer close(replyChan)
receivedStreams := 0
expectedStreams := 1 // expect at least the error stream
if opts.Stdout {
expectedStreams++
}
if opts.Stdin {
expectedStreams++
}
if opts.Stderr {
expectedStreams++
}
if opts.Tty {
expectedStreams++
}
WaitForStreams:
for {
select {
case stream := <-streamCh:
streamType := stream.Headers().Get(v1.StreamType)
switch streamType {
case v1.StreamTypeError:
replyChan <- struct{}{}
ctx.writeStatus = v4WriteStatusFunc(stream)
case v1.StreamTypeStdout:
replyChan <- struct{}{}
ctx.stdoutStream = stream
case v1.StreamTypeStdin:
replyChan <- struct{}{}
ctx.stdinStream = stream
case v1.StreamTypeStderr:
replyChan <- struct{}{}
ctx.stderrStream = stream
case v1.StreamTypeResize:
replyChan <- struct{}{}
ctx.resizeStream = stream
default:
// add other stream ...
return nil, errors.New("unimplemented stream type")
}
case <-replyChan:
receivedStreams++
if receivedStreams == expectedStreams {
break WaitForStreams
}
}
}
if ctx.resizeStream != nil {
ctx.resizeChan = make(chan remotecommand.TerminalSize)
go handleResizeEvents(req.Context(), ctx.resizeStream, ctx.resizeChan)
}
return ctx, nil
}
func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error {
return func(status *apierrors.StatusError) error {
bs, err := json.Marshal(status.Status())
if err != nil {
return err
}
_, err = stream.Write(bs)
return err
}
}
func fakeTransport() (*http.Transport, error) {
cfg := &transport.Config{
TLS: transport.TLSConfig{
Insecure: true,
CAFile: "",
},
}
rt, err := transport.New(cfg)
if err != nil {
return nil, err
}
t, ok := rt.(*http.Transport)
if !ok {
return nil, fmt.Errorf("unknown transport type: %T", rt)
}
return t, nil
}

View File

@ -0,0 +1,51 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"net/http"
"k8s.io/klog/v2"
)
// translatingHandler wraps the delegate handler, implementing the
// http.Handler interface. The delegate handles all requests unless
// the request satisfies the passed "shouldTranslate" function
// (currently only for WebSocket/V5 request), in which case the translator
// handles the request.
type translatingHandler struct {
delegate http.Handler
translator http.Handler
shouldTranslate func(*http.Request) bool
}
func NewTranslatingHandler(delegate http.Handler, translator http.Handler, shouldTranslate func(*http.Request) bool) http.Handler {
return &translatingHandler{
delegate: delegate,
translator: translator,
shouldTranslate: shouldTranslate,
}
}
func (t *translatingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if t.shouldTranslate(req) {
klog.V(4).Infof("request handled by translator proxy")
t.translator.ServeHTTP(w, req)
return
}
t.delegate.ServeHTTP(w, req)
}

View File

@ -0,0 +1,121 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
)
// fakeHandler implements http.Handler interface
type fakeHandler struct {
served bool
}
// ServeHTTP stores the fact that this fake handler was called.
func (fh *fakeHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
fh.served = true
}
func TestTranslatingHandler(t *testing.T) {
tests := map[string]struct {
upgrade string
version string
expectTranslator bool
}{
"websocket/v5 upgrade, serves translator": {
upgrade: "websocket",
version: "v5.channel.k8s.io",
expectTranslator: true,
},
"websocket/v5 upgrade with multiple other versions, serves translator": {
upgrade: "websocket",
version: "v5.channel.k8s.io, v4.channel.k8s.io, v3.channel.k8s.io",
expectTranslator: true,
},
"websocket/v5 upgrade with multiple other versions out of order, serves translator": {
upgrade: "websocket",
version: "v4.channel.k8s.io, v3.channel.k8s.io, v5.channel.k8s.io",
expectTranslator: true,
},
"no upgrade, serves delegate": {
upgrade: "",
version: "",
expectTranslator: false,
},
"no upgrade with v5, serves delegate": {
upgrade: "",
version: "v5.channel.k8s.io",
expectTranslator: false,
},
"websocket/v5 wrong case upgrade, serves delegage": {
upgrade: "websocket",
version: "v5.CHANNEL.k8s.io",
expectTranslator: false,
},
"spdy/v5 upgrade, serves delegate": {
upgrade: "spdy",
version: "v5.channel.k8s.io",
expectTranslator: false,
},
"spdy/v4 upgrade, serves delegate": {
upgrade: "spdy",
version: "v4.channel.k8s.io",
expectTranslator: false,
},
"websocket/v4 upgrade, serves delegate": {
upgrade: "websocket",
version: "v4.channel.k8s.io",
expectTranslator: false,
},
"websocket without version upgrade, serves delegate": {
upgrade: "websocket",
version: "",
expectTranslator: false,
},
}
for name, test := range tests {
req, err := http.NewRequest("GET", "http://www.example.com/", nil)
require.NoError(t, err)
if test.upgrade != "" {
req.Header.Add("Connection", "Upgrade")
req.Header.Add("Upgrade", test.upgrade)
}
if len(test.version) > 0 {
req.Header.Add(wsstream.WebSocketProtocolHeader, test.version)
}
delegate := fakeHandler{}
translator := fakeHandler{}
translatingHandler := NewTranslatingHandler(&delegate, &translator,
wsstream.IsWebSocketRequestWithStreamCloseProtocol)
translatingHandler.ServeHTTP(nil, req)
if !delegate.served && !translator.served {
t.Errorf("unexpected neither translator nor delegate served")
continue
}
if test.expectTranslator {
if !translator.served {
t.Errorf("%s: expected translator served, got delegate served", name)
}
} else if !delegate.served {
t.Errorf("%s: expected delegate served, got translator served", name)
}
}
}

200
pkg/util/proxy/websocket.go Normal file
View File

@ -0,0 +1,200 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
constants "k8s.io/apimachinery/pkg/util/remotecommand"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/client-go/tools/remotecommand"
)
const (
// idleTimeout is the read/write deadline set for websocket server connection. Reading
// or writing the connection will return an i/o timeout if this deadline is exceeded.
// Currently, we use the same value as the kubelet websocket server.
defaultIdleConnectionTimeout = 4 * time.Hour
// Deadline for writing errors to the websocket connection before io/timeout.
writeErrorDeadline = 10 * time.Second
)
// Options contains details about which streams are required for
// remote command execution.
type Options struct {
Stdin bool
Stdout bool
Stderr bool
Tty bool
}
// conns contains the connection and streams used when
// forwarding an attach or execute session into a container.
type conns struct {
conn io.Closer
stdinStream io.ReadCloser
stdoutStream io.WriteCloser
stderrStream io.WriteCloser
writeStatus func(status *apierrors.StatusError) error
resizeStream io.ReadCloser
resizeChan chan remotecommand.TerminalSize
tty bool
}
// Create WebSocket server streams to respond to a WebSocket client. Creates the streams passed
// in the stream options.
func webSocketServerStreams(req *http.Request, w http.ResponseWriter, opts Options) (*conns, error) {
ctx, err := createWebSocketStreams(req, w, opts)
if err != nil {
return nil, err
}
if ctx.resizeStream != nil {
ctx.resizeChan = make(chan remotecommand.TerminalSize)
go func() {
// Resize channel closes in panic case, and panic does not take down caller.
defer func() {
if p := recover(); p != nil {
// Standard panic logging.
for _, fn := range runtime.PanicHandlers {
fn(p)
}
}
}()
handleResizeEvents(req.Context(), ctx.resizeStream, ctx.resizeChan)
}()
}
return ctx, nil
}
// Read terminal resize events off of passed stream and queue into passed channel.
func handleResizeEvents(ctx context.Context, stream io.Reader, channel chan<- remotecommand.TerminalSize) {
defer close(channel)
decoder := json.NewDecoder(stream)
for {
size := remotecommand.TerminalSize{}
if err := decoder.Decode(&size); err != nil {
break
}
select {
case channel <- size:
case <-ctx.Done():
// To avoid leaking this routine, exit if the http request finishes. This path
// would generally be hit if starting the process fails and nothing is started to
// ingest these resize events.
return
}
}
}
// createChannels returns the standard channel types for a shell connection (STDIN 0, STDOUT 1, STDERR 2)
// along with the approximate duplex value. It also creates the error (3) and resize (4) channels.
func createChannels(opts Options) []wsstream.ChannelType {
// open the requested channels, and always open the error channel
channels := make([]wsstream.ChannelType, 5)
channels[constants.StreamStdIn] = readChannel(opts.Stdin)
channels[constants.StreamStdOut] = writeChannel(opts.Stdout)
channels[constants.StreamStdErr] = writeChannel(opts.Stderr)
channels[constants.StreamErr] = wsstream.WriteChannel
channels[constants.StreamResize] = wsstream.ReadChannel
return channels
}
// readChannel returns wsstream.ReadChannel if real is true, or wsstream.IgnoreChannel.
func readChannel(real bool) wsstream.ChannelType {
if real {
return wsstream.ReadChannel
}
return wsstream.IgnoreChannel
}
// writeChannel returns wsstream.WriteChannel if real is true, or wsstream.IgnoreChannel.
func writeChannel(real bool) wsstream.ChannelType {
if real {
return wsstream.WriteChannel
}
return wsstream.IgnoreChannel
}
// createWebSocketStreams returns a "conns" struct containing the websocket connection and
// streams needed to perform an exec or an attach.
func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts Options) (*conns, error) {
channels := createChannels(opts)
conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{
// WebSocket server only supports remote command version 5.
constants.StreamProtocolV5Name: {
Binary: true,
Channels: channels,
},
})
conn.SetIdleTimeout(defaultIdleConnectionTimeout)
// Opening the connection responds to WebSocket client, negotiating
// the WebSocket upgrade connection and the subprotocol.
_, streams, err := conn.Open(w, req)
if err != nil {
return nil, err
}
// Send an empty message to the lowest writable channel to notify the client the connection is established
switch {
case opts.Stdout:
_, err = streams[constants.StreamStdOut].Write([]byte{})
case opts.Stderr:
_, err = streams[constants.StreamStdErr].Write([]byte{})
default:
_, err = streams[constants.StreamErr].Write([]byte{})
}
if err != nil {
conn.Close()
return nil, fmt.Errorf("write error during websocket server creation: %v", err)
}
ctx := &conns{
conn: conn,
stdinStream: streams[constants.StreamStdIn],
stdoutStream: streams[constants.StreamStdOut],
stderrStream: streams[constants.StreamStdErr],
tty: opts.Tty,
resizeStream: streams[constants.StreamResize],
}
// writeStatus returns a WriteStatusFunc that marshals a given api Status
// as json in the error channel.
ctx.writeStatus = func(status *apierrors.StatusError) error {
bs, err := json.Marshal(status.Status())
if err != nil {
return err
}
// Write status error to error stream with deadline.
conn.SetWriteDeadline(writeErrorDeadline)
_, err = streams[constants.StreamErr].Write(bs)
return err
}
return ctx, nil
}