Support reconnection between sidecar and broker for RabbitMQ bindings (#2565)

Signed-off-by: zhangchao <zchao9100@gmail.com>
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Signed-off-by: Taction <zchao9100@gmail.com>
Co-authored-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Josh van Leeuwen <me@joshvanl.dev>
This commit is contained in:
Taction 2023-03-24 01:21:45 +08:00 committed by GitHub
parent e7c51333f8
commit d9e1cc4e86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 315 additions and 84 deletions

View File

@ -39,12 +39,18 @@ const (
deleteWhenUnused = "deleteWhenUnused"
prefetchCount = "prefetchCount"
maxPriority = "maxPriority"
reconnectWaitSecondsKey = "reconnectWaitInSeconds"
rabbitMQQueueMessageTTLKey = "x-message-ttl"
rabbitMQMaxPriorityKey = "x-max-priority"
defaultBase = 10
defaultBitSize = 0
errorChannelConnection = "channel/connection is not open"
defaultReconnectWait = 5 * time.Second
)
var errClosed = errors.New("component is stopped")
// RabbitMQ allows sending/receiving data to/from RabbitMQ.
type RabbitMQ struct {
connection *amqp.Connection
@ -55,6 +61,10 @@ type RabbitMQ struct {
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
// used for reconnect
channelMutex sync.RWMutex
notifyRabbitChannelClose chan *amqp.Error
}
// Metadata is the rabbitmq config.
@ -66,6 +76,7 @@ type rabbitMQMetadata struct {
DeleteWhenUnused bool `json:"deleteWhenUnused,string"`
PrefetchCount int `json:"prefetchCount"`
MaxPriority *uint8 `json:"maxPriority"` // Priority Queue deactivated if nil
reconnectWait time.Duration
defaultQueueTTL *time.Duration
}
@ -84,34 +95,95 @@ func (r *RabbitMQ) Init(_ context.Context, metadata bindings.Metadata) error {
return err
}
conn, err := amqp.Dial(r.metadata.Host)
err = r.connect()
if err != nil {
return err
}
ch, err := conn.Channel()
if err != nil {
return err
}
ch.Qos(r.metadata.PrefetchCount, 0, true)
r.connection = conn
r.channel = ch
q, err := r.declareQueue()
if err != nil {
return err
}
r.queue = q
r.notifyRabbitChannelClose = make(chan *amqp.Error, 1)
r.channel.NotifyClose(r.notifyRabbitChannelClose)
r.wg.Add(1)
go func() {
defer r.wg.Done()
r.reconnectWhenNecessary()
}()
return nil
}
// reconnectWhenNecessary loops until closeCh is closed. Every time a close
// notification arrives on notifyRabbitChannelClose it first tries to open a
// new channel on the existing connection and, if that is not possible, keeps
// calling connect() (waiting reconnectWait between attempts) until it
// succeeds or the component is stopped.
func (r *RabbitMQ) reconnectWhenNecessary() {
	for {
		select {
		case <-r.closeCh:
			return
		case e := <-r.notifyRabbitChannelClose:
			// If this error can not be recovered, first wait and then retry.
			if e != nil && !e.Recover {
				select {
				case <-time.After(r.metadata.reconnectWait):
				case <-r.closeCh:
					return
				}
			}

			// Cheap path: the connection is still usable, only the channel is gone.
			r.channelMutex.Lock()
			if r.connection != nil && !r.connection.IsClosed() {
				ch, err := r.connection.Channel()
				if err == nil {
					// Prefetch settings belong to the channel, so a freshly opened
					// channel needs them re-applied, exactly as connect() does.
					// Best-effort: a failure is logged and the channel is still used.
					if qosErr := ch.Qos(r.metadata.PrefetchCount, 0, true); qosErr != nil {
						r.logger.Warnf("Reconnect: failed to set QoS on the new channel: %v", qosErr)
					}
					r.notifyRabbitChannelClose = make(chan *amqp.Error, 1)
					ch.NotifyClose(r.notifyRabbitChannelClose)
					r.channel = ch
					r.channelMutex.Unlock()
					continue
				}
				// if encounter err fallback to reconnect connection
			}
			r.channelMutex.Unlock()

			// keep trying to reconnect
			for {
				err := r.connect()
				if err == nil {
					break
				}
				// connect() reports errClosed once the component has been stopped;
				// errors.Is keeps working even if that error gets wrapped later.
				if errors.Is(err, errClosed) {
					return
				}
				r.logger.Warnf("Reconnect failed: %v", err)

				select {
				case <-time.After(r.metadata.reconnectWait):
				case <-r.closeCh:
					return
				}
			}
		}
	}
}
// dial connects to the broker at uri and opens one channel on that
// connection. When the channel cannot be opened the connection is closed
// again, so a failed call returns nil for both handles.
func dial(uri string) (conn *amqp.Connection, ch *amqp.Channel, err error) {
	if conn, err = amqp.Dial(uri); err != nil {
		return nil, nil, err
	}

	if ch, err = conn.Channel(); err != nil {
		conn.Close()
		return nil, nil, err
	}

	return conn, ch, nil
}
// Operations reports the operation kinds this binding accepts; the only
// entry is bindings.CreateOperation.
func (r *RabbitMQ) Operations() []bindings.OperationKind {
	supported := []bindings.OperationKind{bindings.CreateOperation}
	return supported
}
func (r *RabbitMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
// check if connection channel to rabbitmq is open
r.channelMutex.RLock()
ch := r.channel
r.channelMutex.RUnlock()
if ch == nil {
return nil, errors.New(errorChannelConnection)
}
pub := amqp.Publishing{
DeliveryMode: amqp.Persistent,
ContentType: "text/plain",
@ -119,18 +191,16 @@ func (r *RabbitMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bi
}
contentType, ok := contribMetadata.TryGetContentType(req.Metadata)
if ok {
pub.ContentType = contentType
}
ttl, ok, err := contribMetadata.TryGetTTL(req.Metadata)
if err != nil {
return nil, err
}
// The default time to live has been set in the queue
// We allow overriding on each call, by setting a value in request metadata
ttl, ok, err := contribMetadata.TryGetTTL(req.Metadata)
if err != nil {
return nil, fmt.Errorf("failed to get TTL: %w", err)
}
if ok {
// RabbitMQ expects the duration in ms
pub.Expiration = strconv.FormatInt(ttl.Milliseconds(), 10)
@ -140,59 +210,57 @@ func (r *RabbitMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bi
if err != nil {
return nil, err
}
if ok {
pub.Priority = priority
}
err = r.channel.PublishWithContext(ctx, "", r.metadata.QueueName, false, false, pub)
err = ch.PublishWithContext(ctx, "", r.metadata.QueueName, false, false, pub)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to publish message: %w", err)
}
return nil, nil
}
func (r *RabbitMQ) parseMetadata(metadata bindings.Metadata) error {
m := rabbitMQMetadata{}
m := rabbitMQMetadata{reconnectWait: defaultReconnectWait}
if val, ok := metadata.Properties[host]; ok && val != "" {
if val := metadata.Properties[host]; val != "" {
m.Host = val
} else {
return errors.New("rabbitMQ binding error: missing host address")
return errors.New("missing host address")
}
if val, ok := metadata.Properties[queueName]; ok && val != "" {
if val := metadata.Properties[queueName]; val != "" {
m.QueueName = val
} else {
return errors.New("rabbitMQ binding error: missing queue Name")
return errors.New("missing queue Name")
}
if val, ok := metadata.Properties[durable]; ok && val != "" {
if val := metadata.Properties[durable]; val != "" {
m.Durable = utils.IsTruthy(val)
}
if val, ok := metadata.Properties[deleteWhenUnused]; ok && val != "" {
if val := metadata.Properties[deleteWhenUnused]; val != "" {
m.DeleteWhenUnused = utils.IsTruthy(val)
}
if val, ok := metadata.Properties[prefetchCount]; ok && val != "" {
if val := metadata.Properties[prefetchCount]; val != "" {
parsedVal, err := strconv.ParseInt(val, defaultBase, defaultBitSize)
if err != nil {
return fmt.Errorf("rabbitMQ binding error: can't parse prefetchCount field: %s", err)
return fmt.Errorf("can't parse prefetchCount field: %s", err)
}
m.PrefetchCount = int(parsedVal)
}
if val, ok := metadata.Properties[exclusive]; ok && val != "" {
if val := metadata.Properties[exclusive]; val != "" {
m.Exclusive = utils.IsTruthy(val)
}
if val, ok := metadata.Properties[maxPriority]; ok && val != "" {
if val := metadata.Properties[maxPriority]; val != "" {
parsedVal, err := strconv.ParseUint(val, defaultBase, defaultBitSize)
if err != nil {
return fmt.Errorf("rabbitMQ binding error: can't parse maxPriority field: %s", err)
return fmt.Errorf("can't parse maxPriority field: %s", err)
}
maxPriority := uint8(parsedVal)
@ -204,11 +272,16 @@ func (r *RabbitMQ) parseMetadata(metadata bindings.Metadata) error {
m.MaxPriority = &maxPriority
}
ttl, ok, err := contribMetadata.TryGetTTL(metadata.Properties)
if err != nil {
return err
if val := metadata.Properties[reconnectWaitSecondsKey]; val != "" {
if intVal, err := strconv.Atoi(val); err == nil && intVal > 0 {
m.reconnectWait = time.Duration(intVal) * time.Second
}
}
ttl, ok, err := contribMetadata.TryGetTTL(metadata.Properties)
if err != nil {
return fmt.Errorf("failed to parse TTL: %w", err)
}
if ok {
m.defaultQueueTTL = &ttl
}
@ -218,7 +291,7 @@ func (r *RabbitMQ) parseMetadata(metadata bindings.Metadata) error {
return nil
}
func (r *RabbitMQ) declareQueue() (amqp.Queue, error) {
func (r *RabbitMQ) declareQueue(channel *amqp.Channel) (amqp.Queue, error) {
args := amqp.Table{}
if r.metadata.defaultQueueTTL != nil {
// Value in ms
@ -230,7 +303,7 @@ func (r *RabbitMQ) declareQueue() (amqp.Queue, error) {
args[rabbitMQMaxPriorityKey] = *r.metadata.MaxPriority
}
return r.channel.QueueDeclare(r.metadata.QueueName, r.metadata.Durable, r.metadata.DeleteWhenUnused, r.metadata.Exclusive, false, args)
return channel.QueueDeclare(r.metadata.QueueName, r.metadata.Durable, r.metadata.DeleteWhenUnused, r.metadata.Exclusive, false, args)
}
func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
@ -238,58 +311,147 @@ func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
return errors.New("binding already closed")
}
msgs, err := r.channel.Consume(
r.queue.Name,
"",
false,
false,
false,
false,
nil,
)
if err != nil {
return err
}
readCtx, cancel := context.WithCancel(ctx)
r.wg.Add(2)
go func() {
defer r.wg.Done()
defer cancel()
select {
case <-r.closeCh:
// nop
case <-readCtx.Done():
// nop
}
r.wg.Done()
cancel()
}()
go func() {
// unless closed, keep trying to read and handle messages forever
defer r.wg.Done()
var err error
for {
select {
case <-readCtx.Done():
return
case d := <-msgs:
_, err = handler(readCtx, &bindings.ReadResponse{
Data: d.Body,
})
if err != nil {
r.channel.Nack(d.DeliveryTag, false, true)
var (
msgs <-chan amqp.Delivery
err error
declaredQueueName string
ch *amqp.Channel
)
r.channelMutex.RLock()
declaredQueueName = r.queue.Name
ch = r.channel
r.channelMutex.RUnlock()
if ch != nil {
msgs, err = ch.Consume(
declaredQueueName,
"",
false,
false,
false,
false,
nil,
)
if err == nil {
// all good, handle messages
r.handleMessage(readCtx, handler, msgs, ch)
} else {
r.channel.Ack(d.DeliveryTag, false)
r.logger.Errorf("Error consuming messages from queue [%s]: %v", r.queue.Name, err)
}
}
// something went wrong, wait for reconnect
select {
case <-time.After(r.metadata.reconnectWait):
continue
case <-readCtx.Done():
r.logger.Info("Input binding closed, stop fetching message")
return
}
}
}()
return nil
}
// handleMessage delivers every message received on msgCh to handler until
// ctx is cancelled or msgCh is closed. A message is acknowledged on ch when
// the handler succeeds and negatively acknowledged with requeue when the
// handler returns an error. Failures of the Ack/Nack calls themselves are
// logged instead of being dropped silently.
func (r *RabbitMQ) handleMessage(ctx context.Context, handler bindings.Handler, msgCh <-chan amqp.Delivery, ch *amqp.Channel) {
	for {
		select {
		case <-ctx.Done():
			return
		case d, ok := <-msgCh:
			if !ok {
				// The delivery channel was closed; the caller decides whether to retry.
				r.logger.Info("Input binding channel closed")
				return
			}

			_, err := handler(ctx, &bindings.ReadResponse{
				Data: d.Body,
			})
			if err != nil {
				// requeue=true so the message can be delivered again.
				if nackErr := ch.Nack(d.DeliveryTag, false, true); nackErr != nil {
					r.logger.Warnf("Failed to nack message: %v", nackErr)
				}
				continue
			}
			if ackErr := ch.Ack(d.DeliveryTag, false); ackErr != nil {
				r.logger.Warnf("Failed to ack message: %v", ackErr)
			}
		}
	}
}
// Close shuts the binding down. The CompareAndSwap guard ensures closeCh is
// closed only once, so repeated calls do not panic on a double close. The
// deferred wg.Wait blocks the return until the goroutines registered on r.wg
// have finished.
func (r *RabbitMQ) Close() error {
if r.closed.CompareAndSwap(false, true) {
close(r.closeCh)
}
defer r.wg.Wait()
// NOTE(review): this listing looks like a rendered diff without +/- markers;
// the return below is presumably the removed pre-change line and the locked
// reset() call after it its replacement — confirm against the real file.
return r.channel.Close()
// reset() is called with channelMutex held for writing.
r.channelMutex.Lock()
defer r.channelMutex.Unlock()
return r.reset()
}
// connect dials the broker, applies the prefetch setting, declares the queue
// and registers a close notification on the new channel. On success it swaps
// the new connection, channel and queue in under channelMutex, closing
// whatever was installed before.
//
// It returns errClosed when the component has been stopped, and otherwise any
// error from dialing or declaring the queue. On every failure path the
// connection opened by this call is closed again so nothing is leaked.
func (r *RabbitMQ) connect() error {
	if r.closed.Load() {
		// Do not reconnect on stopped service.
		return errClosed
	}

	conn, ch, err := dial(r.metadata.Host)
	if err != nil {
		return err
	}

	// Best-effort, as before: a QoS failure does not abort the connection
	// attempt, but it is no longer ignored silently.
	if qosErr := ch.Qos(r.metadata.PrefetchCount, 0, true); qosErr != nil {
		r.logger.Warnf("Failed to set QoS on channel: %v", qosErr)
	}

	q, err := r.declareQueue(ch)
	if err != nil {
		// Closing the connection also tears down its channel; the declare
		// error is the one worth reporting, so the close error is dropped.
		_ = conn.Close()
		return err
	}

	r.notifyRabbitChannelClose = make(chan *amqp.Error, 1)
	ch.NotifyClose(r.notifyRabbitChannelClose)

	r.channelMutex.Lock()
	// Close() sets the closed flag before it takes channelMutex, so checking
	// again under the lock guarantees that a connection is never installed
	// after Close() has already reset the component.
	if r.closed.Load() {
		r.channelMutex.Unlock()
		_ = conn.Close()
		return errClosed
	}
	// try to close the old channel and connection, ignore the error
	_ = r.reset() //nolint:errcheck
	r.connection, r.channel, r.queue = conn, ch, q
	r.channelMutex.Unlock()

	r.logger.Info("Connected to RabbitMQ")

	return nil
}
// reset closes the current channel and connection, if any, and sets both
// fields to nil. Close failures are logged; the returned error is the channel
// close error when there is one, otherwise the connection close error.
// Callers must hold channelMutex.
func (r *RabbitMQ) reset() (err error) {
	if r.channel != nil {
		err = r.channel.Close()
		if err != nil {
			r.logger.Warnf("Reset: channel.Close() failed: %v", err)
		}
		r.channel = nil
	}

	if r.connection == nil {
		return err
	}

	connErr := r.connection.Close()
	if connErr != nil {
		r.logger.Warnf("Reset: connection.Close() failed: %v", connErr)
		// Only report the connection error when the channel closed cleanly.
		if err == nil {
			err = connErr
		}
	}
	r.connection = nil

	return err
}

View File

@ -120,6 +120,59 @@ func TestQueuesWithTTL(t *testing.T) {
assert.NoError(t, r.Close())
}
// TestQueuesReconnect publishes one message, closes the binding's channel to
// force the reconnect logic to run, publishes a second message after the
// binding had time to recover, and expects the handler to have received both.
func TestQueuesReconnect(t *testing.T) {
	rabbitmqHost := getTestRabbitMQHost()
	assert.NotEmpty(t, rabbitmqHost, fmt.Sprintf("RabbitMQ host configuration must be set in environment variable '%s' (example 'amqp://guest:guest@localhost:5672/')", testRabbitMQHostEnvKey))

	queueName := uuid.New().String()
	durable := true
	exclusive := false

	metadata := bindings.Metadata{
		Base: contribMetadata.Base{
			Name: "testQueue",
			Properties: map[string]string{
				"queueName":        queueName,
				"host":             rabbitmqHost,
				"deleteWhenUnused": strconv.FormatBool(exclusive),
				"durable":          strconv.FormatBool(durable),
			},
		},
	}

	// The handler runs on the binding's reader goroutine while the assertion
	// runs on the test goroutine. Counting deliveries through a buffered
	// channel avoids the data race a plain int counter would have.
	received := make(chan struct{}, 16)
	var handler bindings.Handler = func(ctx context.Context, in *bindings.ReadResponse) ([]byte, error) {
		select {
		case received <- struct{}{}:
		default:
			// Buffer full: far more messages than expected, the assertion below fails anyway.
		}
		return nil, nil
	}

	testLogger := logger.NewLogger("test")

	r := NewRabbitMQ(testLogger).(*RabbitMQ)
	err := r.Init(context.Background(), metadata)
	if !assert.NoError(t, err) {
		// Without a successful Init there is no channel to work with.
		return
	}

	err = r.Read(context.Background(), handler)
	assert.NoError(t, err)

	const firstMsgContent = "success_msg1"
	_, err = r.Invoke(context.Background(), &bindings.InvokeRequest{Data: []byte(firstMsgContent)})
	assert.NoError(t, err)

	// Close the channel to the rabbitmq server to trigger the reconnect logic.
	// The field is shared with the reconnect goroutine, so read it under the lock.
	r.channelMutex.RLock()
	ch := r.channel
	r.channelMutex.RUnlock()
	if assert.NotNil(t, ch) {
		ch.Close()
	}

	time.Sleep(3 * defaultReconnectWait)

	const testMsgContent = "reconnect_msg"
	_, err = r.Invoke(context.Background(), &bindings.InvokeRequest{Data: []byte(testMsgContent)})
	assert.NoError(t, err)

	time.Sleep(defaultReconnectWait)
	// sending 2 messages, one before the reconnect and one after
	assert.Equal(t, 2, len(received))

	assert.NoError(t, r.Close())
}
func TestPublishingWithTTL(t *testing.T) {
rabbitmqHost := getTestRabbitMQHost()
assert.NotEmpty(t, rabbitmqHost, fmt.Sprintf("RabbitMQ host configuration must be set in environment variable '%s' (example 'amqp://guest:guest@localhost:5672/')", testRabbitMQHostEnvKey))
@ -196,7 +249,7 @@ func TestPublishingWithTTL(t *testing.T) {
assert.Equal(t, testMsgContent, msgBody)
assert.NoError(t, rabbitMQBinding1.Close())
assert.NoError(t, rabbitMQBinding1.Close())
assert.NoError(t, rabbitMQBinding2.Close())
}
func TestExclusiveQueue(t *testing.T) {

View File

@ -30,20 +30,24 @@ func TestParseMetadata(t *testing.T) {
oneSecondTTL := time.Second
testCases := []struct {
name string
properties map[string]string
expectedDeleteWhenUnused bool
expectedDurable bool
expectedExclusive bool
expectedTTL *time.Duration
expectedPrefetchCount int
expectedMaxPriority *uint8
name string
properties map[string]string
expectedDeleteWhenUnused bool
expectedDurable bool
expectedExclusive bool
expectedTTL *time.Duration
expectedPrefetchCount int
expectedMaxPriority *uint8
expectedReconnectWaitCheck func(expect time.Duration) bool
}{
{
name: "Delete / Durable",
properties: map[string]string{"queueName": queueName, "host": host, "deleteWhenUnused": "true", "durable": "true"},
expectedDeleteWhenUnused: true,
expectedDurable: true,
expectedReconnectWaitCheck: func(expect time.Duration) bool {
return expect == defaultReconnectWait
},
},
{
name: "Not Delete / Not durable",
@ -100,6 +104,15 @@ func TestParseMetadata(t *testing.T) {
return &v
}(),
},
{
name: "With reconnectWait 10 second",
properties: map[string]string{"queueName": queueName, "host": host, "deleteWhenUnused": "false", "durable": "false", "reconnectWaitInSeconds": "10"},
expectedDeleteWhenUnused: false,
expectedDurable: false,
expectedReconnectWaitCheck: func(expect time.Duration) bool {
return expect == 10*time.Second
},
},
}
for _, tt := range testCases {
@ -108,7 +121,7 @@ func TestParseMetadata(t *testing.T) {
m.Properties = tt.properties
r := RabbitMQ{logger: logger.NewLogger("test")}
err := r.parseMetadata(m)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, queueName, r.metadata.QueueName)
assert.Equal(t, host, r.metadata.Host)
assert.Equal(t, tt.expectedDeleteWhenUnused, r.metadata.DeleteWhenUnused)
@ -117,6 +130,9 @@ func TestParseMetadata(t *testing.T) {
assert.Equal(t, tt.expectedPrefetchCount, r.metadata.PrefetchCount)
assert.Equal(t, tt.expectedExclusive, r.metadata.Exclusive)
assert.Equal(t, tt.expectedMaxPriority, r.metadata.MaxPriority)
if tt.expectedReconnectWaitCheck != nil {
assert.True(t, tt.expectedReconnectWaitCheck(r.metadata.reconnectWait))
}
})
}
}