Add bulk publish support to Azure Service Bus (#2106)

* Initial implementation

Signed-off-by: Shubham Sharma <shubhash@microsoft.com>

* Some improvements

Signed-off-by: Shubham Sharma <shubhash@microsoft.com>

* Refactor and add tests

Signed-off-by: Shubham Sharma <shubhash@microsoft.com>

* Lint

Signed-off-by: Shubham Sharma <shubhash@microsoft.com>

* Refactor

Signed-off-by: Shubham Sharma <shubhash@microsoft.com>

* Use request metadata instead of component metadata

Signed-off-by: Shubham Sharma <shubhash@microsoft.com>

* Remove unused method

Signed-off-by: Shubham Sharma <shubhash@microsoft.com>

* Review comments addressed

Signed-off-by: Shubham Sharma <shubhash@microsoft.com>

* Doc comments

Signed-off-by: Shubham Sharma <shubhash@microsoft.com>

Signed-off-by: Shubham Sharma <shubhash@microsoft.com>
This commit is contained in:
Shubham Sharma 2022-09-23 22:27:00 +05:30 committed by GitHub
parent 8500da577c
commit b298b65cfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 256 additions and 57 deletions

View File

@ -1,6 +1,20 @@
/*
Copyright 2022 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
import (
"strconv"
"strings"
)
@ -14,3 +28,21 @@ func IsTruthy(val string) bool {
return false
}
}
// GetElemOrDefaultFromMap returns the value of a key from a map, or a default value
// if the key does not exist or the value is not of the expected type.
func GetElemOrDefaultFromMap[T int | uint64](m map[string]string, key string, def T) T {
if val, ok := m[key]; ok {
switch any(def).(type) {
case int:
if ival, err := strconv.ParseInt(val, 10, 64); err == nil {
return T(ival)
}
case uint64:
if uval, err := strconv.ParseUint(val, 10, 64); err == nil {
return T(uval)
}
}
}
return def
}

View File

@ -0,0 +1,98 @@
/*
Copyright 2022 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
import "testing"
func TestGetElemOrDefaultFromMap(t *testing.T) {
t.Run("test int", func(t *testing.T) {
testcases := []struct {
name string
m map[string]string
key string
def int
expected int
}{
{
name: "Get an int value from map that exists",
m: map[string]string{"key": "1"},
key: "key",
def: 0,
expected: 1,
},
{
name: "Get an int value from map that does not exist",
m: map[string]string{"key": "1"},
key: "key2",
def: 0,
expected: 0,
},
{
name: "Get an int value from map that exists but is not an int",
m: map[string]string{"key": "a"},
key: "key",
def: 0,
expected: 0,
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
actual := GetElemOrDefaultFromMap(tc.m, tc.key, tc.def)
if actual != tc.expected {
t.Errorf("expected %v, got %v", tc.expected, actual)
}
})
}
})
t.Run("test uint64", func(t *testing.T) {
testcases := []struct {
name string
m map[string]string
key string
def uint64
expected uint64
}{
{
name: "Get an uint64 value from map that exists",
m: map[string]string{"key": "1"},
key: "key",
def: uint64(0),
expected: uint64(1),
},
{
name: "Get an uint64 value from map that does not exist",
m: map[string]string{"key": "1"},
key: "key2",
def: uint64(0),
expected: uint64(0),
},
{
name: "Get an int value from map that exists but is not an uint64",
m: map[string]string{"key": "-1"},
key: "key",
def: 0,
expected: 0,
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
actual := GetElemOrDefaultFromMap(tc.m, tc.key, tc.def)
if actual != tc.expected {
t.Errorf("expected %v, got %v", tc.expected, actual)
}
})
}
})
}

View File

@ -44,6 +44,9 @@ const (
// MaxBulkAwaitDurationKey is the key for the max bulk await duration in the metadata.
MaxBulkAwaitDurationMilliSecondsKey string = "maxBulkAwaitDurationMilliSeconds"
// MaxBulkPubBytes defines the maximum bytes to publish in a bulk publish request metadata.
MaxBulkPubBytes string = "maxBulkPubBytes"
)
// TryGetTTL tries to get the ttl as a time.Duration value for pubsub, binding and any other building block.

View File

@ -144,64 +144,98 @@ func NewASBMessageFromPubsubRequest(req *pubsub.PublishRequest) (*azservicebus.M
Body: req.Data,
}
err := addMetadataToMessage(asbMsg, req.Metadata)
return asbMsg, err
}
// NewASBMessageFromBulkMessageEntry builds a new Azure Service Bus message from a BulkMessageEntry.
func NewASBMessageFromBulkMessageEntry(entry pubsub.BulkMessageEntry) (*azservicebus.Message, error) {
asbMsg := &azservicebus.Message{
Body: entry.Event,
ContentType: &entry.ContentType,
}
err := addMetadataToMessage(asbMsg, entry.Metadata)
return asbMsg, err
}
func addMetadataToMessage(asbMsg *azservicebus.Message, metadata map[string]string) error {
// Common properties.
ttl, ok, _ := contribMetadata.TryGetTTL(req.Metadata)
ttl, ok, _ := contribMetadata.TryGetTTL(metadata)
if ok {
asbMsg.TimeToLive = &ttl
}
// Azure Service Bus specific properties.
// reference: https://docs.microsoft.com/en-us/rest/api/servicebus/message-headers-and-properties#message-headers
msgID, ok, _ := tryGetString(req.Metadata, MessageIDMetadataKey)
msgID, ok, _ := tryGetString(metadata, MessageIDMetadataKey)
if ok {
asbMsg.MessageID = &msgID
}
correlationID, ok, _ := tryGetString(req.Metadata, CorrelationIDMetadataKey)
correlationID, ok, _ := tryGetString(metadata, CorrelationIDMetadataKey)
if ok {
asbMsg.CorrelationID = &correlationID
}
sessionID, okSessionID, _ := tryGetString(req.Metadata, SessionIDMetadataKey)
sessionID, okSessionID, _ := tryGetString(metadata, SessionIDMetadataKey)
if okSessionID {
asbMsg.SessionID = &sessionID
}
label, ok, _ := tryGetString(req.Metadata, LabelMetadataKey)
label, ok, _ := tryGetString(metadata, LabelMetadataKey)
if ok {
asbMsg.Subject = &label
}
replyTo, ok, _ := tryGetString(req.Metadata, ReplyToMetadataKey)
replyTo, ok, _ := tryGetString(metadata, ReplyToMetadataKey)
if ok {
asbMsg.ReplyTo = &replyTo
}
to, ok, _ := tryGetString(req.Metadata, ToMetadataKey)
to, ok, _ := tryGetString(metadata, ToMetadataKey)
if ok {
asbMsg.To = &to
}
partitionKey, ok, _ := tryGetString(req.Metadata, PartitionKeyMetadataKey)
partitionKey, ok, _ := tryGetString(metadata, PartitionKeyMetadataKey)
if ok {
if okSessionID && partitionKey != sessionID {
return nil, fmt.Errorf("session id %s and partition key %s should be equal when both present", sessionID, partitionKey)
return fmt.Errorf("session id %s and partition key %s should be equal when both present", sessionID, partitionKey)
}
asbMsg.PartitionKey = &partitionKey
}
contentType, ok, _ := tryGetString(req.Metadata, ContentTypeMetadataKey)
contentType, ok, _ := tryGetString(metadata, ContentTypeMetadataKey)
if ok {
asbMsg.ContentType = &contentType
}
scheduledEnqueueTime, ok, _ := tryGetScheduledEnqueueTime(req.Metadata)
scheduledEnqueueTime, ok, _ := tryGetScheduledEnqueueTime(metadata)
if ok {
asbMsg.ScheduledEnqueueTime = scheduledEnqueueTime
}
return asbMsg, nil
return nil
}
// UpdateASBBatchMessageWithBulkPublishRequest updates the batch message with messages from the bulk publish request.
func UpdateASBBatchMessageWithBulkPublishRequest(asbMsgBatch *azservicebus.MessageBatch, req *pubsub.BulkPublishRequest) error {
// Add entries from bulk request to batch.
for _, entry := range req.Entries {
asbMsg, err := NewASBMessageFromBulkMessageEntry(entry)
if err != nil {
return err
}
err = asbMsgBatch.AddMessage(asbMsg, nil)
if err != nil {
return err
}
}
return nil
}
func tryGetString(props map[string]string, key string) (string, bool, error) {

View File

@ -21,48 +21,43 @@ import (
azservicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/pubsub"
)
func TestNewASBMessageFromPubsubRequest(t *testing.T) {
testMessageData := []byte("test message")
testMessageID := "testMessageId"
testCorrelationID := "testCorrelationId"
testSessionID := "testSessionId"
testLabel := "testLabel"
testReplyTo := "testReplyTo"
testTo := "testTo"
testPartitionKey := testSessionID
testPartitionKeyUnique := "testPartitionKey"
testContentType := "testContentType"
nowUtc := time.Now().UTC()
testScheduledEnqueueTimeUtc := nowUtc.Format(http.TimeFormat)
var (
testMessageID = "testMessageId"
testCorrelationID = "testCorrelationId"
testSessionID = "testSessionId"
testLabel = "testLabel"
testReplyTo = "testReplyTo"
testTo = "testTo"
testPartitionKey = testSessionID
testPartitionKeyUnique = "testPartitionKey"
testContentType = "testContentType"
nowUtc = time.Now().UTC()
testScheduledEnqueueTimeUtc = nowUtc.Format(http.TimeFormat)
)
func TestAddMetadataToMessage(t *testing.T) {
testCases := []struct {
name string
pubsubRequest pubsub.PublishRequest
metadata map[string]string
expectedAzServiceBusMessage azservicebus.Message
expectError bool
}{
{
name: "Maps pubsub request to azure service bus message.",
pubsubRequest: pubsub.PublishRequest{
Data: testMessageData,
Metadata: map[string]string{
MessageIDMetadataKey: testMessageID,
CorrelationIDMetadataKey: testCorrelationID,
SessionIDMetadataKey: testSessionID,
LabelMetadataKey: testLabel,
ReplyToMetadataKey: testReplyTo,
ToMetadataKey: testTo,
PartitionKeyMetadataKey: testPartitionKey,
ContentTypeMetadataKey: testContentType,
ScheduledEnqueueTimeUtcMetadataKey: testScheduledEnqueueTimeUtc,
},
metadata: map[string]string{
MessageIDMetadataKey: testMessageID,
CorrelationIDMetadataKey: testCorrelationID,
SessionIDMetadataKey: testSessionID,
LabelMetadataKey: testLabel,
ReplyToMetadataKey: testReplyTo,
ToMetadataKey: testTo,
PartitionKeyMetadataKey: testPartitionKey,
ContentTypeMetadataKey: testContentType,
ScheduledEnqueueTimeUtcMetadataKey: testScheduledEnqueueTimeUtc,
},
expectedAzServiceBusMessage: azservicebus.Message{
Body: testMessageData,
MessageID: &testMessageID,
CorrelationID: &testCorrelationID,
SessionID: &testSessionID,
@ -77,21 +72,17 @@ func TestNewASBMessageFromPubsubRequest(t *testing.T) {
},
{
name: "Errors when partition key and session id set but not equal.",
pubsubRequest: pubsub.PublishRequest{
Data: testMessageData,
Metadata: map[string]string{
MessageIDMetadataKey: testMessageID,
CorrelationIDMetadataKey: testCorrelationID,
SessionIDMetadataKey: testSessionID,
LabelMetadataKey: testLabel,
ReplyToMetadataKey: testReplyTo,
ToMetadataKey: testTo,
PartitionKeyMetadataKey: testPartitionKeyUnique,
ContentTypeMetadataKey: testContentType,
},
metadata: map[string]string{
MessageIDMetadataKey: testMessageID,
CorrelationIDMetadataKey: testCorrelationID,
SessionIDMetadataKey: testSessionID,
LabelMetadataKey: testLabel,
ReplyToMetadataKey: testReplyTo,
ToMetadataKey: testTo,
PartitionKeyMetadataKey: testPartitionKeyUnique,
ContentTypeMetadataKey: testContentType,
},
expectedAzServiceBusMessage: azservicebus.Message{
Body: testMessageData,
MessageID: &testMessageID,
CorrelationID: &testCorrelationID,
SessionID: &testSessionID,
@ -108,7 +99,8 @@ func TestNewASBMessageFromPubsubRequest(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// act.
msg, err := NewASBMessageFromPubsubRequest(&tc.pubsubRequest)
msg := &azservicebus.Message{}
err := addMetadataToMessage(msg, tc.metadata)
// assert.
if tc.expectError {

View File

@ -30,6 +30,7 @@ import (
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
impl "github.com/dapr/components-contrib/internal/component/azure/servicebus"
"github.com/dapr/components-contrib/internal/utils"
contribMetadata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
@ -37,7 +38,8 @@ import (
)
const (
errorMessagePrefix = "azure service bus error:"
errorMessagePrefix = "azure service bus error:"
defaultMaxBulkPubBytes uint64 = 1024 * 128 // 128 KiB
)
var retriableSendingErrors = map[amqp.ErrorCondition]struct{}{
@ -359,6 +361,44 @@ func (a *azureServiceBus) Publish(req *pubsub.PublishRequest) error {
)
}
func (a *azureServiceBus) BulkPublish(ctx context.Context, req *pubsub.BulkPublishRequest) (pubsub.BulkPublishResponse, error) {
// If the request is empty, sender.SendMessageBatch will panic later.
// Return an empty response to avoid this.
if len(req.Entries) == 0 {
a.logger.Warnf("Empty bulk publish request, skipping")
return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishSucceeded, nil), nil
}
sender, err := a.senderForTopic(ctx, req.Topic)
if err != nil {
return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishFailed, err), err
}
// Create a new batch of messages with batch options.
batchOpts := &servicebus.MessageBatchOptions{
MaxBytes: utils.GetElemOrDefaultFromMap(req.Metadata, contribMetadata.MaxBulkPubBytes, defaultMaxBulkPubBytes),
}
batchMsg, err := sender.NewMessageBatch(ctx, batchOpts)
if err != nil {
return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishFailed, err), err
}
// Add messages from the bulk publish request to the batch.
err = UpdateASBBatchMessageWithBulkPublishRequest(batchMsg, req)
if err != nil {
return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishFailed, err), err
}
// Azure Service Bus does not return individual status for each message in the request.
err = sender.SendMessageBatch(ctx, batchMsg, nil)
if err != nil {
return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishFailed, err), err
}
return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishSucceeded, nil), nil
}
func (a *azureServiceBus) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error {
subID := a.metadata.ConsumerID
if !a.metadata.DisableEntityManagement {