Add message time to live in RabbitMQ, Azure Service Bus/Storage Queue bindings (#298)

* Add message ttl for Azure Service Bus/Storage Queue
and RabbitMQ

* Rename metadata key to ttlInSeconds
Move defaultQueueTTL to RabbitMQ metadata

* Move integration tests to own files

* Add +build integration_test tag

* Remove integration test skip

Co-authored-by: Aman Bhardwaj <amanbha@users.noreply.github.com>
This commit is contained in:
Francisco Beltrao 2020-04-15 19:07:16 +02:00 committed by GitHub
parent 4f9d6d97e3
commit 75fcd7accf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 859 additions and 58 deletions

View File

@ -19,6 +19,9 @@ const (
correlationID = "correlationID"
label = "label"
id = "id"
// AzureServiceBusDefaultMessageTimeToLive defines the default time to live for queues, which is 14 days. The same way Azure Portal does.
AzureServiceBusDefaultMessageTimeToLive = time.Hour * 24 * 14
)
// AzureServiceBusQueues is an input/output binding reading from and sending events to Azure Service Bus queues
@ -32,6 +35,7 @@ type AzureServiceBusQueues struct {
type serviceBusQueuesMetadata struct {
ConnectionString string `json:"connectionString"`
QueueName string `json:"queueName"`
ttl time.Duration
}
// NewAzureServiceBusQueues returns a new AzureServiceBusQueues instance
@ -52,7 +56,43 @@ func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) error {
return err
}
client, err := ns.NewQueue(a.metadata.QueueName)
qm := ns.NewQueueManager()
ctx := context.Background()
queues, err := qm.List(ctx)
if err != nil {
return err
}
var entity *servicebus.QueueEntity
for _, q := range queues {
if q.Name == a.metadata.QueueName {
entity = q
break
}
}
// Create queue if it does not exist
if entity == nil {
var ttl time.Duration
var ok bool
ttl, ok, err = bindings.TryGetTTL(metadata.Properties)
if err != nil {
return err
}
if !ok {
ttl = a.metadata.ttl
}
entity, err = qm.Put(ctx, a.metadata.QueueName, servicebus.QueueEntityWithMessageTimeToLive(&ttl))
if err != nil {
return err
}
}
client, err := ns.NewQueue(entity.Name)
if err != nil {
return err
}
@ -71,6 +111,19 @@ func (a *AzureServiceBusQueues) parseMetadata(metadata bindings.Metadata) (*serv
if err != nil {
return nil, err
}
ttl, ok, err := bindings.TryGetTTL(metadata.Properties)
if err != nil {
return nil, err
}
// set the same default message time to live as suggested in Azure Portal to 14 days (otherwise it will be 10675199 days)
if !ok {
ttl = AzureServiceBusDefaultMessageTimeToLive
}
m.ttl = ttl
return &m, nil
}
@ -85,8 +138,17 @@ func (a *AzureServiceBusQueues) Write(req *bindings.WriteRequest) error {
if val, ok := req.Metadata[correlationID]; ok && val != "" {
msg.CorrelationID = val
}
err := a.client.Send(ctx, msg)
return err
ttl, ok, err := bindings.TryGetTTL(req.Metadata)
if err != nil {
return err
}
if ok {
msg.TTL = &ttl
}
return a.client.Send(ctx, msg)
}
func (a *AzureServiceBusQueues) Read(handler func(*bindings.ReadResponse) error) error {

View File

@ -0,0 +1,180 @@
// +build integration_test
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// ------------------------------------------------------------
package servicebusqueues
import (
"context"
"fmt"
"os"
"testing"
"time"
servicebus "github.com/Azure/azure-service-bus-go"
"github.com/dapr/components-contrib/bindings"
"github.com/dapr/dapr/pkg/logger"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
const (
// Environment variable containing the connection string to Azure Service Bus
testServiceBusEnvKey = "DAPR_TEST_AZURE_SERVICEBUS"
)
func getTestServiceBusConnectionString() string {
return os.Getenv(testServiceBusEnvKey)
}
type testQueueHandler struct {
callback func(*servicebus.Message)
}
func (h testQueueHandler) Handle(ctx context.Context, message *servicebus.Message) error {
h.callback(message)
return message.Complete(ctx)
}
func getMessageWithRetries(queue *servicebus.Queue, maxDuration time.Duration) (*servicebus.Message, bool, error) {
var receivedMessage *servicebus.Message
queueHandler := testQueueHandler{
callback: func(msg *servicebus.Message) {
receivedMessage = msg
},
}
ctx, cancel := context.WithTimeout(context.Background(), maxDuration)
defer cancel()
err := queue.ReceiveOne(ctx, queueHandler)
if err != nil && err != context.DeadlineExceeded {
return nil, false, err
}
return receivedMessage, receivedMessage != nil, nil
}
func TestQueueWithTTL(t *testing.T) {
serviceBusConnectionString := getTestServiceBusConnectionString()
assert.NotEmpty(serviceBusConnectionString, fmt.Sprintf("Azure ServiceBus connection string must set in environment variable '%s'", testServiceBusEnvKey))
queueName := uuid.New().String()
a := NewAzureServiceBusQueues(logger.NewLogger("test"))
m := bindings.Metadata{}
m.Properties = map[string]string{"connectionString": serviceBusConnectionString, "queueName": queueName, bindings.TTLMetadataKey: "1"}
err := a.Init(m)
assert.Nil(t, err)
// Assert thet queue was created with an time to live value
ns, err := servicebus.NewNamespace(servicebus.NamespaceWithConnectionString(serviceBusConnectionString))
assert.Nil(t, err)
queue, err := ns.NewQueue(queueName)
assert.Nil(t, err)
qmr := ns.NewQueueManager()
defer qmr.Delete(context.Background(), queueName)
queueEntity, err := qmr.Get(context.Background(), queueName)
assert.Nil(t, err)
assert.Equal(t, "PT1S", *queueEntity.DefaultMessageTimeToLive)
// Assert that if waited too long, we won't see any message
const tooLateMsgContent = "too_late_msg"
err = a.Write(&bindings.WriteRequest{Data: []byte(tooLateMsgContent)})
assert.Nil(t, err)
time.Sleep(time.Second * 2)
const ttlInSeconds = 1
const maxGetDuration = ttlInSeconds * time.Second
_, ok, err := getMessageWithRetries(queue, maxGetDuration)
assert.Nil(t, err)
assert.False(t, ok)
// Getting before it is expired, should return it
const testMsgContent = "test_msg"
err = a.Write(&bindings.WriteRequest{Data: []byte(testMsgContent)})
assert.Nil(t, err)
msg, ok, err := getMessageWithRetries(queue, maxGetDuration)
assert.Nil(t, err)
assert.True(t, ok)
msgBody := string(msg.Data)
assert.Equal(t, testMsgContent, msgBody)
assert.NotNil(t, msg.TTL)
assert.Equal(t, time.Second, *msg.TTL)
}
func TestPublishingWithTTL(t *testing.T) {
serviceBusConnectionString := getTestServiceBusConnectionString()
assert.NotEmpty(serviceBusConnectionString, fmt.Sprintf("Azure ServiceBus connection string must set in environment variable '%s'", testServiceBusEnvKey))
queueName := uuid.New().String()
queueBinding1 := NewAzureServiceBusQueues(logger.NewLogger("test"))
bindingMetadata := bindings.Metadata{}
bindingMetadata.Properties = map[string]string{"connectionString": serviceBusConnectionString, "queueName": queueName}
err := queueBinding1.Init(bindingMetadata)
assert.Nil(t, err)
// Assert thet queue was created with Azure default time to live value
ns, err := servicebus.NewNamespace(servicebus.NamespaceWithConnectionString(serviceBusConnectionString))
assert.Nil(t, err)
queue, err := ns.NewQueue(queueName)
assert.Nil(t, err)
qmr := ns.NewQueueManager()
defer qmr.Delete(context.Background(), queueName)
queueEntity, err := qmr.Get(context.Background(), queueName)
assert.Nil(t, err)
const defaultAzureServiceBusMessageTimeToLive = "P14D"
assert.Equal(t, defaultAzureServiceBusMessageTimeToLive, *queueEntity.DefaultMessageTimeToLive)
const tooLateMsgContent = "too_late_msg"
writeRequest := bindings.WriteRequest{
Data: []byte(tooLateMsgContent),
Metadata: map[string]string{
bindings.TTLMetadataKey: "1",
},
}
err = queueBinding1.Write(&writeRequest)
assert.Nil(t, err)
time.Sleep(time.Second * 5)
const ttlInSeconds = 1
const maxGetDuration = ttlInSeconds * time.Second
_, ok, err := getMessageWithRetries(queue, maxGetDuration)
assert.Nil(t, err)
assert.False(t, ok)
// Getting before it is expired, should return it
queueBinding2 := NewAzureServiceBusQueues(logger.NewLogger("test"))
err = queueBinding2.Init(bindingMetadata)
assert.Nil(t, err)
const testMsgContent = "test_msg"
writeRequest = bindings.WriteRequest{
Data: []byte(testMsgContent),
Metadata: map[string]string{
bindings.TTLMetadataKey: "1",
},
}
err = queueBinding2.Write(&writeRequest)
assert.Nil(t, err)
msg, ok, err := getMessageWithRetries(queue, maxGetDuration)
assert.Nil(t, err)
assert.True(t, ok)
msgBody := string(msg.Data)
assert.Equal(t, testMsgContent, msgBody)
assert.NotNil(t, msg.TTL)
assert.Equal(t, time.Second, *msg.TTL)
}

View File

@ -7,6 +7,7 @@ package servicebusqueues
import (
"testing"
"time"
"github.com/dapr/components-contrib/bindings"
"github.com/dapr/dapr/pkg/logger"
@ -14,11 +15,79 @@ import (
)
func TestParseMetadata(t *testing.T) {
m := bindings.Metadata{}
m.Properties = map[string]string{"connectionString": "connString", "queueName": "queue1"}
a := NewAzureServiceBusQueues(logger.NewLogger("test"))
meta, err := a.parseMetadata(m)
assert.Nil(t, err)
assert.Equal(t, "connString", meta.ConnectionString)
assert.Equal(t, "queue1", meta.QueueName)
var oneSecondDuration time.Duration = time.Second
testCases := []struct {
name string
properties map[string]string
expectedConnectionString string
expectedQueueName string
expectedTTL time.Duration
}{
{
name: "ConnectionString and queue name",
properties: map[string]string{"connectionString": "connString", "queueName": "queue1"},
expectedConnectionString: "connString",
expectedQueueName: "queue1",
expectedTTL: AzureServiceBusDefaultMessageTimeToLive,
},
{
name: "Empty TTL",
properties: map[string]string{"connectionString": "connString", "queueName": "queue1", bindings.TTLMetadataKey: ""},
expectedConnectionString: "connString",
expectedQueueName: "queue1",
expectedTTL: AzureServiceBusDefaultMessageTimeToLive,
},
{
name: "With TTL",
properties: map[string]string{"connectionString": "connString", "queueName": "queue1", bindings.TTLMetadataKey: "1"},
expectedConnectionString: "connString",
expectedQueueName: "queue1",
expectedTTL: oneSecondDuration,
},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
m := bindings.Metadata{}
m.Properties = tt.properties
a := NewAzureServiceBusQueues(logger.NewLogger("test"))
meta, err := a.parseMetadata(m)
assert.Nil(t, err)
assert.Equal(t, tt.expectedConnectionString, meta.ConnectionString)
assert.Equal(t, tt.expectedQueueName, meta.QueueName)
assert.Equal(t, tt.expectedTTL, meta.ttl)
})
}
}
func TestParseMetadataWithInvalidTTL(t *testing.T) {
testCases := []struct {
name string
properties map[string]string
}{
{
name: "Whitespaces TTL",
properties: map[string]string{"connectionString": "connString", "queueName": "queue1", bindings.TTLMetadataKey: " "},
},
{
name: "Negative ttl",
properties: map[string]string{"connectionString": "connString", "queueName": "queue1", bindings.TTLMetadataKey: "-1"},
},
{
name: "Non-numeric ttl",
properties: map[string]string{"connectionString": "connString", "queueName": "queue1", bindings.TTLMetadataKey: "abc"},
},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
m := bindings.Metadata{}
m.Properties = tt.properties
a := NewAzureServiceBusQueues(logger.NewLogger("test"))
_, err := a.parseMetadata(m)
assert.NotNil(t, err)
})
}
}

