diff --git a/pubsub/rabbitmq/metadata.go b/pubsub/rabbitmq/metadata.go index ffb1aceb1..6228595be 100644 --- a/pubsub/rabbitmq/metadata.go +++ b/pubsub/rabbitmq/metadata.go @@ -37,6 +37,7 @@ type metadata struct { concurrency pubsub.ConcurrencyMode maxLen int64 maxLenBytes int64 + exchangeKind string } // createMetadata creates a new instance from the pubsub metadata. @@ -46,6 +47,7 @@ func createMetadata(pubSubMetadata pubsub.Metadata) (*metadata, error) { deleteWhenUnused: true, autoAck: false, reconnectWait: time.Duration(defaultReconnectWaitSeconds) * time.Second, + exchangeKind: fanoutExchangeKind, } if val, found := pubSubMetadata.Properties[metadataHostKey]; found && val != "" { @@ -121,6 +123,14 @@ func createMetadata(pubSubMetadata pubsub.Metadata) (*metadata, error) { } } + if val, found := pubSubMetadata.Properties[metadataExchangeKind]; found && val != "" { + if exchangeKindValid(val) { + result.exchangeKind = val + } else { + return &result, fmt.Errorf("%s invalid RabbitMQ exchange kind %s", errorMessagePrefix, val) + } + } + c, err := pubsub.Concurrency(pubSubMetadata.Properties) if err != nil { return &result, err @@ -143,3 +153,7 @@ func (m *metadata) formatQueueDeclareArgs(origin amqp.Table) amqp.Table { return origin } + +func exchangeKindValid(kind string) bool { + return kind == amqp.ExchangeFanout || kind == amqp.ExchangeTopic || kind == amqp.ExchangeDirect || kind == amqp.ExchangeHeaders +} diff --git a/pubsub/rabbitmq/metadata_test.go b/pubsub/rabbitmq/metadata_test.go index ce4fdef35..9399e6b19 100644 --- a/pubsub/rabbitmq/metadata_test.go +++ b/pubsub/rabbitmq/metadata_test.go @@ -17,6 +17,8 @@ import ( "fmt" "testing" + "github.com/streadway/amqp" + "github.com/stretchr/testify/assert" "github.com/dapr/components-contrib/pubsub" @@ -63,6 +65,7 @@ func TestCreateMetadata(t *testing.T) { assert.Equal(t, uint8(0), m.prefetchCount) assert.Equal(t, int64(0), m.maxLen) assert.Equal(t, int64(0), m.maxLenBytes) + assert.Equal(t, fanoutExchangeKind, m.exchangeKind) }) t.Run("host is not given", func(t *testing.T) { @@ -274,4 +277,40 @@ func TestCreateMetadata(t *testing.T) { assert.Equal(t, tt.expected, m.enableDeadLetter) }) } + validExchangeKind := []string{amqp.ExchangeDirect, amqp.ExchangeTopic, amqp.ExchangeFanout, amqp.ExchangeHeaders} + + for _, exchangeKind := range validExchangeKind { + t.Run(fmt.Sprintf("exchangeKind value=%s", exchangeKind), func(t *testing.T) { + fakeProperties := getFakeProperties() + + fakeMetaData := pubsub.Metadata{ + Properties: fakeProperties, + } + fakeMetaData.Properties[metadataExchangeKind] = exchangeKind + + // act + m, err := createMetadata(fakeMetaData) + + // assert + assert.NoError(t, err) + assert.Equal(t, fakeProperties[metadataHostKey], m.host) + assert.Equal(t, fakeProperties[metadataConsumerIDKey], m.consumerID) + assert.Equal(t, exchangeKind, m.exchangeKind) + }) + } + + t.Run("exchangeKind is invalid", func(t *testing.T) { + fakeProperties := getFakeProperties() + + fakeMetaData := pubsub.Metadata{ + Properties: fakeProperties, + } + fakeMetaData.Properties[metadataExchangeKind] = "invalid" + + // act + _, err := createMetadata(fakeMetaData) + + // assert + assert.Error(t, err) + }) } diff --git a/pubsub/rabbitmq/rabbitmq.go b/pubsub/rabbitmq/rabbitmq.go index 66328b5c3..d9ed0859a 100644 --- a/pubsub/rabbitmq/rabbitmq.go +++ b/pubsub/rabbitmq/rabbitmq.go @@ -49,6 +49,7 @@ const ( metadataEnableDeadLetter = "enableDeadLetter" metadataMaxLen = "maxLen" metadataMaxLenBytes = "maxLenBytes" + metadataExchangeKind = "exchangeKind" defaultReconnectWaitSeconds = 3 publishMaxRetries = 3 @@ -60,6 +61,7 @@ const ( argMaxLengthBytes = "x-max-length-bytes" argDeadLetterExchange = "x-dead-letter-exchange" queueModeLazy = "lazy" + reqMetadataRoutingKey = "routingKey" ) // RabbitMQ allows sending/receiving messages in pub/sub format. @@ -198,13 +200,17 @@ func (r *rabbitMQ) publishSync(req *pubsub.PublishRequest) (rabbitMQChannelBroke return r.channel, r.connectionCount, errors.New(errorChannelNotInitialized) } - if err := r.ensureExchangeDeclared(r.channel, req.Topic); err != nil { + if err := r.ensureExchangeDeclared(r.channel, req.Topic, r.metadata.exchangeKind); err != nil { r.logger.Errorf("%s publishing to %s failed in ensureExchangeDeclared: %v", logMessagePrefix, req.Topic, err) return r.channel, r.connectionCount, err } + routingKey := "" + if val, ok := req.Metadata[reqMetadataRoutingKey]; ok && val != "" { + routingKey = val + } - if err := r.channel.Publish(req.Topic, "", false, false, amqp.Publishing{ + if err := r.channel.Publish(req.Topic, routingKey, false, false, amqp.Publishing{ ContentType: "text/plain", Body: req.Data, DeliveryMode: r.metadata.deliveryMode, @@ -266,7 +272,7 @@ func (r *rabbitMQ) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler // this function call should be wrapped by channelMutex. func (r *rabbitMQ) prepareSubscription(channel rabbitMQChannelBroker, req pubsub.SubscribeRequest, queueName string) (*amqp.Queue, error) { - err := r.ensureExchangeDeclared(channel, req.Topic) + err := r.ensureExchangeDeclared(channel, req.Topic, r.metadata.exchangeKind) if err != nil { r.logger.Errorf("%s prepareSubscription for topic/queue '%s/%s' failed in ensureExchangeDeclared: %v", logMessagePrefix, req.Topic, queueName, err) @@ -279,7 +285,7 @@ func (r *rabbitMQ) prepareSubscription(channel rabbitMQChannelBroker, req pubsub // declare dead letter exchange dlxName := fmt.Sprintf(defaultDeadLetterExchangeFormat, queueName) dlqName := fmt.Sprintf(defaultDeadLetterQueueFormat, queueName) - err = r.ensureExchangeDeclared(channel, dlxName) + err = r.ensureExchangeDeclared(channel, dlxName, fanoutExchangeKind) if err != nil { r.logger.Errorf("%s prepareSubscription for topic/queue '%s/%s' failed in ensureExchangeDeclared: %v", logMessagePrefix, req.Topic, dlqName, err) @@ -322,8 +328,12 @@ func (r *rabbitMQ) prepareSubscription(channel rabbitMQChannelBroker, req pubsub } } - r.logger.Infof("%s binding queue '%s' to exchange '%s'", logMessagePrefix, q.Name, req.Topic) - err = channel.QueueBind(q.Name, "", req.Topic, false, nil) + routingKey := "" + if val, ok := req.Metadata[reqMetadataRoutingKey]; ok && val != "" { + routingKey = val + } + r.logger.Infof("%s binding queue '%s' to exchange '%s' with routing key '%s'", logMessagePrefix, q.Name, req.Topic, routingKey) + err = channel.QueueBind(q.Name, routingKey, req.Topic, false, nil) if err != nil { r.logger.Errorf("%s prepareSubscription for topic/queue '%s/%s' failed in channel.QueueBind: %v", logMessagePrefix, req.Topic, queueName, err) @@ -468,10 +478,10 @@ func (r *rabbitMQ) handleMessage(channel rabbitMQChannelBroker, d amqp.Delivery, } // this function call should be wrapped by channelMutex. -func (r *rabbitMQ) ensureExchangeDeclared(channel rabbitMQChannelBroker, exchange string) error { +func (r *rabbitMQ) ensureExchangeDeclared(channel rabbitMQChannelBroker, exchange, exchangeKind string) error { if !r.containsExchange(exchange) { - r.logger.Debugf("%s declaring exchange '%s' of kind '%s'", logMessagePrefix, exchange, fanoutExchangeKind) - err := channel.ExchangeDeclare(exchange, fanoutExchangeKind, true, false, false, false, nil) + r.logger.Debugf("%s declaring exchange '%s' of kind '%s'", logMessagePrefix, exchange, exchangeKind) + err := channel.ExchangeDeclare(exchange, exchangeKind, true, false, false, false, nil) if err != nil { r.logger.Errorf("%s ensureExchangeDeclared: channel.ExchangeDeclare failed: %v", logMessagePrefix, err)