diff --git a/bindings/azure/servicebusqueues/servicebusqueues.go b/bindings/azure/servicebusqueues/servicebusqueues.go index e50fcd0c8..a3a736cf7 100644 --- a/bindings/azure/servicebusqueues/servicebusqueues.go +++ b/bindings/azure/servicebusqueues/servicebusqueues.go @@ -17,7 +17,6 @@ import ( "context" "errors" "fmt" - "sync" "time" servicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus" @@ -36,19 +35,16 @@ const ( // AzureServiceBusQueues is an input/output binding reading from and sending events to Azure Service Bus queues. type AzureServiceBusQueues struct { - metadata *impl.Metadata - client *impl.Client - timeout time.Duration - sender *servicebus.Sender - senderLock sync.RWMutex - logger logger.Logger + metadata *impl.Metadata + client *impl.Client + timeout time.Duration + logger logger.Logger } // NewAzureServiceBusQueues returns a new AzureServiceBusQueues instance. func NewAzureServiceBusQueues(logger logger.Logger) bindings.InputOutputBinding { return &AzureServiceBusQueues{ - senderLock: sync.RWMutex{}, - logger: logger, + logger: logger, } } @@ -79,7 +75,7 @@ func (a *AzureServiceBusQueues) Operations() []bindings.OperationKind { } func (a *AzureServiceBusQueues) Invoke(invokeCtx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { - sender, err := a.getSender() + sender, err := a.client.GetSender(invokeCtx, a.metadata.QueueName) if err != nil { return nil, fmt.Errorf("failed to create a sender for the Service Bus queue: %w", err) } @@ -96,7 +92,7 @@ func (a *AzureServiceBusQueues) Invoke(invokeCtx context.Context, req *bindings. if err != nil { if impl.IsNetworkError(err) { // Force reconnection on next call - a.deleteSender() + a.client.CloseSender(a.metadata.QueueName) } return nil, err } @@ -173,46 +169,6 @@ func (a *AzureServiceBusQueues) Read(subscribeCtx context.Context, handler bindi return nil } -// getSender returns the Sender object, creating a new connection if needed -func (a *AzureServiceBusQueues) getSender() (*servicebus.Sender, error) { - // Check if the sender already exists - a.senderLock.RLock() - if a.sender != nil { - a.senderLock.RUnlock() - return a.sender, nil - } - a.senderLock.RUnlock() - - // Acquire a write lock then try checking a.sender again in case another goroutine modified that in the meanwhile - a.senderLock.Lock() - defer a.senderLock.Unlock() - - if a.sender != nil { - return a.sender, nil - } - - // Create a new sender - sender, err := a.client.GetClient().NewSender(a.metadata.QueueName, nil) - if err != nil { - return nil, err - } - a.sender = sender - - return sender, nil -} - -// deleteSender deletes the sender, closing the connection -func (a *AzureServiceBusQueues) deleteSender() { - a.senderLock.Lock() - if a.sender != nil { - closeCtx, closeCancel := context.WithTimeout(context.Background(), time.Second) - _ = a.sender.Close(closeCtx) - closeCancel() - a.sender = nil - } - a.senderLock.Unlock() -} - func (a *AzureServiceBusQueues) getHandlerFunc(handler bindings.Handler) impl.HandlerFunc { return func(ctx context.Context, asbMsgs []*servicebus.ReceivedMessage) ([]impl.HandlerResponseItem, error) { if len(asbMsgs) != 1 { @@ -245,19 +201,7 @@ func (a *AzureServiceBusQueues) getHandlerFunc(handler bindings.Handler) impl.Ha } func (a *AzureServiceBusQueues) Close() (err error) { - a.senderLock.Lock() - defer a.senderLock.Unlock() - a.logger.Debug("Closing component") - - if a.sender != nil { - ctx, cancel := context.WithTimeout(context.Background(), a.timeout) - err = a.sender.Close(ctx) - cancel() - a.sender = nil - if err != nil { - return err - } - } + a.client.CloseSender(a.metadata.QueueName) return nil } diff --git a/internal/component/azure/servicebus/client.go b/internal/component/azure/servicebus/client.go index ba41869ff..1ee4b5ab5 100644 --- a/internal/component/azure/servicebus/client.go +++ b/internal/component/azure/servicebus/client.go @@ -16,9 +16,9 @@ package servicebus import ( "context" "fmt" + "sync" "time" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus" servicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus" sbadmin "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/admin" @@ -26,17 +26,21 @@ import ( "github.com/dapr/kit/logger" ) -// Client contains the clients for Service Bus and methods to create topics, subscriptions, queues. +// Client contains the clients for Service Bus and methods to get senders and to create topics, subscriptions, queues. type Client struct { client *servicebus.Client adminClient *sbadmin.Client metadata *Metadata + lock *sync.RWMutex + senders map[string]*servicebus.Sender } // NewClient creates a new Client object. func NewClient(metadata *Metadata, rawMetadata map[string]string) (*Client, error) { client := &Client{ metadata: metadata, + lock: &sync.RWMutex{}, + senders: make(map[string]*servicebus.Sender), } clientOpts := &servicebus.ClientOptions{ @@ -89,10 +93,84 @@ func NewClient(metadata *Metadata, rawMetadata map[string]string) (*Client, erro } // GetClient returns the azservicebus.Client object. -func (c *Client) GetClient() *azservicebus.Client { +func (c *Client) GetClient() *servicebus.Client { return c.client } +// GetSenderForTopic returns the sender for a topic, or creates a new one if it doesn't exist +func (c *Client) GetSender(ctx context.Context, queueOrTopic string) (*servicebus.Sender, error) { + c.lock.RLock() + sender, ok := c.senders[queueOrTopic] + c.lock.RUnlock() + if ok && sender != nil { + return sender, nil + } + + c.lock.Lock() + defer c.lock.Unlock() + + // Check again after acquiring a write lock in case another goroutine created the sender + sender, ok = c.senders[queueOrTopic] + if ok && sender != nil { + return sender, nil + } + + // Create the sender + sender, err := c.client.NewSender(queueOrTopic, nil) + if err != nil { + return nil, err + } + c.senders[queueOrTopic] = sender + + return sender, nil +} + +// CloseSender closes a sender for a queue or topic. +func (c *Client) CloseSender(queueOrTopic string) { + c.lock.Lock() + defer c.lock.Unlock() + + sender, ok := c.senders[queueOrTopic] + if ok && sender != nil { + closeCtx, closeCancel := context.WithTimeout(context.Background(), time.Second) + _ = sender.Close(closeCtx) + closeCancel() + } + delete(c.senders, queueOrTopic) +} + +// CloseAllSenders closes all sender connections. +func (c *Client) CloseAllSenders(log logger.Logger) { + c.lock.Lock() + defer c.lock.Unlock() + + // Close all senders, up to 3 in parallel + workersCh := make(chan bool, 3) + for k, t := range c.senders { + // Blocks if we have too many goroutines + workersCh <- true + go func(k string, t *servicebus.Sender) { + log.Debugf("Closing sender %s", k) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.metadata.TimeoutInSec)*time.Second) + err := t.Close(ctx) + cancel() + if err != nil { + // Log only + log.Warnf("Error closing sender %s: %v", k, err) + } + <-workersCh + }(k, t) + } + for i := 0; i < cap(workersCh); i++ { + // Wait for all workers to be done + workersCh <- true + } + close(workersCh) + + // Clear the map + c.senders = make(map[string]*servicebus.Sender) +} + // EnsureTopic creates the topic if it doesn't exist. // Returns with nil error if the admin client doesn't exist. func (c *Client) EnsureTopic(ctx context.Context, topic string) error { diff --git a/internal/component/azure/servicebus/message_test.go b/internal/component/azure/servicebus/message_test.go index 63186078c..44e53e431 100644 --- a/internal/component/azure/servicebus/message_test.go +++ b/internal/component/azure/servicebus/message_test.go @@ -36,12 +36,6 @@ var ( testContentType = "testContentType" nowUtc = time.Now().UTC() testScheduledEnqueueTimeUtc = nowUtc.Format(http.TimeFormat) - testLockTokenString = "bG9ja3Rva2VuAAAAAAAAAA==" //nolint:gosec - testLockTokenBytes = [16]byte{108, 111, 99, 107, 116, 111, 107, 101, 110} - testDeliveryCount = uint32(1) - testSampleTime = time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) - testSampleTimeHTTPFormat = "Thu, 01 Jan 1970 00:00:00 GMT" - testSequenceNumber = int64(1) ) func TestAddMetadataToMessage(t *testing.T) { diff --git a/pubsub/azure/servicebus/message_test.go b/pubsub/azure/servicebus/message_test.go index 6ba336440..0c9615ea8 100644 --- a/pubsub/azure/servicebus/message_test.go +++ b/pubsub/azure/servicebus/message_test.go @@ -15,34 +15,30 @@ package servicebus import ( "fmt" - "net/http" "testing" "time" + azservicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus" "github.com/stretchr/testify/assert" - azservicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus" impl "github.com/dapr/components-contrib/internal/component/azure/servicebus" ) var ( - testMessageID = "testMessageId" - testCorrelationID = "testCorrelationId" - testSessionID = "testSessionId" - testLabel = "testLabel" - testReplyTo = "testReplyTo" - testTo = "testTo" - testPartitionKey = testSessionID - testPartitionKeyUnique = "testPartitionKey" - testContentType = "testContentType" - nowUtc = time.Now().UTC() - testScheduledEnqueueTimeUtc = nowUtc.Format(http.TimeFormat) - testLockTokenString = "bG9ja3Rva2VuAAAAAAAAAA==" //nolint:gosec - testLockTokenBytes = [16]byte{108, 111, 99, 107, 116, 111, 107, 101, 110} - testDeliveryCount = uint32(1) - testSampleTime = time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) - testSampleTimeHTTPFormat = "Thu, 01 Jan 1970 00:00:00 GMT" - testSequenceNumber = int64(1) + testMessageID = "testMessageId" + testCorrelationID = "testCorrelationId" + testSessionID = "testSessionId" + testLabel = "testLabel" + testReplyTo = "testReplyTo" + testTo = "testTo" + testPartitionKey = testSessionID + testContentType = "testContentType" + testLockTokenString = "bG9ja3Rva2VuAAAAAAAAAA==" //nolint:gosec + testLockTokenBytes = [16]byte{108, 111, 99, 107, 116, 111, 107, 101, 110} + testDeliveryCount = uint32(1) + testSampleTime = time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) + testSampleTimeHTTPFormat = "Thu, 01 Jan 1970 00:00:00 GMT" + testSequenceNumber = int64(1) ) func TestAddMessageAttributesToMetadata(t *testing.T) { diff --git a/pubsub/azure/servicebus/servicebus.go b/pubsub/azure/servicebus/servicebus.go index af11fef9e..775374590 100644 --- a/pubsub/azure/servicebus/servicebus.go +++ b/pubsub/azure/servicebus/servicebus.go @@ -17,7 +17,6 @@ import ( "context" "errors" "fmt" - "sync" "time" servicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus" @@ -38,13 +37,10 @@ const ( ) type azureServiceBus struct { - metadata *impl.Metadata - client *impl.Client - logger logger.Logger - features []pubsub.Feature - topics map[string]*servicebus.Sender - topicsLock *sync.RWMutex - + metadata *impl.Metadata + client *impl.Client + logger logger.Logger + features []pubsub.Feature publishCtx context.Context publishCancel context.CancelFunc } @@ -52,10 +48,8 @@ type azureServiceBus struct { // NewAzureServiceBus returns a new Azure ServiceBus pub-sub implementation. func NewAzureServiceBus(logger logger.Logger) pubsub.PubSub { return &azureServiceBus{ - logger: logger, - features: []pubsub.Feature{pubsub.FeatureMessageTTL}, - topics: map[string]*servicebus.Sender{}, - topicsLock: &sync.RWMutex{}, + logger: logger, + features: []pubsub.Feature{pubsub.FeatureMessageTTL}, } } @@ -92,9 +86,16 @@ func (a *azureServiceBus) Publish(req *pubsub.PublishRequest) error { } return retry.NotifyRecover( func() (err error) { + // Ensure the queue or topic exists the first time it is referenced + // This does nothing if DisableEntityManagement is true + err = a.client.EnsureTopic(a.publishCtx, req.Topic) + if err != nil { + return err + } + // Get the sender var sender *servicebus.Sender - sender, err = a.senderForTopic(a.publishCtx, req.Topic) + sender, err = a.client.GetSender(a.publishCtx, req.Topic) if err != nil { return err } @@ -106,7 +107,7 @@ func (a *azureServiceBus) Publish(req *pubsub.PublishRequest) error { if err != nil { if impl.IsNetworkError(err) { // Retry after reconnecting - a.deleteSenderForTopic(req.Topic) + a.client.CloseSender(req.Topic) return err } @@ -138,7 +139,15 @@ func (a *azureServiceBus) BulkPublish(ctx context.Context, req *pubsub.BulkPubli return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishSucceeded, nil), nil } - sender, err := a.senderForTopic(ctx, req.Topic) + // Ensure the queue or topic exists the first time it is referenced + // This does nothing if DisableEntityManagement is true + err := a.client.EnsureTopic(a.publishCtx, req.Topic) + if err != nil { + return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishFailed, err), err + } + + // Get the sender + sender, err := a.client.GetSender(ctx, req.Topic) if err != nil { return pubsub.NewBulkPublishResponse(req.Entries, pubsub.PublishFailed, err), err } @@ -338,84 +347,9 @@ func (a *azureServiceBus) getBulkHandlerFunc(topic string, handler pubsub.BulkHa } } -// senderForTopic returns the sender for a topic, or creates a new one if it doesn't exist -func (a *azureServiceBus) senderForTopic(ctx context.Context, topic string) (*servicebus.Sender, error) { - a.topicsLock.RLock() - sender, ok := a.topics[topic] - a.topicsLock.RUnlock() - if ok && sender != nil { - return sender, nil - } - - a.topicsLock.Lock() - defer a.topicsLock.Unlock() - - // Check again after acquiring a write lock in case another goroutine created the sender - sender, ok = a.topics[topic] - if ok && sender != nil { - return sender, nil - } - - // Ensure the topic exists the first time it is referenced - // This does nothing if DisableEntityManagement is true - err := a.client.EnsureTopic(ctx, topic) - if err != nil { - return nil, err - } - - // Create the sender - sender, err = a.client.GetClient().NewSender(topic, nil) - if err != nil { - return nil, err - } - a.topics[topic] = sender - - return sender, nil -} - -// deleteSenderForTopic deletes a sender for a topic, closing the connection -func (a *azureServiceBus) deleteSenderForTopic(topic string) { - a.topicsLock.Lock() - defer a.topicsLock.Unlock() - - sender, ok := a.topics[topic] - if ok && sender != nil { - closeCtx, closeCancel := context.WithTimeout(context.Background(), time.Second) - _ = sender.Close(closeCtx) - closeCancel() - } - delete(a.topics, topic) -} - func (a *azureServiceBus) Close() (err error) { - a.topicsLock.Lock() - defer a.topicsLock.Unlock() - a.publishCancel() - - // Close all topics, up to 3 in parallel - workersCh := make(chan bool, 3) - for k, t := range a.topics { - // Blocks if we have too many goroutines - workersCh <- true - go func(k string, t *servicebus.Sender) { - a.logger.Debugf("Closing topic %s", k) - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(a.metadata.TimeoutInSec)*time.Second) - err = t.Close(ctx) - cancel() - if err != nil { - // Log only - a.logger.Warnf("%s closing topic %s: %+v", errorMessagePrefix, k, err) - } - <-workersCh - }(k, t) - } - for i := 0; i < cap(workersCh); i++ { - // Wait for all workers to be done - workersCh <- true - } - close(workersCh) - + a.client.CloseAllSenders(a.logger) return nil }