Add message time to live in RabbitMQ, Azure Service Bus/Storage Queue bindings (#298)
* Add message ttl for Azure Service Bus/Storage Queue and RabbitMQ * Rename metadata key to ttlInSeconds Move defaultQueueTTL to RabbitMQ metadata * Move integration tests to own files * Add +build integration_test tag * Remove integration test skip Co-authored-by: Aman Bhardwaj <amanbha@users.noreply.github.com>
This commit is contained in:
parent
4f9d6d97e3
commit
75fcd7accf
|
|
@ -19,6 +19,9 @@ const (
|
|||
correlationID = "correlationID"
|
||||
label = "label"
|
||||
id = "id"
|
||||
|
||||
// AzureServiceBusDefaultMessageTimeToLive defines the default time to live for queues, which is 14 days. The same way Azure Portal does.
|
||||
AzureServiceBusDefaultMessageTimeToLive = time.Hour * 24 * 14
|
||||
)
|
||||
|
||||
// AzureServiceBusQueues is an input/output binding reading from and sending events to Azure Service Bus queues
|
||||
|
|
@ -32,6 +35,7 @@ type AzureServiceBusQueues struct {
|
|||
type serviceBusQueuesMetadata struct {
|
||||
ConnectionString string `json:"connectionString"`
|
||||
QueueName string `json:"queueName"`
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// NewAzureServiceBusQueues returns a new AzureServiceBusQueues instance
|
||||
|
|
@ -52,7 +56,43 @@ func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) error {
|
|||
return err
|
||||
}
|
||||
|
||||
client, err := ns.NewQueue(a.metadata.QueueName)
|
||||
qm := ns.NewQueueManager()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
queues, err := qm.List(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var entity *servicebus.QueueEntity
|
||||
for _, q := range queues {
|
||||
if q.Name == a.metadata.QueueName {
|
||||
entity = q
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Create queue if it does not exist
|
||||
if entity == nil {
|
||||
var ttl time.Duration
|
||||
var ok bool
|
||||
ttl, ok, err = bindings.TryGetTTL(metadata.Properties)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
ttl = a.metadata.ttl
|
||||
}
|
||||
|
||||
entity, err = qm.Put(ctx, a.metadata.QueueName, servicebus.QueueEntityWithMessageTimeToLive(&ttl))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
client, err := ns.NewQueue(entity.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -71,6 +111,19 @@ func (a *AzureServiceBusQueues) parseMetadata(metadata bindings.Metadata) (*serv
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ttl, ok, err := bindings.TryGetTTL(metadata.Properties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// set the same default message time to live as suggested in Azure Portal to 14 days (otherwise it will be 10675199 days)
|
||||
if !ok {
|
||||
ttl = AzureServiceBusDefaultMessageTimeToLive
|
||||
}
|
||||
|
||||
m.ttl = ttl
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
|
|
@ -85,8 +138,17 @@ func (a *AzureServiceBusQueues) Write(req *bindings.WriteRequest) error {
|
|||
if val, ok := req.Metadata[correlationID]; ok && val != "" {
|
||||
msg.CorrelationID = val
|
||||
}
|
||||
err := a.client.Send(ctx, msg)
|
||||
return err
|
||||
|
||||
ttl, ok, err := bindings.TryGetTTL(req.Metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok {
|
||||
msg.TTL = &ttl
|
||||
}
|
||||
|
||||
return a.client.Send(ctx, msg)
|
||||
}
|
||||
|
||||
func (a *AzureServiceBusQueues) Read(handler func(*bindings.ReadResponse) error) error {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,180 @@
|
|||
// +build integration_test
|
||||
|
||||
// ------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
// ------------------------------------------------------------
|
||||
|
||||
package servicebusqueues
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
servicebus "github.com/Azure/azure-service-bus-go"
|
||||
"github.com/dapr/components-contrib/bindings"
|
||||
"github.com/dapr/dapr/pkg/logger"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
// Environment variable containing the connection string to Azure Service Bus
|
||||
testServiceBusEnvKey = "DAPR_TEST_AZURE_SERVICEBUS"
|
||||
)
|
||||
|
||||
func getTestServiceBusConnectionString() string {
|
||||
return os.Getenv(testServiceBusEnvKey)
|
||||
}
|
||||
|
||||
type testQueueHandler struct {
|
||||
callback func(*servicebus.Message)
|
||||
}
|
||||
|
||||
func (h testQueueHandler) Handle(ctx context.Context, message *servicebus.Message) error {
|
||||
h.callback(message)
|
||||
return message.Complete(ctx)
|
||||
}
|
||||
|
||||
func getMessageWithRetries(queue *servicebus.Queue, maxDuration time.Duration) (*servicebus.Message, bool, error) {
|
||||
var receivedMessage *servicebus.Message
|
||||
|
||||
queueHandler := testQueueHandler{
|
||||
callback: func(msg *servicebus.Message) {
|
||||
receivedMessage = msg
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), maxDuration)
|
||||
defer cancel()
|
||||
err := queue.ReceiveOne(ctx, queueHandler)
|
||||
if err != nil && err != context.DeadlineExceeded {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return receivedMessage, receivedMessage != nil, nil
|
||||
}
|
||||
|
||||
func TestQueueWithTTL(t *testing.T) {
|
||||
serviceBusConnectionString := getTestServiceBusConnectionString()
|
||||
assert.NotEmpty(serviceBusConnectionString, fmt.Sprintf("Azure ServiceBus connection string must set in environment variable '%s'", testServiceBusEnvKey))
|
||||
|
||||
queueName := uuid.New().String()
|
||||
a := NewAzureServiceBusQueues(logger.NewLogger("test"))
|
||||
m := bindings.Metadata{}
|
||||
m.Properties = map[string]string{"connectionString": serviceBusConnectionString, "queueName": queueName, bindings.TTLMetadataKey: "1"}
|
||||
err := a.Init(m)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Assert thet queue was created with an time to live value
|
||||
ns, err := servicebus.NewNamespace(servicebus.NamespaceWithConnectionString(serviceBusConnectionString))
|
||||
assert.Nil(t, err)
|
||||
queue, err := ns.NewQueue(queueName)
|
||||
assert.Nil(t, err)
|
||||
|
||||
qmr := ns.NewQueueManager()
|
||||
defer qmr.Delete(context.Background(), queueName)
|
||||
|
||||
queueEntity, err := qmr.Get(context.Background(), queueName)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "PT1S", *queueEntity.DefaultMessageTimeToLive)
|
||||
|
||||
// Assert that if waited too long, we won't see any message
|
||||
const tooLateMsgContent = "too_late_msg"
|
||||
err = a.Write(&bindings.WriteRequest{Data: []byte(tooLateMsgContent)})
|
||||
assert.Nil(t, err)
|
||||
|
||||
time.Sleep(time.Second * 2)
|
||||
|
||||
const ttlInSeconds = 1
|
||||
const maxGetDuration = ttlInSeconds * time.Second
|
||||
|
||||
_, ok, err := getMessageWithRetries(queue, maxGetDuration)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
|
||||
// Getting before it is expired, should return it
|
||||
const testMsgContent = "test_msg"
|
||||
err = a.Write(&bindings.WriteRequest{Data: []byte(testMsgContent)})
|
||||
assert.Nil(t, err)
|
||||
|
||||
msg, ok, err := getMessageWithRetries(queue, maxGetDuration)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
msgBody := string(msg.Data)
|
||||
assert.Equal(t, testMsgContent, msgBody)
|
||||
assert.NotNil(t, msg.TTL)
|
||||
assert.Equal(t, time.Second, *msg.TTL)
|
||||
}
|
||||
|
||||
func TestPublishingWithTTL(t *testing.T) {
|
||||
serviceBusConnectionString := getTestServiceBusConnectionString()
|
||||
assert.NotEmpty(serviceBusConnectionString, fmt.Sprintf("Azure ServiceBus connection string must set in environment variable '%s'", testServiceBusEnvKey))
|
||||
|
||||
queueName := uuid.New().String()
|
||||
queueBinding1 := NewAzureServiceBusQueues(logger.NewLogger("test"))
|
||||
bindingMetadata := bindings.Metadata{}
|
||||
bindingMetadata.Properties = map[string]string{"connectionString": serviceBusConnectionString, "queueName": queueName}
|
||||
err := queueBinding1.Init(bindingMetadata)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Assert thet queue was created with Azure default time to live value
|
||||
ns, err := servicebus.NewNamespace(servicebus.NamespaceWithConnectionString(serviceBusConnectionString))
|
||||
assert.Nil(t, err)
|
||||
|
||||
queue, err := ns.NewQueue(queueName)
|
||||
assert.Nil(t, err)
|
||||
|
||||
qmr := ns.NewQueueManager()
|
||||
defer qmr.Delete(context.Background(), queueName)
|
||||
|
||||
queueEntity, err := qmr.Get(context.Background(), queueName)
|
||||
assert.Nil(t, err)
|
||||
const defaultAzureServiceBusMessageTimeToLive = "P14D"
|
||||
assert.Equal(t, defaultAzureServiceBusMessageTimeToLive, *queueEntity.DefaultMessageTimeToLive)
|
||||
|
||||
const tooLateMsgContent = "too_late_msg"
|
||||
writeRequest := bindings.WriteRequest{
|
||||
Data: []byte(tooLateMsgContent),
|
||||
Metadata: map[string]string{
|
||||
bindings.TTLMetadataKey: "1",
|
||||
},
|
||||
}
|
||||
err = queueBinding1.Write(&writeRequest)
|
||||
assert.Nil(t, err)
|
||||
|
||||
time.Sleep(time.Second * 5)
|
||||
|
||||
const ttlInSeconds = 1
|
||||
const maxGetDuration = ttlInSeconds * time.Second
|
||||
|
||||
_, ok, err := getMessageWithRetries(queue, maxGetDuration)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
|
||||
// Getting before it is expired, should return it
|
||||
queueBinding2 := NewAzureServiceBusQueues(logger.NewLogger("test"))
|
||||
err = queueBinding2.Init(bindingMetadata)
|
||||
assert.Nil(t, err)
|
||||
|
||||
const testMsgContent = "test_msg"
|
||||
writeRequest = bindings.WriteRequest{
|
||||
Data: []byte(testMsgContent),
|
||||
Metadata: map[string]string{
|
||||
bindings.TTLMetadataKey: "1",
|
||||
},
|
||||
}
|
||||
err = queueBinding2.Write(&writeRequest)
|
||||
assert.Nil(t, err)
|
||||
|
||||
msg, ok, err := getMessageWithRetries(queue, maxGetDuration)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
msgBody := string(msg.Data)
|
||||
assert.Equal(t, testMsgContent, msgBody)
|
||||
assert.NotNil(t, msg.TTL)
|
||||
|
||||
assert.Equal(t, time.Second, *msg.TTL)
|
||||
}
|
||||
|
|
@ -7,6 +7,7 @@ package servicebusqueues
|
|||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dapr/components-contrib/bindings"
|
||||
"github.com/dapr/dapr/pkg/logger"
|
||||
|
|
@ -14,11 +15,79 @@ import (
|
|||
)
|
||||
|
||||
func TestParseMetadata(t *testing.T) {
|
||||
m := bindings.Metadata{}
|
||||
m.Properties = map[string]string{"connectionString": "connString", "queueName": "queue1"}
|
||||
a := NewAzureServiceBusQueues(logger.NewLogger("test"))
|
||||
meta, err := a.parseMetadata(m)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "connString", meta.ConnectionString)
|
||||
assert.Equal(t, "queue1", meta.QueueName)
|
||||
var oneSecondDuration time.Duration = time.Second
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
properties map[string]string
|
||||
expectedConnectionString string
|
||||
expectedQueueName string
|
||||
expectedTTL time.Duration
|
||||
}{
|
||||
{
|
||||
name: "ConnectionString and queue name",
|
||||
properties: map[string]string{"connectionString": "connString", "queueName": "queue1"},
|
||||
expectedConnectionString: "connString",
|
||||
expectedQueueName: "queue1",
|
||||
expectedTTL: AzureServiceBusDefaultMessageTimeToLive,
|
||||
},
|
||||
{
|
||||
name: "Empty TTL",
|
||||
properties: map[string]string{"connectionString": "connString", "queueName": "queue1", bindings.TTLMetadataKey: ""},
|
||||
expectedConnectionString: "connString",
|
||||
expectedQueueName: "queue1",
|
||||
expectedTTL: AzureServiceBusDefaultMessageTimeToLive,
|
||||
},
|
||||
{
|
||||
name: "With TTL",
|
||||
properties: map[string]string{"connectionString": "connString", "queueName": "queue1", bindings.TTLMetadataKey: "1"},
|
||||
expectedConnectionString: "connString",
|
||||
expectedQueueName: "queue1",
|
||||
expectedTTL: oneSecondDuration,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := bindings.Metadata{}
|
||||
m.Properties = tt.properties
|
||||
a := NewAzureServiceBusQueues(logger.NewLogger("test"))
|
||||
meta, err := a.parseMetadata(m)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, tt.expectedConnectionString, meta.ConnectionString)
|
||||
assert.Equal(t, tt.expectedQueueName, meta.QueueName)
|
||||
assert.Equal(t, tt.expectedTTL, meta.ttl)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMetadataWithInvalidTTL(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
properties map[string]string
|
||||
}{
|
||||
{
|
||||
name: "Whitespaces TTL",
|
||||
properties: map[string]string{"connectionString": "connString", "queueName": "queue1", bindings.TTLMetadataKey: " "},
|
||||
},
|
||||
{
|
||||
name: "Negative ttl",
|
||||
properties: map[string]string{"connectionString": "connString", "queueName": "queue1", bindings.TTLMetadataKey: "-1"},
|
||||
},
|
||||
{
|
||||
name: "Non-numeric ttl",
|
||||
properties: map[string]string{"connectionString": "connString", "queueName": "queue1", bindings.TTLMetadataKey: "abc"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := bindings.Metadata{}
|
||||
m.Properties = tt.properties
|
||||
|
||||
a := NewAzureServiceBusQueues(logger.NewLogger("test"))
|
||||
_, err := a.parseMetadata(m)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,10 @@ import (
|
|||
"github.com/dapr/dapr/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTTL = time.Minute * 10
|
||||
)
|
||||
|
||||
type consumer struct {
|
||||
callback func(*bindings.ReadResponse) error
|
||||
}
|
||||
|
|
@ -30,7 +34,7 @@ type consumer struct {
|
|||
// QueueHelper enables injection for testnig
|
||||
type QueueHelper interface {
|
||||
Init(accountName string, accountKey string, queueName string, decodeBase64 bool) error
|
||||
Write(data []byte) error
|
||||
Write(data []byte, ttl *time.Duration) error
|
||||
Read(ctx context.Context, consumer *consumer) error
|
||||
}
|
||||
|
||||
|
|
@ -61,11 +65,16 @@ func (d *AzureQueueHelper) Init(accountName string, accountKey string, queueName
|
|||
return nil
|
||||
}
|
||||
|
||||
func (d *AzureQueueHelper) Write(data []byte) error {
|
||||
func (d *AzureQueueHelper) Write(data []byte, ttl *time.Duration) error {
|
||||
ctx := context.TODO()
|
||||
messagesURL := d.queueURL.NewMessagesURL()
|
||||
s := string(data)
|
||||
_, err := messagesURL.Enqueue(ctx, s, time.Second*0, time.Minute*10)
|
||||
|
||||
if ttl == nil {
|
||||
ttlToUse := defaultTTL
|
||||
ttl = &ttlToUse
|
||||
}
|
||||
_, err := messagesURL.Enqueue(ctx, s, time.Second*0, *ttl)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -131,6 +140,7 @@ type storageQueuesMetadata struct {
|
|||
QueueName string `json:"queue"`
|
||||
AccountName string `json:"storageAccount"`
|
||||
DecodeBase64 string `json:"decodeBase64"`
|
||||
ttl *time.Duration
|
||||
}
|
||||
|
||||
// NewAzureStorageQueues returns a new AzureStorageQueues instance
|
||||
|
|
@ -168,11 +178,31 @@ func (a *AzureStorageQueues) parseMetadata(metadata bindings.Metadata) (*storage
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ttl, ok, err := bindings.TryGetTTL(metadata.Properties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok {
|
||||
m.ttl = &ttl
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func (a *AzureStorageQueues) Write(req *bindings.WriteRequest) error {
|
||||
err := a.helper.Write(req.Data)
|
||||
ttlToUse := a.metadata.ttl
|
||||
ttl, ok, err := bindings.TryGetTTL(req.Metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok {
|
||||
ttlToUse = &ttl
|
||||
}
|
||||
|
||||
err = a.helper.Write(req.Data, ttlToUse)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,8 +27,9 @@ func (m *MockHelper) Init(accountName string, accountKey string, queueName strin
|
|||
return retvals.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockHelper) Write(data []byte) error {
|
||||
return nil
|
||||
func (m *MockHelper) Write(data []byte, ttl *time.Duration) error {
|
||||
retvals := m.Called(data, ttl)
|
||||
return retvals.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {
|
||||
|
|
@ -38,6 +39,9 @@ func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {
|
|||
func TestWriteQueue(t *testing.T) {
|
||||
mm := new(MockHelper)
|
||||
mm.On("Init", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
|
||||
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
|
||||
return in == nil
|
||||
})).Return(nil)
|
||||
|
||||
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
|
||||
|
||||
|
|
@ -54,6 +58,53 @@ func TestWriteQueue(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestWriteWithTTLInQueue(t *testing.T) {
|
||||
mm := new(MockHelper)
|
||||
mm.On("Init", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
|
||||
mm.On("Write", mock.AnythingOfTypeArgument("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
|
||||
return in != nil && *in == time.Second
|
||||
})).Return(nil)
|
||||
|
||||
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
|
||||
|
||||
m := bindings.Metadata{}
|
||||
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", bindings.TTLMetadataKey: "1"}
|
||||
|
||||
err := a.Init(m)
|
||||
assert.Nil(t, err)
|
||||
|
||||
r := bindings.WriteRequest{Data: []byte("This is my message")}
|
||||
|
||||
err = a.Write(&r)
|
||||
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestWriteWithTTLInWrite(t *testing.T) {
|
||||
mm := new(MockHelper)
|
||||
mm.On("Init", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
|
||||
mm.On("Write", mock.AnythingOfTypeArgument("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
|
||||
return in != nil && *in == time.Second
|
||||
})).Return(nil)
|
||||
|
||||
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
|
||||
|
||||
m := bindings.Metadata{}
|
||||
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", bindings.TTLMetadataKey: "1"}
|
||||
|
||||
err := a.Init(m)
|
||||
assert.Nil(t, err)
|
||||
|
||||
r := bindings.WriteRequest{
|
||||
Data: []byte("This is my message"),
|
||||
Metadata: map[string]string{bindings.TTLMetadataKey: "1"},
|
||||
}
|
||||
|
||||
err = a.Write(&r)
|
||||
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
// Uncomment this function to write a message to local storage queue
|
||||
/* func TestWriteLocalQueue(t *testing.T) {
|
||||
|
||||
|
|
@ -75,7 +126,7 @@ func TestWriteQueue(t *testing.T) {
|
|||
func TestReadQueue(t *testing.T) {
|
||||
mm := new(MockHelper)
|
||||
mm.On("Init", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
|
||||
|
||||
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
|
||||
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
|
||||
|
||||
m := bindings.Metadata{}
|
||||
|
|
@ -108,6 +159,7 @@ func TestReadQueue(t *testing.T) {
|
|||
func TestReadQueueDecode(t *testing.T) {
|
||||
mm := new(MockHelper)
|
||||
mm.On("Init", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
|
||||
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
|
||||
|
||||
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
|
||||
|
||||
|
|
@ -169,6 +221,7 @@ func TestReadQueueDecode(t *testing.T) {
|
|||
func TestReadQueueNoMessage(t *testing.T) {
|
||||
mm := new(MockHelper)
|
||||
mm.On("Init", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), false).Return(nil)
|
||||
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
|
||||
|
||||
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
|
||||
|
||||
|
|
@ -194,13 +247,79 @@ func TestReadQueueNoMessage(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestParseMetadata(t *testing.T) {
|
||||
m := bindings.Metadata{}
|
||||
m.Properties = map[string]string{"storageAccessKey": "myKey", "queue": "queue1", "storageAccount": "devstoreaccount1"}
|
||||
var oneSecondDuration time.Duration = time.Second
|
||||
|
||||
a := NewAzureStorageQueues(logger.NewLogger("test"))
|
||||
meta, err := a.parseMetadata(m)
|
||||
testCases := []struct {
|
||||
name string
|
||||
properties map[string]string
|
||||
expectedAccountKey string
|
||||
expectedQueueName string
|
||||
expectedTTL *time.Duration
|
||||
}{
|
||||
{
|
||||
name: "Account and key",
|
||||
properties: map[string]string{"storageAccessKey": "myKey", "queue": "queue1", "storageAccount": "devstoreaccount1"},
|
||||
expectedAccountKey: "myKey",
|
||||
expectedQueueName: "queue1",
|
||||
},
|
||||
{
|
||||
name: "Empty TTL",
|
||||
properties: map[string]string{"storageAccessKey": "myKey", "queue": "queue1", "storageAccount": "devstoreaccount1", bindings.TTLMetadataKey: ""},
|
||||
expectedAccountKey: "myKey",
|
||||
expectedQueueName: "queue1",
|
||||
},
|
||||
{
|
||||
name: "With TTL",
|
||||
properties: map[string]string{"storageAccessKey": "myKey", "queue": "queue1", "storageAccount": "devstoreaccount1", bindings.TTLMetadataKey: "1"},
|
||||
expectedAccountKey: "myKey",
|
||||
expectedQueueName: "queue1",
|
||||
expectedTTL: &oneSecondDuration,
|
||||
},
|
||||
}
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "myKey", meta.AccountKey)
|
||||
assert.Equal(t, "queue1", meta.QueueName)
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := bindings.Metadata{}
|
||||
m.Properties = tt.properties
|
||||
|
||||
a := NewAzureStorageQueues(logger.NewLogger("test"))
|
||||
meta, err := a.parseMetadata(m)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, tt.expectedAccountKey, meta.AccountKey)
|
||||
assert.Equal(t, tt.expectedQueueName, meta.QueueName)
|
||||
assert.Equal(t, tt.expectedTTL, meta.ttl)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMetadataWithInvalidTTL(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
properties map[string]string
|
||||
}{
|
||||
{
|
||||
name: "Whitespaces TTL",
|
||||
properties: map[string]string{"storageAccessKey": "myKey", "queue": "queue1", "storageAccount": "devstoreaccount1", bindings.TTLMetadataKey: " "},
|
||||
},
|
||||
{
|
||||
name: "Negative ttl",
|
||||
properties: map[string]string{"storageAccessKey": "myKey", "queue": "queue1", "storageAccount": "devstoreaccount1", bindings.TTLMetadataKey: "-1"},
|
||||
},
|
||||
{
|
||||
name: "Non-numeric ttl",
|
||||
properties: map[string]string{"storageAccessKey": "myKey", "queue": "queue1", "storageAccount": "devstoreaccount1", bindings.TTLMetadataKey: "abc"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := bindings.Metadata{}
|
||||
m.Properties = tt.properties
|
||||
|
||||
a := NewAzureStorageQueues(logger.NewLogger("test"))
|
||||
_, err := a.parseMetadata(m)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,18 +7,25 @@ package rabbitmq
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/dapr/components-contrib/bindings"
|
||||
"github.com/dapr/dapr/pkg/logger"
|
||||
"github.com/streadway/amqp"
|
||||
)
|
||||
|
||||
const (
|
||||
rabbitMQQueueMessageTTLKey = "x-message-ttl"
|
||||
)
|
||||
|
||||
// RabbitMQ allows sending/receiving data to/from RabbitMQ
|
||||
type RabbitMQ struct {
|
||||
connection *amqp.Connection
|
||||
channel *amqp.Channel
|
||||
metadata *rabbitMQMetadata
|
||||
metadata rabbitMQMetadata
|
||||
logger logger.Logger
|
||||
queue amqp.Queue
|
||||
}
|
||||
|
||||
// Metadata is the rabbitmq config
|
||||
|
|
@ -27,6 +34,7 @@ type rabbitMQMetadata struct {
|
|||
Host string `json:"host"`
|
||||
Durable bool `json:"durable,string"`
|
||||
DeleteWhenUnused bool `json:"deleteWhenUnused,string"`
|
||||
defaultQueueTTL *time.Duration
|
||||
}
|
||||
|
||||
// NewRabbitMQ returns a new rabbitmq instance
|
||||
|
|
@ -36,14 +44,12 @@ func NewRabbitMQ(logger logger.Logger) *RabbitMQ {
|
|||
|
||||
// Init does metadata parsing and connection creation
|
||||
func (r *RabbitMQ) Init(metadata bindings.Metadata) error {
|
||||
meta, err := r.getRabbitMQMetadata(metadata)
|
||||
err := r.parseMetadata(metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.metadata = meta
|
||||
|
||||
conn, err := amqp.Dial(meta.Host)
|
||||
conn, err := amqp.Dial(r.metadata.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -55,42 +61,84 @@ func (r *RabbitMQ) Init(metadata bindings.Metadata) error {
|
|||
|
||||
r.connection = conn
|
||||
r.channel = ch
|
||||
|
||||
q, err := r.declareQueue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.queue = q
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RabbitMQ) Write(req *bindings.WriteRequest) error {
|
||||
err := r.channel.Publish("", r.metadata.QueueName, false, false, amqp.Publishing{
|
||||
ContentType: "text/plain",
|
||||
Body: req.Data,
|
||||
})
|
||||
pub := amqp.Publishing{
|
||||
DeliveryMode: amqp.Persistent,
|
||||
ContentType: "text/plain",
|
||||
Body: req.Data,
|
||||
}
|
||||
|
||||
ttl, ok, err := bindings.TryGetTTL(req.Metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// The default time to live has been set in the queue
|
||||
// We allow overriding on each call, by setting a value in request metadata
|
||||
if ok {
|
||||
// RabbitMQ expects the duration in ms
|
||||
pub.Expiration = strconv.FormatInt(ttl.Milliseconds(), 10)
|
||||
}
|
||||
|
||||
err = r.channel.Publish("", r.metadata.QueueName, false, false, pub)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RabbitMQ) getRabbitMQMetadata(metadata bindings.Metadata) (*rabbitMQMetadata, error) {
|
||||
func (r *RabbitMQ) parseMetadata(metadata bindings.Metadata) error {
|
||||
b, err := json.Marshal(metadata.Properties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rabbitMQMeta rabbitMQMetadata
|
||||
err = json.Unmarshal(b, &rabbitMQMeta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rabbitMQMeta, nil
|
||||
}
|
||||
|
||||
func (r *RabbitMQ) Read(handler func(*bindings.ReadResponse) error) error {
|
||||
q, err := r.channel.QueueDeclare(r.metadata.QueueName, r.metadata.Durable, r.metadata.DeleteWhenUnused, false, false, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var m rabbitMQMetadata
|
||||
err = json.Unmarshal(b, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ttl, ok, err := bindings.TryGetTTL(metadata.Properties)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok {
|
||||
m.defaultQueueTTL = &ttl
|
||||
}
|
||||
|
||||
r.metadata = m
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RabbitMQ) declareQueue() (amqp.Queue, error) {
|
||||
args := amqp.Table{}
|
||||
if r.metadata.defaultQueueTTL != nil {
|
||||
// Value in ms
|
||||
ttl := *r.metadata.defaultQueueTTL / time.Millisecond
|
||||
args[rabbitMQQueueMessageTTLKey] = int(ttl)
|
||||
}
|
||||
|
||||
return r.channel.QueueDeclare(r.metadata.QueueName, r.metadata.Durable, r.metadata.DeleteWhenUnused, false, false, args)
|
||||
}
|
||||
|
||||
func (r *RabbitMQ) Read(handler func(*bindings.ReadResponse) error) error {
|
||||
msgs, err := r.channel.Consume(
|
||||
q.Name,
|
||||
r.queue.Name,
|
||||
"",
|
||||
false,
|
||||
false,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,179 @@
|
|||
// +build integration_test
|
||||
|
||||
// ------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
// ------------------------------------------------------------
|
||||
|
||||
package rabbitmq
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dapr/components-contrib/bindings"
|
||||
"github.com/dapr/dapr/pkg/logger"
|
||||
"github.com/google/uuid"
|
||||
"github.com/streadway/amqp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
// Environment variable containing the host name for RabbitMQ integration tests
|
||||
// To run using docker: docker run -d --hostname -rabbit --name test-rabbit -p 15672:15672 -p 5672:5672 rabbitmq:3-management
|
||||
// In that case the connection string will be: amqp://guest:guest@localhost:5672/
|
||||
testRabbitMQHostEnvKey = "DAPR_TEST_RABBITMQ_HOST"
|
||||
)
|
||||
|
||||
func getTestRabbitMQHost() string {
|
||||
return os.Getenv(testRabbitMQHostEnvKey)
|
||||
}
|
||||
|
||||
func getMessageWithRetries(ch *amqp.Channel, queueName string, maxDuration time.Duration) (msg amqp.Delivery, ok bool, err error) {
|
||||
start := time.Now()
|
||||
for time.Since(start) < maxDuration {
|
||||
msg, ok, err := ch.Get(queueName, true)
|
||||
if err != nil || ok {
|
||||
return msg, ok, err
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
return amqp.Delivery{}, false, nil
|
||||
}
|
||||
|
||||
func TestQueuesWithTTL(t *testing.T) {
|
||||
rabbitmqHost := getTestRabbitMQHost()
|
||||
assert.NotEmpty(t, rabbitmqHost, fmt.Sprintf("RabbitMQ host configuration must be set in environment variable '%s' (example 'amqp://guest:guest@localhost:5672/')", testRabbitMQHostEnvKey))
|
||||
|
||||
queueName := uuid.New().String()
|
||||
durable := true
|
||||
exclusive := false
|
||||
const ttlInSeconds = 1
|
||||
const maxGetDuration = ttlInSeconds * time.Second
|
||||
|
||||
metadata := bindings.Metadata{
|
||||
Name: "testQueue",
|
||||
Properties: map[string]string{
|
||||
"queueName": queueName,
|
||||
"host": rabbitmqHost,
|
||||
"deleteWhenUnused": strconv.FormatBool(exclusive),
|
||||
"durable": strconv.FormatBool(durable),
|
||||
bindings.TTLMetadataKey: strconv.FormatInt(ttlInSeconds, 10),
|
||||
},
|
||||
}
|
||||
|
||||
logger := logger.NewLogger("test")
|
||||
|
||||
r := NewRabbitMQ(logger)
|
||||
err := r.Init(metadata)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Assert that if waited too long, we won't see any message
|
||||
conn, err := amqp.Dial(rabbitmqHost)
|
||||
assert.Nil(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
ch, err := conn.Channel()
|
||||
assert.Nil(t, err)
|
||||
defer ch.Close()
|
||||
|
||||
const tooLateMsgContent = "too_late_msg"
|
||||
err = r.Write(&bindings.WriteRequest{Data: []byte(tooLateMsgContent)})
|
||||
assert.Nil(t, err)
|
||||
|
||||
time.Sleep(time.Second + (ttlInSeconds * time.Second))
|
||||
|
||||
_, ok, err := getMessageWithRetries(ch, queueName, maxGetDuration)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
|
||||
// Getting before it is expired, should return it
|
||||
const testMsgContent = "test_msg"
|
||||
err = r.Write(&bindings.WriteRequest{Data: []byte(testMsgContent)})
|
||||
assert.Nil(t, err)
|
||||
|
||||
msg, ok, err := getMessageWithRetries(ch, queueName, maxGetDuration)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
msgBody := string(msg.Body)
|
||||
assert.Equal(t, testMsgContent, msgBody)
|
||||
}
|
||||
|
||||
func TestPublishingWithTTL(t *testing.T) {
|
||||
rabbitmqHost := getTestRabbitMQHost()
|
||||
assert.NotEmpty(t, rabbitmqHost, fmt.Sprintf("RabbitMQ host configuration must be set in environment variable '%s' (example 'amqp://guest:guest@localhost:5672/')", testRabbitMQHostEnvKey))
|
||||
|
||||
queueName := uuid.New().String()
|
||||
durable := true
|
||||
exclusive := false
|
||||
const ttlInSeconds = 1
|
||||
const maxGetDuration = ttlInSeconds * time.Second
|
||||
|
||||
metadata := bindings.Metadata{
|
||||
Name: "testQueue",
|
||||
Properties: map[string]string{
|
||||
"queueName": queueName,
|
||||
"host": rabbitmqHost,
|
||||
"deleteWhenUnused": strconv.FormatBool(exclusive),
|
||||
"durable": strconv.FormatBool(durable),
|
||||
},
|
||||
}
|
||||
|
||||
logger := logger.NewLogger("test")
|
||||
|
||||
rabbitMQBinding1 := NewRabbitMQ(logger)
|
||||
err := rabbitMQBinding1.Init(metadata)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Assert that if waited too long, we won't see any message
|
||||
conn, err := amqp.Dial(rabbitmqHost)
|
||||
assert.Nil(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
ch, err := conn.Channel()
|
||||
assert.Nil(t, err)
|
||||
defer ch.Close()
|
||||
|
||||
const tooLateMsgContent = "too_late_msg"
|
||||
writeRequest := bindings.WriteRequest{
|
||||
Data: []byte(tooLateMsgContent),
|
||||
Metadata: map[string]string{
|
||||
bindings.TTLMetadataKey: strconv.Itoa(ttlInSeconds),
|
||||
},
|
||||
}
|
||||
|
||||
err = rabbitMQBinding1.Write(&writeRequest)
|
||||
assert.Nil(t, err)
|
||||
|
||||
time.Sleep(time.Second + (ttlInSeconds * time.Second))
|
||||
|
||||
_, ok, err := getMessageWithRetries(ch, queueName, maxGetDuration)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
|
||||
// Getting before it is expired, should return it
|
||||
rabbitMQBinding2 := NewRabbitMQ(logger)
|
||||
err = rabbitMQBinding2.Init(metadata)
|
||||
assert.Nil(t, err)
|
||||
|
||||
const testMsgContent = "test_msg"
|
||||
writeRequest = bindings.WriteRequest{
|
||||
Data: []byte(testMsgContent),
|
||||
Metadata: map[string]string{
|
||||
bindings.TTLMetadataKey: strconv.Itoa(ttlInSeconds * 1000),
|
||||
},
|
||||
}
|
||||
err = rabbitMQBinding2.Write(&writeRequest)
|
||||
assert.Nil(t, err)
|
||||
|
||||
msg, ok, err := getMessageWithRetries(ch, queueName, maxGetDuration)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
msgBody := string(msg.Body)
|
||||
assert.Equal(t, testMsgContent, msgBody)
|
||||
}
|
||||
|
|
@ -7,6 +7,7 @@ package rabbitmq
|
|||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dapr/components-contrib/bindings"
|
||||
"github.com/dapr/dapr/pkg/logger"
|
||||
|
|
@ -14,13 +15,89 @@ import (
|
|||
)
|
||||
|
||||
func TestParseMetadata(t *testing.T) {
|
||||
m := bindings.Metadata{}
|
||||
m.Properties = map[string]string{"QueueName": "a", "Host": "a", "DeleteWhenUnused": "true", "Durable": "true"}
|
||||
r := RabbitMQ{logger: logger.NewLogger("test")}
|
||||
rm, err := r.getRabbitMQMetadata(m)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "a", rm.QueueName)
|
||||
assert.Equal(t, "a", rm.Host)
|
||||
assert.Equal(t, true, rm.DeleteWhenUnused)
|
||||
assert.Equal(t, true, rm.Durable)
|
||||
const queueName = "test-queue"
|
||||
const host = "test-host"
|
||||
var oneSecondTTL time.Duration = time.Second
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
properties map[string]string
|
||||
expectedDeleteWhenUnused bool
|
||||
expectedDurable bool
|
||||
expectedTTL *time.Duration
|
||||
}{
|
||||
{
|
||||
name: "Delete / Durable",
|
||||
properties: map[string]string{"QueueName": queueName, "Host": host, "DeleteWhenUnused": "true", "Durable": "true"},
|
||||
expectedDeleteWhenUnused: true,
|
||||
expectedDurable: true,
|
||||
},
|
||||
{
|
||||
name: "Not Delete / Not Durable",
|
||||
properties: map[string]string{"QueueName": queueName, "Host": host, "DeleteWhenUnused": "false", "Durable": "false"},
|
||||
expectedDeleteWhenUnused: false,
|
||||
expectedDurable: false,
|
||||
},
|
||||
{
|
||||
name: "With one second TTL",
|
||||
properties: map[string]string{"QueueName": queueName, "Host": host, "DeleteWhenUnused": "false", "Durable": "false", bindings.TTLMetadataKey: "1"},
|
||||
expectedDeleteWhenUnused: false,
|
||||
expectedDurable: false,
|
||||
expectedTTL: &oneSecondTTL,
|
||||
},
|
||||
{
|
||||
name: "Empty TTL",
|
||||
properties: map[string]string{"QueueName": queueName, "Host": host, "DeleteWhenUnused": "false", "Durable": "false", bindings.TTLMetadataKey: ""},
|
||||
expectedDeleteWhenUnused: false,
|
||||
expectedDurable: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := bindings.Metadata{}
|
||||
m.Properties = tt.properties
|
||||
r := RabbitMQ{logger: logger.NewLogger("test")}
|
||||
err := r.parseMetadata(m)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, queueName, r.metadata.QueueName)
|
||||
assert.Equal(t, host, r.metadata.Host)
|
||||
assert.Equal(t, tt.expectedDeleteWhenUnused, r.metadata.DeleteWhenUnused)
|
||||
assert.Equal(t, tt.expectedDurable, r.metadata.Durable)
|
||||
assert.Equal(t, tt.expectedTTL, r.metadata.defaultQueueTTL)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMetadataWithInvalidTTL(t *testing.T) {
|
||||
const queueName = "test-queue"
|
||||
const host = "test-host"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
properties map[string]string
|
||||
}{
|
||||
{
|
||||
name: "Whitespaces TTL",
|
||||
properties: map[string]string{"QueueName": queueName, "Host": host, bindings.TTLMetadataKey: " "},
|
||||
},
|
||||
{
|
||||
name: "Negative ttl",
|
||||
properties: map[string]string{"QueueName": queueName, "Host": host, bindings.TTLMetadataKey: "-1"},
|
||||
},
|
||||
{
|
||||
name: "Non-numeric ttl",
|
||||
properties: map[string]string{"QueueName": queueName, "Host": host, bindings.TTLMetadataKey: "abc"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := bindings.Metadata{}
|
||||
m.Properties = tt.properties
|
||||
r := RabbitMQ{logger: logger.NewLogger("test")}
|
||||
err := r.parseMetadata(m)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,37 @@
|
|||
// ------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
// ------------------------------------------------------------
|
||||
|
||||
package bindings
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
// TTLMetadataKey defines the metadata key for setting a time to live (in seconds)
|
||||
TTLMetadataKey = "ttlInSeconds"
|
||||
)
|
||||
|
||||
// TryGetTTL tries to get the ttl (in seconds) value for a binding
|
||||
func TryGetTTL(props map[string]string) (time.Duration, bool, error) {
|
||||
if val, ok := props[TTLMetadataKey]; ok && val != "" {
|
||||
valInt, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return 0, false, errors.Wrapf(err, "%s value must be a valid integer: actual is '%s'", TTLMetadataKey, val)
|
||||
}
|
||||
|
||||
if valInt <= 0 {
|
||||
return 0, false, fmt.Errorf("%s value must be higher than zero: actual is %d", TTLMetadataKey, valInt)
|
||||
}
|
||||
|
||||
return time.Duration(valInt) * time.Second, true, nil
|
||||
}
|
||||
|
||||
return 0, false, nil
|
||||
}
|
||||
Loading…
Reference in New Issue