diff --git a/bindings/rabbitmq/rabbitmq.go b/bindings/rabbitmq/rabbitmq.go index f04165557..71718f64a 100644 --- a/bindings/rabbitmq/rabbitmq.go +++ b/bindings/rabbitmq/rabbitmq.go @@ -39,12 +39,18 @@ const ( deleteWhenUnused = "deleteWhenUnused" prefetchCount = "prefetchCount" maxPriority = "maxPriority" + reconnectWaitSecondsKey = "reconnectWaitInSeconds" rabbitMQQueueMessageTTLKey = "x-message-ttl" rabbitMQMaxPriorityKey = "x-max-priority" defaultBase = 10 defaultBitSize = 0 + + errorChannelConnection = "channel/connection is not open" + defaultReconnectWait = 5 * time.Second ) +var errClosed = errors.New("component is stopped") + // RabbitMQ allows sending/receiving data to/from RabbitMQ. type RabbitMQ struct { connection *amqp.Connection @@ -55,6 +61,10 @@ type RabbitMQ struct { closed atomic.Bool closeCh chan struct{} wg sync.WaitGroup + + // used for reconnect + channelMutex sync.RWMutex + notifyRabbitChannelClose chan *amqp.Error } // Metadata is the rabbitmq config. @@ -66,6 +76,7 @@ type rabbitMQMetadata struct { DeleteWhenUnused bool `json:"deleteWhenUnused,string"` PrefetchCount int `json:"prefetchCount"` MaxPriority *uint8 `json:"maxPriority"` // Priority Queue deactivated if nil + reconnectWait time.Duration defaultQueueTTL *time.Duration } @@ -84,34 +95,95 @@ func (r *RabbitMQ) Init(_ context.Context, metadata bindings.Metadata) error { return err } - conn, err := amqp.Dial(r.metadata.Host) + err = r.connect() if err != nil { return err } - - ch, err := conn.Channel() - if err != nil { - return err - } - ch.Qos(r.metadata.PrefetchCount, 0, true) - r.connection = conn - r.channel = ch - - q, err := r.declareQueue() - if err != nil { - return err - } - - r.queue = q - + r.notifyRabbitChannelClose = make(chan *amqp.Error, 1) + r.channel.NotifyClose(r.notifyRabbitChannelClose) + r.wg.Add(1) + go func() { + defer r.wg.Done() + r.reconnectWhenNecessary() + }() return nil } +func (r *RabbitMQ) reconnectWhenNecessary() { + for { + select { + case <-r.closeCh: + return + case e := <-r.notifyRabbitChannelClose: + // If this error can not be recovered, first wait and then retry. + if e != nil && !e.Recover { + select { + case <-time.After(r.metadata.reconnectWait): + case <-r.closeCh: + return + } + } + r.channelMutex.Lock() + if r.connection != nil && !r.connection.IsClosed() { + ch, err := r.connection.Channel() + if err == nil { + r.notifyRabbitChannelClose = make(chan *amqp.Error, 1) + ch.NotifyClose(r.notifyRabbitChannelClose) + r.channel = ch + r.channelMutex.Unlock() + continue + } + // if encounter err fallback to reconnect connection + } + r.channelMutex.Unlock() + // keep trying to reconnect + for { + err := r.connect() + if err == nil { + break + } + if err == errClosed { + return + } + r.logger.Warnf("Reconnect failed: %v", err) + + select { + case <-time.After(r.metadata.reconnectWait): + case <-r.closeCh: + return + } + } + } + } +} + +func dial(uri string) (conn *amqp.Connection, ch *amqp.Channel, err error) { + conn, err = amqp.Dial(uri) + if err != nil { + return nil, nil, err + } + + ch, err = conn.Channel() + if err != nil { + conn.Close() + return nil, nil, err + } + + return conn, ch, nil +} + func (r *RabbitMQ) Operations() []bindings.OperationKind { return []bindings.OperationKind{bindings.CreateOperation} } func (r *RabbitMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { + // check if connection channel to rabbitmq is open + r.channelMutex.RLock() + ch := r.channel + r.channelMutex.RUnlock() + if ch == nil { + return nil, errors.New(errorChannelConnection) + } pub := amqp.Publishing{ DeliveryMode: amqp.Persistent, ContentType: "text/plain", @@ -119,18 +191,16 @@ func (r *RabbitMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bi } contentType, ok := contribMetadata.TryGetContentType(req.Metadata) - if ok { pub.ContentType = contentType } - ttl, ok, err := contribMetadata.TryGetTTL(req.Metadata) - if err != nil { - return nil, err - } - // The default time to live has been set in the queue // We allow overriding on each call, by setting a value in request metadata + ttl, ok, err := contribMetadata.TryGetTTL(req.Metadata) + if err != nil { + return nil, fmt.Errorf("failed to get TTL: %w", err) + } if ok { // RabbitMQ expects the duration in ms pub.Expiration = strconv.FormatInt(ttl.Milliseconds(), 10) @@ -140,59 +210,57 @@ func (r *RabbitMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bi if err != nil { return nil, err } - if ok { pub.Priority = priority } - err = r.channel.PublishWithContext(ctx, "", r.metadata.QueueName, false, false, pub) - + err = ch.PublishWithContext(ctx, "", r.metadata.QueueName, false, false, pub) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to publish message: %w", err) } return nil, nil } func (r *RabbitMQ) parseMetadata(metadata bindings.Metadata) error { - m := rabbitMQMetadata{} + m := rabbitMQMetadata{reconnectWait: defaultReconnectWait} - if val, ok := metadata.Properties[host]; ok && val != "" { + if val := metadata.Properties[host]; val != "" { m.Host = val } else { - return errors.New("rabbitMQ binding error: missing host address") + return errors.New("missing host address") } - if val, ok := metadata.Properties[queueName]; ok && val != "" { + if val := metadata.Properties[queueName]; val != "" { m.QueueName = val } else { - return errors.New("rabbitMQ binding error: missing queue Name") + return errors.New("missing queue Name") } - if val, ok := metadata.Properties[durable]; ok && val != "" { + if val := metadata.Properties[durable]; val != "" { m.Durable = utils.IsTruthy(val) } - if val, ok := metadata.Properties[deleteWhenUnused]; ok && val != "" { + if val := metadata.Properties[deleteWhenUnused]; val != "" { m.DeleteWhenUnused = utils.IsTruthy(val) } - if val, ok := metadata.Properties[prefetchCount]; ok && val != "" { + if val := metadata.Properties[prefetchCount]; val != "" { parsedVal, err := strconv.ParseInt(val, defaultBase, defaultBitSize) if err != nil { - return fmt.Errorf("rabbitMQ binding error: can't parse prefetchCount field: %s", err) + return fmt.Errorf("can't parse prefetchCount field: %s", err) } m.PrefetchCount = int(parsedVal) } - if val, ok := metadata.Properties[exclusive]; ok && val != "" { + if val := metadata.Properties[exclusive]; val != "" { m.Exclusive = utils.IsTruthy(val) } - if val, ok := metadata.Properties[maxPriority]; ok && val != "" { + if val := metadata.Properties[maxPriority]; val != "" { parsedVal, err := strconv.ParseUint(val, defaultBase, defaultBitSize) if err != nil { - return fmt.Errorf("rabbitMQ binding error: can't parse maxPriority field: %s", err) + return fmt.Errorf("can't parse maxPriority field: %s", err) } maxPriority := uint8(parsedVal) @@ -204,11 +272,16 @@ func (r *RabbitMQ) parseMetadata(metadata bindings.Metadata) error { m.MaxPriority = &maxPriority } - ttl, ok, err := contribMetadata.TryGetTTL(metadata.Properties) - if err != nil { - return err + if val := metadata.Properties[reconnectWaitSecondsKey]; val != "" { + if intVal, err := strconv.Atoi(val); err == nil && intVal > 0 { + m.reconnectWait = time.Duration(intVal) * time.Second + } } + ttl, ok, err := contribMetadata.TryGetTTL(metadata.Properties) + if err != nil { + return fmt.Errorf("failed to parse TTL: %w", err) + } if ok { m.defaultQueueTTL = &ttl } @@ -218,7 +291,7 @@ func (r *RabbitMQ) parseMetadata(metadata bindings.Metadata) error { return nil } -func (r *RabbitMQ) declareQueue() (amqp.Queue, error) { +func (r *RabbitMQ) declareQueue(channel *amqp.Channel) (amqp.Queue, error) { args := amqp.Table{} if r.metadata.defaultQueueTTL != nil { // Value in ms @@ -230,7 +303,7 @@ func (r *RabbitMQ) declareQueue() (amqp.Queue, error) { args[rabbitMQMaxPriorityKey] = *r.metadata.MaxPriority } - return r.channel.QueueDeclare(r.metadata.QueueName, r.metadata.Durable, r.metadata.DeleteWhenUnused, r.metadata.Exclusive, false, args) + return channel.QueueDeclare(r.metadata.QueueName, r.metadata.Durable, r.metadata.DeleteWhenUnused, r.metadata.Exclusive, false, args) } func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error { @@ -238,58 +311,147 @@ func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error { return errors.New("binding already closed") } - msgs, err := r.channel.Consume( - r.queue.Name, - "", - false, - false, - false, - false, - nil, - ) - if err != nil { - return err - } - readCtx, cancel := context.WithCancel(ctx) - r.wg.Add(2) go func() { - defer r.wg.Done() - defer cancel() select { case <-r.closeCh: + // nop case <-readCtx.Done(): + // nop } + r.wg.Done() + cancel() }() - go func() { + // unless closed, keep trying to read and handle messages forever defer r.wg.Done() - var err error for { - select { - case <-readCtx.Done(): - return - case d := <-msgs: - _, err = handler(readCtx, &bindings.ReadResponse{ - Data: d.Body, - }) - if err != nil { - r.channel.Nack(d.DeliveryTag, false, true) + var ( + msgs <-chan amqp.Delivery + err error + declaredQueueName string + ch *amqp.Channel + ) + r.channelMutex.RLock() + declaredQueueName = r.queue.Name + ch = r.channel + r.channelMutex.RUnlock() + + if ch != nil { + msgs, err = ch.Consume( + declaredQueueName, + "", + false, + false, + false, + false, + nil, + ) + if err == nil { + // all good, handle messages + r.handleMessage(readCtx, handler, msgs, ch) } else { - r.channel.Ack(d.DeliveryTag, false) + r.logger.Errorf("Error consuming messages from queue [%s]: %v", r.queue.Name, err) } } + + // something went wrong, wait for reconnect + select { + case <-time.After(r.metadata.reconnectWait): + continue + case <-readCtx.Done(): + r.logger.Info("Input binding closed, stop fetching message") + return + } } }() - return nil } +// handleMessage handles incoming messages from RabbitMQ +func (r *RabbitMQ) handleMessage(ctx context.Context, handler bindings.Handler, msgCh <-chan amqp.Delivery, ch *amqp.Channel) { + for { + select { + case <-ctx.Done(): + return + case d, ok := <-msgCh: + if !ok { + r.logger.Info("Input binding channel closed") + return + } + _, err := handler(ctx, &bindings.ReadResponse{ + Data: d.Body, + }) + if err != nil { + ch.Nack(d.DeliveryTag, false, true) + } else { + ch.Ack(d.DeliveryTag, false) + } + } + } +} + func (r *RabbitMQ) Close() error { if r.closed.CompareAndSwap(false, true) { close(r.closeCh) } defer r.wg.Wait() - return r.channel.Close() + r.channelMutex.Lock() + defer r.channelMutex.Unlock() + return r.reset() +} + +func (r *RabbitMQ) connect() error { + if r.closed.Load() { + // Do not reconnect on stopped service. + return errClosed + } + + conn, ch, err := dial(r.metadata.Host) + if err != nil { + return err + } + + ch.Qos(r.metadata.PrefetchCount, 0, true) + q, err := r.declareQueue(ch) + if err != nil { + return err + } + + r.notifyRabbitChannelClose = make(chan *amqp.Error, 1) + ch.NotifyClose(r.notifyRabbitChannelClose) + + r.channelMutex.Lock() + // try to close the old channel and connection, ignore the error + _ = r.reset() //nolint:errcheck + r.connection, r.channel, r.queue = conn, ch, q + r.channelMutex.Unlock() + + r.logger.Info("Connected to RabbitMQ") + + return nil +} + +// reset the channel and the connection when encountered a connection error. +// this function call should be wrapped by channelMutex. +func (r *RabbitMQ) reset() (err error) { + if r.channel != nil { + if err = r.channel.Close(); err != nil { + r.logger.Warnf("Reset: channel.Close() failed: %v", err) + } + r.channel = nil + } + + if r.connection != nil { + if err2 := r.connection.Close(); err2 != nil { + r.logger.Warnf("Reset: connection.Close() failed: %v", err2) + if err == nil { + err = err2 + } + } + r.connection = nil + } + + return err } diff --git a/bindings/rabbitmq/rabbitmq_integration_test.go b/bindings/rabbitmq/rabbitmq_integration_test.go index 79bdc33a4..6428b8a62 100644 --- a/bindings/rabbitmq/rabbitmq_integration_test.go +++ b/bindings/rabbitmq/rabbitmq_integration_test.go @@ -120,6 +120,59 @@ func TestQueuesWithTTL(t *testing.T) { assert.NoError(t, r.Close()) } +func TestQueuesReconnect(t *testing.T) { + rabbitmqHost := getTestRabbitMQHost() + assert.NotEmpty(t, rabbitmqHost, fmt.Sprintf("RabbitMQ host configuration must be set in environment variable '%s' (example 'amqp://guest:guest@localhost:5672/')", testRabbitMQHostEnvKey)) + + queueName := uuid.New().String() + durable := true + exclusive := false + + metadata := bindings.Metadata{ + Base: contribMetadata.Base{ + Name: "testQueue", + Properties: map[string]string{ + "queueName": queueName, + "host": rabbitmqHost, + "deleteWhenUnused": strconv.FormatBool(exclusive), + "durable": strconv.FormatBool(durable), + }, + }, + } + + var messageReceivedCount int + var handler bindings.Handler = func(ctx context.Context, in *bindings.ReadResponse) ([]byte, error) { + messageReceivedCount++ + return nil, nil + } + + logger := logger.NewLogger("test") + + r := NewRabbitMQ(logger).(*RabbitMQ) + err := r.Init(context.Background(), metadata) + assert.Nil(t, err) + + err = r.Read(context.Background(), handler) + assert.Nil(t, err) + + const tooLateMsgContent = "success_msg1" + _, err = r.Invoke(context.Background(), &bindings.InvokeRequest{Data: []byte(tooLateMsgContent)}) + assert.Nil(t, err) + + // perform a close connection with the rabbitmq server + r.channel.Close() + time.Sleep(3 * defaultReconnectWait) + + const testMsgContent = "reconnect_msg" + _, err = r.Invoke(context.Background(), &bindings.InvokeRequest{Data: []byte(testMsgContent)}) + assert.Nil(t, err) + + time.Sleep(defaultReconnectWait) + // sending 2 messages, one before the reconnect and one after + assert.Equal(t, 2, messageReceivedCount) + assert.NoError(t, r.Close()) +} + func TestPublishingWithTTL(t *testing.T) { rabbitmqHost := getTestRabbitMQHost() assert.NotEmpty(t, rabbitmqHost, fmt.Sprintf("RabbitMQ host configuration must be set in environment variable '%s' (example 'amqp://guest:guest@localhost:5672/')", testRabbitMQHostEnvKey)) @@ -196,7 +249,7 @@ func TestPublishingWithTTL(t *testing.T) { assert.Equal(t, testMsgContent, msgBody) assert.NoError(t, rabbitMQBinding1.Close()) - assert.NoError(t, rabbitMQBinding1.Close()) + assert.NoError(t, rabbitMQBinding2.Close()) } func TestExclusiveQueue(t *testing.T) { diff --git a/bindings/rabbitmq/rabbitmq_test.go b/bindings/rabbitmq/rabbitmq_test.go index 0e2a58d24..d43e28074 100644 --- a/bindings/rabbitmq/rabbitmq_test.go +++ b/bindings/rabbitmq/rabbitmq_test.go @@ -30,20 +30,24 @@ func TestParseMetadata(t *testing.T) { oneSecondTTL := time.Second testCases := []struct { - name string - properties map[string]string - expectedDeleteWhenUnused bool - expectedDurable bool - expectedExclusive bool - expectedTTL *time.Duration - expectedPrefetchCount int - expectedMaxPriority *uint8 + name string + properties map[string]string + expectedDeleteWhenUnused bool + expectedDurable bool + expectedExclusive bool + expectedTTL *time.Duration + expectedPrefetchCount int + expectedMaxPriority *uint8 + expectedReconnectWaitCheck func(expect time.Duration) bool }{ { name: "Delete / Durable", properties: map[string]string{"queueName": queueName, "host": host, "deleteWhenUnused": "true", "durable": "true"}, expectedDeleteWhenUnused: true, expectedDurable: true, + expectedReconnectWaitCheck: func(expect time.Duration) bool { + return expect == defaultReconnectWait + }, }, { name: "Not Delete / Not durable", @@ -100,6 +104,15 @@ func TestParseMetadata(t *testing.T) { return &v }(), }, + { + name: "With reconnectWait 10 second", + properties: map[string]string{"queueName": queueName, "host": host, "deleteWhenUnused": "false", "durable": "false", "reconnectWaitInSeconds": "10"}, + expectedDeleteWhenUnused: false, + expectedDurable: false, + expectedReconnectWaitCheck: func(expect time.Duration) bool { + return expect == 10*time.Second + }, + }, } for _, tt := range testCases { @@ -108,7 +121,7 @@ func TestParseMetadata(t *testing.T) { m.Properties = tt.properties r := RabbitMQ{logger: logger.NewLogger("test")} err := r.parseMetadata(m) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, queueName, r.metadata.QueueName) assert.Equal(t, host, r.metadata.Host) assert.Equal(t, tt.expectedDeleteWhenUnused, r.metadata.DeleteWhenUnused) @@ -117,6 +130,9 @@ func TestParseMetadata(t *testing.T) { assert.Equal(t, tt.expectedPrefetchCount, r.metadata.PrefetchCount) assert.Equal(t, tt.expectedExclusive, r.metadata.Exclusive) assert.Equal(t, tt.expectedMaxPriority, r.metadata.MaxPriority) + if tt.expectedReconnectWaitCheck != nil { + assert.True(t, tt.expectedReconnectWaitCheck(r.metadata.reconnectWait)) + } }) } }