Hardening the websocket client. (#262)

* Cleanup websocket connection, actually test reconnects.

* Some more cleanup.

* Locally define connFactory to avoid races.

* Move locks around, harden test.

* Add logging.

* Drop redundant target.

* Move message encoding outside of the writerLock.

* Fix assignment nit.

* Remove named return value.

* Add close signal to long-running loops.

* Add todo for returning a messageType.

* Bump header to 2019.

* Add note on draining the messageChan.

* Drop target from signature.

* Drop target from test.

* Add a more speaking example to draining the messageChan.

* Fix typo.

* Relax read lock, improve test.

* Bump test coverage.

* Add double shutdown test.

* Remove code duplication in test.
This commit is contained in:
Markus Thömmes 2019-02-08 18:51:41 +01:00 committed by Knative Prow Robot
parent bd5a391c64
commit 3df885a8fc
2 changed files with 255 additions and 158 deletions

View File

@ -1,5 +1,5 @@
/* /*
Copyright 2018 The Knative Authors Copyright 2019 The Knative Authors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -20,11 +20,14 @@ import (
"bytes" "bytes"
"encoding/gob" "encoding/gob"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"sync" "sync"
"time" "time"
"go.uber.org/zap"
"k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/util/wait"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -35,13 +38,8 @@ var (
// but no connection is already created. // but no connection is already created.
ErrConnectionNotEstablished = errors.New("connection has not yet been established") ErrConnectionNotEstablished = errors.New("connection has not yet been established")
connFactory = func(target string) (rawConnection, error) { // errShuttingDown is returned internally once the shutdown signal has been sent.
dialer := &websocket.Dialer{ errShuttingDown = errors.New("shutdown in progress")
HandshakeTimeout: 3 * time.Second,
}
conn, _, err := dialer.Dial(target, nil)
return conn, err
}
) )
// RawConnection is an interface defining the methods needed // RawConnection is an interface defining the methods needed
@ -54,9 +52,11 @@ type rawConnection interface {
// ManagedConnection represents a websocket connection. // ManagedConnection represents a websocket connection.
type ManagedConnection struct { type ManagedConnection struct {
target string connection rawConnection
connection rawConnection connectionFactory func() (rawConnection, error)
closeChan chan struct{}
closeChan chan struct{}
closeOnce sync.Once
// If set, messages will be forwarded to this channel // If set, messages will be forwarded to this channel
messageChan chan []byte messageChan chan []byte
@ -78,8 +78,8 @@ type ManagedConnection struct {
// that can only send messages to the endpoint it connects to. // that can only send messages to the endpoint it connects to.
// The connection will continuously be kept alive and reconnected // The connection will continuously be kept alive and reconnected
// in case of a loss of connectivity. // in case of a loss of connectivity.
func NewDurableSendingConnection(target string) *ManagedConnection { func NewDurableSendingConnection(target string, logger *zap.SugaredLogger) *ManagedConnection {
return NewDurableConnection(target, nil) return NewDurableConnection(target, nil, logger)
} }
// NewDurableConnection creates a new websocket connection, that // NewDurableConnection creates a new websocket connection, that
@ -87,30 +87,44 @@ func NewDurableSendingConnection(target string) *ManagedConnection {
// send messages to the endpoint it connects to. // send messages to the endpoint it connects to.
// The connection will continuously be kept alive and reconnected // The connection will continuously be kept alive and reconnected
// in case of a loss of connectivity. // in case of a loss of connectivity.
func NewDurableConnection(target string, messageChan chan []byte) *ManagedConnection { //
c := newConnection(target, messageChan) // Note: The given channel needs to be drained after calling `Shutdown`
// to not cause any deadlocks. If the channel's buffer is likely to be
// filled, this needs to happen in separate goroutines, i.e.
//
// go func() {conn.Shutdown(); close(messageChan)}
// go func() {for range messageChan {}}
func NewDurableConnection(target string, messageChan chan []byte, logger *zap.SugaredLogger) *ManagedConnection {
websocketConnectionFactory := func() (rawConnection, error) {
dialer := &websocket.Dialer{
HandshakeTimeout: 3 * time.Second,
}
conn, _, err := dialer.Dial(target, nil)
return conn, err
}
c := newConnection(websocketConnectionFactory, messageChan)
// Keep the connection alive asynchronously and reconnect on // Keep the connection alive asynchronously and reconnect on
// connection failure. // connection failure.
go func() { go func() {
// If the close signal races the connection attempt, make
// sure the connection actually closes.
defer func() {
c.connectionLock.RLock()
defer c.connectionLock.RUnlock()
if conn := c.connection; conn != nil {
conn.Close()
}
}()
for { for {
select { select {
default: default:
logger.Infof("Connecting to %q", target)
if err := c.connect(); err != nil { if err := c.connect(); err != nil {
logger.Errorw(fmt.Sprintf("Connecting to %q failed", target), zap.Error(err))
continue continue
} }
c.keepalive() logger.Infof("Connected to %q", target)
if err := c.keepalive(); err != nil {
logger.Errorw(fmt.Sprintf("Connection to %q broke down, reconnecting...", target), zap.Error(err))
}
if err := c.closeConnection(); err != nil {
logger.Errorw("Failed to close the connection after crashing", zap.Error(err))
}
case <-c.closeChan: case <-c.closeChan:
logger.Infof("Connection to %q is being shutdown", target)
return return
} }
} }
@ -120,11 +134,11 @@ func NewDurableConnection(target string, messageChan chan []byte) *ManagedConnec
} }
// newConnection creates a new connection primitive. // newConnection creates a new connection primitive.
func newConnection(target string, messageChan chan []byte) *ManagedConnection { func newConnection(connFactory func() (rawConnection, error), messageChan chan []byte) *ManagedConnection {
conn := &ManagedConnection{ conn := &ManagedConnection{
target: target, connectionFactory: connFactory,
closeChan: make(chan struct{}, 1), closeChan: make(chan struct{}),
messageChan: messageChan, messageChan: messageChan,
connectionBackoff: wait.Backoff{ connectionBackoff: wait.Backoff{
Duration: 100 * time.Millisecond, Duration: 100 * time.Millisecond,
Factor: 1.3, Factor: 1.3,
@ -137,58 +151,87 @@ func newConnection(target string, messageChan chan []byte) *ManagedConnection {
} }
// connect tries to establish a websocket connection. // connect tries to establish a websocket connection.
func (c *ManagedConnection) connect() (err error) { func (c *ManagedConnection) connect() error {
var err error
wait.ExponentialBackoff(c.connectionBackoff, func() (bool, error) { wait.ExponentialBackoff(c.connectionBackoff, func() (bool, error) {
var conn rawConnection select {
conn, err = connFactory(c.target) default:
if err != nil { var conn rawConnection
return false, nil conn, err = c.connectionFactory()
} if err != nil {
c.connectionLock.Lock() return false, nil
defer c.connectionLock.Unlock() }
c.connection = conn c.connectionLock.Lock()
return true, nil defer c.connectionLock.Unlock()
c.connection = conn
return true, nil
case <-c.closeChan:
err = errShuttingDown
return false, err
}
}) })
return err return err
} }
// keepalive keeps the connection open and reads control messages. // keepalive keeps the connection open.
// All messages are discarded. func (c *ManagedConnection) keepalive() error {
func (c *ManagedConnection) keepalive() (err error) { for {
select {
default:
if err := c.read(); err != nil {
return err
}
case <-c.closeChan:
return errShuttingDown
}
}
}
// closeConnection closes the underlying websocket connection.
func (c *ManagedConnection) closeConnection() error {
c.connectionLock.Lock()
defer c.connectionLock.Unlock()
if c.connection != nil {
err := c.connection.Close()
c.connection = nil
return err
}
return nil
}
// read reads the next message from the connection.
// If a messageChan is supplied and the current message type is not
// a control message, the message is sent to that channel.
func (c *ManagedConnection) read() error {
c.connectionLock.RLock()
defer c.connectionLock.RUnlock()
if c.connection == nil {
return ErrConnectionNotEstablished
}
c.readerLock.Lock() c.readerLock.Lock()
defer c.readerLock.Unlock() defer c.readerLock.Unlock()
for { messageType, reader, err := c.connection.NextReader()
func() { if err != nil {
c.connectionLock.RLock() return err
defer c.connectionLock.RUnlock() }
if conn := c.connection; conn != nil { // Send the message to the channel if its an application level message
var reader io.Reader // and if that channel is set.
var messageType int // TODO(markusthoemmes): Return the messageType along with the payload.
messageType, reader, err = conn.NextReader() if c.messageChan != nil && (messageType == websocket.TextMessage || messageType == websocket.BinaryMessage) {
if err != nil { if message, _ := ioutil.ReadAll(reader); message != nil {
conn.Close() c.messageChan <- message
}
// Send the message to the channel if its an application level message
// and if that channel is set.
if c.messageChan != nil && (messageType == websocket.TextMessage || messageType == websocket.BinaryMessage) {
if message, _ := ioutil.ReadAll(reader); message != nil {
c.messageChan <- message
}
}
} else {
err = ErrConnectionNotEstablished
}
}()
if err != nil {
return err
} }
} }
return nil
} }
// Send sends an encodable message over the websocket connection. // Send sends an encodable message over the websocket connection.
@ -196,31 +239,27 @@ func (c *ManagedConnection) Send(msg interface{}) error {
c.connectionLock.RLock() c.connectionLock.RLock()
defer c.connectionLock.RUnlock() defer c.connectionLock.RUnlock()
conn := c.connection if c.connection == nil {
if conn == nil {
return ErrConnectionNotEstablished return ErrConnectionNotEstablished
} }
c.writerLock.Lock()
defer c.writerLock.Unlock()
var b bytes.Buffer var b bytes.Buffer
enc := gob.NewEncoder(&b) enc := gob.NewEncoder(&b)
if err := enc.Encode(msg); err != nil { if err := enc.Encode(msg); err != nil {
return err return err
} }
return conn.WriteMessage(websocket.BinaryMessage, b.Bytes()) c.writerLock.Lock()
defer c.writerLock.Unlock()
return c.connection.WriteMessage(websocket.BinaryMessage, b.Bytes())
} }
// Close closes the websocket connection. // Shutdown closes the websocket connection.
func (c *ManagedConnection) Close() error { func (c *ManagedConnection) Shutdown() error {
c.closeChan <- struct{}{} c.closeOnce.Do(func() {
c.connectionLock.RLock() close(c.closeChan)
defer c.connectionLock.RUnlock() })
if conn := c.connection; conn != nil { return c.closeConnection()
return conn.Close()
}
return nil
} }

View File

@ -1,5 +1,5 @@
/* /*
Copyright 2018 The Knative Authors Copyright 2019 The Knative Authors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -19,16 +19,20 @@ package websocket
import ( import (
"errors" "errors"
"io" "io"
"net/http"
"net/http/httptest"
"strings" "strings"
"testing" "testing"
"time" "time"
ktesting "github.com/knative/pkg/logging/testing"
"k8s.io/apimachinery/pkg/util/wait"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
const ( const propagationTimeout = 5 * time.Second
target = "test"
)
type inspectableConnection struct { type inspectableConnection struct {
nextReaderCalls chan struct{} nextReaderCalls chan struct{}
@ -39,20 +43,41 @@ type inspectableConnection struct {
} }
func (c *inspectableConnection) WriteMessage(messageType int, data []byte) error { func (c *inspectableConnection) WriteMessage(messageType int, data []byte) error {
c.writeMessageCalls <- struct{}{} if c.writeMessageCalls != nil {
c.writeMessageCalls <- struct{}{}
}
return nil return nil
} }
func (c *inspectableConnection) NextReader() (int, io.Reader, error) { func (c *inspectableConnection) NextReader() (int, io.Reader, error) {
c.nextReaderCalls <- struct{}{} if c.nextReaderCalls != nil {
c.nextReaderCalls <- struct{}{}
}
return c.nextReaderFunc() return c.nextReaderFunc()
} }
func (c *inspectableConnection) Close() error { func (c *inspectableConnection) Close() error {
c.closeCalls <- struct{}{} if c.closeCalls != nil {
c.closeCalls <- struct{}{}
}
return nil return nil
} }
// staticConnFactory returns a static connection, for example
// an inspectable connection.
func staticConnFactory(conn rawConnection) func() (rawConnection, error) {
return func() (rawConnection, error) {
return conn, nil
}
}
// errConnFactory returns a static error.
func errConnFactory(err error) func() (rawConnection, error) {
return func() (rawConnection, error) {
return nil, err
}
}
func TestRetriesWhileConnect(t *testing.T) { func TestRetriesWhileConnect(t *testing.T) {
want := 2 want := 2
got := 0 got := 0
@ -61,17 +86,17 @@ func TestRetriesWhileConnect(t *testing.T) {
closeCalls: make(chan struct{}, 1), closeCalls: make(chan struct{}, 1),
} }
connFactory = func(_ string) (rawConnection, error) { connFactory := func() (rawConnection, error) {
got++ got++
if got == want { if got == want {
return spy, nil return spy, nil
} }
return nil, errors.New("not yet") return nil, errors.New("not yet")
} }
conn := newConnection(target, nil) conn := newConnection(connFactory, nil)
conn.connect() conn.connect()
conn.Close() conn.Shutdown()
if got != want { if got != want {
t.Fatalf("Wanted %v retries. Got %v.", want, got) t.Fatalf("Wanted %v retries. Got %v.", want, got)
@ -96,11 +121,7 @@ func TestSendErrorOnEncode(t *testing.T) {
spy := &inspectableConnection{ spy := &inspectableConnection{
writeMessageCalls: make(chan struct{}, 1), writeMessageCalls: make(chan struct{}, 1),
} }
conn := newConnection(staticConnFactory(spy), nil)
connFactory = func(_ string) (rawConnection, error) {
return spy, nil
}
conn := newConnection(target, nil)
conn.connect() conn.connect()
// gob cannot encode nil values // gob cannot encode nil values
got := conn.Send(nil) got := conn.Send(nil)
@ -117,10 +138,7 @@ func TestSendMessage(t *testing.T) {
spy := &inspectableConnection{ spy := &inspectableConnection{
writeMessageCalls: make(chan struct{}, 1), writeMessageCalls: make(chan struct{}, 1),
} }
connFactory = func(_ string) (rawConnection, error) { conn := newConnection(staticConnFactory(spy), nil)
return spy, nil
}
conn := newConnection(target, nil)
conn.connect() conn.connect()
got := conn.Send("test") got := conn.Send("test")
@ -142,12 +160,9 @@ func TestReceiveMessage(t *testing.T) {
return websocket.TextMessage, strings.NewReader(testMessage), nil return websocket.TextMessage, strings.NewReader(testMessage), nil
}, },
} }
connFactory = func(_ string) (rawConnection, error) {
return spy, nil
}
messageChan := make(chan []byte, 1) messageChan := make(chan []byte, 1)
conn := newConnection(target, messageChan) conn := newConnection(staticConnFactory(spy), messageChan)
conn.connect() conn.connect()
go conn.keepalive() go conn.keepalive()
@ -162,12 +177,9 @@ func TestCloseClosesConnection(t *testing.T) {
spy := &inspectableConnection{ spy := &inspectableConnection{
closeCalls: make(chan struct{}, 1), closeCalls: make(chan struct{}, 1),
} }
connFactory = func(_ string) (rawConnection, error) { conn := newConnection(staticConnFactory(spy), nil)
return spy, nil
}
conn := newConnection(target, nil)
conn.connect() conn.connect()
conn.Close() conn.Shutdown()
if len(spy.closeCalls) != 1 { if len(spy.closeCalls) != 1 {
t.Fatalf("Expected 'Close' to be called once, got %v", len(spy.closeCalls)) t.Fatalf("Expected 'Close' to be called once, got %v", len(spy.closeCalls))
@ -178,7 +190,7 @@ func TestCloseIgnoresNoConnection(t *testing.T) {
conn := &ManagedConnection{ conn := &ManagedConnection{
closeChan: make(chan struct{}, 1), closeChan: make(chan struct{}, 1),
} }
got := conn.Close() got := conn.Shutdown()
if got != nil { if got != nil {
t.Fatalf("Expected no error, got %v", got) t.Fatalf("Expected no error, got %v", got)
@ -186,60 +198,46 @@ func TestCloseIgnoresNoConnection(t *testing.T) {
} }
func TestDurableConnectionWhenConnectionBreaksDown(t *testing.T) { func TestDurableConnectionWhenConnectionBreaksDown(t *testing.T) {
testConn := &inspectableConnection{ testPayload := "test"
nextReaderCalls: make(chan struct{}), reconnectChan := make(chan struct{})
writeMessageCalls: make(chan struct{}),
closeCalls: make(chan struct{}),
nextReaderFunc: func() (int, io.Reader, error) { upgrader := websocket.Upgrader{}
return 1, nil, errors.New("next reader errored") s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}, c, err := upgrader.Upgrade(w, r, nil)
} if err != nil {
connectAttempts := make(chan struct{}) return
connFactory = func(_ string) (rawConnection, error) { }
connectAttempts <- struct{}{}
return testConn, nil
}
conn := NewDurableSendingConnection(target)
// the connection is constantly created, tried to read from // Waits for a message to be sent before dropping the connection.
// and closed because NextReader (which holds the connection <-reconnectChan
// open) fails. c.Close()
for i := 0; i < 100; i++ { }))
<-connectAttempts defer s.Close()
<-testConn.nextReaderCalls
<-testConn.closeCalls
}
// Enter the reconnect loop logger := ktesting.TestLogger(t)
<-connectAttempts target := "ws" + strings.TrimPrefix(s.URL, "http")
conn := NewDurableSendingConnection(target, logger)
defer conn.Shutdown()
// Call 'Close' asynchronously and wait for it to reach for i := 0; i < 10; i++ {
// the channel. err := wait.PollImmediate(50*time.Millisecond, 5*time.Second, func() (bool, error) {
go conn.Close() if err := conn.Send(testPayload); err != nil {
<-testConn.closeCalls return false, nil
}
return true, nil
})
// Advance the reconnect loop until 'Close' is called. if err != nil {
<-testConn.nextReaderCalls t.Errorf("Timed out trying to send a message: %v", err)
<-testConn.closeCalls }
// Wait for the final call to 'Close' (when the loop is aborted) // Message successfully sent, instruct the server to drop the connection.
<-testConn.closeCalls reconnectChan <- struct{}{}
if len(connectAttempts) > 1 {
t.Fatalf("Expected at most one connection attempts, got %v", len(connectAttempts))
}
if len(testConn.nextReaderCalls) > 1 {
t.Fatalf("Expected at most one calls to 'NextReader', got %v", len(testConn.nextReaderCalls))
} }
} }
func TestConnectFailureReturnsError(t *testing.T) { func TestConnectFailureReturnsError(t *testing.T) {
connFactory = func(_ string) (rawConnection, error) { conn := newConnection(errConnFactory(ErrConnectionNotEstablished), nil)
return nil, ErrConnectionNotEstablished
}
conn := newConnection(target, nil)
// Shorten the connection backoff duration for this test // Shorten the connection backoff duration for this test
conn.connectionBackoff.Duration = 1 * time.Millisecond conn.connectionBackoff.Duration = 1 * time.Millisecond
@ -252,10 +250,70 @@ func TestConnectFailureReturnsError(t *testing.T) {
} }
func TestKeepaliveWithNoConnectionReturnsError(t *testing.T) { func TestKeepaliveWithNoConnectionReturnsError(t *testing.T) {
conn := newConnection(target, nil) conn := newConnection(nil, nil)
got := conn.keepalive() got := conn.keepalive()
if got == nil { if got == nil {
t.Fatal("Expected an error but got none") t.Fatal("Expected an error but got none")
} }
} }
func TestConnectLoopIsStopped(t *testing.T) {
conn := newConnection(errConnFactory(errors.New("connection error")), nil)
errorChan := make(chan error)
go func() {
errorChan <- conn.connect()
}()
conn.Shutdown()
select {
case err := <-errorChan:
if err != errShuttingDown {
t.Errorf("Wrong 'connect' error, got %v, want %v", err, errShuttingDown)
}
case <-time.After(propagationTimeout):
t.Error("Timed out waiting for the keepalive loop to stop.")
}
}
func TestKeepaliveLoopIsStopped(t *testing.T) {
spy := &inspectableConnection{
nextReaderFunc: func() (int, io.Reader, error) {
return websocket.TextMessage, nil, nil
},
}
conn := newConnection(staticConnFactory(spy), nil)
conn.connect()
errorChan := make(chan error)
go func() {
errorChan <- conn.keepalive()
}()
conn.Shutdown()
select {
case err := <-errorChan:
if err != errShuttingDown {
t.Errorf("Wrong 'keepalive' error, got %v, want %v", err, errShuttingDown)
}
case <-time.After(propagationTimeout):
t.Error("Timed out waiting for the keepalive loop to stop.")
}
}
func TestDoubleShutdown(t *testing.T) {
spy := &inspectableConnection{
closeCalls: make(chan struct{}, 2), // potentially allow 2 calls
}
conn := newConnection(staticConnFactory(spy), nil)
conn.connect()
conn.Shutdown()
conn.Shutdown()
if want, got := 1, len(spy.closeCalls); want != got {
t.Errorf("Wrong 'Close' callcount, got %d, want %d", got, want)
}
}