569 lines
14 KiB
Go
569 lines
14 KiB
Go
/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
|
|
|
package rabbitmq
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"net/url"
|
|
"reflect"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/dapr/components-contrib/internal/utils"
|
|
|
|
amqp "github.com/rabbitmq/amqp091-go"
|
|
|
|
"github.com/dapr/components-contrib/bindings"
|
|
"github.com/dapr/components-contrib/metadata"
|
|
"github.com/dapr/kit/logger"
|
|
)
|
|
|
|
const (
	// Metadata property names.
	host                    = "host"
	queueName               = "queueName"
	durable                 = "durable"
	exclusive               = "exclusive"
	deleteWhenUnused        = "deleteWhenUnused"
	prefetchCount           = "prefetchCount"
	maxPriority             = "maxPriority"
	reconnectWaitSecondsKey = "reconnectWaitInSeconds"
	caCert                  = "caCert"
	clientCert              = "clientCert"
	clientKey               = "clientKey"
	externalSasl            = "saslExternal"

	// Keys of the optional arguments supplied when the queue is declared.
	rabbitMQQueueMessageTTLKey = "x-message-ttl"
	rabbitMQMaxPriorityKey     = "x-max-priority"

	// Base and bit size handed to strconv.ParseUint when parsing maxPriority.
	defaultBase    = 10
	defaultBitSize = 0

	// errorChannelConnection is the text of the error returned by Invoke when
	// there is no channel to publish on.
	errorChannelConnection = "channel/connection is not open"

	// defaultReconnectWait is the reconnect delay used unless the metadata
	// overrides it.
	defaultReconnectWait = 5 * time.Second
)