View File

@ -23,6 +23,10 @@ import (
"github.com/dapr/dapr/pkg/logger"
)
const (
defaultTTL = time.Minute * 10
)
type consumer struct {
callback func(*bindings.ReadResponse) error
}
@ -30,7 +34,7 @@ type consumer struct {
// QueueHelper enables injection for testnig
type QueueHelper interface {
Init(accountName string, accountKey string, queueName string, decodeBase64 bool) error
Write(data []byte) error
Write(data []byte, ttl *time.Duration) error
Read(ctx context.Context, consumer *consumer) error
}
@ -61,11 +65,16 @@ func (d *AzureQueueHelper) Init(accountName string, accountKey string, queueName
return nil
}
func (d *AzureQueueHelper) Write(data []byte) error {
func (d *AzureQueueHelper) Write(data []byte, ttl *time.Duration) error {
ctx := context.TODO()
messagesURL := d.queueURL.NewMessagesURL()
s := string(data)
_, err := messagesURL.Enqueue(ctx, s, time.Second*0, time.Minute*10)
if ttl == nil {
ttlToUse := defaultTTL
ttl = &ttlToUse
}
_, err := messagesURL.Enqueue(ctx, s, time.Second*0, *ttl)
return err
}
@ -131,6 +140,7 @@ type storageQueuesMetadata struct {
QueueName string `json:"queue"`
AccountName string `json:"storageAccount"`
DecodeBase64 string `json:"decodeBase64"`
ttl *time.Duration
}
// NewAzureStorageQueues returns a new AzureStorageQueues instance
@ -168,11 +178,31 @@ func (a *AzureStorageQueues) parseMetadata(metadata bindings.Metadata) (*storage
if err != nil {
return nil, err
}
ttl, ok, err := bindings.TryGetTTL(metadata.Properties)
if err != nil {
return nil, err
}
if ok {
m.ttl = &ttl
}
return &m, nil
}
func (a *AzureStorageQueues) Write(req *bindings.WriteRequest) error {
err := a.helper.Write(req.Data)
ttlToUse := a.metadata.ttl
ttl, ok, err := bindings.TryGetTTL(req.Metadata)
if err != nil {
return err
}
if ok {
ttlToUse = &ttl
}
err = a.helper.Write(req.Data, ttlToUse)
if err != nil {
return err
}

View File

@ -27,8 +27,9 @@ func (m *MockHelper) Init(accountName string, accountKey string, queueName strin
return retvals.Error(0)
}
func (m *MockHelper) Write(data []byte) error {
return nil
func (m *MockHelper) Write(data []byte, ttl *time.Duration) error {
retvals := m.Called(data, ttl)
return retvals.Error(0)
}
func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {
@ -38,6 +39,9 @@ func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {
func TestWriteQueue(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
return in == nil
})).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
@ -54,6 +58,53 @@ func TestWriteQueue(t *testing.T) {
assert.Nil(t, err)
}
func TestWriteWithTTLInQueue(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfTypeArgument("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
return in != nil && *in == time.Second
})).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", bindings.TTLMetadataKey: "1"}
err := a.Init(m)
assert.Nil(t, err)
r := bindings.WriteRequest{Data: []byte("This is my message")}
err = a.Write(&r)
assert.Nil(t, err)
}
func TestWriteWithTTLInWrite(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfTypeArgument("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
return in != nil && *in == time.Second
})).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", bindings.TTLMetadataKey: "1"}
err := a.Init(m)
assert.Nil(t, err)
r := bindings.WriteRequest{
Data: []byte("This is my message"),
Metadata: map[string]string{bindings.TTLMetadataKey: "1"},
}
err = a.Write(&r)
assert.Nil(t, err)
}
// Uncomment this function to write a message to local storage queue
/* func TestWriteLocalQueue(t *testing.T) {
@ -75,7 +126,7 @@ func TestWriteQueue(t *testing.T) {
func TestReadQueue(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
m := bindings.Metadata{}
@ -108,6 +159,7 @@ func TestReadQueue(t *testing.T) {
func TestReadQueueDecode(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
@ -169,6 +221,7 @@ func TestReadQueueDecode(t *testing.T) {
func TestReadQueueNoMessage(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), false).Return(nil)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
@ -194,13 +247,79 @@ func TestReadQueueNoMessage(t *testing.T) {
}
func TestParseMetadata(t *testing.T) {
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "myKey", "queue": "queue1", "storageAccount": "devstoreaccount1"}
var oneSecondDuration time.Duration = time.Second
a := NewAzureStorageQueues(logger.NewLogger("test"))
meta, err := a.parseMetadata(m)
testCases := []struct {
name string
properties map[string]string
expectedAccountKey string
expectedQueueName string
expectedTTL *time.Duration
}{
{
name: "Account and key",
properties: map[string]string{"storageAccessKey": "myKey", "queue": "queue1", "storageAccount": "devstoreaccount1"},
expectedAccountKey: "myKey",
expectedQueueName: "queue1",
},
{
name: "Empty TTL",
properties: map[string]string{"storageAccessKey": "myKey", "queue": "queue1", "storageAccount": "devstoreaccount1", bindings.TTLMetadataKey: ""},
expectedAccountKey: "myKey",
expectedQueueName: "queue1",
},
{
name: "With TTL",
properties: map[string]string{"storageAccessKey": "myKey", "queue": "queue1", "storageAccount": "devstoreaccount1", bindings.TTLMetadataKey: "1"},
expectedAccountKey: "myKey",
expectedQueueName: "queue1",
expectedTTL: &oneSecondDuration,
},
}
assert.Nil(t, err)
assert.Equal(t, "myKey", meta.AccountKey)
assert.Equal(t, "queue1", meta.QueueName)
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
m := bindings.Metadata{}
m.Properties = tt.properties
a := NewAzureStorageQueues(logger.NewLogger("test"))
meta, err := a.parseMetadata(m)
assert.Nil(t, err)
assert.Equal(t, tt.expectedAccountKey, meta.AccountKey)
assert.Equal(t, tt.expectedQueueName, meta.QueueName)
assert.Equal(t, tt.expectedTTL, meta.ttl)
})
}
}
func TestParseMetadataWithInvalidTTL(t *testing.T) {
testCases := []struct {
name string
properties map[string]string
}{
{
name: "Whitespaces TTL",
properties: map[string]string{"storageAccessKey": "myKey", "queue": "queue1", "storageAccount": "devstoreaccount1", bindings.TTLMetadataKey: " "},
},
{
name: "Negative ttl",
properties: map[string]string{"storageAccessKey": "myKey", "queue": "queue1", "storageAccount": "devstoreaccount1", bindings.TTLMetadataKey: "-1"},
},
{
name: "Non-numeric ttl",
properties: map[string]string{"storageAccessKey": "myKey", "queue": "queue1", "storageAccount": "devstoreaccount1", bindings.TTLMetadataKey: "abc"},
},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
m := bindings.Metadata{}
m.Properties = tt.properties
a := NewAzureStorageQueues(logger.NewLogger("test"))
_, err := a.parseMetadata(m)
assert.NotNil(t, err)
})
}
}

View File

@ -7,18 +7,25 @@ package rabbitmq
import (
"encoding/json"
"strconv"
"time"
"github.com/dapr/components-contrib/bindings"
"github.com/dapr/dapr/pkg/logger"
"github.com/streadway/amqp"
)
const (
rabbitMQQueueMessageTTLKey = "x-message-ttl"
)
// RabbitMQ allows sending/receiving data to/from RabbitMQ
type RabbitMQ struct {
connection *amqp.Connection
channel *amqp.Channel
metadata *rabbitMQMetadata
metadata rabbitMQMetadata
logger logger.Logger
queue amqp.Queue
}
// Metadata is the rabbitmq config
@ -27,6 +34,7 @@ type rabbitMQMetadata struct {
Host string `json:"host"`
Durable bool `json:"durable,string"`
DeleteWhenUnused bool `json:"deleteWhenUnused,string"`
defaultQueueTTL *time.Duration
}
// NewRabbitMQ returns a new rabbitmq instance
@ -36,14 +44,12 @@ func NewRabbitMQ(logger logger.Logger) *RabbitMQ {
// Init does metadata parsing and connection creation
func (r *RabbitMQ) Init(metadata bindings.Metadata) error {
meta, err := r.getRabbitMQMetadata(metadata)
err := r.parseMetadata(metadata)
if err != nil {
return err
}
r.metadata = meta
conn, err := amqp.Dial(meta.Host)
conn, err := amqp.Dial(r.metadata.Host)
if err != nil {
return err
}
@ -55,42 +61,84 @@ func (r *RabbitMQ) Init(metadata bindings.Metadata) error {
r.connection = conn
r.channel = ch
q, err := r.declareQueue()
if err != nil {
return err
}
r.queue = q
return nil
}
func (r *RabbitMQ) Write(req *bindings.WriteRequest) error {
err := r.channel.Publish("", r.metadata.QueueName, false, false, amqp.Publishing{
ContentType: "text/plain",
Body: req.Data,
})
pub := amqp.Publishing{
DeliveryMode: amqp.Persistent,
ContentType: "text/plain",
Body: req.Data,
}
ttl, ok, err := bindings.TryGetTTL(req.Metadata)
if err != nil {
return err
}
// The default time to live has been set in the queue
// We allow overriding on each call, by setting a value in request metadata
if ok {
// RabbitMQ expects the duration in ms
pub.Expiration = strconv.FormatInt(ttl.Milliseconds(), 10)
}
err = r.channel.Publish("", r.metadata.QueueName, false, false, pub)
if err != nil {
return err
}
return nil
}
func (r *RabbitMQ) getRabbitMQMetadata(metadata bindings.Metadata) (*rabbitMQMetadata, error) {
func (r *RabbitMQ) parseMetadata(metadata bindings.Metadata) error {
b, err := json.Marshal(metadata.Properties)
if err != nil {
return nil, err
}
var rabbitMQMeta rabbitMQMetadata
err = json.Unmarshal(b, &rabbitMQMeta)
if err != nil {
return nil, err
}
return &rabbitMQMeta, nil
}
func (r *RabbitMQ) Read(handler func(*bindings.ReadResponse) error) error {
q, err := r.channel.QueueDeclare(r.metadata.QueueName, r.metadata.Durable, r.metadata.DeleteWhenUnused, false, false, nil)
if err != nil {
return err
}
var m rabbitMQMetadata
err = json.Unmarshal(b, &m)
if err != nil {
return err
}
ttl, ok, err := bindings.TryGetTTL(metadata.Properties)
if err != nil {
return err
}
if ok {
m.defaultQueueTTL = &ttl
}
r.metadata = m
return nil
}
func (r *RabbitMQ) declareQueue() (amqp.Queue, error) {
args := amqp.Table{}
if r.metadata.defaultQueueTTL != nil {
// Value in ms
ttl := *r.metadata.defaultQueueTTL / time.Millisecond
args[rabbitMQQueueMessageTTLKey] = int(ttl)
}
return r.channel.QueueDeclare(r.metadata.QueueName, r.metadata.Durable, r.metadata.DeleteWhenUnused, false, false, args)
}
func (r *RabbitMQ) Read(handler func(*bindings.ReadResponse) error) error {
msgs, err := r.channel.Consume(
q.Name,
r.queue.Name,
"",
false,
false,

View File

@ -0,0 +1,179 @@
// +build integration_test
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// ------------------------------------------------------------
package rabbitmq
import (
"fmt"
"os"
"strconv"
"testing"
"time"
"github.com/dapr/components-contrib/bindings"
"github.com/dapr/dapr/pkg/logger"
"github.com/google/uuid"
"github.com/streadway/amqp"
"github.com/stretchr/testify/assert"
)
const (
// Environment variable containing the host name for RabbitMQ integration tests
// To run using docker: docker run -d --hostname -rabbit --name test-rabbit -p 15672:15672 -p 5672:5672 rabbitmq:3-management
// In that case the connection string will be: amqp://guest:guest@localhost:5672/
testRabbitMQHostEnvKey = "DAPR_TEST_RABBITMQ_HOST"
)
func getTestRabbitMQHost() string {
return os.Getenv(testRabbitMQHostEnvKey)
}
func getMessageWithRetries(ch *amqp.Channel, queueName string, maxDuration time.Duration) (msg amqp.Delivery, ok bool, err error) {
start := time.Now()
for time.Since(start) < maxDuration {
msg, ok, err := ch.Get(queueName, true)
if err != nil || ok {
return msg, ok, err
}
time.Sleep(100 * time.Millisecond)
}
return amqp.Delivery{}, false, nil
}
func TestQueuesWithTTL(t *testing.T) {
rabbitmqHost := getTestRabbitMQHost()
assert.NotEmpty(t, rabbitmqHost, fmt.Sprintf("RabbitMQ host configuration must be set in environment variable '%s' (example 'amqp://guest:guest@localhost:5672/')", testRabbitMQHostEnvKey))
queueName := uuid.New().String()
durable := true
exclusive := false
const ttlInSeconds = 1
const maxGetDuration = ttlInSeconds * time.Second
metadata := bindings.Metadata{
Name: "testQueue",
Properties: map[string]string{
"queueName": queueName,
"host": rabbitmqHost,
"deleteWhenUnused": strconv.FormatBool(exclusive),
"durable": strconv.FormatBool(durable),
bindings.TTLMetadataKey: strconv.FormatInt(ttlInSeconds, 10),
},
}
logger := logger.NewLogger("test")
r := NewRabbitMQ(logger)
err := r.Init(metadata)
assert.Nil(t, err)
// Assert that if waited too long, we won't see any message
conn, err := amqp.Dial(rabbitmqHost)
assert.Nil(t, err)
defer conn.Close()
ch, err := conn.Channel()
assert.Nil(t, err)
defer ch.Close()
const tooLateMsgContent = "too_late_msg"
err = r.Write(&bindings.WriteRequest{Data: []byte(tooLateMsgContent)})
assert.Nil(t, err)
time.Sleep(time.Second + (ttlInSeconds * time.Second))
_, ok, err := getMessageWithRetries(ch, queueName, maxGetDuration)
assert.Nil(t, err)
assert.False(t, ok)
// Getting before it is expired, should return it
const testMsgContent = "test_msg"
err = r.Write(&bindings.WriteRequest{Data: []byte(testMsgContent)})
assert.Nil(t, err)
msg, ok, err := getMessageWithRetries(ch, queueName, maxGetDuration)
assert.Nil(t, err)
assert.True(t, ok)
msgBody := string(msg.Body)
assert.Equal(t, testMsgContent, msgBody)
}
func TestPublishingWithTTL(t *testing.T) {
rabbitmqHost := getTestRabbitMQHost()
assert.NotEmpty(t, rabbitmqHost, fmt.Sprintf("RabbitMQ host configuration must be set in environment variable '%s' (example 'amqp://guest:guest@localhost:5672/')", testRabbitMQHostEnvKey))
queueName := uuid.New().String()
durable := true
exclusive := false
const ttlInSeconds = 1
const maxGetDuration = ttlInSeconds * time.Second
metadata := bindings.Metadata{
Name: "testQueue",
Properties: map[string]string{
"queueName": queueName,
"host": rabbitmqHost,
"deleteWhenUnused": strconv.FormatBool(exclusive),
"durable": strconv.FormatBool(durable),
},
}
logger := logger.NewLogger("test")
rabbitMQBinding1 := NewRabbitMQ(logger)
err := rabbitMQBinding1.Init(metadata)
assert.Nil(t, err)
// Assert that if waited too long, we won't see any message
conn, err := amqp.Dial(rabbitmqHost)
assert.Nil(t, err)
defer conn.Close()
ch, err := conn.Channel()
assert.Nil(t, err)
defer ch.Close()
const tooLateMsgContent = "too_late_msg"
writeRequest := bindings.WriteRequest{
Data: []byte(tooLateMsgContent),
Metadata: map[string]string{
bindings.TTLMetadataKey: strconv.Itoa(ttlInSeconds),
},
}
err = rabbitMQBinding1.Write(&writeRequest)
assert.Nil(t, err)
time.Sleep(time.Second + (ttlInSeconds * time.Second))
_, ok, err := getMessageWithRetries(ch, queueName, maxGetDuration)
assert.Nil(t, err)
assert.False(t, ok)
// Getting before it is expired, should return it
rabbitMQBinding2 := NewRabbitMQ(logger)
err = rabbitMQBinding2.Init(metadata)
assert.Nil(t, err)
const testMsgContent = "test_msg"
writeRequest = bindings.WriteRequest{
Data: []byte(testMsgContent),
Metadata: map[string]string{
bindings.TTLMetadataKey: strconv.Itoa(ttlInSeconds * 1000),
},
}
err = rabbitMQBinding2.Write(&writeRequest)
assert.Nil(t, err)
msg, ok, err := getMessageWithRetries(ch, queueName, maxGetDuration)
assert.Nil(t, err)
assert.True(t, ok)
msgBody := string(msg.Body)
assert.Equal(t, testMsgContent, msgBody)
}

View File

@ -7,6 +7,7 @@ package rabbitmq
import (
"testing"
"time"
"github.com/dapr/components-contrib/bindings"
"github.com/dapr/dapr/pkg/logger"
@ -14,13 +15,89 @@ import (
)
func TestParseMetadata(t *testing.T) {
m := bindings.Metadata{}
m.Properties = map[string]string{"QueueName": "a", "Host": "a", "DeleteWhenUnused": "true", "Durable": "true"}
r := RabbitMQ{logger: logger.NewLogger("test")}
rm, err := r.getRabbitMQMetadata(m)
assert.Nil(t, err)
assert.Equal(t, "a", rm.QueueName)
assert.Equal(t, "a", rm.Host)
assert.Equal(t, true, rm.DeleteWhenUnused)
assert.Equal(t, true, rm.Durable)
const queueName = "test-queue"
const host = "test-host"
var oneSecondTTL time.Duration = time.Second
testCases := []struct {
name string
properties map[string]string
expectedDeleteWhenUnused bool
expectedDurable bool
expectedTTL *time.Duration
}{
{
name: "Delete / Durable",
properties: map[string]string{"QueueName": queueName, "Host": host, "DeleteWhenUnused": "true", "Durable": "true"},
expectedDeleteWhenUnused: true,
expectedDurable: true,
},
{
name: "Not Delete / Not Durable",
properties: map[string]string{"QueueName": queueName, "Host": host, "DeleteWhenUnused": "false", "Durable": "false"},
expectedDeleteWhenUnused: false,
expectedDurable: false,
},
{
name: "With one second TTL",
properties: map[string]string{"QueueName": queueName, "Host": host, "DeleteWhenUnused": "false", "Durable": "false", bindings.TTLMetadataKey: "1"},
expectedDeleteWhenUnused: false,
expectedDurable: false,
expectedTTL: &oneSecondTTL,
},
{
name: "Empty TTL",
properties: map[string]string{"QueueName": queueName, "Host": host, "DeleteWhenUnused": "false", "Durable": "false", bindings.TTLMetadataKey: ""},
expectedDeleteWhenUnused: false,
expectedDurable: false,
},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
m := bindings.Metadata{}
m.Properties = tt.properties
r := RabbitMQ{logger: logger.NewLogger("test")}
err := r.parseMetadata(m)
assert.Nil(t, err)
assert.Equal(t, queueName, r.metadata.QueueName)
assert.Equal(t, host, r.metadata.Host)
assert.Equal(t, tt.expectedDeleteWhenUnused, r.metadata.DeleteWhenUnused)
assert.Equal(t, tt.expectedDurable, r.metadata.Durable)
assert.Equal(t, tt.expectedTTL, r.metadata.defaultQueueTTL)
})
}
}
func TestParseMetadataWithInvalidTTL(t *testing.T) {
const queueName = "test-queue"
const host = "test-host"
testCases := []struct {
name string
properties map[string]string
}{
{
name: "Whitespaces TTL",
properties: map[string]string{"QueueName": queueName, "Host": host, bindings.TTLMetadataKey: " "},
},
{
name: "Negative ttl",
properties: map[string]string{"QueueName": queueName, "Host": host, bindings.TTLMetadataKey: "-1"},
},
{
name: "Non-numeric ttl",
properties: map[string]string{"QueueName": queueName, "Host": host, bindings.TTLMetadataKey: "abc"},
},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
m := bindings.Metadata{}
m.Properties = tt.properties
r := RabbitMQ{logger: logger.NewLogger("test")}
err := r.parseMetadata(m)
assert.NotNil(t, err)
})
}
}

37
bindings/utils.go Normal file
View File

@ -0,0 +1,37 @@
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// ------------------------------------------------------------
package bindings
import (
"fmt"
"strconv"
"time"
"github.com/pkg/errors"
)
const (
// TTLMetadataKey defines the metadata key for setting a time to live (in seconds)
TTLMetadataKey = "ttlInSeconds"
)
// TryGetTTL tries to get the ttl (in seconds) value for a binding
func TryGetTTL(props map[string]string) (time.Duration, bool, error) {
if val, ok := props[TTLMetadataKey]; ok && val != "" {
valInt, err := strconv.Atoi(val)
if err != nil {
return 0, false, errors.Wrapf(err, "%s value must be a valid integer: actual is '%s'", TTLMetadataKey, val)
}
if valInt <= 0 {
return 0, false, fmt.Errorf("%s value must be higher than zero: actual is %d", TTLMetadataKey, valInt)
}
return time.Duration(valInt) * time.Second, true, nil
}
return 0, false, nil
}