mirror of https://github.com/knative/pkg.git
				
				
				
			Hardening the websocket client. (#262)
* Cleanup websocket connection, actually test reconnects. * Some more cleanup. * Locally define connFactory to avoid races. * Move locks around, harden test. * Add logging. * Drop redundant target. * Move message encoding outside of the writerLock. * Fix assignment nit. * Remove named return value. * Add close signal to long-running loops. * Add todo for returning a messageType. * Bump header to 2019. * Add note on draining the messageChan. * Drop target from signature. * Drop target from test. * Add a more speaking example to draining the messageChan. * Fix typo. * Relax read lock, improve test. * Bump test coverage. * Add double shutdown test. * Remove code duplication in test.
This commit is contained in:
		
							parent
							
								
									bd5a391c64
								
							
						
					
					
						commit
						3df885a8fc
					
				|  | @ -1,5 +1,5 @@ | |||
| /* | ||||
| Copyright 2018 The Knative Authors | ||||
| Copyright 2019 The Knative Authors | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
|  | @ -20,11 +20,14 @@ import ( | |||
| 	"bytes" | ||||
| 	"encoding/gob" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"go.uber.org/zap" | ||||
| 
 | ||||
| 	"k8s.io/apimachinery/pkg/util/wait" | ||||
| 
 | ||||
| 	"github.com/gorilla/websocket" | ||||
|  | @ -35,13 +38,8 @@ var ( | |||
| 	// but no connection is already created.
 | ||||
| 	ErrConnectionNotEstablished = errors.New("connection has not yet been established") | ||||
| 
 | ||||
| 	connFactory = func(target string) (rawConnection, error) { | ||||
| 		dialer := &websocket.Dialer{ | ||||
| 			HandshakeTimeout: 3 * time.Second, | ||||
| 		} | ||||
| 		conn, _, err := dialer.Dial(target, nil) | ||||
| 		return conn, err | ||||
| 	} | ||||
| 	// errShuttingDown is returned internally once the shutdown signal has been sent.
 | ||||
| 	errShuttingDown = errors.New("shutdown in progress") | ||||
| ) | ||||
| 
 | ||||
| // RawConnection is an interface defining the methods needed
 | ||||
|  | @ -54,9 +52,11 @@ type rawConnection interface { | |||
| 
 | ||||
| // ManagedConnection represents a websocket connection.
 | ||||
| type ManagedConnection struct { | ||||
| 	target     string | ||||
| 	connection rawConnection | ||||
| 	closeChan  chan struct{} | ||||
| 	connection        rawConnection | ||||
| 	connectionFactory func() (rawConnection, error) | ||||
| 
 | ||||
| 	closeChan chan struct{} | ||||
| 	closeOnce sync.Once | ||||
| 
 | ||||
| 	// If set, messages will be forwarded to this channel
 | ||||
| 	messageChan chan []byte | ||||
|  | @ -78,8 +78,8 @@ type ManagedConnection struct { | |||
| // that can only send messages to the endpoint it connects to.
 | ||||
| // The connection will continuously be kept alive and reconnected
 | ||||
| // in case of a loss of connectivity.
 | ||||
| func NewDurableSendingConnection(target string) *ManagedConnection { | ||||
| 	return NewDurableConnection(target, nil) | ||||
| func NewDurableSendingConnection(target string, logger *zap.SugaredLogger) *ManagedConnection { | ||||
| 	return NewDurableConnection(target, nil, logger) | ||||
| } | ||||
| 
 | ||||
| // NewDurableConnection creates a new websocket connection, that
 | ||||
|  | @ -87,30 +87,44 @@ func NewDurableSendingConnection(target string) *ManagedConnection { | |||
| // send messages to the endpoint it connects to.
 | ||||
| // The connection will continuously be kept alive and reconnected
 | ||||
| // in case of a loss of connectivity.
 | ||||
| func NewDurableConnection(target string, messageChan chan []byte) *ManagedConnection { | ||||
| 	c := newConnection(target, messageChan) | ||||
| //
 | ||||
| // Note: The given channel needs to be drained after calling `Shutdown`
 | ||||
| // to not cause any deadlocks. If the channel's buffer is likely to be
 | ||||
| // filled, this needs to happen in separate goroutines, i.e.
 | ||||
| //
 | ||||
| // go func() {conn.Shutdown(); close(messageChan)}
 | ||||
| // go func() {for range messageChan {}}
 | ||||
| func NewDurableConnection(target string, messageChan chan []byte, logger *zap.SugaredLogger) *ManagedConnection { | ||||
| 	websocketConnectionFactory := func() (rawConnection, error) { | ||||
| 		dialer := &websocket.Dialer{ | ||||
| 			HandshakeTimeout: 3 * time.Second, | ||||
| 		} | ||||
| 		conn, _, err := dialer.Dial(target, nil) | ||||
| 		return conn, err | ||||
| 	} | ||||
| 
 | ||||
| 	c := newConnection(websocketConnectionFactory, messageChan) | ||||
| 
 | ||||
| 	// Keep the connection alive asynchronously and reconnect on
 | ||||
| 	// connection failure.
 | ||||
| 	go func() { | ||||
| 		// If the close signal races the connection attempt, make
 | ||||
| 		// sure the connection actually closes.
 | ||||
| 		defer func() { | ||||
| 			c.connectionLock.RLock() | ||||
| 			defer c.connectionLock.RUnlock() | ||||
| 
 | ||||
| 			if conn := c.connection; conn != nil { | ||||
| 				conn.Close() | ||||
| 			} | ||||
| 		}() | ||||
| 		for { | ||||
| 			select { | ||||
| 			default: | ||||
| 				logger.Infof("Connecting to %q", target) | ||||
| 				if err := c.connect(); err != nil { | ||||
| 					logger.Errorw(fmt.Sprintf("Connecting to %q failed", target), zap.Error(err)) | ||||
| 					continue | ||||
| 				} | ||||
| 				c.keepalive() | ||||
| 				logger.Infof("Connected to %q", target) | ||||
| 				if err := c.keepalive(); err != nil { | ||||
| 					logger.Errorw(fmt.Sprintf("Connection to %q broke down, reconnecting...", target), zap.Error(err)) | ||||
| 				} | ||||
| 				if err := c.closeConnection(); err != nil { | ||||
| 					logger.Errorw("Failed to close the connection after crashing", zap.Error(err)) | ||||
| 				} | ||||
| 			case <-c.closeChan: | ||||
| 				logger.Infof("Connection to %q is being shutdown", target) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
|  | @ -120,11 +134,11 @@ func NewDurableConnection(target string, messageChan chan []byte) *ManagedConnec | |||
| } | ||||
| 
 | ||||
| // newConnection creates a new connection primitive.
 | ||||
| func newConnection(target string, messageChan chan []byte) *ManagedConnection { | ||||
| func newConnection(connFactory func() (rawConnection, error), messageChan chan []byte) *ManagedConnection { | ||||
| 	conn := &ManagedConnection{ | ||||
| 		target:      target, | ||||
| 		closeChan:   make(chan struct{}, 1), | ||||
| 		messageChan: messageChan, | ||||
| 		connectionFactory: connFactory, | ||||
| 		closeChan:         make(chan struct{}), | ||||
| 		messageChan:       messageChan, | ||||
| 		connectionBackoff: wait.Backoff{ | ||||
| 			Duration: 100 * time.Millisecond, | ||||
| 			Factor:   1.3, | ||||
|  | @ -137,58 +151,87 @@ func newConnection(target string, messageChan chan []byte) *ManagedConnection { | |||
| } | ||||
| 
 | ||||
| // connect tries to establish a websocket connection.
 | ||||
| func (c *ManagedConnection) connect() (err error) { | ||||
| func (c *ManagedConnection) connect() error { | ||||
| 	var err error | ||||
| 	wait.ExponentialBackoff(c.connectionBackoff, func() (bool, error) { | ||||
| 		var conn rawConnection | ||||
| 		conn, err = connFactory(c.target) | ||||
| 		if err != nil { | ||||
| 			return false, nil | ||||
| 		} | ||||
| 		c.connectionLock.Lock() | ||||
| 		defer c.connectionLock.Unlock() | ||||
| 		select { | ||||
| 		default: | ||||
| 			var conn rawConnection | ||||
| 			conn, err = c.connectionFactory() | ||||
| 			if err != nil { | ||||
| 				return false, nil | ||||
| 			} | ||||
| 
 | ||||
| 		c.connection = conn | ||||
| 		return true, nil | ||||
| 			c.connectionLock.Lock() | ||||
| 			defer c.connectionLock.Unlock() | ||||
| 
 | ||||
| 			c.connection = conn | ||||
| 			return true, nil | ||||
| 		case <-c.closeChan: | ||||
| 			err = errShuttingDown | ||||
| 			return false, err | ||||
| 		} | ||||
| 	}) | ||||
| 
 | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| // keepalive keeps the connection open and reads control messages.
 | ||||
| // All messages are discarded.
 | ||||
| func (c *ManagedConnection) keepalive() (err error) { | ||||
| // keepalive keeps the connection open.
 | ||||
| func (c *ManagedConnection) keepalive() error { | ||||
| 	for { | ||||
| 		select { | ||||
| 		default: | ||||
| 			if err := c.read(); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		case <-c.closeChan: | ||||
| 			return errShuttingDown | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // closeConnection closes the underlying websocket connection.
 | ||||
| func (c *ManagedConnection) closeConnection() error { | ||||
| 	c.connectionLock.Lock() | ||||
| 	defer c.connectionLock.Unlock() | ||||
| 
 | ||||
| 	if c.connection != nil { | ||||
| 		err := c.connection.Close() | ||||
| 		c.connection = nil | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // read reads the next message from the connection.
 | ||||
| // If a messageChan is supplied and the current message type is not
 | ||||
| // a control message, the message is sent to that channel.
 | ||||
| func (c *ManagedConnection) read() error { | ||||
| 	c.connectionLock.RLock() | ||||
| 	defer c.connectionLock.RUnlock() | ||||
| 
 | ||||
| 	if c.connection == nil { | ||||
| 		return ErrConnectionNotEstablished | ||||
| 	} | ||||
| 
 | ||||
| 	c.readerLock.Lock() | ||||
| 	defer c.readerLock.Unlock() | ||||
| 
 | ||||
| 	for { | ||||
| 		func() { | ||||
| 			c.connectionLock.RLock() | ||||
| 			defer c.connectionLock.RUnlock() | ||||
| 	messageType, reader, err := c.connection.NextReader() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 			if conn := c.connection; conn != nil { | ||||
| 				var reader io.Reader | ||||
| 				var messageType int | ||||
| 				messageType, reader, err = conn.NextReader() | ||||
| 				if err != nil { | ||||
| 					conn.Close() | ||||
| 				} | ||||
| 
 | ||||
| 				// Send the message to the channel if its an application level message
 | ||||
| 				// and if that channel is set.
 | ||||
| 				if c.messageChan != nil && (messageType == websocket.TextMessage || messageType == websocket.BinaryMessage) { | ||||
| 					if message, _ := ioutil.ReadAll(reader); message != nil { | ||||
| 						c.messageChan <- message | ||||
| 					} | ||||
| 				} | ||||
| 			} else { | ||||
| 				err = ErrConnectionNotEstablished | ||||
| 			} | ||||
| 		}() | ||||
| 
 | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 	// Send the message to the channel if its an application level message
 | ||||
| 	// and if that channel is set.
 | ||||
| 	// TODO(markusthoemmes): Return the messageType along with the payload.
 | ||||
| 	if c.messageChan != nil && (messageType == websocket.TextMessage || messageType == websocket.BinaryMessage) { | ||||
| 		if message, _ := ioutil.ReadAll(reader); message != nil { | ||||
| 			c.messageChan <- message | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Send sends an encodable message over the websocket connection.
 | ||||
|  | @ -196,31 +239,27 @@ func (c *ManagedConnection) Send(msg interface{}) error { | |||
| 	c.connectionLock.RLock() | ||||
| 	defer c.connectionLock.RUnlock() | ||||
| 
 | ||||
| 	conn := c.connection | ||||
| 	if conn == nil { | ||||
| 	if c.connection == nil { | ||||
| 		return ErrConnectionNotEstablished | ||||
| 	} | ||||
| 
 | ||||
| 	c.writerLock.Lock() | ||||
| 	defer c.writerLock.Unlock() | ||||
| 
 | ||||
| 	var b bytes.Buffer | ||||
| 	enc := gob.NewEncoder(&b) | ||||
| 	if err := enc.Encode(msg); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	return conn.WriteMessage(websocket.BinaryMessage, b.Bytes()) | ||||
| 	c.writerLock.Lock() | ||||
| 	defer c.writerLock.Unlock() | ||||
| 
 | ||||
| 	return c.connection.WriteMessage(websocket.BinaryMessage, b.Bytes()) | ||||
| } | ||||
| 
 | ||||
| // Close closes the websocket connection.
 | ||||
| func (c *ManagedConnection) Close() error { | ||||
| 	c.closeChan <- struct{}{} | ||||
| 	c.connectionLock.RLock() | ||||
| 	defer c.connectionLock.RUnlock() | ||||
| // Shutdown closes the websocket connection.
 | ||||
| func (c *ManagedConnection) Shutdown() error { | ||||
| 	c.closeOnce.Do(func() { | ||||
| 		close(c.closeChan) | ||||
| 	}) | ||||
| 
 | ||||
| 	if conn := c.connection; conn != nil { | ||||
| 		return conn.Close() | ||||
| 	} | ||||
| 	return nil | ||||
| 	return c.closeConnection() | ||||
| } | ||||
|  |  | |||
|  | @ -1,5 +1,5 @@ | |||
| /* | ||||
| Copyright 2018 The Knative Authors | ||||
| Copyright 2019 The Knative Authors | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
|  | @ -19,16 +19,20 @@ package websocket | |||
| import ( | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	ktesting "github.com/knative/pkg/logging/testing" | ||||
| 
 | ||||
| 	"k8s.io/apimachinery/pkg/util/wait" | ||||
| 
 | ||||
| 	"github.com/gorilla/websocket" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	target = "test" | ||||
| ) | ||||
| const propagationTimeout = 5 * time.Second | ||||
| 
 | ||||
| type inspectableConnection struct { | ||||
| 	nextReaderCalls   chan struct{} | ||||
|  | @ -39,20 +43,41 @@ type inspectableConnection struct { | |||
| } | ||||
| 
 | ||||
| func (c *inspectableConnection) WriteMessage(messageType int, data []byte) error { | ||||
| 	c.writeMessageCalls <- struct{}{} | ||||
| 	if c.writeMessageCalls != nil { | ||||
| 		c.writeMessageCalls <- struct{}{} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (c *inspectableConnection) NextReader() (int, io.Reader, error) { | ||||
| 	c.nextReaderCalls <- struct{}{} | ||||
| 	if c.nextReaderCalls != nil { | ||||
| 		c.nextReaderCalls <- struct{}{} | ||||
| 	} | ||||
| 	return c.nextReaderFunc() | ||||
| } | ||||
| 
 | ||||
| func (c *inspectableConnection) Close() error { | ||||
| 	c.closeCalls <- struct{}{} | ||||
| 	if c.closeCalls != nil { | ||||
| 		c.closeCalls <- struct{}{} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // staticConnFactory returns a static connection, for example
 | ||||
| // an inspectable connection.
 | ||||
| func staticConnFactory(conn rawConnection) func() (rawConnection, error) { | ||||
| 	return func() (rawConnection, error) { | ||||
| 		return conn, nil | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // errConnFactory returns a static error.
 | ||||
| func errConnFactory(err error) func() (rawConnection, error) { | ||||
| 	return func() (rawConnection, error) { | ||||
| 		return nil, err | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestRetriesWhileConnect(t *testing.T) { | ||||
| 	want := 2 | ||||
| 	got := 0 | ||||
|  | @ -61,17 +86,17 @@ func TestRetriesWhileConnect(t *testing.T) { | |||
| 		closeCalls: make(chan struct{}, 1), | ||||
| 	} | ||||
| 
 | ||||
| 	connFactory = func(_ string) (rawConnection, error) { | ||||
| 	connFactory := func() (rawConnection, error) { | ||||
| 		got++ | ||||
| 		if got == want { | ||||
| 			return spy, nil | ||||
| 		} | ||||
| 		return nil, errors.New("not yet") | ||||
| 	} | ||||
| 	conn := newConnection(target, nil) | ||||
| 	conn := newConnection(connFactory, nil) | ||||
| 
 | ||||
| 	conn.connect() | ||||
| 	conn.Close() | ||||
| 	conn.Shutdown() | ||||
| 
 | ||||
| 	if got != want { | ||||
| 		t.Fatalf("Wanted %v retries. Got %v.", want, got) | ||||
|  | @ -96,11 +121,7 @@ func TestSendErrorOnEncode(t *testing.T) { | |||
| 	spy := &inspectableConnection{ | ||||
| 		writeMessageCalls: make(chan struct{}, 1), | ||||
| 	} | ||||
| 
 | ||||
| 	connFactory = func(_ string) (rawConnection, error) { | ||||
| 		return spy, nil | ||||
| 	} | ||||
| 	conn := newConnection(target, nil) | ||||
| 	conn := newConnection(staticConnFactory(spy), nil) | ||||
| 	conn.connect() | ||||
| 	// gob cannot encode nil values
 | ||||
| 	got := conn.Send(nil) | ||||
|  | @ -117,10 +138,7 @@ func TestSendMessage(t *testing.T) { | |||
| 	spy := &inspectableConnection{ | ||||
| 		writeMessageCalls: make(chan struct{}, 1), | ||||
| 	} | ||||
| 	connFactory = func(_ string) (rawConnection, error) { | ||||
| 		return spy, nil | ||||
| 	} | ||||
| 	conn := newConnection(target, nil) | ||||
| 	conn := newConnection(staticConnFactory(spy), nil) | ||||
| 	conn.connect() | ||||
| 	got := conn.Send("test") | ||||
| 
 | ||||
|  | @ -142,12 +160,9 @@ func TestReceiveMessage(t *testing.T) { | |||
| 			return websocket.TextMessage, strings.NewReader(testMessage), nil | ||||
| 		}, | ||||
| 	} | ||||
| 	connFactory = func(_ string) (rawConnection, error) { | ||||
| 		return spy, nil | ||||
| 	} | ||||
| 
 | ||||
| 	messageChan := make(chan []byte, 1) | ||||
| 	conn := newConnection(target, messageChan) | ||||
| 	conn := newConnection(staticConnFactory(spy), messageChan) | ||||
| 	conn.connect() | ||||
| 	go conn.keepalive() | ||||
| 
 | ||||
|  | @ -162,12 +177,9 @@ func TestCloseClosesConnection(t *testing.T) { | |||
| 	spy := &inspectableConnection{ | ||||
| 		closeCalls: make(chan struct{}, 1), | ||||
| 	} | ||||
| 	connFactory = func(_ string) (rawConnection, error) { | ||||
| 		return spy, nil | ||||
| 	} | ||||
| 	conn := newConnection(target, nil) | ||||
| 	conn := newConnection(staticConnFactory(spy), nil) | ||||
| 	conn.connect() | ||||
| 	conn.Close() | ||||
| 	conn.Shutdown() | ||||
| 
 | ||||
| 	if len(spy.closeCalls) != 1 { | ||||
| 		t.Fatalf("Expected 'Close' to be called once, got %v", len(spy.closeCalls)) | ||||
|  | @ -178,7 +190,7 @@ func TestCloseIgnoresNoConnection(t *testing.T) { | |||
| 	conn := &ManagedConnection{ | ||||
| 		closeChan: make(chan struct{}, 1), | ||||
| 	} | ||||
| 	got := conn.Close() | ||||
| 	got := conn.Shutdown() | ||||
| 
 | ||||
| 	if got != nil { | ||||
| 		t.Fatalf("Expected no error, got %v", got) | ||||
|  | @ -186,60 +198,46 @@ func TestCloseIgnoresNoConnection(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestDurableConnectionWhenConnectionBreaksDown(t *testing.T) { | ||||
| 	testConn := &inspectableConnection{ | ||||
| 		nextReaderCalls:   make(chan struct{}), | ||||
| 		writeMessageCalls: make(chan struct{}), | ||||
| 		closeCalls:        make(chan struct{}), | ||||
| 	testPayload := "test" | ||||
| 	reconnectChan := make(chan struct{}) | ||||
| 
 | ||||
| 		nextReaderFunc: func() (int, io.Reader, error) { | ||||
| 			return 1, nil, errors.New("next reader errored") | ||||
| 		}, | ||||
| 	} | ||||
| 	connectAttempts := make(chan struct{}) | ||||
| 	connFactory = func(_ string) (rawConnection, error) { | ||||
| 		connectAttempts <- struct{}{} | ||||
| 		return testConn, nil | ||||
| 	} | ||||
| 	conn := NewDurableSendingConnection(target) | ||||
| 	upgrader := websocket.Upgrader{} | ||||
| 	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		c, err := upgrader.Upgrade(w, r, nil) | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 	// the connection is constantly created, tried to read from
 | ||||
| 	// and closed because NextReader (which holds the connection
 | ||||
| 	// open) fails.
 | ||||
| 	for i := 0; i < 100; i++ { | ||||
| 		<-connectAttempts | ||||
| 		<-testConn.nextReaderCalls | ||||
| 		<-testConn.closeCalls | ||||
| 	} | ||||
| 		// Waits for a message to be sent before dropping the connection.
 | ||||
| 		<-reconnectChan | ||||
| 		c.Close() | ||||
| 	})) | ||||
| 	defer s.Close() | ||||
| 
 | ||||
| 	// Enter the reconnect loop
 | ||||
| 	<-connectAttempts | ||||
| 	logger := ktesting.TestLogger(t) | ||||
| 	target := "ws" + strings.TrimPrefix(s.URL, "http") | ||||
| 	conn := NewDurableSendingConnection(target, logger) | ||||
| 	defer conn.Shutdown() | ||||
| 
 | ||||
| 	// Call 'Close' asynchronously and wait for it to reach
 | ||||
| 	// the channel.
 | ||||
| 	go conn.Close() | ||||
| 	<-testConn.closeCalls | ||||
| 	for i := 0; i < 10; i++ { | ||||
| 		err := wait.PollImmediate(50*time.Millisecond, 5*time.Second, func() (bool, error) { | ||||
| 			if err := conn.Send(testPayload); err != nil { | ||||
| 				return false, nil | ||||
| 			} | ||||
| 			return true, nil | ||||
| 		}) | ||||
| 
 | ||||
| 	// Advance the reconnect loop until 'Close' is called.
 | ||||
| 	<-testConn.nextReaderCalls | ||||
| 	<-testConn.closeCalls | ||||
| 		if err != nil { | ||||
| 			t.Errorf("Timed out trying to send a message: %v", err) | ||||
| 		} | ||||
| 
 | ||||
| 	// Wait for the final call to 'Close' (when the loop is aborted)
 | ||||
| 	<-testConn.closeCalls | ||||
| 
 | ||||
| 	if len(connectAttempts) > 1 { | ||||
| 		t.Fatalf("Expected at most one connection attempts, got %v", len(connectAttempts)) | ||||
| 	} | ||||
| 	if len(testConn.nextReaderCalls) > 1 { | ||||
| 		t.Fatalf("Expected at most one calls to 'NextReader', got %v", len(testConn.nextReaderCalls)) | ||||
| 		// Message successfully sent, instruct the server to drop the connection.
 | ||||
| 		reconnectChan <- struct{}{} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestConnectFailureReturnsError(t *testing.T) { | ||||
| 	connFactory = func(_ string) (rawConnection, error) { | ||||
| 		return nil, ErrConnectionNotEstablished | ||||
| 	} | ||||
| 
 | ||||
| 	conn := newConnection(target, nil) | ||||
| 	conn := newConnection(errConnFactory(ErrConnectionNotEstablished), nil) | ||||
| 
 | ||||
| 	// Shorten the connection backoff duration for this test
 | ||||
| 	conn.connectionBackoff.Duration = 1 * time.Millisecond | ||||
|  | @ -252,10 +250,70 @@ func TestConnectFailureReturnsError(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestKeepaliveWithNoConnectionReturnsError(t *testing.T) { | ||||
| 	conn := newConnection(target, nil) | ||||
| 	conn := newConnection(nil, nil) | ||||
| 	got := conn.keepalive() | ||||
| 
 | ||||
| 	if got == nil { | ||||
| 		t.Fatal("Expected an error but got none") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestConnectLoopIsStopped(t *testing.T) { | ||||
| 	conn := newConnection(errConnFactory(errors.New("connection error")), nil) | ||||
| 
 | ||||
| 	errorChan := make(chan error) | ||||
| 	go func() { | ||||
| 		errorChan <- conn.connect() | ||||
| 	}() | ||||
| 
 | ||||
| 	conn.Shutdown() | ||||
| 
 | ||||
| 	select { | ||||
| 	case err := <-errorChan: | ||||
| 		if err != errShuttingDown { | ||||
| 			t.Errorf("Wrong 'connect' error, got %v, want %v", err, errShuttingDown) | ||||
| 		} | ||||
| 	case <-time.After(propagationTimeout): | ||||
| 		t.Error("Timed out waiting for the keepalive loop to stop.") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestKeepaliveLoopIsStopped(t *testing.T) { | ||||
| 	spy := &inspectableConnection{ | ||||
| 		nextReaderFunc: func() (int, io.Reader, error) { | ||||
| 			return websocket.TextMessage, nil, nil | ||||
| 		}, | ||||
| 	} | ||||
| 	conn := newConnection(staticConnFactory(spy), nil) | ||||
| 	conn.connect() | ||||
| 
 | ||||
| 	errorChan := make(chan error) | ||||
| 	go func() { | ||||
| 		errorChan <- conn.keepalive() | ||||
| 	}() | ||||
| 
 | ||||
| 	conn.Shutdown() | ||||
| 
 | ||||
| 	select { | ||||
| 	case err := <-errorChan: | ||||
| 		if err != errShuttingDown { | ||||
| 			t.Errorf("Wrong 'keepalive' error, got %v, want %v", err, errShuttingDown) | ||||
| 		} | ||||
| 	case <-time.After(propagationTimeout): | ||||
| 		t.Error("Timed out waiting for the keepalive loop to stop.") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestDoubleShutdown(t *testing.T) { | ||||
| 	spy := &inspectableConnection{ | ||||
| 		closeCalls: make(chan struct{}, 2), // potentially allow 2 calls
 | ||||
| 	} | ||||
| 	conn := newConnection(staticConnFactory(spy), nil) | ||||
| 	conn.connect() | ||||
| 	conn.Shutdown() | ||||
| 	conn.Shutdown() | ||||
| 
 | ||||
| 	if want, got := 1, len(spy.closeCalls); want != got { | ||||
| 		t.Errorf("Wrong 'Close' callcount, got %d, want %d", got, want) | ||||
| 	} | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue