mirror of https://github.com/knative/pkg.git
Hardening the websocket client. (#262)
* Cleanup websocket connection, actually test reconnects. * Some more cleanup. * Locally define connFactory to avoid races. * Move locks around, harden test. * Add logging. * Drop redundant target. * Move message encoding outside of the writerLock. * Fix assignment nit. * Remove named return value. * Add close signal to long-running loops. * Add todo for returning a messageType. * Bump header to 2019. * Add note on draining the messageChan. * Drop target from signature. * Drop target from test. * Add a more speaking example to draining the messageChan. * Fix typo. * Relax read lock, improve test. * Bump test coverage. * Add double shutdown test. * Remove code duplication in test.
This commit is contained in:
parent
bd5a391c64
commit
3df885a8fc
|
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
Copyright 2018 The Knative Authors
|
Copyright 2019 The Knative Authors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|
@ -20,11 +20,14 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
|
||||||
"k8s.io/apimachinery/pkg/util/wait"
|
"k8s.io/apimachinery/pkg/util/wait"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
|
@ -35,13 +38,8 @@ var (
|
||||||
// but no connection is already created.
|
// but no connection is already created.
|
||||||
ErrConnectionNotEstablished = errors.New("connection has not yet been established")
|
ErrConnectionNotEstablished = errors.New("connection has not yet been established")
|
||||||
|
|
||||||
connFactory = func(target string) (rawConnection, error) {
|
// errShuttingDown is returned internally once the shutdown signal has been sent.
|
||||||
dialer := &websocket.Dialer{
|
errShuttingDown = errors.New("shutdown in progress")
|
||||||
HandshakeTimeout: 3 * time.Second,
|
|
||||||
}
|
|
||||||
conn, _, err := dialer.Dial(target, nil)
|
|
||||||
return conn, err
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// RawConnection is an interface defining the methods needed
|
// RawConnection is an interface defining the methods needed
|
||||||
|
|
@ -54,9 +52,11 @@ type rawConnection interface {
|
||||||
|
|
||||||
// ManagedConnection represents a websocket connection.
|
// ManagedConnection represents a websocket connection.
|
||||||
type ManagedConnection struct {
|
type ManagedConnection struct {
|
||||||
target string
|
connection rawConnection
|
||||||
connection rawConnection
|
connectionFactory func() (rawConnection, error)
|
||||||
closeChan chan struct{}
|
|
||||||
|
closeChan chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
|
||||||
// If set, messages will be forwarded to this channel
|
// If set, messages will be forwarded to this channel
|
||||||
messageChan chan []byte
|
messageChan chan []byte
|
||||||
|
|
@ -78,8 +78,8 @@ type ManagedConnection struct {
|
||||||
// that can only send messages to the endpoint it connects to.
|
// that can only send messages to the endpoint it connects to.
|
||||||
// The connection will continuously be kept alive and reconnected
|
// The connection will continuously be kept alive and reconnected
|
||||||
// in case of a loss of connectivity.
|
// in case of a loss of connectivity.
|
||||||
func NewDurableSendingConnection(target string) *ManagedConnection {
|
func NewDurableSendingConnection(target string, logger *zap.SugaredLogger) *ManagedConnection {
|
||||||
return NewDurableConnection(target, nil)
|
return NewDurableConnection(target, nil, logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDurableConnection creates a new websocket connection, that
|
// NewDurableConnection creates a new websocket connection, that
|
||||||
|
|
@ -87,30 +87,44 @@ func NewDurableSendingConnection(target string) *ManagedConnection {
|
||||||
// send messages to the endpoint it connects to.
|
// send messages to the endpoint it connects to.
|
||||||
// The connection will continuously be kept alive and reconnected
|
// The connection will continuously be kept alive and reconnected
|
||||||
// in case of a loss of connectivity.
|
// in case of a loss of connectivity.
|
||||||
func NewDurableConnection(target string, messageChan chan []byte) *ManagedConnection {
|
//
|
||||||
c := newConnection(target, messageChan)
|
// Note: The given channel needs to be drained after calling `Shutdown`
|
||||||
|
// to not cause any deadlocks. If the channel's buffer is likely to be
|
||||||
|
// filled, this needs to happen in separate goroutines, i.e.
|
||||||
|
//
|
||||||
|
// go func() {conn.Shutdown(); close(messageChan)}
|
||||||
|
// go func() {for range messageChan {}}
|
||||||
|
func NewDurableConnection(target string, messageChan chan []byte, logger *zap.SugaredLogger) *ManagedConnection {
|
||||||
|
websocketConnectionFactory := func() (rawConnection, error) {
|
||||||
|
dialer := &websocket.Dialer{
|
||||||
|
HandshakeTimeout: 3 * time.Second,
|
||||||
|
}
|
||||||
|
conn, _, err := dialer.Dial(target, nil)
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c := newConnection(websocketConnectionFactory, messageChan)
|
||||||
|
|
||||||
// Keep the connection alive asynchronously and reconnect on
|
// Keep the connection alive asynchronously and reconnect on
|
||||||
// connection failure.
|
// connection failure.
|
||||||
go func() {
|
go func() {
|
||||||
// If the close signal races the connection attempt, make
|
|
||||||
// sure the connection actually closes.
|
|
||||||
defer func() {
|
|
||||||
c.connectionLock.RLock()
|
|
||||||
defer c.connectionLock.RUnlock()
|
|
||||||
|
|
||||||
if conn := c.connection; conn != nil {
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
default:
|
default:
|
||||||
|
logger.Infof("Connecting to %q", target)
|
||||||
if err := c.connect(); err != nil {
|
if err := c.connect(); err != nil {
|
||||||
|
logger.Errorw(fmt.Sprintf("Connecting to %q failed", target), zap.Error(err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
c.keepalive()
|
logger.Infof("Connected to %q", target)
|
||||||
|
if err := c.keepalive(); err != nil {
|
||||||
|
logger.Errorw(fmt.Sprintf("Connection to %q broke down, reconnecting...", target), zap.Error(err))
|
||||||
|
}
|
||||||
|
if err := c.closeConnection(); err != nil {
|
||||||
|
logger.Errorw("Failed to close the connection after crashing", zap.Error(err))
|
||||||
|
}
|
||||||
case <-c.closeChan:
|
case <-c.closeChan:
|
||||||
|
logger.Infof("Connection to %q is being shutdown", target)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -120,11 +134,11 @@ func NewDurableConnection(target string, messageChan chan []byte) *ManagedConnec
|
||||||
}
|
}
|
||||||
|
|
||||||
// newConnection creates a new connection primitive.
|
// newConnection creates a new connection primitive.
|
||||||
func newConnection(target string, messageChan chan []byte) *ManagedConnection {
|
func newConnection(connFactory func() (rawConnection, error), messageChan chan []byte) *ManagedConnection {
|
||||||
conn := &ManagedConnection{
|
conn := &ManagedConnection{
|
||||||
target: target,
|
connectionFactory: connFactory,
|
||||||
closeChan: make(chan struct{}, 1),
|
closeChan: make(chan struct{}),
|
||||||
messageChan: messageChan,
|
messageChan: messageChan,
|
||||||
connectionBackoff: wait.Backoff{
|
connectionBackoff: wait.Backoff{
|
||||||
Duration: 100 * time.Millisecond,
|
Duration: 100 * time.Millisecond,
|
||||||
Factor: 1.3,
|
Factor: 1.3,
|
||||||
|
|
@ -137,58 +151,87 @@ func newConnection(target string, messageChan chan []byte) *ManagedConnection {
|
||||||
}
|
}
|
||||||
|
|
||||||
// connect tries to establish a websocket connection.
|
// connect tries to establish a websocket connection.
|
||||||
func (c *ManagedConnection) connect() (err error) {
|
func (c *ManagedConnection) connect() error {
|
||||||
|
var err error
|
||||||
wait.ExponentialBackoff(c.connectionBackoff, func() (bool, error) {
|
wait.ExponentialBackoff(c.connectionBackoff, func() (bool, error) {
|
||||||
var conn rawConnection
|
select {
|
||||||
conn, err = connFactory(c.target)
|
default:
|
||||||
if err != nil {
|
var conn rawConnection
|
||||||
return false, nil
|
conn, err = c.connectionFactory()
|
||||||
}
|
if err != nil {
|
||||||
c.connectionLock.Lock()
|
return false, nil
|
||||||
defer c.connectionLock.Unlock()
|
}
|
||||||
|
|
||||||
c.connection = conn
|
c.connectionLock.Lock()
|
||||||
return true, nil
|
defer c.connectionLock.Unlock()
|
||||||
|
|
||||||
|
c.connection = conn
|
||||||
|
return true, nil
|
||||||
|
case <-c.closeChan:
|
||||||
|
err = errShuttingDown
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// keepalive keeps the connection open and reads control messages.
|
// keepalive keeps the connection open.
|
||||||
// All messages are discarded.
|
func (c *ManagedConnection) keepalive() error {
|
||||||
func (c *ManagedConnection) keepalive() (err error) {
|
for {
|
||||||
|
select {
|
||||||
|
default:
|
||||||
|
if err := c.read(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case <-c.closeChan:
|
||||||
|
return errShuttingDown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeConnection closes the underlying websocket connection.
|
||||||
|
func (c *ManagedConnection) closeConnection() error {
|
||||||
|
c.connectionLock.Lock()
|
||||||
|
defer c.connectionLock.Unlock()
|
||||||
|
|
||||||
|
if c.connection != nil {
|
||||||
|
err := c.connection.Close()
|
||||||
|
c.connection = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// read reads the next message from the connection.
|
||||||
|
// If a messageChan is supplied and the current message type is not
|
||||||
|
// a control message, the message is sent to that channel.
|
||||||
|
func (c *ManagedConnection) read() error {
|
||||||
|
c.connectionLock.RLock()
|
||||||
|
defer c.connectionLock.RUnlock()
|
||||||
|
|
||||||
|
if c.connection == nil {
|
||||||
|
return ErrConnectionNotEstablished
|
||||||
|
}
|
||||||
|
|
||||||
c.readerLock.Lock()
|
c.readerLock.Lock()
|
||||||
defer c.readerLock.Unlock()
|
defer c.readerLock.Unlock()
|
||||||
|
|
||||||
for {
|
messageType, reader, err := c.connection.NextReader()
|
||||||
func() {
|
if err != nil {
|
||||||
c.connectionLock.RLock()
|
return err
|
||||||
defer c.connectionLock.RUnlock()
|
}
|
||||||
|
|
||||||
if conn := c.connection; conn != nil {
|
// Send the message to the channel if its an application level message
|
||||||
var reader io.Reader
|
// and if that channel is set.
|
||||||
var messageType int
|
// TODO(markusthoemmes): Return the messageType along with the payload.
|
||||||
messageType, reader, err = conn.NextReader()
|
if c.messageChan != nil && (messageType == websocket.TextMessage || messageType == websocket.BinaryMessage) {
|
||||||
if err != nil {
|
if message, _ := ioutil.ReadAll(reader); message != nil {
|
||||||
conn.Close()
|
c.messageChan <- message
|
||||||
}
|
|
||||||
|
|
||||||
// Send the message to the channel if its an application level message
|
|
||||||
// and if that channel is set.
|
|
||||||
if c.messageChan != nil && (messageType == websocket.TextMessage || messageType == websocket.BinaryMessage) {
|
|
||||||
if message, _ := ioutil.ReadAll(reader); message != nil {
|
|
||||||
c.messageChan <- message
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err = ErrConnectionNotEstablished
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send sends an encodable message over the websocket connection.
|
// Send sends an encodable message over the websocket connection.
|
||||||
|
|
@ -196,31 +239,27 @@ func (c *ManagedConnection) Send(msg interface{}) error {
|
||||||
c.connectionLock.RLock()
|
c.connectionLock.RLock()
|
||||||
defer c.connectionLock.RUnlock()
|
defer c.connectionLock.RUnlock()
|
||||||
|
|
||||||
conn := c.connection
|
if c.connection == nil {
|
||||||
if conn == nil {
|
|
||||||
return ErrConnectionNotEstablished
|
return ErrConnectionNotEstablished
|
||||||
}
|
}
|
||||||
|
|
||||||
c.writerLock.Lock()
|
|
||||||
defer c.writerLock.Unlock()
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
enc := gob.NewEncoder(&b)
|
enc := gob.NewEncoder(&b)
|
||||||
if err := enc.Encode(msg); err != nil {
|
if err := enc.Encode(msg); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn.WriteMessage(websocket.BinaryMessage, b.Bytes())
|
c.writerLock.Lock()
|
||||||
|
defer c.writerLock.Unlock()
|
||||||
|
|
||||||
|
return c.connection.WriteMessage(websocket.BinaryMessage, b.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the websocket connection.
|
// Shutdown closes the websocket connection.
|
||||||
func (c *ManagedConnection) Close() error {
|
func (c *ManagedConnection) Shutdown() error {
|
||||||
c.closeChan <- struct{}{}
|
c.closeOnce.Do(func() {
|
||||||
c.connectionLock.RLock()
|
close(c.closeChan)
|
||||||
defer c.connectionLock.RUnlock()
|
})
|
||||||
|
|
||||||
if conn := c.connection; conn != nil {
|
return c.closeConnection()
|
||||||
return conn.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
Copyright 2018 The Knative Authors
|
Copyright 2019 The Knative Authors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|
@ -19,16 +19,20 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
ktesting "github.com/knative/pkg/logging/testing"
|
||||||
|
|
||||||
|
"k8s.io/apimachinery/pkg/util/wait"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const propagationTimeout = 5 * time.Second
|
||||||
target = "test"
|
|
||||||
)
|
|
||||||
|
|
||||||
type inspectableConnection struct {
|
type inspectableConnection struct {
|
||||||
nextReaderCalls chan struct{}
|
nextReaderCalls chan struct{}
|
||||||
|
|
@ -39,20 +43,41 @@ type inspectableConnection struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *inspectableConnection) WriteMessage(messageType int, data []byte) error {
|
func (c *inspectableConnection) WriteMessage(messageType int, data []byte) error {
|
||||||
c.writeMessageCalls <- struct{}{}
|
if c.writeMessageCalls != nil {
|
||||||
|
c.writeMessageCalls <- struct{}{}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *inspectableConnection) NextReader() (int, io.Reader, error) {
|
func (c *inspectableConnection) NextReader() (int, io.Reader, error) {
|
||||||
c.nextReaderCalls <- struct{}{}
|
if c.nextReaderCalls != nil {
|
||||||
|
c.nextReaderCalls <- struct{}{}
|
||||||
|
}
|
||||||
return c.nextReaderFunc()
|
return c.nextReaderFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *inspectableConnection) Close() error {
|
func (c *inspectableConnection) Close() error {
|
||||||
c.closeCalls <- struct{}{}
|
if c.closeCalls != nil {
|
||||||
|
c.closeCalls <- struct{}{}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// staticConnFactory returns a static connection, for example
|
||||||
|
// an inspectable connection.
|
||||||
|
func staticConnFactory(conn rawConnection) func() (rawConnection, error) {
|
||||||
|
return func() (rawConnection, error) {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// errConnFactory returns a static error.
|
||||||
|
func errConnFactory(err error) func() (rawConnection, error) {
|
||||||
|
return func() (rawConnection, error) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRetriesWhileConnect(t *testing.T) {
|
func TestRetriesWhileConnect(t *testing.T) {
|
||||||
want := 2
|
want := 2
|
||||||
got := 0
|
got := 0
|
||||||
|
|
@ -61,17 +86,17 @@ func TestRetriesWhileConnect(t *testing.T) {
|
||||||
closeCalls: make(chan struct{}, 1),
|
closeCalls: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
connFactory = func(_ string) (rawConnection, error) {
|
connFactory := func() (rawConnection, error) {
|
||||||
got++
|
got++
|
||||||
if got == want {
|
if got == want {
|
||||||
return spy, nil
|
return spy, nil
|
||||||
}
|
}
|
||||||
return nil, errors.New("not yet")
|
return nil, errors.New("not yet")
|
||||||
}
|
}
|
||||||
conn := newConnection(target, nil)
|
conn := newConnection(connFactory, nil)
|
||||||
|
|
||||||
conn.connect()
|
conn.connect()
|
||||||
conn.Close()
|
conn.Shutdown()
|
||||||
|
|
||||||
if got != want {
|
if got != want {
|
||||||
t.Fatalf("Wanted %v retries. Got %v.", want, got)
|
t.Fatalf("Wanted %v retries. Got %v.", want, got)
|
||||||
|
|
@ -96,11 +121,7 @@ func TestSendErrorOnEncode(t *testing.T) {
|
||||||
spy := &inspectableConnection{
|
spy := &inspectableConnection{
|
||||||
writeMessageCalls: make(chan struct{}, 1),
|
writeMessageCalls: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
|
conn := newConnection(staticConnFactory(spy), nil)
|
||||||
connFactory = func(_ string) (rawConnection, error) {
|
|
||||||
return spy, nil
|
|
||||||
}
|
|
||||||
conn := newConnection(target, nil)
|
|
||||||
conn.connect()
|
conn.connect()
|
||||||
// gob cannot encode nil values
|
// gob cannot encode nil values
|
||||||
got := conn.Send(nil)
|
got := conn.Send(nil)
|
||||||
|
|
@ -117,10 +138,7 @@ func TestSendMessage(t *testing.T) {
|
||||||
spy := &inspectableConnection{
|
spy := &inspectableConnection{
|
||||||
writeMessageCalls: make(chan struct{}, 1),
|
writeMessageCalls: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
connFactory = func(_ string) (rawConnection, error) {
|
conn := newConnection(staticConnFactory(spy), nil)
|
||||||
return spy, nil
|
|
||||||
}
|
|
||||||
conn := newConnection(target, nil)
|
|
||||||
conn.connect()
|
conn.connect()
|
||||||
got := conn.Send("test")
|
got := conn.Send("test")
|
||||||
|
|
||||||
|
|
@ -142,12 +160,9 @@ func TestReceiveMessage(t *testing.T) {
|
||||||
return websocket.TextMessage, strings.NewReader(testMessage), nil
|
return websocket.TextMessage, strings.NewReader(testMessage), nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
connFactory = func(_ string) (rawConnection, error) {
|
|
||||||
return spy, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
messageChan := make(chan []byte, 1)
|
messageChan := make(chan []byte, 1)
|
||||||
conn := newConnection(target, messageChan)
|
conn := newConnection(staticConnFactory(spy), messageChan)
|
||||||
conn.connect()
|
conn.connect()
|
||||||
go conn.keepalive()
|
go conn.keepalive()
|
||||||
|
|
||||||
|
|
@ -162,12 +177,9 @@ func TestCloseClosesConnection(t *testing.T) {
|
||||||
spy := &inspectableConnection{
|
spy := &inspectableConnection{
|
||||||
closeCalls: make(chan struct{}, 1),
|
closeCalls: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
connFactory = func(_ string) (rawConnection, error) {
|
conn := newConnection(staticConnFactory(spy), nil)
|
||||||
return spy, nil
|
|
||||||
}
|
|
||||||
conn := newConnection(target, nil)
|
|
||||||
conn.connect()
|
conn.connect()
|
||||||
conn.Close()
|
conn.Shutdown()
|
||||||
|
|
||||||
if len(spy.closeCalls) != 1 {
|
if len(spy.closeCalls) != 1 {
|
||||||
t.Fatalf("Expected 'Close' to be called once, got %v", len(spy.closeCalls))
|
t.Fatalf("Expected 'Close' to be called once, got %v", len(spy.closeCalls))
|
||||||
|
|
@ -178,7 +190,7 @@ func TestCloseIgnoresNoConnection(t *testing.T) {
|
||||||
conn := &ManagedConnection{
|
conn := &ManagedConnection{
|
||||||
closeChan: make(chan struct{}, 1),
|
closeChan: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
got := conn.Close()
|
got := conn.Shutdown()
|
||||||
|
|
||||||
if got != nil {
|
if got != nil {
|
||||||
t.Fatalf("Expected no error, got %v", got)
|
t.Fatalf("Expected no error, got %v", got)
|
||||||
|
|
@ -186,60 +198,46 @@ func TestCloseIgnoresNoConnection(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDurableConnectionWhenConnectionBreaksDown(t *testing.T) {
|
func TestDurableConnectionWhenConnectionBreaksDown(t *testing.T) {
|
||||||
testConn := &inspectableConnection{
|
testPayload := "test"
|
||||||
nextReaderCalls: make(chan struct{}),
|
reconnectChan := make(chan struct{})
|
||||||
writeMessageCalls: make(chan struct{}),
|
|
||||||
closeCalls: make(chan struct{}),
|
|
||||||
|
|
||||||
nextReaderFunc: func() (int, io.Reader, error) {
|
upgrader := websocket.Upgrader{}
|
||||||
return 1, nil, errors.New("next reader errored")
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
},
|
c, err := upgrader.Upgrade(w, r, nil)
|
||||||
}
|
if err != nil {
|
||||||
connectAttempts := make(chan struct{})
|
return
|
||||||
connFactory = func(_ string) (rawConnection, error) {
|
}
|
||||||
connectAttempts <- struct{}{}
|
|
||||||
return testConn, nil
|
|
||||||
}
|
|
||||||
conn := NewDurableSendingConnection(target)
|
|
||||||
|
|
||||||
// the connection is constantly created, tried to read from
|
// Waits for a message to be sent before dropping the connection.
|
||||||
// and closed because NextReader (which holds the connection
|
<-reconnectChan
|
||||||
// open) fails.
|
c.Close()
|
||||||
for i := 0; i < 100; i++ {
|
}))
|
||||||
<-connectAttempts
|
defer s.Close()
|
||||||
<-testConn.nextReaderCalls
|
|
||||||
<-testConn.closeCalls
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enter the reconnect loop
|
logger := ktesting.TestLogger(t)
|
||||||
<-connectAttempts
|
target := "ws" + strings.TrimPrefix(s.URL, "http")
|
||||||
|
conn := NewDurableSendingConnection(target, logger)
|
||||||
|
defer conn.Shutdown()
|
||||||
|
|
||||||
// Call 'Close' asynchronously and wait for it to reach
|
for i := 0; i < 10; i++ {
|
||||||
// the channel.
|
err := wait.PollImmediate(50*time.Millisecond, 5*time.Second, func() (bool, error) {
|
||||||
go conn.Close()
|
if err := conn.Send(testPayload); err != nil {
|
||||||
<-testConn.closeCalls
|
return false, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
|
||||||
// Advance the reconnect loop until 'Close' is called.
|
if err != nil {
|
||||||
<-testConn.nextReaderCalls
|
t.Errorf("Timed out trying to send a message: %v", err)
|
||||||
<-testConn.closeCalls
|
}
|
||||||
|
|
||||||
// Wait for the final call to 'Close' (when the loop is aborted)
|
// Message successfully sent, instruct the server to drop the connection.
|
||||||
<-testConn.closeCalls
|
reconnectChan <- struct{}{}
|
||||||
|
|
||||||
if len(connectAttempts) > 1 {
|
|
||||||
t.Fatalf("Expected at most one connection attempts, got %v", len(connectAttempts))
|
|
||||||
}
|
|
||||||
if len(testConn.nextReaderCalls) > 1 {
|
|
||||||
t.Fatalf("Expected at most one calls to 'NextReader', got %v", len(testConn.nextReaderCalls))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnectFailureReturnsError(t *testing.T) {
|
func TestConnectFailureReturnsError(t *testing.T) {
|
||||||
connFactory = func(_ string) (rawConnection, error) {
|
conn := newConnection(errConnFactory(ErrConnectionNotEstablished), nil)
|
||||||
return nil, ErrConnectionNotEstablished
|
|
||||||
}
|
|
||||||
|
|
||||||
conn := newConnection(target, nil)
|
|
||||||
|
|
||||||
// Shorten the connection backoff duration for this test
|
// Shorten the connection backoff duration for this test
|
||||||
conn.connectionBackoff.Duration = 1 * time.Millisecond
|
conn.connectionBackoff.Duration = 1 * time.Millisecond
|
||||||
|
|
@ -252,10 +250,70 @@ func TestConnectFailureReturnsError(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestKeepaliveWithNoConnectionReturnsError(t *testing.T) {
|
func TestKeepaliveWithNoConnectionReturnsError(t *testing.T) {
|
||||||
conn := newConnection(target, nil)
|
conn := newConnection(nil, nil)
|
||||||
got := conn.keepalive()
|
got := conn.keepalive()
|
||||||
|
|
||||||
if got == nil {
|
if got == nil {
|
||||||
t.Fatal("Expected an error but got none")
|
t.Fatal("Expected an error but got none")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnectLoopIsStopped(t *testing.T) {
|
||||||
|
conn := newConnection(errConnFactory(errors.New("connection error")), nil)
|
||||||
|
|
||||||
|
errorChan := make(chan error)
|
||||||
|
go func() {
|
||||||
|
errorChan <- conn.connect()
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn.Shutdown()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errorChan:
|
||||||
|
if err != errShuttingDown {
|
||||||
|
t.Errorf("Wrong 'connect' error, got %v, want %v", err, errShuttingDown)
|
||||||
|
}
|
||||||
|
case <-time.After(propagationTimeout):
|
||||||
|
t.Error("Timed out waiting for the keepalive loop to stop.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeepaliveLoopIsStopped(t *testing.T) {
|
||||||
|
spy := &inspectableConnection{
|
||||||
|
nextReaderFunc: func() (int, io.Reader, error) {
|
||||||
|
return websocket.TextMessage, nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn := newConnection(staticConnFactory(spy), nil)
|
||||||
|
conn.connect()
|
||||||
|
|
||||||
|
errorChan := make(chan error)
|
||||||
|
go func() {
|
||||||
|
errorChan <- conn.keepalive()
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn.Shutdown()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errorChan:
|
||||||
|
if err != errShuttingDown {
|
||||||
|
t.Errorf("Wrong 'keepalive' error, got %v, want %v", err, errShuttingDown)
|
||||||
|
}
|
||||||
|
case <-time.After(propagationTimeout):
|
||||||
|
t.Error("Timed out waiting for the keepalive loop to stop.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDoubleShutdown(t *testing.T) {
|
||||||
|
spy := &inspectableConnection{
|
||||||
|
closeCalls: make(chan struct{}, 2), // potentially allow 2 calls
|
||||||
|
}
|
||||||
|
conn := newConnection(staticConnFactory(spy), nil)
|
||||||
|
conn.connect()
|
||||||
|
conn.Shutdown()
|
||||||
|
conn.Shutdown()
|
||||||
|
|
||||||
|
if want, got := 1, len(spy.closeCalls); want != got {
|
||||||
|
t.Errorf("Wrong 'Close' callcount, got %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue