From 3df885a8fc84be700e8a73ef2b19ffbc24411b6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20Th=C3=B6mmes?= Date: Fri, 8 Feb 2019 18:51:41 +0100 Subject: [PATCH] Hardening the websocket client. (#262) * Cleanup websocket connection, actually test reconnects. * Some more cleanup. * Locally define connFactory to avoid races. * Move locks around, harden test. * Add logging. * Drop redundant target. * Move message encoding outside of the writerLock. * Fix assignment nit. * Remove named return value. * Add close signal to long-running loops. * Add todo for returning a messageType. * Bump header to 2019. * Add note on draining the messageChan. * Drop target from signature. * Drop target from test. * Add a more speaking example to draining the messageChan. * Fix typo. * Relax read lock, improve test. * Bump test coverage. * Add double shutdown test. * Remove code duplication in test. --- websocket/connection.go | 207 +++++++++++++++++++++-------------- websocket/connection_test.go | 206 +++++++++++++++++++++------------- 2 files changed, 255 insertions(+), 158 deletions(-) diff --git a/websocket/connection.go b/websocket/connection.go index f644d9709..1e19455fb 100644 --- a/websocket/connection.go +++ b/websocket/connection.go @@ -1,5 +1,5 @@ /* -Copyright 2018 The Knative Authors +Copyright 2019 The Knative Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,11 +20,14 @@ import ( "bytes" "encoding/gob" "errors" + "fmt" "io" "io/ioutil" "sync" "time" + "go.uber.org/zap" + "k8s.io/apimachinery/pkg/util/wait" "github.com/gorilla/websocket" @@ -35,13 +38,8 @@ var ( // but no connection is already created. ErrConnectionNotEstablished = errors.New("connection has not yet been established") - connFactory = func(target string) (rawConnection, error) { - dialer := &websocket.Dialer{ - HandshakeTimeout: 3 * time.Second, - } - conn, _, err := dialer.Dial(target, nil) - return conn, err - } + // errShuttingDown is returned internally once the shutdown signal has been sent. + errShuttingDown = errors.New("shutdown in progress") ) // RawConnection is an interface defining the methods needed @@ -54,9 +52,11 @@ type rawConnection interface { // ManagedConnection represents a websocket connection. type ManagedConnection struct { - target string - connection rawConnection - closeChan chan struct{} + connection rawConnection + connectionFactory func() (rawConnection, error) + + closeChan chan struct{} + closeOnce sync.Once // If set, messages will be forwarded to this channel messageChan chan []byte @@ -78,8 +78,8 @@ type ManagedConnection struct { // that can only send messages to the endpoint it connects to. // The connection will continuously be kept alive and reconnected // in case of a loss of connectivity. -func NewDurableSendingConnection(target string) *ManagedConnection { - return NewDurableConnection(target, nil) +func NewDurableSendingConnection(target string, logger *zap.SugaredLogger) *ManagedConnection { + return NewDurableConnection(target, nil, logger) } // NewDurableConnection creates a new websocket connection, that @@ -87,30 +87,44 @@ func NewDurableSendingConnection(target string) *ManagedConnection { // send messages to the endpoint it connects to. // The connection will continuously be kept alive and reconnected // in case of a loss of connectivity. -func NewDurableConnection(target string, messageChan chan []byte) *ManagedConnection { - c := newConnection(target, messageChan) +// +// Note: The given channel needs to be drained after calling `Shutdown` +// to not cause any deadlocks. If the channel's buffer is likely to be +// filled, this needs to happen in separate goroutines, i.e. +// +// go func() {conn.Shutdown(); close(messageChan)} +// go func() {for range messageChan {}} +func NewDurableConnection(target string, messageChan chan []byte, logger *zap.SugaredLogger) *ManagedConnection { + websocketConnectionFactory := func() (rawConnection, error) { + dialer := &websocket.Dialer{ + HandshakeTimeout: 3 * time.Second, + } + conn, _, err := dialer.Dial(target, nil) + return conn, err + } + + c := newConnection(websocketConnectionFactory, messageChan) // Keep the connection alive asynchronously and reconnect on // connection failure. go func() { - // If the close signal races the connection attempt, make - // sure the connection actually closes. - defer func() { - c.connectionLock.RLock() - defer c.connectionLock.RUnlock() - - if conn := c.connection; conn != nil { - conn.Close() - } - }() for { select { default: + logger.Infof("Connecting to %q", target) if err := c.connect(); err != nil { + logger.Errorw(fmt.Sprintf("Connecting to %q failed", target), zap.Error(err)) continue } - c.keepalive() + logger.Infof("Connected to %q", target) + if err := c.keepalive(); err != nil { + logger.Errorw(fmt.Sprintf("Connection to %q broke down, reconnecting...", target), zap.Error(err)) + } + if err := c.closeConnection(); err != nil { + logger.Errorw("Failed to close the connection after crashing", zap.Error(err)) + } case <-c.closeChan: + logger.Infof("Connection to %q is being shutdown", target) return } } @@ -120,11 +134,11 @@ func NewDurableConnection(target string, messageChan chan []byte) *ManagedConnec } // newConnection creates a new connection primitive. -func newConnection(target string, messageChan chan []byte) *ManagedConnection { +func newConnection(connFactory func() (rawConnection, error), messageChan chan []byte) *ManagedConnection { conn := &ManagedConnection{ - target: target, - closeChan: make(chan struct{}, 1), - messageChan: messageChan, + connectionFactory: connFactory, + closeChan: make(chan struct{}), + messageChan: messageChan, connectionBackoff: wait.Backoff{ Duration: 100 * time.Millisecond, Factor: 1.3, @@ -137,58 +151,87 @@ func newConnection(target string, messageChan chan []byte) *ManagedConnection { } // connect tries to establish a websocket connection. -func (c *ManagedConnection) connect() (err error) { +func (c *ManagedConnection) connect() error { + var err error wait.ExponentialBackoff(c.connectionBackoff, func() (bool, error) { - var conn rawConnection - conn, err = connFactory(c.target) - if err != nil { - return false, nil - } - c.connectionLock.Lock() - defer c.connectionLock.Unlock() + select { + default: + var conn rawConnection + conn, err = c.connectionFactory() + if err != nil { + return false, nil + } - c.connection = conn - return true, nil + c.connectionLock.Lock() + defer c.connectionLock.Unlock() + + c.connection = conn + return true, nil + case <-c.closeChan: + err = errShuttingDown + return false, err + } }) return err } -// keepalive keeps the connection open and reads control messages. -// All messages are discarded. -func (c *ManagedConnection) keepalive() (err error) { +// keepalive keeps the connection open. +func (c *ManagedConnection) keepalive() error { + for { + select { + default: + if err := c.read(); err != nil { + return err + } + case <-c.closeChan: + return errShuttingDown + } + } +} + +// closeConnection closes the underlying websocket connection. +func (c *ManagedConnection) closeConnection() error { + c.connectionLock.Lock() + defer c.connectionLock.Unlock() + + if c.connection != nil { + err := c.connection.Close() + c.connection = nil + return err + } + return nil +} + +// read reads the next message from the connection. +// If a messageChan is supplied and the current message type is not +// a control message, the message is sent to that channel. +func (c *ManagedConnection) read() error { + c.connectionLock.RLock() + defer c.connectionLock.RUnlock() + + if c.connection == nil { + return ErrConnectionNotEstablished + } + c.readerLock.Lock() defer c.readerLock.Unlock() - for { - func() { - c.connectionLock.RLock() - defer c.connectionLock.RUnlock() + messageType, reader, err := c.connection.NextReader() + if err != nil { + return err + } - if conn := c.connection; conn != nil { - var reader io.Reader - var messageType int - messageType, reader, err = conn.NextReader() - if err != nil { - conn.Close() - } - - // Send the message to the channel if its an application level message - // and if that channel is set. - if c.messageChan != nil && (messageType == websocket.TextMessage || messageType == websocket.BinaryMessage) { - if message, _ := ioutil.ReadAll(reader); message != nil { - c.messageChan <- message - } - } - } else { - err = ErrConnectionNotEstablished - } - }() - - if err != nil { - return err + // Send the message to the channel if its an application level message + // and if that channel is set. + // TODO(markusthoemmes): Return the messageType along with the payload. + if c.messageChan != nil && (messageType == websocket.TextMessage || messageType == websocket.BinaryMessage) { + if message, _ := ioutil.ReadAll(reader); message != nil { + c.messageChan <- message } } + + return nil } // Send sends an encodable message over the websocket connection. @@ -196,31 +239,27 @@ func (c *ManagedConnection) Send(msg interface{}) error { c.connectionLock.RLock() defer c.connectionLock.RUnlock() - conn := c.connection - if conn == nil { + if c.connection == nil { return ErrConnectionNotEstablished } - c.writerLock.Lock() - defer c.writerLock.Unlock() - var b bytes.Buffer enc := gob.NewEncoder(&b) if err := enc.Encode(msg); err != nil { return err } - return conn.WriteMessage(websocket.BinaryMessage, b.Bytes()) + c.writerLock.Lock() + defer c.writerLock.Unlock() + + return c.connection.WriteMessage(websocket.BinaryMessage, b.Bytes()) } -// Close closes the websocket connection. -func (c *ManagedConnection) Close() error { - c.closeChan <- struct{}{} - c.connectionLock.RLock() - defer c.connectionLock.RUnlock() +// Shutdown closes the websocket connection. +func (c *ManagedConnection) Shutdown() error { + c.closeOnce.Do(func() { + close(c.closeChan) + }) - if conn := c.connection; conn != nil { - return conn.Close() - } - return nil + return c.closeConnection() } diff --git a/websocket/connection_test.go b/websocket/connection_test.go index eb080eea1..5c6abeee7 100644 --- a/websocket/connection_test.go +++ b/websocket/connection_test.go @@ -1,5 +1,5 @@ /* -Copyright 2018 The Knative Authors +Copyright 2019 The Knative Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,16 +19,20 @@ package websocket import ( "errors" "io" + "net/http" + "net/http/httptest" "strings" "testing" "time" + ktesting "github.com/knative/pkg/logging/testing" + + "k8s.io/apimachinery/pkg/util/wait" + "github.com/gorilla/websocket" ) -const ( - target = "test" -) +const propagationTimeout = 5 * time.Second type inspectableConnection struct { nextReaderCalls chan struct{} @@ -39,20 +43,41 @@ type inspectableConnection struct { } func (c *inspectableConnection) WriteMessage(messageType int, data []byte) error { - c.writeMessageCalls <- struct{}{} + if c.writeMessageCalls != nil { + c.writeMessageCalls <- struct{}{} + } return nil } func (c *inspectableConnection) NextReader() (int, io.Reader, error) { - c.nextReaderCalls <- struct{}{} + if c.nextReaderCalls != nil { + c.nextReaderCalls <- struct{}{} + } return c.nextReaderFunc() } func (c *inspectableConnection) Close() error { - c.closeCalls <- struct{}{} + if c.closeCalls != nil { + c.closeCalls <- struct{}{} + } return nil } +// staticConnFactory returns a static connection, for example +// an inspectable connection. +func staticConnFactory(conn rawConnection) func() (rawConnection, error) { + return func() (rawConnection, error) { + return conn, nil + } +} + +// errConnFactory returns a static error. +func errConnFactory(err error) func() (rawConnection, error) { + return func() (rawConnection, error) { + return nil, err + } +} + func TestRetriesWhileConnect(t *testing.T) { want := 2 got := 0 @@ -61,17 +86,17 @@ func TestRetriesWhileConnect(t *testing.T) { closeCalls: make(chan struct{}, 1), } - connFactory = func(_ string) (rawConnection, error) { + connFactory := func() (rawConnection, error) { got++ if got == want { return spy, nil } return nil, errors.New("not yet") } - conn := newConnection(target, nil) + conn := newConnection(connFactory, nil) conn.connect() - conn.Close() + conn.Shutdown() if got != want { t.Fatalf("Wanted %v retries. Got %v.", want, got) @@ -96,11 +121,7 @@ func TestSendErrorOnEncode(t *testing.T) { spy := &inspectableConnection{ writeMessageCalls: make(chan struct{}, 1), } - - connFactory = func(_ string) (rawConnection, error) { - return spy, nil - } - conn := newConnection(target, nil) + conn := newConnection(staticConnFactory(spy), nil) conn.connect() // gob cannot encode nil values got := conn.Send(nil) @@ -117,10 +138,7 @@ func TestSendMessage(t *testing.T) { spy := &inspectableConnection{ writeMessageCalls: make(chan struct{}, 1), } - connFactory = func(_ string) (rawConnection, error) { - return spy, nil - } - conn := newConnection(target, nil) + conn := newConnection(staticConnFactory(spy), nil) conn.connect() got := conn.Send("test") @@ -142,12 +160,9 @@ func TestReceiveMessage(t *testing.T) { return websocket.TextMessage, strings.NewReader(testMessage), nil }, } - connFactory = func(_ string) (rawConnection, error) { - return spy, nil - } messageChan := make(chan []byte, 1) - conn := newConnection(target, messageChan) + conn := newConnection(staticConnFactory(spy), messageChan) conn.connect() go conn.keepalive() @@ -162,12 +177,9 @@ func TestCloseClosesConnection(t *testing.T) { spy := &inspectableConnection{ closeCalls: make(chan struct{}, 1), } - connFactory = func(_ string) (rawConnection, error) { - return spy, nil - } - conn := newConnection(target, nil) + conn := newConnection(staticConnFactory(spy), nil) conn.connect() - conn.Close() + conn.Shutdown() if len(spy.closeCalls) != 1 { t.Fatalf("Expected 'Close' to be called once, got %v", len(spy.closeCalls)) @@ -178,7 +190,7 @@ func TestCloseIgnoresNoConnection(t *testing.T) { conn := &ManagedConnection{ closeChan: make(chan struct{}, 1), } - got := conn.Close() + got := conn.Shutdown() if got != nil { t.Fatalf("Expected no error, got %v", got) @@ -186,60 +198,46 @@ func TestCloseIgnoresNoConnection(t *testing.T) { } func TestDurableConnectionWhenConnectionBreaksDown(t *testing.T) { - testConn := &inspectableConnection{ - nextReaderCalls: make(chan struct{}), - writeMessageCalls: make(chan struct{}), - closeCalls: make(chan struct{}), + testPayload := "test" + reconnectChan := make(chan struct{}) - nextReaderFunc: func() (int, io.Reader, error) { - return 1, nil, errors.New("next reader errored") - }, - } - connectAttempts := make(chan struct{}) - connFactory = func(_ string) (rawConnection, error) { - connectAttempts <- struct{}{} - return testConn, nil - } - conn := NewDurableSendingConnection(target) + upgrader := websocket.Upgrader{} + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } - // the connection is constantly created, tried to read from - // and closed because NextReader (which holds the connection - // open) fails. - for i := 0; i < 100; i++ { - <-connectAttempts - <-testConn.nextReaderCalls - <-testConn.closeCalls - } + // Waits for a message to be sent before dropping the connection. + <-reconnectChan + c.Close() + })) + defer s.Close() - // Enter the reconnect loop - <-connectAttempts + logger := ktesting.TestLogger(t) + target := "ws" + strings.TrimPrefix(s.URL, "http") + conn := NewDurableSendingConnection(target, logger) + defer conn.Shutdown() - // Call 'Close' asynchronously and wait for it to reach - // the channel. - go conn.Close() - <-testConn.closeCalls + for i := 0; i < 10; i++ { + err := wait.PollImmediate(50*time.Millisecond, 5*time.Second, func() (bool, error) { + if err := conn.Send(testPayload); err != nil { + return false, nil + } + return true, nil + }) - // Advance the reconnect loop until 'Close' is called. - <-testConn.nextReaderCalls - <-testConn.closeCalls + if err != nil { + t.Errorf("Timed out trying to send a message: %v", err) + } - // Wait for the final call to 'Close' (when the loop is aborted) - <-testConn.closeCalls - - if len(connectAttempts) > 1 { - t.Fatalf("Expected at most one connection attempts, got %v", len(connectAttempts)) - } - if len(testConn.nextReaderCalls) > 1 { - t.Fatalf("Expected at most one calls to 'NextReader', got %v", len(testConn.nextReaderCalls)) + // Message successfully sent, instruct the server to drop the connection. + reconnectChan <- struct{}{} } } func TestConnectFailureReturnsError(t *testing.T) { - connFactory = func(_ string) (rawConnection, error) { - return nil, ErrConnectionNotEstablished - } - - conn := newConnection(target, nil) + conn := newConnection(errConnFactory(ErrConnectionNotEstablished), nil) // Shorten the connection backoff duration for this test conn.connectionBackoff.Duration = 1 * time.Millisecond @@ -252,10 +250,70 @@ func TestConnectFailureReturnsError(t *testing.T) { } func TestKeepaliveWithNoConnectionReturnsError(t *testing.T) { - conn := newConnection(target, nil) + conn := newConnection(nil, nil) got := conn.keepalive() if got == nil { t.Fatal("Expected an error but got none") } } + +func TestConnectLoopIsStopped(t *testing.T) { + conn := newConnection(errConnFactory(errors.New("connection error")), nil) + + errorChan := make(chan error) + go func() { + errorChan <- conn.connect() + }() + + conn.Shutdown() + + select { + case err := <-errorChan: + if err != errShuttingDown { + t.Errorf("Wrong 'connect' error, got %v, want %v", err, errShuttingDown) + } + case <-time.After(propagationTimeout): + t.Error("Timed out waiting for the keepalive loop to stop.") + } +} + +func TestKeepaliveLoopIsStopped(t *testing.T) { + spy := &inspectableConnection{ + nextReaderFunc: func() (int, io.Reader, error) { + return websocket.TextMessage, nil, nil + }, + } + conn := newConnection(staticConnFactory(spy), nil) + conn.connect() + + errorChan := make(chan error) + go func() { + errorChan <- conn.keepalive() + }() + + conn.Shutdown() + + select { + case err := <-errorChan: + if err != errShuttingDown { + t.Errorf("Wrong 'keepalive' error, got %v, want %v", err, errShuttingDown) + } + case <-time.After(propagationTimeout): + t.Error("Timed out waiting for the keepalive loop to stop.") + } +} + +func TestDoubleShutdown(t *testing.T) { + spy := &inspectableConnection{ + closeCalls: make(chan struct{}, 2), // potentially allow 2 calls + } + conn := newConnection(staticConnFactory(spy), nil) + conn.connect() + conn.Shutdown() + conn.Shutdown() + + if want, got := 1, len(spy.closeCalls); want != got { + t.Errorf("Wrong 'Close' callcount, got %d, want %d", got, want) + } +}