components-contrib/pubsub/azure/servicebus/servicebus.go

349 lines
11 KiB
Go

// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// ------------------------------------------------------------
package servicebus
import (
"context"
"fmt"
"strconv"
"time"
azservicebus "github.com/Azure/azure-service-bus-go"
"github.com/dapr/components-contrib/pubsub"
log "github.com/sirupsen/logrus"
)
// Names of the component metadata properties read by
// parseAzureServiceBusMetadata.
const (
	connectionString              = "connectionString"
	consumerID                    = "consumerID"
	maxDeliveryCount              = "maxDeliveryCount"
	timeoutInSec                  = "timeoutInSec"
	lockDurationInSec             = "lockDurationInSec"
	defaultMessageTimeToLiveInSec = "defaultMessageTimeToLiveInSec"
	autoDeleteOnIdleInSec         = "autoDeleteOnIdleInSec"
	disableEntityManagement       = "disableEntityManagement"
)

// errorMessagePrefix is the leading text of the errors and log lines this
// component produces.
const errorMessagePrefix = "azure service bus error:"

// Values applied when the corresponding optional property is absent or empty.
const (
	defaultTimeoutInSec            = 60
	defaultDisableEntityManagement = false
)
// azureServiceBus is the Azure Service Bus pub-sub component. Its fields are
// populated by Init and read by Publish, Subscribe and the entity helpers.
type azureServiceBus struct {
// metadata is the parsed component configuration stored by Init.
metadata metadata
// namespace is the client Init builds from the configured connection string.
namespace *azservicebus.Namespace
// topicManager is obtained from namespace in Init and is used to look up and
// create topics.
topicManager *azservicebus.TopicManager
}
// subscription is the set of subscription operations this package depends on.
// Subscribe stores the value returned by topic.NewSubscription in a variable of
// this type and hands it to handleSubscriptionMessages.
type subscription interface {
// Close takes a context and reports an error.
// NOTE(review): presumably releases the underlying receiver - confirm against
// the SDK.
Close(ctx context.Context) error
// Receive passes messages to handler; handleSubscriptionMessages treats a
// non-nil return value as a reason to stop receiving.
Receive(ctx context.Context, handler azservicebus.Handler) error
}
// NewAzureServiceBus returns a new, not yet initialised Azure Service Bus
// pub-sub implementation. Init must be called before the value is used: the
// namespace and topic manager are only created there.
func NewAzureServiceBus() pubsub.PubSub {
	return new(azureServiceBus)
}
// parseAzureServiceBusMetadata converts the component's raw string properties
// into a typed metadata value.
//
// connectionString and consumerID are required and must be non-empty.
// timeoutInSec and disableEntityManagement fall back to defaultTimeoutInSec and
// defaultDisableEntityManagement when absent or empty. maxDeliveryCount,
// lockDurationInSec, defaultMessageTimeToLiveInSec and autoDeleteOnIdleInSec
// are left nil when absent or empty.
//
// Every error starts with errorMessagePrefix and names the offending property.
// The metadata returned together with an error is only partially populated and
// should not be used.
func parseAzureServiceBusMetadata(meta pubsub.Metadata) (metadata, error) {
	m := metadata{}

	/* Required configuration settings - no defaults */
	// A missing key reads as "", so one emptiness check covers both cases.
	m.ConnectionString = meta.Properties[connectionString]
	if m.ConnectionString == "" {
		return m, fmt.Errorf("%s missing connection string", errorMessagePrefix)
	}

	m.ConsumerID = meta.Properties[consumerID]
	if m.ConsumerID == "" {
		return m, fmt.Errorf("%s missing consumerID", errorMessagePrefix)
	}

	/* Optional configuration settings - defaults will be set by the client */
	m.TimeoutInSec = defaultTimeoutInSec
	if val := meta.Properties[timeoutInSec]; val != "" {
		timeout, err := strconv.Atoi(val)
		if err != nil {
			return m, fmt.Errorf("%s invalid timeoutInSec %s, %s", errorMessagePrefix, val, err)
		}
		m.TimeoutInSec = timeout
	}

	m.DisableEntityManagement = defaultDisableEntityManagement
	if val := meta.Properties[disableEntityManagement]; val != "" {
		disabled, err := strconv.ParseBool(val)
		if err != nil {
			return m, fmt.Errorf("%s invalid disableEntityManagement %s, %s", errorMessagePrefix, val, err)
		}
		m.DisableEntityManagement = disabled
	}

	/* Nullable configuration settings - defaults will be set by the server */
	// Each entry pairs a property key with the field it fills. The key doubles
	// as the property name in the error message, so the text can no longer
	// drift away from the real key name.
	nullable := []struct {
		key string
		dst **int
	}{
		{maxDeliveryCount, &m.MaxDeliveryCount},
		{lockDurationInSec, &m.LockDurationInSec},
		{defaultMessageTimeToLiveInSec, &m.DefaultMessageTimeToLiveInSec},
		{autoDeleteOnIdleInSec, &m.AutoDeleteOnIdleInSec},
	}
	for _, setting := range nullable {
		val := meta.Properties[setting.key]
		if val == "" {
			continue
		}
		parsed, err := strconv.Atoi(val)
		if err != nil {
			return m, fmt.Errorf("%s invalid %s %s, %s", errorMessagePrefix, setting.key, val, err)
		}
		// parsed is declared inside the loop body, so each field gets its own
		// address.
		*setting.dst = &parsed
	}

	return m, nil
}
// Init parses the component metadata, stores it on the receiver, and creates
// the Service Bus namespace client and its topic manager from the configured
// connection string. It returns the first error from parsing or from creating
// the namespace.
func (a *azureServiceBus) Init(metadata pubsub.Metadata) error {
	parsed, err := parseAzureServiceBusMetadata(metadata)
	if err != nil {
		return err
	}
	a.metadata = parsed

	connOpt := azservicebus.NamespaceWithConnectionString(parsed.ConnectionString)
	if a.namespace, err = azservicebus.NewNamespace(connOpt); err != nil {
		return err
	}

	a.topicManager = a.namespace.NewTopicManager()
	return nil
}
// Publish sends req.Data as a single message to the topic named req.Topic.
//
// Unless entity management is disabled, the topic is created first if it does
// not exist. The send is bounded by the configured timeout. The error from
// ensuring the topic, creating the sender or sending is returned unchanged.
func (a *azureServiceBus) Publish(req *pubsub.PublishRequest) error {
	if !a.metadata.DisableEntityManagement {
		if err := a.ensureTopic(req.Topic); err != nil {
			return err
		}
	}

	sender, err := a.namespace.NewTopic(req.Topic)
	if err != nil {
		return err
	}

	timeout := time.Second * time.Duration(a.metadata.TimeoutInSec)

	// A new sender is created on every call, so it must be closed on every
	// call; otherwise each publish leaves its sender open. Close gets its own
	// context so it still runs when the send context has already expired.
	defer func() {
		closeCtx, closeCancel := context.WithTimeout(context.Background(), timeout)
		defer closeCancel()
		if closeErr := sender.Close(closeCtx); closeErr != nil {
			log.Warnf("%s error closing sender for topic %s, %s", errorMessagePrefix, req.Topic, closeErr)
		}
	}()

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	return sender.Send(ctx, azservicebus.NewMessage(req.Data))
}
// Subscribe starts delivering messages from req.Topic to handler.
//
// The subscription name is the configured consumerID. Unless entity management
// is disabled, the topic and subscription are created first if they do not
// exist. Receiving happens on a separate goroutine, so Subscribe returns as
// soon as the subscription client has been created; a nil result does not mean
// that any message has been received.
func (a *azureServiceBus) Subscribe(req pubsub.SubscribeRequest, handler func(msg *pubsub.NewMessage) error) error {
	subID := a.metadata.ConsumerID
	if !a.metadata.DisableEntityManagement {
		if err := a.ensureSubscription(subID, req.Topic); err != nil {
			return err
		}
	}

	topic, err := a.namespace.NewTopic(req.Topic)
	if err != nil {
		return fmt.Errorf("%s could not instantiate topic %s, %s", errorMessagePrefix, req.Topic, err)
	}

	var sub subscription
	sub, err = topic.NewSubscription(subID)
	if err != nil {
		// Include the underlying cause, as the topic error above does.
		return fmt.Errorf("%s could not instantiate subscription %s for topic %s, %s", errorMessagePrefix, subID, req.Topic, err)
	}

	sbHandlerFunc := azservicebus.HandlerFunc(a.getHandlerFunc(req.Topic, handler))

	ctx := context.Background()
	go a.handleSubscriptionMessages(ctx, req.Topic, sub, sbHandlerFunc)

	return nil
}
// getHandlerFunc adapts a pub-sub handler to the Service Bus handler signature.
//
// The returned function wraps the message data and the topic name in a
// pubsub.NewMessage and passes it to handler. If handler returns an error the
// message is abandoned, otherwise it is completed. The result of Abandon or
// Complete is what the returned function returns; the handler's own error is
// not propagated.
func (a *azureServiceBus) getHandlerFunc(topic string, handler func(msg *pubsub.NewMessage) error) func(ctx context.Context, message *azservicebus.Message) error {
	return func(ctx context.Context, message *azservicebus.Message) error {
		if handleErr := handler(&pubsub.NewMessage{Data: message.Data, Topic: topic}); handleErr != nil {
			return message.Abandon(ctx)
		}
		return message.Complete(ctx)
	}
}
// handleSubscriptionMessages calls sub.Receive in a loop until it returns an
// error, then logs that error and stops. It is meant to run on its own
// goroutine (see Subscribe).
//
// The subscription is closed when the loop ends, so a failed receive no longer
// leaves the subscription open. Close uses a fresh context bounded by the
// configured timeout, so it does not depend on the state of ctx.
func (a *azureServiceBus) handleSubscriptionMessages(ctx context.Context, topic string, sub subscription, handlerFunc azservicebus.HandlerFunc) {
	defer func() {
		closeCtx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(a.metadata.TimeoutInSec))
		defer cancel()
		if err := sub.Close(closeCtx); err != nil {
			log.Errorf("%s error closing subscription for topic %s, %s", errorMessagePrefix, topic, err)
		}
	}()

	for {
		if err := sub.Receive(ctx, handlerFunc); err != nil {
			log.Errorf("%s error receiving from topic %s, %s", errorMessagePrefix, topic, err)
			return
		}
	}
}
// ensureTopic creates the named topic when a lookup finds no existing entity.
// It returns the lookup error or the creation error, and nil when the topic
// already exists or was created.
func (a *azureServiceBus) ensureTopic(topic string) error {
	existing, err := a.getTopicEntity(topic)
	switch {
	case err != nil:
		return err
	case existing != nil:
		// Already there, nothing to create.
		return nil
	default:
		return a.createTopicEntity(topic)
	}
}
// ensureSubscription makes sure the topic exists and then creates the
// subscription called name on it when a lookup finds no existing entity. The
// first error from any step is returned.
func (a *azureServiceBus) ensureSubscription(name string, topic string) error {
	if err := a.ensureTopic(topic); err != nil {
		return err
	}

	mgr, err := a.namespace.NewSubscriptionManager(topic)
	if err != nil {
		return err
	}

	existing, err := a.getSubscriptionEntity(mgr, topic, name)
	switch {
	case err != nil:
		return err
	case existing != nil:
		// Already there, nothing to create.
		return nil
	default:
		return a.createSubscriptionEntity(mgr, topic, name)
	}
}
// getTopicEntity looks up the named topic, bounded by the configured timeout.
//
// A "not found" response is not an error: the entity returned by the lookup is
// passed through with a nil error (ensureTopic treats a nil entity as "does not
// exist"). Any other failure is wrapped with errorMessagePrefix. An error is
// also returned when the topic manager has not been set up by Init.
func (a *azureServiceBus) getTopicEntity(topic string) (*azservicebus.TopicEntity, error) {
	if a.topicManager == nil {
		return nil, fmt.Errorf("%s init() has not been called", errorMessagePrefix)
	}

	timeout := time.Second * time.Duration(a.metadata.TimeoutInSec)
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	entity, err := a.topicManager.Get(ctx, topic)
	if err == nil || azservicebus.IsErrNotFound(err) {
		return entity, nil
	}
	return nil, fmt.Errorf("%s could not get topic %s, %s", errorMessagePrefix, topic, err)
}
// createTopicEntity creates the named topic through the topic manager, bounded
// by the configured timeout. A failure is wrapped with errorMessagePrefix.
func (a *azureServiceBus) createTopicEntity(topic string) error {
	timeout := time.Second * time.Duration(a.metadata.TimeoutInSec)
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	if _, putErr := a.topicManager.Put(ctx, topic); putErr != nil {
		return fmt.Errorf("%s could not put topic %s, %s", errorMessagePrefix, topic, putErr)
	}
	return nil
}
// getSubscriptionEntity looks up the named subscription through mgr, bounded by
// the configured timeout.
//
// A "not found" response is not an error: the entity returned by the lookup is
// passed through with a nil error. Any other failure is wrapped with
// errorMessagePrefix. The topic parameter is not used by this function.
func (a *azureServiceBus) getSubscriptionEntity(mgr *azservicebus.SubscriptionManager, topic, subscription string) (*azservicebus.SubscriptionEntity, error) {
	timeout := time.Second * time.Duration(a.metadata.TimeoutInSec)
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	found, err := mgr.Get(ctx, subscription)
	if err == nil || azservicebus.IsErrNotFound(err) {
		return found, nil
	}
	return nil, fmt.Errorf("%s could not get subscription %s, %s", errorMessagePrefix, subscription, err)
}
// createSubscriptionEntity creates the named subscription through mgr, applying
// the management options derived from the component metadata. The call is
// bounded by the configured timeout. A failure to create the subscription is
// wrapped with errorMessagePrefix. The topic parameter is not used by this
// function.
func (a *azureServiceBus) createSubscriptionEntity(mgr *azservicebus.SubscriptionManager, topic, subscription string) error {
	timeout := time.Second * time.Duration(a.metadata.TimeoutInSec)
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	mgmtOpts, optErr := a.createSubscriptionManagementOptions()
	if optErr != nil {
		return optErr
	}

	if _, putErr := mgr.Put(ctx, subscription, mgmtOpts...); putErr != nil {
		return fmt.Errorf("%s could not put subscription %s, %s", errorMessagePrefix, subscription, putErr)
	}
	return nil
}
// createSubscriptionManagementOptions builds one management option for each
// nullable metadata setting that is set, in the order max delivery count, lock
// duration, default message time-to-live, auto-delete-on-idle. Settings that
// are nil contribute nothing, so the result is a nil slice when none are set.
// The returned error is always nil.
func (a *azureServiceBus) createSubscriptionManagementOptions() ([]azservicebus.SubscriptionManagementOption, error) {
	candidates := []struct {
		value *int
		build func(*int) azservicebus.SubscriptionManagementOption
	}{
		{a.metadata.MaxDeliveryCount, subscriptionManagementOptionsWithMaxDeliveryCount},
		{a.metadata.LockDurationInSec, subscriptionManagementOptionsWithLockDuration},
		{a.metadata.DefaultMessageTimeToLiveInSec, subscriptionManagementOptionsWithDefaultMessageTimeToLive},
		{a.metadata.AutoDeleteOnIdleInSec, subscriptionManagementOptionsWithAutoDeleteOnIdle},
	}

	var opts []azservicebus.SubscriptionManagementOption
	for _, c := range candidates {
		if c.value == nil {
			continue
		}
		opts = append(opts, c.build(c.value))
	}
	return opts, nil
}
// subscriptionManagementOptionsWithMaxDeliveryCount returns an option that sets
// MaxDeliveryCount on a subscription description to *maxDeliveryCount converted
// to int32. The pointer is dereferenced when the option is applied, not when it
// is built, so it must be non-nil by then.
func subscriptionManagementOptionsWithMaxDeliveryCount(maxDeliveryCount *int) azservicebus.SubscriptionManagementOption {
	apply := func(desc *azservicebus.SubscriptionDescription) error {
		count := new(int32)
		*count = int32(*maxDeliveryCount)
		desc.MaxDeliveryCount = count
		return nil
	}
	return apply
}
// subscriptionManagementOptionsWithAutoDeleteOnIdle returns an option that sets
// AutoDeleteOnIdle on a subscription description to the string "PT<n>S", where
// n is *durationInSec. The pointer is dereferenced when the option is applied,
// so it must be non-nil by then.
func subscriptionManagementOptionsWithAutoDeleteOnIdle(durationInSec *int) azservicebus.SubscriptionManagementOption {
	apply := func(desc *azservicebus.SubscriptionDescription) error {
		idle := fmt.Sprintf("PT%dS", *durationInSec)
		desc.AutoDeleteOnIdle = &idle
		return nil
	}
	return apply
}
// subscriptionManagementOptionsWithDefaultMessageTimeToLive returns an option
// that sets DefaultMessageTimeToLive on a subscription description to the
// string "PT<n>S", where n is *durationInSec. The pointer is dereferenced when
// the option is applied, so it must be non-nil by then.
func subscriptionManagementOptionsWithDefaultMessageTimeToLive(durationInSec *int) azservicebus.SubscriptionManagementOption {
	apply := func(desc *azservicebus.SubscriptionDescription) error {
		ttl := fmt.Sprintf("PT%dS", *durationInSec)
		desc.DefaultMessageTimeToLive = &ttl
		return nil
	}
	return apply
}
// subscriptionManagementOptionsWithLockDuration returns an option that sets
// LockDuration on a subscription description to the string "PT<n>S", where n is
// *durationInSec. The pointer is dereferenced when the option is applied, so it
// must be non-nil by then.
func subscriptionManagementOptionsWithLockDuration(durationInSec *int) azservicebus.SubscriptionManagementOption {
	apply := func(desc *azservicebus.SubscriptionDescription) error {
		lock := fmt.Sprintf("PT%dS", *durationInSec)
		desc.LockDuration = &lock
		return nil
	}
	return apply
}