Refactor wsstream library from apiserver to apimachinery
Kubernetes-commit: 8f3109da7913ef17c6656893f12f0e29ceabbde0
This commit is contained in:
parent
7d79c570c7
commit
cdd93b4685
|
@ -24,8 +24,8 @@ import (
|
|||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
|
||||
"k8s.io/apiserver/pkg/authentication/authenticator"
|
||||
"k8s.io/apiserver/pkg/util/wsstream"
|
||||
)
|
||||
|
||||
const bearerProtocolPrefix = "base64url.bearer.authorization.k8s.io."
|
||||
|
|
|
@ -34,6 +34,7 @@ import (
|
|||
|
||||
"k8s.io/apimachinery/pkg/runtime"
|
||||
"k8s.io/apimachinery/pkg/runtime/schema"
|
||||
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
|
||||
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
|
||||
"k8s.io/apiserver/pkg/audit"
|
||||
"k8s.io/apiserver/pkg/endpoints/handlers/negotiation"
|
||||
|
@ -42,7 +43,6 @@ import (
|
|||
"k8s.io/apiserver/pkg/registry/rest"
|
||||
utilfeature "k8s.io/apiserver/pkg/util/feature"
|
||||
"k8s.io/apiserver/pkg/util/flushwriter"
|
||||
"k8s.io/apiserver/pkg/util/wsstream"
|
||||
"k8s.io/component-base/tracing"
|
||||
)
|
||||
|
||||
|
|
|
@ -30,12 +30,12 @@ import (
|
|||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/apimachinery/pkg/runtime"
|
||||
"k8s.io/apimachinery/pkg/runtime/serializer/streaming"
|
||||
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
|
||||
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
|
||||
"k8s.io/apimachinery/pkg/watch"
|
||||
"k8s.io/apiserver/pkg/endpoints/handlers/negotiation"
|
||||
"k8s.io/apiserver/pkg/endpoints/metrics"
|
||||
apirequest "k8s.io/apiserver/pkg/endpoints/request"
|
||||
"k8s.io/apiserver/pkg/util/wsstream"
|
||||
)
|
||||
|
||||
// nothing will ever be sent down this channel
|
||||
|
|
|
@ -1,350 +0,0 @@
|
|||
/*
|
||||
Copyright 2015 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package wsstream
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
"k8s.io/klog/v2"
|
||||
|
||||
"k8s.io/apimachinery/pkg/util/runtime"
|
||||
)
|
||||
|
||||
// The Websocket subprotocol "channel.k8s.io" prepends each binary message with a byte indicating
|
||||
// the channel number (zero indexed) the message was sent on. Messages in both directions should
|
||||
// prefix their messages with this channel byte. When used for remote execution, the channel numbers
|
||||
// are by convention defined to match the POSIX file-descriptors assigned to STDIN, STDOUT, and STDERR
|
||||
// (0, 1, and 2). No other conversion is performed on the raw subprotocol - writes are sent as they
|
||||
// are received by the server.
|
||||
//
|
||||
// Example client session:
|
||||
//
|
||||
// CONNECT http://server.com with subprotocol "channel.k8s.io"
|
||||
// WRITE []byte{0, 102, 111, 111, 10} # send "foo\n" on channel 0 (STDIN)
|
||||
// READ []byte{1, 10} # receive "\n" on channel 1 (STDOUT)
|
||||
// CLOSE
|
||||
const ChannelWebSocketProtocol = "channel.k8s.io"
|
||||
|
||||
// The Websocket subprotocol "base64.channel.k8s.io" base64 encodes each message with a character
|
||||
// indicating the channel number (zero indexed) the message was sent on. Messages in both directions
|
||||
// should prefix their messages with this channel char. When used for remote execution, the channel
|
||||
// numbers are by convention defined to match the POSIX file-descriptors assigned to STDIN, STDOUT,
|
||||
// and STDERR ('0', '1', and '2'). The data received on the server is base64 decoded (and must be
|
||||
// be valid) and data written by the server to the client is base64 encoded.
|
||||
//
|
||||
// Example client session:
|
||||
//
|
||||
// CONNECT http://server.com with subprotocol "base64.channel.k8s.io"
|
||||
// WRITE []byte{48, 90, 109, 57, 118, 67, 103, 111, 61} # send "foo\n" (base64: "Zm9vCgo=") on channel '0' (STDIN)
|
||||
// READ []byte{49, 67, 103, 61, 61} # receive "\n" (base64: "Cg==") on channel '1' (STDOUT)
|
||||
// CLOSE
|
||||
const Base64ChannelWebSocketProtocol = "base64.channel.k8s.io"
|
||||
|
||||
type codecType int
|
||||
|
||||
const (
|
||||
rawCodec codecType = iota
|
||||
base64Codec
|
||||
)
|
||||
|
||||
type ChannelType int
|
||||
|
||||
const (
|
||||
IgnoreChannel ChannelType = iota
|
||||
ReadChannel
|
||||
WriteChannel
|
||||
ReadWriteChannel
|
||||
)
|
||||
|
||||
var (
|
||||
// connectionUpgradeRegex matches any Connection header value that includes upgrade
|
||||
connectionUpgradeRegex = regexp.MustCompile("(^|.*,\\s*)upgrade($|\\s*,)")
|
||||
)
|
||||
|
||||
// IsWebSocketRequest returns true if the incoming request contains connection upgrade headers
|
||||
// for WebSockets.
|
||||
func IsWebSocketRequest(req *http.Request) bool {
|
||||
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
|
||||
return false
|
||||
}
|
||||
return connectionUpgradeRegex.MatchString(strings.ToLower(req.Header.Get("Connection")))
|
||||
}
|
||||
|
||||
// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the
|
||||
// read and write deadlines are pushed every time a new message is received.
|
||||
func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
|
||||
defer runtime.HandleCrash()
|
||||
var data []byte
|
||||
for {
|
||||
resetTimeout(ws, timeout)
|
||||
if err := websocket.Message.Receive(ws, &data); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handshake ensures the provided user protocol matches one of the allowed protocols. It returns
|
||||
// no error if no protocol is specified.
|
||||
func handshake(config *websocket.Config, req *http.Request, allowed []string) error {
|
||||
protocols := config.Protocol
|
||||
if len(protocols) == 0 {
|
||||
protocols = []string{""}
|
||||
}
|
||||
|
||||
for _, protocol := range protocols {
|
||||
for _, allow := range allowed {
|
||||
if allow == protocol {
|
||||
config.Protocol = []string{protocol}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed)
|
||||
}
|
||||
|
||||
// ChannelProtocolConfig describes a websocket subprotocol with channels.
|
||||
type ChannelProtocolConfig struct {
|
||||
Binary bool
|
||||
Channels []ChannelType
|
||||
}
|
||||
|
||||
// NewDefaultChannelProtocols returns a channel protocol map with the
|
||||
// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io" and the given
|
||||
// channels.
|
||||
func NewDefaultChannelProtocols(channels []ChannelType) map[string]ChannelProtocolConfig {
|
||||
return map[string]ChannelProtocolConfig{
|
||||
"": {Binary: true, Channels: channels},
|
||||
ChannelWebSocketProtocol: {Binary: true, Channels: channels},
|
||||
Base64ChannelWebSocketProtocol: {Binary: false, Channels: channels},
|
||||
}
|
||||
}
|
||||
|
||||
// Conn supports sending multiple binary channels over a websocket connection.
|
||||
type Conn struct {
|
||||
protocols map[string]ChannelProtocolConfig
|
||||
selectedProtocol string
|
||||
channels []*websocketChannel
|
||||
codec codecType
|
||||
ready chan struct{}
|
||||
ws *websocket.Conn
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewConn creates a WebSocket connection that supports a set of channels. Channels begin each
|
||||
// web socket message with a single byte indicating the channel number (0-N). 255 is reserved for
|
||||
// future use. The channel types for each channel are passed as an array, supporting the different
|
||||
// duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer.
|
||||
//
|
||||
// The protocols parameter maps subprotocol names to ChannelProtocols. The empty string subprotocol
|
||||
// name is used if websocket.Config.Protocol is empty.
|
||||
func NewConn(protocols map[string]ChannelProtocolConfig) *Conn {
|
||||
return &Conn{
|
||||
ready: make(chan struct{}),
|
||||
protocols: protocols,
|
||||
}
|
||||
}
|
||||
|
||||
// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified,
|
||||
// there is no timeout on the connection.
|
||||
func (conn *Conn) SetIdleTimeout(duration time.Duration) {
|
||||
conn.timeout = duration
|
||||
}
|
||||
|
||||
// Open the connection and create channels for reading and writing. It returns
|
||||
// the selected subprotocol, a slice of channels and an error.
|
||||
func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) {
|
||||
go func() {
|
||||
defer runtime.HandleCrash()
|
||||
defer conn.Close()
|
||||
websocket.Server{Handshake: conn.handshake, Handler: conn.handle}.ServeHTTP(w, req)
|
||||
}()
|
||||
<-conn.ready
|
||||
rwc := make([]io.ReadWriteCloser, len(conn.channels))
|
||||
for i := range conn.channels {
|
||||
rwc[i] = conn.channels[i]
|
||||
}
|
||||
return conn.selectedProtocol, rwc, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) initialize(ws *websocket.Conn) {
|
||||
negotiated := ws.Config().Protocol
|
||||
conn.selectedProtocol = negotiated[0]
|
||||
p := conn.protocols[conn.selectedProtocol]
|
||||
if p.Binary {
|
||||
conn.codec = rawCodec
|
||||
} else {
|
||||
conn.codec = base64Codec
|
||||
}
|
||||
conn.ws = ws
|
||||
conn.channels = make([]*websocketChannel, len(p.Channels))
|
||||
for i, t := range p.Channels {
|
||||
switch t {
|
||||
case ReadChannel:
|
||||
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
|
||||
case WriteChannel:
|
||||
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
|
||||
case ReadWriteChannel:
|
||||
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
|
||||
case IgnoreChannel:
|
||||
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
|
||||
}
|
||||
}
|
||||
|
||||
close(conn.ready)
|
||||
}
|
||||
|
||||
func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error {
|
||||
supportedProtocols := make([]string, 0, len(conn.protocols))
|
||||
for p := range conn.protocols {
|
||||
supportedProtocols = append(supportedProtocols, p)
|
||||
}
|
||||
return handshake(config, req, supportedProtocols)
|
||||
}
|
||||
|
||||
func (conn *Conn) resetTimeout() {
|
||||
if conn.timeout > 0 {
|
||||
conn.ws.SetDeadline(time.Now().Add(conn.timeout))
|
||||
}
|
||||
}
|
||||
|
||||
// Close is only valid after Open has been called
|
||||
func (conn *Conn) Close() error {
|
||||
<-conn.ready
|
||||
for _, s := range conn.channels {
|
||||
s.Close()
|
||||
}
|
||||
conn.ws.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// handle implements a websocket handler.
|
||||
func (conn *Conn) handle(ws *websocket.Conn) {
|
||||
defer conn.Close()
|
||||
conn.initialize(ws)
|
||||
|
||||
for {
|
||||
conn.resetTimeout()
|
||||
var data []byte
|
||||
if err := websocket.Message.Receive(ws, &data); err != nil {
|
||||
if err != io.EOF {
|
||||
klog.Errorf("Error on socket receive: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
channel := data[0]
|
||||
if conn.codec == base64Codec {
|
||||
channel = channel - '0'
|
||||
}
|
||||
data = data[1:]
|
||||
if int(channel) >= len(conn.channels) {
|
||||
klog.V(6).Infof("Frame is targeted for a reader %d that is not valid, possible protocol error", channel)
|
||||
continue
|
||||
}
|
||||
if _, err := conn.channels[channel].DataFromSocket(data); err != nil {
|
||||
klog.Errorf("Unable to write frame to %d: %v\n%s", channel, err, string(data))
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// write multiplexes the specified channel onto the websocket
|
||||
func (conn *Conn) write(num byte, data []byte) (int, error) {
|
||||
conn.resetTimeout()
|
||||
switch conn.codec {
|
||||
case rawCodec:
|
||||
frame := make([]byte, len(data)+1)
|
||||
frame[0] = num
|
||||
copy(frame[1:], data)
|
||||
if err := websocket.Message.Send(conn.ws, frame); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case base64Codec:
|
||||
frame := string('0'+num) + base64.StdEncoding.EncodeToString(data)
|
||||
if err := websocket.Message.Send(conn.ws, frame); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// websocketChannel represents a channel in a connection
|
||||
type websocketChannel struct {
|
||||
conn *Conn
|
||||
num byte
|
||||
r io.Reader
|
||||
w io.WriteCloser
|
||||
|
||||
read, write bool
|
||||
}
|
||||
|
||||
// newWebsocketChannel creates a pipe for writing to a websocket. Do not write to this pipe
|
||||
// prior to the connection being opened. It may be no, half, or full duplex depending on
|
||||
// read and write.
|
||||
func newWebsocketChannel(conn *Conn, num byte, read, write bool) *websocketChannel {
|
||||
r, w := io.Pipe()
|
||||
return &websocketChannel{conn, num, r, w, read, write}
|
||||
}
|
||||
|
||||
func (p *websocketChannel) Write(data []byte) (int, error) {
|
||||
if !p.write {
|
||||
return len(data), nil
|
||||
}
|
||||
return p.conn.write(p.num, data)
|
||||
}
|
||||
|
||||
// DataFromSocket is invoked by the connection receiver to move data from the connection
|
||||
// into a specific channel.
|
||||
func (p *websocketChannel) DataFromSocket(data []byte) (int, error) {
|
||||
if !p.read {
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
switch p.conn.codec {
|
||||
case rawCodec:
|
||||
return p.w.Write(data)
|
||||
case base64Codec:
|
||||
dst := make([]byte, len(data))
|
||||
n, err := base64.StdEncoding.Decode(dst, data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return p.w.Write(dst[:n])
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (p *websocketChannel) Read(data []byte) (int, error) {
|
||||
if !p.read {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return p.r.Read(data)
|
||||
}
|
||||
|
||||
func (p *websocketChannel) Close() error {
|
||||
return p.w.Close()
|
||||
}
|
|
@ -1,274 +0,0 @@
|
|||
/*
|
||||
Copyright 2015 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package wsstream
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
func newServer(handler http.Handler) (*httptest.Server, string) {
|
||||
server := httptest.NewServer(handler)
|
||||
serverAddr := server.Listener.Addr().String()
|
||||
return server, serverAddr
|
||||
}
|
||||
|
||||
func TestRawConn(t *testing.T) {
|
||||
channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel}
|
||||
conn := NewConn(NewDefaultChannelProtocols(channels))
|
||||
|
||||
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
conn.Open(w, req)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
client, err := websocket.Dial("ws://"+addr, "", "http://localhost/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
<-conn.ready
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
// verify we can read a client write
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
data, err := ioutil.ReadAll(conn.channels[0])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(data, []byte("client")) {
|
||||
t.Errorf("unexpected server read: %v", data)
|
||||
}
|
||||
}()
|
||||
|
||||
if n, err := client.Write(append([]byte{0}, []byte("client")...)); err != nil || n != 7 {
|
||||
t.Fatalf("%d: %v", n, err)
|
||||
}
|
||||
|
||||
// verify we can read a server write
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 {
|
||||
t.Errorf("%d: %v", n, err)
|
||||
}
|
||||
}()
|
||||
|
||||
data := make([]byte, 1024)
|
||||
if n, err := io.ReadAtLeast(client, data, 6); n != 7 || err != nil {
|
||||
t.Fatalf("%d: %v", n, err)
|
||||
}
|
||||
if !reflect.DeepEqual(data[:7], append([]byte{1}, []byte("server")...)) {
|
||||
t.Errorf("unexpected client read: %v", data[:7])
|
||||
}
|
||||
|
||||
// verify that an ignore channel is empty in both directions.
|
||||
if n, err := conn.channels[2].Write([]byte("test")); n != 4 || err != nil {
|
||||
t.Errorf("writes should be ignored")
|
||||
}
|
||||
data = make([]byte, 1024)
|
||||
if n, err := conn.channels[2].Read(data); n != 0 || err != io.EOF {
|
||||
t.Errorf("reads should be ignored")
|
||||
}
|
||||
|
||||
// verify that a write to a Read channel doesn't block
|
||||
if n, err := conn.channels[3].Write([]byte("test")); n != 4 || err != nil {
|
||||
t.Errorf("writes should be ignored")
|
||||
}
|
||||
|
||||
// verify that a read from a Write channel doesn't block
|
||||
data = make([]byte, 1024)
|
||||
if n, err := conn.channels[4].Read(data); n != 0 || err != io.EOF {
|
||||
t.Errorf("reads should be ignored")
|
||||
}
|
||||
|
||||
// verify that a client write to a Write channel doesn't block (is dropped)
|
||||
if n, err := client.Write(append([]byte{4}, []byte("ignored")...)); err != nil || n != 8 {
|
||||
t.Fatalf("%d: %v", n, err)
|
||||
}
|
||||
|
||||
client.Close()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestBase64Conn(t *testing.T) {
|
||||
conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel}))
|
||||
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
conn.Open(w, req)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config.Protocol = []string{"base64.channel.k8s.io"}
|
||||
client, err := websocket.DialConfig(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
<-conn.ready
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
data, err := ioutil.ReadAll(conn.channels[0])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(data, []byte("client")) {
|
||||
t.Errorf("unexpected server read: %s", string(data))
|
||||
}
|
||||
}()
|
||||
|
||||
clientData := base64.StdEncoding.EncodeToString([]byte("client"))
|
||||
if n, err := client.Write(append([]byte{'0'}, clientData...)); err != nil || n != len(clientData)+1 {
|
||||
t.Fatalf("%d: %v", n, err)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 {
|
||||
t.Errorf("%d: %v", n, err)
|
||||
}
|
||||
}()
|
||||
|
||||
data := make([]byte, 1024)
|
||||
if n, err := io.ReadAtLeast(client, data, 9); n != 9 || err != nil {
|
||||
t.Fatalf("%d: %v", n, err)
|
||||
}
|
||||
expect := []byte(base64.StdEncoding.EncodeToString([]byte("server")))
|
||||
|
||||
if !reflect.DeepEqual(data[:9], append([]byte{'1'}, expect...)) {
|
||||
t.Errorf("unexpected client read: %v", data[:9])
|
||||
}
|
||||
|
||||
client.Close()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
type versionTest struct {
|
||||
supported map[string]bool // protocol -> binary
|
||||
requested []string
|
||||
error bool
|
||||
expected string
|
||||
}
|
||||
|
||||
func versionTests() []versionTest {
|
||||
const (
|
||||
binary = true
|
||||
base64 = false
|
||||
)
|
||||
return []versionTest{
|
||||
{
|
||||
supported: nil,
|
||||
requested: []string{"raw"},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
|
||||
requested: nil,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
|
||||
requested: []string{"v1.raw"},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
|
||||
requested: []string{"v1.raw", "v1.base64"},
|
||||
error: true,
|
||||
}, {
|
||||
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
|
||||
requested: []string{"v1.raw", "raw"},
|
||||
expected: "raw",
|
||||
},
|
||||
{
|
||||
supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
|
||||
requested: []string{"v1.raw"},
|
||||
expected: "v1.raw",
|
||||
},
|
||||
{
|
||||
supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
|
||||
requested: []string{"v2.base64"},
|
||||
expected: "v2.base64",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersionedConn(t *testing.T) {
|
||||
for i, test := range versionTests() {
|
||||
func() {
|
||||
supportedProtocols := map[string]ChannelProtocolConfig{}
|
||||
for p, binary := range test.supported {
|
||||
supportedProtocols[p] = ChannelProtocolConfig{
|
||||
Binary: binary,
|
||||
Channels: []ChannelType{ReadWriteChannel},
|
||||
}
|
||||
}
|
||||
conn := NewConn(supportedProtocols)
|
||||
// note that it's not enough to wait for conn.ready to avoid a race here. Hence,
|
||||
// we use a channel.
|
||||
selectedProtocol := make(chan string)
|
||||
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
p, _, _ := conn.Open(w, req)
|
||||
selectedProtocol <- p
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config.Protocol = test.requested
|
||||
client, err := websocket.DialConfig(config)
|
||||
if err != nil {
|
||||
if !test.error {
|
||||
t.Fatalf("test %d: didn't expect error: %v", i, err)
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
defer client.Close()
|
||||
if test.error && err == nil {
|
||||
t.Fatalf("test %d: expected an error", i)
|
||||
}
|
||||
|
||||
<-conn.ready
|
||||
if got, expected := <-selectedProtocol, test.expected; got != expected {
|
||||
t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
/*
|
||||
Copyright 2015 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package wsstream contains utilities for streaming content over WebSockets.
|
||||
// The Conn type allows callers to multiplex multiple read/write channels over
|
||||
// a single websocket. The Reader type allows an io.Reader to be copied over
|
||||
// a websocket channel as binary content.
|
||||
package wsstream // import "k8s.io/apiserver/pkg/util/wsstream"
|
|
@ -1,177 +0,0 @@
|
|||
/*
|
||||
Copyright 2015 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package wsstream
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
|
||||
"k8s.io/apimachinery/pkg/util/runtime"
|
||||
)
|
||||
|
||||
// The WebSocket subprotocol "binary.k8s.io" will only send messages to the
|
||||
// client and ignore messages sent to the server. The received messages are
|
||||
// the exact bytes written to the stream. Zero byte messages are possible.
|
||||
const binaryWebSocketProtocol = "binary.k8s.io"
|
||||
|
||||
// The WebSocket subprotocol "base64.binary.k8s.io" will only send messages to the
|
||||
// client and ignore messages sent to the server. The received messages are
|
||||
// a base64 version of the bytes written to the stream. Zero byte messages are
|
||||
// possible.
|
||||
const base64BinaryWebSocketProtocol = "base64.binary.k8s.io"
|
||||
|
||||
// ReaderProtocolConfig describes a websocket subprotocol with one stream.
|
||||
type ReaderProtocolConfig struct {
|
||||
Binary bool
|
||||
}
|
||||
|
||||
// NewDefaultReaderProtocols returns a stream protocol map with the
|
||||
// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io".
|
||||
func NewDefaultReaderProtocols() map[string]ReaderProtocolConfig {
|
||||
return map[string]ReaderProtocolConfig{
|
||||
"": {Binary: true},
|
||||
binaryWebSocketProtocol: {Binary: true},
|
||||
base64BinaryWebSocketProtocol: {Binary: false},
|
||||
}
|
||||
}
|
||||
|
||||
// Reader supports returning an arbitrary byte stream over a websocket channel.
|
||||
type Reader struct {
|
||||
err chan error
|
||||
r io.Reader
|
||||
ping bool
|
||||
timeout time.Duration
|
||||
protocols map[string]ReaderProtocolConfig
|
||||
selectedProtocol string
|
||||
|
||||
handleCrash func(additionalHandlers ...func(interface{})) // overridable for testing
|
||||
}
|
||||
|
||||
// NewReader creates a WebSocket pipe that will copy the contents of r to a provided
|
||||
// WebSocket connection. If ping is true, a zero length message will be sent to the client
|
||||
// before the stream begins reading.
|
||||
//
|
||||
// The protocols parameter maps subprotocol names to StreamProtocols. The empty string
|
||||
// subprotocol name is used if websocket.Config.Protocol is empty.
|
||||
func NewReader(r io.Reader, ping bool, protocols map[string]ReaderProtocolConfig) *Reader {
|
||||
return &Reader{
|
||||
r: r,
|
||||
err: make(chan error),
|
||||
ping: ping,
|
||||
protocols: protocols,
|
||||
handleCrash: runtime.HandleCrash,
|
||||
}
|
||||
}
|
||||
|
||||
// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified,
|
||||
// there is no timeout on the reader.
|
||||
func (r *Reader) SetIdleTimeout(duration time.Duration) {
|
||||
r.timeout = duration
|
||||
}
|
||||
|
||||
func (r *Reader) handshake(config *websocket.Config, req *http.Request) error {
|
||||
supportedProtocols := make([]string, 0, len(r.protocols))
|
||||
for p := range r.protocols {
|
||||
supportedProtocols = append(supportedProtocols, p)
|
||||
}
|
||||
return handshake(config, req, supportedProtocols)
|
||||
}
|
||||
|
||||
// Copy the reader to the response. The created WebSocket is closed after this
|
||||
// method completes.
|
||||
func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error {
|
||||
go func() {
|
||||
defer r.handleCrash()
|
||||
websocket.Server{Handshake: r.handshake, Handler: r.handle}.ServeHTTP(w, req)
|
||||
}()
|
||||
return <-r.err
|
||||
}
|
||||
|
||||
// handle implements a WebSocket handler.
|
||||
func (r *Reader) handle(ws *websocket.Conn) {
|
||||
// Close the connection when the client requests it, or when we finish streaming, whichever happens first
|
||||
closeConnOnce := &sync.Once{}
|
||||
closeConn := func() {
|
||||
closeConnOnce.Do(func() {
|
||||
ws.Close()
|
||||
})
|
||||
}
|
||||
|
||||
negotiated := ws.Config().Protocol
|
||||
r.selectedProtocol = negotiated[0]
|
||||
defer close(r.err)
|
||||
defer closeConn()
|
||||
|
||||
go func() {
|
||||
defer runtime.HandleCrash()
|
||||
// This blocks until the connection is closed.
|
||||
// Client should not send anything.
|
||||
IgnoreReceives(ws, r.timeout)
|
||||
// Once the client closes, we should also close
|
||||
closeConn()
|
||||
}()
|
||||
|
||||
r.err <- messageCopy(ws, r.r, !r.protocols[r.selectedProtocol].Binary, r.ping, r.timeout)
|
||||
}
|
||||
|
||||
func resetTimeout(ws *websocket.Conn, timeout time.Duration) {
|
||||
if timeout > 0 {
|
||||
ws.SetDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
}
|
||||
|
||||
func messageCopy(ws *websocket.Conn, r io.Reader, base64Encode, ping bool, timeout time.Duration) error {
|
||||
buf := make([]byte, 2048)
|
||||
if ping {
|
||||
resetTimeout(ws, timeout)
|
||||
if base64Encode {
|
||||
if err := websocket.Message.Send(ws, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := websocket.Message.Send(ws, []byte{}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for {
|
||||
resetTimeout(ws, timeout)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if n > 0 {
|
||||
if base64Encode {
|
||||
if err := websocket.Message.Send(ws, base64.StdEncoding.EncodeToString(buf[:n])); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := websocket.Message.Send(ws, buf[:n]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,294 +0,0 @@
|
|||
/*
|
||||
Copyright 2015 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package wsstream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
func TestStream(t *testing.T) {
|
||||
input := "some random text"
|
||||
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
|
||||
r.SetIdleTimeout(time.Second)
|
||||
data, err := readWebSocket(r, t, nil)
|
||||
if !reflect.DeepEqual(data, []byte(input)) {
|
||||
t.Errorf("unexpected server read: %v", data)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamPing(t *testing.T) {
|
||||
input := "some random text"
|
||||
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
|
||||
r.SetIdleTimeout(time.Second)
|
||||
err := expectWebSocketFrames(r, t, nil, [][]byte{
|
||||
{},
|
||||
[]byte(input),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamBase64(t *testing.T) {
|
||||
input := "some random text"
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(input))
|
||||
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
|
||||
data, err := readWebSocket(r, t, nil, "base64.binary.k8s.io")
|
||||
if !reflect.DeepEqual(data, []byte(encoded)) {
|
||||
t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamVersionedBase64(t *testing.T) {
|
||||
input := "some random text"
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(input))
|
||||
r := NewReader(bytes.NewBuffer([]byte(input)), true, map[string]ReaderProtocolConfig{
|
||||
"": {Binary: true},
|
||||
"binary.k8s.io": {Binary: true},
|
||||
"base64.binary.k8s.io": {Binary: false},
|
||||
"v1.binary.k8s.io": {Binary: true},
|
||||
"v1.base64.binary.k8s.io": {Binary: false},
|
||||
"v2.binary.k8s.io": {Binary: true},
|
||||
"v2.base64.binary.k8s.io": {Binary: false},
|
||||
})
|
||||
data, err := readWebSocket(r, t, nil, "v2.base64.binary.k8s.io")
|
||||
if !reflect.DeepEqual(data, []byte(encoded)) {
|
||||
t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamVersionedCopy(t *testing.T) {
|
||||
for i, test := range versionTests() {
|
||||
func() {
|
||||
supportedProtocols := map[string]ReaderProtocolConfig{}
|
||||
for p, binary := range test.supported {
|
||||
supportedProtocols[p] = ReaderProtocolConfig{
|
||||
Binary: binary,
|
||||
}
|
||||
}
|
||||
input := "some random text"
|
||||
r := NewReader(bytes.NewBuffer([]byte(input)), true, supportedProtocols)
|
||||
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
err := r.Copy(w, req)
|
||||
if err != nil {
|
||||
w.WriteHeader(503)
|
||||
}
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
config.Protocol = test.requested
|
||||
client, err := websocket.DialConfig(config)
|
||||
if err != nil {
|
||||
if !test.error {
|
||||
t.Errorf("test %d: didn't expect error: %v", i, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer client.Close()
|
||||
if test.error && err == nil {
|
||||
t.Errorf("test %d: expected an error", i)
|
||||
return
|
||||
}
|
||||
|
||||
<-r.err
|
||||
if got, expected := r.selectedProtocol, test.expected; got != expected {
|
||||
t.Errorf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamError(t *testing.T) {
|
||||
input := "some random text"
|
||||
errs := &errorReader{
|
||||
reads: [][]byte{
|
||||
[]byte("some random"),
|
||||
[]byte(" text"),
|
||||
},
|
||||
err: fmt.Errorf("bad read"),
|
||||
}
|
||||
r := NewReader(errs, false, NewDefaultReaderProtocols())
|
||||
|
||||
data, err := readWebSocket(r, t, nil)
|
||||
if !reflect.DeepEqual(data, []byte(input)) {
|
||||
t.Errorf("unexpected server read: %v", data)
|
||||
}
|
||||
if err == nil || err.Error() != "bad read" {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamSurvivesPanic(t *testing.T) {
|
||||
input := "some random text"
|
||||
errs := &errorReader{
|
||||
reads: [][]byte{
|
||||
[]byte("some random"),
|
||||
[]byte(" text"),
|
||||
},
|
||||
panicMessage: "bad read",
|
||||
}
|
||||
r := NewReader(errs, false, NewDefaultReaderProtocols())
|
||||
|
||||
// do not call runtime.HandleCrash() in handler. Otherwise, the tests are interrupted.
|
||||
r.handleCrash = func(additionalHandlers ...func(interface{})) { recover() }
|
||||
|
||||
data, err := readWebSocket(r, t, nil)
|
||||
if !reflect.DeepEqual(data, []byte(input)) {
|
||||
t.Errorf("unexpected server read: %v", data)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamClosedDuringRead(t *testing.T) {
|
||||
for i := 0; i < 25; i++ {
|
||||
ch := make(chan struct{})
|
||||
input := "some random text"
|
||||
errs := &errorReader{
|
||||
reads: [][]byte{
|
||||
[]byte("some random"),
|
||||
[]byte(" text"),
|
||||
},
|
||||
err: fmt.Errorf("stuff"),
|
||||
pause: ch,
|
||||
}
|
||||
r := NewReader(errs, false, NewDefaultReaderProtocols())
|
||||
|
||||
data, err := readWebSocket(r, t, func(c *websocket.Conn) {
|
||||
c.Close()
|
||||
close(ch)
|
||||
})
|
||||
// verify that the data returned by the server on an early close always has a specific error
|
||||
if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// verify that the data returned is a strict subset of the input
|
||||
if !bytes.HasPrefix([]byte(input), data) && len(data) != 0 {
|
||||
t.Fatalf("unexpected server read: %q", string(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type errorReader struct {
|
||||
reads [][]byte
|
||||
err error
|
||||
panicMessage string
|
||||
pause chan struct{}
|
||||
}
|
||||
|
||||
func (r *errorReader) Read(p []byte) (int, error) {
|
||||
if len(r.reads) == 0 {
|
||||
if r.pause != nil {
|
||||
<-r.pause
|
||||
}
|
||||
if len(r.panicMessage) != 0 {
|
||||
panic(r.panicMessage)
|
||||
}
|
||||
return 0, r.err
|
||||
}
|
||||
next := r.reads[0]
|
||||
r.reads = r.reads[1:]
|
||||
copy(p, next)
|
||||
return len(next), nil
|
||||
}
|
||||
|
||||
func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) {
|
||||
errCh := make(chan error, 1)
|
||||
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
errCh <- r.Copy(w, req)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
|
||||
config.Protocol = protocols
|
||||
client, err := websocket.DialConfig(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
if fn != nil {
|
||||
fn(client)
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadAll(client)
|
||||
if err != nil {
|
||||
return data, err
|
||||
}
|
||||
return data, <-errCh
|
||||
}
|
||||
|
||||
func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error {
|
||||
errCh := make(chan error, 1)
|
||||
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
errCh <- r.Copy(w, req)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
|
||||
config.Protocol = protocols
|
||||
ws, err := websocket.DialConfig(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
if fn != nil {
|
||||
fn(ws)
|
||||
}
|
||||
|
||||
for i := range frames {
|
||||
var data []byte
|
||||
if err := websocket.Message.Receive(ws, &data); err != nil {
|
||||
return err
|
||||
}
|
||||
if !reflect.DeepEqual(frames[i], data) {
|
||||
return fmt.Errorf("frame %d did not match expected: %v", data, err)
|
||||
}
|
||||
}
|
||||
var data []byte
|
||||
if err := websocket.Message.Receive(ws, &data); err != io.EOF {
|
||||
return fmt.Errorf("expected no more frames: %v (%v)", err, data)
|
||||
}
|
||||
return <-errCh
|
||||
}
|
Loading…
Reference in New Issue