From b298b65cfeca5bcb103b571639965e01707b73fa Mon Sep 17 00:00:00 2001 From: Shubham Sharma Date: Fri, 23 Sep 2022 22:27:00 +0530 Subject: [PATCH] Add bulk publish support to Azure Service Bus (#2106) * Initial implementation Signed-off-by: Shubham Sharma * Some improvements Signed-off-by: Shubham Sharma * Refactor and add tests Signed-off-by: Shubham Sharma * Lint Signed-off-by: Shubham Sharma * Refactor Signed-off-by: Shubham Sharma * Use request metadata instead of component metadata Signed-off-by: Shubham Sharma * Remove unused method Signed-off-by: Shubham Sharma * Review comments addressed Signed-off-by: Shubham Sharma * Doc comments Signed-off-by: Shubham Sharma Signed-off-by: Shubham Sharma --- internal/utils/utils.go | 32 ++++++++ internal/utils/utils_test.go | 98 +++++++++++++++++++++++++ metadata/utils.go | 3 + pubsub/azure/servicebus/message.go | 58 ++++++++++++--- pubsub/azure/servicebus/message_test.go | 80 +++++++++----------- pubsub/azure/servicebus/servicebus.go | 42 ++++++++++- 6 files changed, 256 insertions(+), 57 deletions(-) create mode 100644 internal/utils/utils_test.go diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 4a7f02e8e..3ae9d1c82 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -1,6 +1,20 @@ +/* +Copyright 2022 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package utils import ( + "strconv" "strings" ) @@ -14,3 +28,21 @@ func IsTruthy(val string) bool { return false } } + +// GetElemOrDefaultFromMap returns the value of a key from a map, or a default value +// if the key does not exist or the value is not of the expected type. +func GetElemOrDefaultFromMap[T int | uint64](m map[string]string, key string, def T) T { + if val, ok := m[key]; ok { + switch any(def).(type) { + case int: + if ival, err := strconv.ParseInt(val, 10, 64); err == nil { + return T(ival) + } + case uint64: + if uval, err := strconv.ParseUint(val, 10, 64); err == nil { + return T(uval) + } + } + } + return def +} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go new file mode 100644 index 000000000..e3b0199da --- /dev/null +++ b/internal/utils/utils_test.go @@ -0,0 +1,98 @@ +/* +Copyright 2022 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import "testing" + +func TestGetElemOrDefaultFromMap(t *testing.T) { + t.Run("test int", func(t *testing.T) { + testcases := []struct { + name string + m map[string]string + key string + def int + expected int + }{ + { + name: "Get an int value from map that exists", + m: map[string]string{"key": "1"}, + key: "key", + def: 0, + expected: 1, + }, + { + name: "Get an int value from map that does not exist", + m: map[string]string{"key": "1"}, + key: "key2", + def: 0, + expected: 0, + }, + { + name: "Get an int value from map that exists but is not an int", + m: map[string]string{"key": "a"}, + key: "key", + def: 0, + expected: 0, + }, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + actual := GetElemOrDefaultFromMap(tc.m, tc.key, tc.def) + if actual != tc.expected { + t.Errorf("expected %v, got %v", tc.expected, actual) + } + }) + } + }) + t.Run("test uint64", func(t *testing.T) { + testcases := []struct { + name string + m map[string]string + key string + def uint64 + expected uint64 + }{ + { + name: "Get an uint64 value from map that exists", + m: map[string]string{"key": "1"}, + key: "key", + def: uint64(0), + expected: uint64(1), + }, + { + name: "Get an uint64 value from map that does not exist", + m: map[string]string{"key": "1"}, + key: "key2", + def: uint64(0), + expected: uint64(0), + }, + { + name: "Get an int value from map that exists but is not an uint64", + m: map[string]string{"key": "-1"}, + key: "key", + def: 0, + expected: 0, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + actual := GetElemOrDefaultFromMap(tc.m, tc.key, tc.def) + if actual != tc.expected { + t.Errorf("expected %v, got %v", tc.expected, actual) + } + }) + } + }) +} diff --git a/metadata/utils.go b/metadata/utils.go index ddbcd36bb..13e5c619b 100644 --- a/metadata/utils.go +++ b/metadata/utils.go @@ -44,6 +44,9 @@ const ( // MaxBulkAwaitDurationKey is the key for the max bulk await duration in the metadata. MaxBulkAwaitDurationMilliSecondsKey string = "maxBulkAwaitDurationMilliSeconds" + + // MaxBulkPubBytes defines the maximum bytes to publish in a bulk publish request metadata. + MaxBulkPubBytes string = "maxBulkPubBytes" ) // TryGetTTL tries to get the ttl as a time.Duration value for pubsub, binding and any other building block. diff --git a/pubsub/azure/servicebus/message.go b/pubsub/azure/servicebus/message.go index 0b40f243b..a489d564b 100644 --- a/pubsub/azure/servicebus/message.go +++ b/pubsub/azure/servicebus/message.go @@ -144,64 +144,98 @@ func NewASBMessageFromPubsubRequest(req *pubsub.PublishRequest) (*azservicebus.M Body: req.Data, } + err := addMetadataToMessage(asbMsg, req.Metadata) + return asbMsg, err +} + +// NewASBMessageFromBulkMessageEntry builds a new Azure Service Bus message from a BulkMessageEntry. +func NewASBMessageFromBulkMessageEntry(entry pubsub.BulkMessageEntry) (*azservicebus.Message, error) { + asbMsg := &azservicebus.Message{ + Body: entry.Event, + ContentType: &entry.ContentType, + } + + err := addMetadataToMessage(asbMsg, entry.Metadata) + return asbMsg, err +} + +func addMetadataToMessage(asbMsg *azservicebus.Message, metadata map[string]string) error { // Common properties. - ttl, ok, _ := contribMetadata.TryGetTTL(req.Metadata) + ttl, ok, _ := contribMetadata.TryGetTTL(metadata) if ok { asbMsg.TimeToLive = &ttl } // Azure Service Bus specific properties. // reference: https://docs.microsoft.com/en-us/rest/api/servicebus/message-headers-and-properties#message-headers - msgID, ok, _ := tryGetString(req.Metadata, MessageIDMetadataKey) + msgID, ok, _ := tryGetString(metadata, MessageIDMetadataKey) if ok { asbMsg.MessageID = &msgID } - correlationID, ok, _ := tryGetString(req.Metadata, CorrelationIDMetadataKey) + correlationID, ok, _ := tryGetString(metadata, CorrelationIDMetadataKey) if ok { asbMsg.CorrelationID = &correlationID } - sessionID, okSessionID, _ := tryGetString(req.Metadata, SessionIDMetadataKey) + sessionID, okSessionID, _ := tryGetString(metadata, SessionIDMetadataKey) if okSessionID { asbMsg.SessionID = &sessionID } - label, ok, _ := tryGetString(req.Metadata, LabelMetadataKey) + label, ok, _ := tryGetString(metadata, LabelMetadataKey) if ok { asbMsg.Subject = &label } - replyTo, ok, _ := tryGetString(req.Metadata, ReplyToMetadataKey) + replyTo, ok, _ := tryGetString(metadata, ReplyToMetadataKey) if ok { asbMsg.ReplyTo = &replyTo } - to, ok, _ := tryGetString(req.Metadata, ToMetadataKey) + to, ok, _ := tryGetString(metadata, ToMetadataKey) if ok { asbMsg.To = &to } - partitionKey, ok, _ := tryGetString(req.Metadata, PartitionKeyMetadataKey) + partitionKey, ok, _ := tryGetString(metadata, PartitionKeyMetadataKey) if ok { if okSessionID && partitionKey != sessionID { - return nil, fmt.Errorf("session id %s and partition key %s should be equal when both present", sessionID, partitionKey) + return fmt.Errorf("session id %s and partition key %s should be equal when both present", sessionID, partitionKey) } asbMsg.PartitionKey = &partitionKey } - contentType, ok, _ := tryGetString(req.Metadata, ContentTypeMetadataKey) + contentType, ok, _ := tryGetString(metadata, ContentTypeMetadataKey) if ok { asbMsg.ContentType = &contentType } - scheduledEnqueueTime, ok, _ := tryGetScheduledEnqueueTime(req.Metadata) + scheduledEnqueueTime, ok, _ := tryGetScheduledEnqueueTime(metadata) if ok { asbMsg.ScheduledEnqueueTime = scheduledEnqueueTime } - return asbMsg, nil + return nil +} + +// UpdateASBBatchMessageWithBulkPublishRequest updates the batch message with messages from the bulk publish request. +func UpdateASBBatchMessageWithBulkPublishRequest(asbMsgBatch *azservicebus.MessageBatch, req *pubsub.BulkPublishRequest) error { + // Add entries from bulk request to batch. + for _, entry := range req.Entries { + asbMsg, err := NewASBMessageFromBulkMessageEntry(entry) + if err != nil { + return err + } + + err = asbMsgBatch.AddMessage(asbMsg, nil) + if err != nil { + return err + } + } + + return nil } func tryGetString(props map[string]string, key string) (string, bool, error) { diff --git a/pubsub/azure/servicebus/message_test.go b/pubsub/azure/servicebus/message_test.go index a6827f1f3..5ce732535 100644 --- a/pubsub/azure/servicebus/message_test.go +++ b/pubsub/azure/servicebus/message_test.go @@ -21,48 +21,43 @@ import ( azservicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/dapr/components-contrib/pubsub" ) -func TestNewASBMessageFromPubsubRequest(t *testing.T) { - testMessageData := []byte("test message") - testMessageID := "testMessageId" - testCorrelationID := "testCorrelationId" - testSessionID := "testSessionId" - testLabel := "testLabel" - testReplyTo := "testReplyTo" - testTo := "testTo" - testPartitionKey := testSessionID - testPartitionKeyUnique := "testPartitionKey" - testContentType := "testContentType" - nowUtc := time.Now().UTC() - testScheduledEnqueueTimeUtc := nowUtc.Format(http.TimeFormat) +var ( + testMessageID = "testMessageId" + testCorrelationID = "testCorrelationId" + testSessionID = "testSessionId" + testLabel = "testLabel" + testReplyTo = "testReplyTo" + testTo = "testTo" + testPartitionKey = testSessionID + testPartitionKeyUnique = "testPartitionKey" + testContentType = "testContentType" + nowUtc = time.Now().UTC() + testScheduledEnqueueTimeUtc = nowUtc.Format(http.TimeFormat) +) +func TestAddMetadataToMessage(t *testing.T) { testCases := []struct { name string - pubsubRequest pubsub.PublishRequest + metadata map[string]string expectedAzServiceBusMessage azservicebus.Message expectError bool }{ { name: "Maps pubsub request to azure service bus message.", - pubsubRequest: pubsub.PublishRequest{ - Data: testMessageData, - Metadata: map[string]string{ - MessageIDMetadataKey: testMessageID, - CorrelationIDMetadataKey: testCorrelationID, - SessionIDMetadataKey: testSessionID, - LabelMetadataKey: testLabel, - ReplyToMetadataKey: testReplyTo, - ToMetadataKey: testTo, - PartitionKeyMetadataKey: testPartitionKey, - ContentTypeMetadataKey: testContentType, - ScheduledEnqueueTimeUtcMetadataKey: testScheduledEnqueueTimeUtc, - }, + metadata: map[string]string{ + MessageIDMetadataKey: testMessageID, + CorrelationIDMetadataKey: testCorrelationID, + SessionIDMetadataKey: testSessionID, + LabelMetadataKey: testLabel, + ReplyToMetadataKey: testReplyTo, + ToMetadataKey: testTo, + PartitionKeyMetadataKey: testPartitionKey, + ContentTypeMetadataKey: testContentType, + ScheduledEnqueueTimeUtcMetadataKey: testScheduledEnqueueTimeUtc, }, expectedAzServiceBusMessage: azservicebus.Message{ - Body: testMessageData, MessageID: &testMessageID, CorrelationID: &testCorrelationID, SessionID: &testSessionID, @@ -77,21 +72,17 @@ func TestNewASBMessageFromPubsubRequest(t *testing.T) { }, { name: "Errors when partition key and session id set but not equal.", - pubsubRequest: pubsub.PublishRequest{ - Data: testMessageData, - Metadata: map[string]string{ - MessageIDMetadataKey: testMessageID, - CorrelationIDMetadataKey: testCorrelationID, - SessionIDMetadataKey: testSessionID, - LabelMetadataKey: testLabel, - ReplyToMetadataKey: testReplyTo, - ToMetadataKey: testTo, - PartitionKeyMetadataKey: testPartitionKeyUnique, - ContentTypeMetadataKey: testContentType, - }, + metadata: map[string]string{ + MessageIDMetadataKey: testMessageID, + CorrelationIDMetadataKey: testCorrelationID, + SessionIDMetadataKey: testSessionID, + LabelMetadataKey: testLabel, + ReplyToMetadataKey: testReplyTo, + ToMetadataKey: testTo, + PartitionKeyMetadataKey: testPartitionKeyUnique, + ContentTypeMetadataKey: testContentType, }, expectedAzServiceBusMessage: azservicebus.Message{ - Body: testMessageData, MessageID: &testMessageID, CorrelationID: &testCorrelationID, SessionID: &testSessionID, @@ -108,7 +99,8 @@ func TestNewASBMessageFromPubsubRequest(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // act. - msg, err := NewASBMessageFromPubsubRequest(&tc.pubsubRequest) + msg := &azservicebus.Message{} + err := addMetadataToMessage(msg, tc.metadata) // assert. if tc.expectError { diff --git a/pubsub/azure/servicebus/servicebus.go b/pubsub/azure/servicebus/servicebus.go index 3cfa63fc1..73e67b227 100644 --- a/pubsub/azure/servicebus/servicebus.go +++ b/pubsub/azure/servicebus/servicebus.go @@ -30,6 +30,7 @@ import ( azauth "github.com/dapr/components-contrib/internal/authentication/azure" impl "github.com/dapr/components-contrib/internal/component/azure/servicebus" + "github.com/dapr/components-contrib/internal/utils" contribMetadata "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/pubsub" "github.com/dapr/kit/logger" @@ -37,7 +38,8 @@ import ( ) const ( - errorMessagePrefix = "azure service bus error:" + errorMessagePrefix = "azure service bus error:" + defaultMaxBulkPubBytes uint64 = 1024 * 128 // 128 KiB ) var retriableSendingErrors = map[amqp.ErrorCondition]struct{}{ @@ -359,6 +361,44 @@ func (a *azureServiceBus) Publish(req *pubsub.PublishRequest) error { ) } +func (a *azureServiceBus) BulkPublish(ctx context.Context, req *pubsub.BulkPublishRequest) (pubsub.BulkPublishResponse, error) { + // If the request is empty, sender.SendMessageBatch will panic later. + // Return an empty response to avoid this. + if len(req.Entries) == 0 { + a.logger.Warnf("Empty bulk publish request, skipping") + return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishSucceeded, nil), nil + } + + sender, err := a.senderForTopic(ctx, req.Topic) + if err != nil { + return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishFailed, err), err + } + + // Create a new batch of messages with batch options. + batchOpts := &servicebus.MessageBatchOptions{ + MaxBytes: utils.GetElemOrDefaultFromMap(req.Metadata, contribMetadata.MaxBulkPubBytes, defaultMaxBulkPubBytes), + } + + batchMsg, err := sender.NewMessageBatch(ctx, batchOpts) + if err != nil { + return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishFailed, err), err + } + + // Add messages from the bulk publish request to the batch. + err = UpdateASBBatchMessageWithBulkPublishRequest(batchMsg, req) + if err != nil { + return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishFailed, err), err + } + + // Azure Service Bus does not return individual status for each message in the request. + err = sender.SendMessageBatch(ctx, batchMsg, nil) + if err != nil { + return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishFailed, err), err + } + + return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishSucceeded, nil), nil +} + func (a *azureServiceBus) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error { subID := a.metadata.ConsumerID if !a.metadata.DisableEntityManagement {