// errClosed is returned by connect once the component has been closed.
var errClosed = errors.New("component is stopped")
|
|
|
|
// RabbitMQ allows sending/receiving data to/from RabbitMQ.
|
|
type RabbitMQ struct {
|
|
connection *amqp.Connection
|
|
channel *amqp.Channel
|
|
metadata rabbitMQMetadata
|
|
logger logger.Logger
|
|
queue amqp.Queue
|
|
closed atomic.Bool
|
|
closeCh chan struct{}
|
|
wg sync.WaitGroup
|
|
|
|
// used for reconnect
|
|
channelMutex sync.RWMutex
|
|
notifyRabbitChannelClose chan *amqp.Error
|
|
}
|
|
|
|
// rabbitMQMetadata holds the configuration of the RabbitMQ binding as decoded
// from the component metadata.
type rabbitMQMetadata struct {
	// Host is the URI passed to the AMQP dial functions.
	Host string `mapstructure:"host"`
	// QueueName is the queue that is declared, published to and consumed from.
	QueueName string `mapstructure:"queueName"`

	// Exclusive is forwarded as the "exclusive" flag of the queue declaration.
	Exclusive bool `mapstructure:"exclusive"`
	// Durable is forwarded as the "durable" flag of the queue declaration.
	Durable bool `mapstructure:"durable"`
	// DeleteWhenUnused is forwarded as the "auto-delete" flag of the queue declaration.
	DeleteWhenUnused bool `mapstructure:"deleteWhenUnused"`

	// PrefetchCount is the prefetch count applied to the channel via Qos.
	PrefetchCount int `mapstructure:"prefetchCount"`
	// MaxPriority is sent as "x-max-priority"; the priority queue is deactivated if nil.
	MaxPriority *uint8 `mapstructure:"maxPriority"`
	// ReconnectWait is the pause between reconnect and consume attempts.
	ReconnectWait time.Duration `mapstructure:"reconnectWaitInSeconds"`
	// DefaultQueueTTL is sent (in milliseconds) as "x-message-ttl" when set.
	DefaultQueueTTL *time.Duration `mapstructure:"ttlInSeconds"`

	// CaCert is the PEM encoded CA certificate used to verify the server.
	CaCert string `mapstructure:"caCert"`
	// ClientCert is the PEM encoded client certificate.
	ClientCert string `mapstructure:"clientCert"`
	// ClientKey is the PEM encoded private key matching ClientCert.
	ClientKey string `mapstructure:"clientKey"`
	// ExternalSasl selects the TLS dial variant that uses external authentication.
	ExternalSasl bool `mapstructure:"externalSasl"`
}
|
|
|
|
// NewRabbitMQ returns a new rabbitmq instance.
|
|
func NewRabbitMQ(logger logger.Logger) bindings.InputOutputBinding {
|
|
return &RabbitMQ{
|
|
logger: logger,
|
|
closeCh: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
// Init does metadata parsing and connection creation.
|
|
func (r *RabbitMQ) Init(_ context.Context, metadata bindings.Metadata) error {
|
|
err := r.parseMetadata(metadata)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = r.connect()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
r.notifyRabbitChannelClose = make(chan *amqp.Error, 1)
|
|
r.channel.NotifyClose(r.notifyRabbitChannelClose)
|
|
r.wg.Add(1)
|
|
go func() {
|
|
defer r.wg.Done()
|
|
r.reconnectWhenNecessary()
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
func (r *RabbitMQ) reconnectWhenNecessary() {
|
|
for {
|
|
select {
|
|
case <-r.closeCh:
|
|
return
|
|
case e := <-r.notifyRabbitChannelClose:
|
|
// If this error can not be recovered, first wait and then retry.
|
|
if e != nil && !e.Recover {
|
|
select {
|
|
case <-time.After(r.metadata.ReconnectWait):
|
|
case <-r.closeCh:
|
|
return
|
|
}
|
|
}
|
|
r.channelMutex.Lock()
|
|
if r.connection != nil && !r.connection.IsClosed() {
|
|
ch, err := r.connection.Channel()
|
|
if err == nil {
|
|
r.notifyRabbitChannelClose = make(chan *amqp.Error, 1)
|
|
ch.NotifyClose(r.notifyRabbitChannelClose)
|
|
r.channel = ch
|
|
r.channelMutex.Unlock()
|
|
continue
|
|
}
|
|
// if encounter err fallback to reconnect connection
|
|
}
|
|
r.channelMutex.Unlock()
|
|
// keep trying to reconnect
|
|
for {
|
|
err := r.connect()
|
|
if err == nil {
|
|
break
|
|
}
|
|
if err == errClosed {
|
|
return
|
|
}
|
|
r.logger.Warnf("Reconnect failed: %v", err)
|
|
|
|
select {
|
|
case <-time.After(r.metadata.ReconnectWait):
|
|
case <-r.closeCh:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func dial(uri string) (conn *amqp.Connection, ch *amqp.Channel, err error) {
|
|
conn, err = amqp.Dial(uri)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
ch, err = conn.Channel()
|
|
if err != nil {
|
|
conn.Close()
|
|
return nil, nil, err
|
|
}
|
|
|
|
return conn, ch, nil
|
|
}
|
|
|
|
func dialTLS(uri string, tlsConfig *tls.Config, externalAuth bool) (conn *amqp.Connection, ch *amqp.Channel, err error) {
|
|
if externalAuth {
|
|
conn, err = amqp.DialTLS_ExternalAuth(uri, tlsConfig)
|
|
} else {
|
|
conn, err = amqp.DialTLS(uri, tlsConfig)
|
|
}
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
ch, err = conn.Channel()
|
|
if err != nil {
|
|
conn.Close()
|
|
return nil, nil, err
|
|
}
|
|
|
|
return conn, ch, nil
|
|
}
|
|
|
|
func (r *RabbitMQ) Operations() []bindings.OperationKind {
|
|
return []bindings.OperationKind{bindings.CreateOperation}
|
|
}
|
|
|
|
func (r *RabbitMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
|
|
// check if connection channel to rabbitmq is open
|
|
r.channelMutex.RLock()
|
|
ch := r.channel
|
|
r.channelMutex.RUnlock()
|
|
if ch == nil {
|
|
return nil, errors.New(errorChannelConnection)
|
|
}
|
|
pub := amqp.Publishing{
|
|
DeliveryMode: amqp.Persistent,
|
|
ContentType: "text/plain",
|
|
Body: req.Data,
|
|
Headers: make(amqp.Table, len(req.Metadata)),
|
|
}
|
|
|
|
for k, v := range req.Metadata {
|
|
pub.Headers[k] = v
|
|
}
|
|
|
|
contentType, ok := metadata.TryGetContentType(req.Metadata)
|
|
if ok {
|
|
pub.ContentType = contentType
|
|
}
|
|
|
|
// The default time to live has been set in the queue
|
|
// We allow overriding on each call, by setting a value in request metadata
|
|
ttl, ok, err := metadata.TryGetTTL(req.Metadata)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get TTL: %w", err)
|
|
}
|
|
if ok {
|
|
// RabbitMQ expects the duration in ms
|
|
pub.Expiration = strconv.FormatInt(ttl.Milliseconds(), 10)
|
|
}
|
|
|
|
priority, ok, err := metadata.TryGetPriority(req.Metadata)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if ok {
|
|
pub.Priority = priority
|
|
}
|
|
|
|
err = ch.PublishWithContext(ctx, "", r.metadata.QueueName, false, false, pub)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to publish message: %w", err)
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
func (r *RabbitMQ) parseMetadata(meta bindings.Metadata) error {
|
|
m := rabbitMQMetadata{
|
|
ReconnectWait: defaultReconnectWait,
|
|
}
|
|
|
|
metadata.DecodeMetadata(meta.Properties, &m)
|
|
|
|
if m.Host == "" {
|
|
return errors.New("missing host address")
|
|
}
|
|
|
|
if m.QueueName == "" {
|
|
return errors.New("missing queue Name")
|
|
}
|
|
|
|
if val := meta.Properties[maxPriority]; val != "" {
|
|
parsedVal, err := strconv.ParseUint(val, defaultBase, defaultBitSize)
|
|
if err != nil {
|
|
return fmt.Errorf("can't parse maxPriority field: %s", err)
|
|
}
|
|
|
|
maxPriority := uint8(parsedVal)
|
|
if parsedVal > 255 {
|
|
// Overflow
|
|
maxPriority = math.MaxUint8
|
|
}
|
|
|
|
m.MaxPriority = &maxPriority
|
|
}
|
|
|
|
if val, ok := meta.Properties[caCert]; ok && val != "" {
|
|
if !isValidPEM(val) {
|
|
return errors.New("invalid ca certificate")
|
|
}
|
|
m.CaCert = val
|
|
}
|
|
if val, ok := meta.Properties[clientCert]; ok && val != "" {
|
|
if !isValidPEM(val) {
|
|
return errors.New("invalid client certificate")
|
|
}
|
|
m.ClientCert = val
|
|
}
|
|
if val, ok := meta.Properties[clientKey]; ok && val != "" {
|
|
if !isValidPEM(val) {
|
|
return errors.New("invalid client certificate key")
|
|
}
|
|
m.ClientKey = val
|
|
}
|
|
|
|
if val, ok := meta.Properties[externalSasl]; ok && val != "" {
|
|
m.ExternalSasl = utils.IsTruthy(val)
|
|
}
|
|
|
|
if val, ok := meta.Properties[caCert]; ok && val != "" {
|
|
if !isValidPEM(val) {
|
|
return errors.New("invalid ca certificate")
|
|
}
|
|
m.CaCert = val
|
|
}
|
|
if val, ok := meta.Properties[clientCert]; ok && val != "" {
|
|
if !isValidPEM(val) {
|
|
return errors.New("invalid client certificate")
|
|
}
|
|
m.ClientCert = val
|
|
}
|
|
if val, ok := meta.Properties[clientKey]; ok && val != "" {
|
|
if !isValidPEM(val) {
|
|
return errors.New("invalid client certificate key")
|
|
}
|
|
m.ClientKey = val
|
|
}
|
|
|
|
if val, ok := meta.Properties[externalSasl]; ok && val != "" {
|
|
m.ExternalSasl = utils.IsTruthy(val)
|
|
}
|
|
|
|
ttl, ok, err := metadata.TryGetTTL(meta.Properties)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse TTL: %w", err)
|
|
}
|
|
if ok {
|
|
m.DefaultQueueTTL = &ttl
|
|
}
|
|
|
|
r.metadata = m
|
|
return nil
|
|
}
|
|
|
|
func (r *RabbitMQ) declareQueue(channel *amqp.Channel) (amqp.Queue, error) {
|
|
args := amqp.Table{}
|
|
if r.metadata.DefaultQueueTTL != nil {
|
|
// Value in ms
|
|
ttl := *r.metadata.DefaultQueueTTL / time.Millisecond
|
|
args[rabbitMQQueueMessageTTLKey] = int(ttl)
|
|
}
|
|
|
|
if r.metadata.MaxPriority != nil {
|
|
args[rabbitMQMaxPriorityKey] = *r.metadata.MaxPriority
|
|
}
|
|
|
|
return channel.QueueDeclare(r.metadata.QueueName, r.metadata.Durable, r.metadata.DeleteWhenUnused, r.metadata.Exclusive, false, args)
|
|
}
|
|
|
|
func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
|
|
if r.closed.Load() {
|
|
return errors.New("binding already closed")
|
|
}
|
|
|
|
readCtx, cancel := context.WithCancel(ctx)
|
|
r.wg.Add(2)
|
|
go func() {
|
|
select {
|
|
case <-r.closeCh:
|
|
// nop
|
|
case <-readCtx.Done():
|
|
// nop
|
|
}
|
|
r.wg.Done()
|
|
cancel()
|
|
}()
|
|
go func() {
|
|
// unless closed, keep trying to read and handle messages forever
|
|
defer r.wg.Done()
|
|
for {
|
|
var (
|
|
msgs <-chan amqp.Delivery
|
|
err error
|
|
declaredQueueName string
|
|
ch *amqp.Channel
|
|
)
|
|
r.channelMutex.RLock()
|
|
declaredQueueName = r.queue.Name
|
|
ch = r.channel
|
|
r.channelMutex.RUnlock()
|
|
|
|
if ch != nil {
|
|
msgs, err = ch.Consume(
|
|
declaredQueueName,
|
|
"",
|
|
false,
|
|
false,
|
|
false,
|
|
false,
|
|
nil,
|
|
)
|
|
if err == nil {
|
|
// all good, handle messages
|
|
r.handleMessage(readCtx, handler, msgs, ch)
|
|
} else {
|
|
r.logger.Errorf("Error consuming messages from queue [%s]: %v", r.queue.Name, err)
|
|
}
|
|
}
|
|
|
|
// something went wrong, wait for reconnect
|
|
select {
|
|
case <-time.After(r.metadata.ReconnectWait):
|
|
continue
|
|
case <-readCtx.Done():
|
|
r.logger.Info("Input binding closed, stop fetching message")
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
func (r *RabbitMQ) newTLSConfig() *tls.Config {
|
|
tlsConfig := new(tls.Config)
|
|
|
|
if r.metadata.ClientCert != "" && r.metadata.ClientKey != "" {
|
|
cert, err := tls.X509KeyPair([]byte(r.metadata.ClientCert), []byte(r.metadata.ClientKey))
|
|
if err != nil {
|
|
r.logger.Warnf("Unable to load client certificate and key pair. Err: %v", err)
|
|
return tlsConfig
|
|
}
|
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
|
}
|
|
|
|
if r.metadata.CaCert != "" {
|
|
tlsConfig.RootCAs = x509.NewCertPool()
|
|
if ok := tlsConfig.RootCAs.AppendCertsFromPEM([]byte(r.metadata.CaCert)); !ok {
|
|
r.logger.Warnf("Unable to load CA certificate.")
|
|
}
|
|
}
|
|
return tlsConfig
|
|
}
|
|
|
|
// isValidPEM reports whether val contains at least one PEM formatted block.
func isValidPEM(val string) bool {
	if block, _ := pem.Decode([]byte(val)); block == nil {
		return false
	}
	return true
}
|
|
|
|
// handleMessage handles incoming messages from RabbitMQ
|
|
func (r *RabbitMQ) handleMessage(ctx context.Context, handler bindings.Handler, msgCh <-chan amqp.Delivery, ch *amqp.Channel) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case d, ok := <-msgCh:
|
|
if !ok {
|
|
r.logger.Info("Input binding channel closed")
|
|
return
|
|
}
|
|
|
|
metadata := make(map[string]string, len(d.Headers))
|
|
// Passthrough any custom metadata to the handler.
|
|
for k, v := range d.Headers {
|
|
if s, ok := v.(string); ok {
|
|
// Escape the key and value to ensure they are valid URL query parameters.
|
|
// This is necessary for them to be sent as HTTP Metadata.
|
|
metadata[url.QueryEscape(k)] = url.QueryEscape(s)
|
|
}
|
|
}
|
|
|
|
_, err := handler(ctx, &bindings.ReadResponse{
|
|
Data: d.Body,
|
|
Metadata: metadata,
|
|
})
|
|
if err != nil {
|
|
ch.Nack(d.DeliveryTag, false, true)
|
|
} else {
|
|
ch.Ack(d.DeliveryTag, false)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *RabbitMQ) Close() error {
|
|
if r.closed.CompareAndSwap(false, true) {
|
|
close(r.closeCh)
|
|
}
|
|
defer r.wg.Wait()
|
|
r.channelMutex.Lock()
|
|
defer r.channelMutex.Unlock()
|
|
return r.reset()
|
|
}
|
|
|
|
func (r *RabbitMQ) connect() error {
|
|
if r.closed.Load() {
|
|
// Do not reconnect on stopped service.
|
|
return errClosed
|
|
}
|
|
var conn *amqp.Connection
|
|
var ch *amqp.Channel
|
|
var err error
|
|
if r.metadata.ClientCert != "" && r.metadata.ClientKey != "" && r.metadata.CaCert != "" {
|
|
tlsConfig := r.newTLSConfig()
|
|
conn, ch, err = dialTLS(r.metadata.Host, tlsConfig, r.metadata.ExternalSasl)
|
|
} else {
|
|
conn, ch, err = dial(r.metadata.Host)
|
|
}
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ch.Qos(r.metadata.PrefetchCount, 0, true)
|
|
q, err := r.declareQueue(ch)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
r.notifyRabbitChannelClose = make(chan *amqp.Error, 1)
|
|
ch.NotifyClose(r.notifyRabbitChannelClose)
|
|
|
|
r.channelMutex.Lock()
|
|
// try to close the old channel and connection, ignore the error
|
|
_ = r.reset() //nolint:errcheck
|
|
r.connection, r.channel, r.queue = conn, ch, q
|
|
r.channelMutex.Unlock()
|
|
|
|
r.logger.Info("Connected to RabbitMQ")
|
|
|
|
return nil
|
|
}
|
|
|
|
// reset the channel and the connection when encountered a connection error.
|
|
// this function call should be wrapped by channelMutex.
|
|
func (r *RabbitMQ) reset() (err error) {
|
|
if r.channel != nil {
|
|
if err = r.channel.Close(); err != nil {
|
|
r.logger.Warnf("Reset: channel.Close() failed: %v", err)
|
|
}
|
|
r.channel = nil
|
|
}
|
|
|
|
if r.connection != nil {
|
|
if err2 := r.connection.Close(); err2 != nil {
|
|
r.logger.Warnf("Reset: connection.Close() failed: %v", err2)
|
|
if err == nil {
|
|
err = err2
|
|
}
|
|
}
|
|
r.connection = nil
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (r *RabbitMQ) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
|
|
metadataStruct := rabbitMQMetadata{}
|
|
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.BindingType)
|
|
return
|
|
}
|