// Source: components-contrib/common/component/azure/servicebus/client.go
// (Go, 398 lines, 11 KiB)

/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package servicebus
import (
"context"
"fmt"
"sync"
"time"
servicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
sbadmin "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/admin"
"github.com/cenkalti/backoff/v4"
"golang.org/x/exp/maps"
azauth "github.com/dapr/components-contrib/common/authentication/azure"
"github.com/dapr/kit/logger"
)
// ensureFn is the signature of an idempotent "make sure this entity exists"
// callback, e.g. Client.EnsureTopic or Client.EnsureQueue. GetSender calls it
// with the queue or topic name before creating a sender for the first time.
type ensureFn func(ctx context.Context, entity string) error
// Client contains the clients for Service Bus and methods to get senders and to create topics, subscriptions, queues.
type Client struct {
	// Data-plane client; Close sets it to nil after closing it.
	client *servicebus.Client
	// Management client used by the Ensure* methods; left nil by NewClient when
	// metadata.DisableEntityManagement is true, which turns the Ensure* methods into no-ops.
	adminClient *sbadmin.Client
	metadata *Metadata
	// Guards senders, and client in GetSender/Close.
	lock *sync.RWMutex
	// Cache of senders keyed by queue or topic name.
	senders map[string]*servicebus.Sender
}
// NewClient creates a new Client object from the component metadata.
//
// The authentication mode is selected by metadata.ConnectionString. When it is
// non-empty, the clients are built from the connection string. Otherwise the
// Azure environment settings are parsed from rawMetadata to obtain a token
// credential for metadata.NamespaceName. In both modes the admin client is
// created only when metadata.DisableEntityManagement is false; otherwise it is
// left nil.
func NewClient(metadata *Metadata, rawMetadata map[string]string) (*Client, error) {
	c := &Client{
		metadata: metadata,
		lock:     &sync.RWMutex{},
		senders:  make(map[string]*servicebus.Sender),
	}

	opts := &servicebus.ClientOptions{
		ApplicationID: "dapr-" + logger.DaprVersion,
		// TODO: Use the built-in retry in the SDK rather than our own on top of that
		/*RetryOptions: servicebus.RetryOptions{
			MaxRetries: int32(metadata.PublishMaxRetries),
			RetryDelay: time.Duration(metadata.PublishInitialRetryIntervalInMs) * time.Millisecond,
		},*/
	}
	manageEntities := !metadata.DisableEntityManagement

	// Connection-string authentication.
	if connStr := metadata.ConnectionString; connStr != "" {
		var err error
		if c.client, err = servicebus.NewClientFromConnectionString(connStr, opts); err != nil {
			return nil, err
		}
		if manageEntities {
			if c.adminClient, err = sbadmin.NewClientFromConnectionString(connStr, nil); err != nil {
				return nil, err
			}
		}
		return c, nil
	}

	// Token-credential authentication against the namespace.
	settings, err := azauth.NewEnvironmentSettings(rawMetadata)
	if err != nil {
		return nil, err
	}
	credential, err := settings.GetTokenCredential()
	if err != nil {
		return nil, err
	}
	if c.client, err = servicebus.NewClient(metadata.NamespaceName, credential, opts); err != nil {
		return nil, err
	}
	if manageEntities {
		if c.adminClient, err = sbadmin.NewClient(metadata.NamespaceName, credential, nil); err != nil {
			return nil, err
		}
	}
	return c, nil
}
// GetClient returns the underlying azservicebus.Client, or nil once Close has
// been called.
//
// The field is read under the read lock because Close clears it while holding
// the write lock; an unguarded read would race with Close.
func (c *Client) GetClient() *servicebus.Client {
	c.lock.RLock()
	defer c.lock.RUnlock()
	return c.client
}
// GetSender returns the cached sender for a queue or topic. If none is cached
// yet, it creates one, caches it, and returns it.
//
// If ensure is non-nil, it is invoked with queueOrTopic (while the write lock is
// held) before the sender is first created. Pass a function such as
// Client.EnsureTopic or Client.EnsureQueue, whose signatures match; those do
// nothing when entity management is disabled.
//
// An error is returned if the client has already been closed.
func (c *Client) GetSender(ctx context.Context, queueOrTopic string, ensure ensureFn) (*servicebus.Sender, error) {
	// Fast path: a shared lock is enough when the sender is already cached.
	c.lock.RLock()
	sender := c.senders[queueOrTopic]
	c.lock.RUnlock()
	if sender != nil {
		return sender, nil
	}

	c.lock.Lock()
	defer c.lock.Unlock()

	// Check again after acquiring the write lock, in case another goroutine
	// created the sender in the meantime.
	if sender = c.senders[queueOrTopic]; sender != nil {
		return sender, nil
	}

	// Close sets c.client to nil; fail cleanly instead of dereferencing it.
	if c.client == nil {
		return nil, fmt.Errorf("cannot create sender for %s: service bus client is closed", queueOrTopic)
	}

	// Ensure the queue or topic exists the first time it is referenced.
	if ensure != nil {
		if err := ensure(ctx, queueOrTopic); err != nil {
			return nil, err
		}
	}

	sender, err := c.client.NewSender(queueOrTopic, nil)
	if err != nil {
		return nil, err
	}
	c.senders[queueOrTopic] = sender
	return sender, nil
}
// CloseSender removes the cached sender for a queue or topic and closes it.
//
// It does nothing beyond the cache removal when no sender is cached under
// queueOrTopic. The close is bounded by a one-second timeout. A failure to
// close is logged as a warning and not returned.
func (c *Client) CloseSender(queueOrTopic string, log logger.Logger) {
	// Take the sender out of the cache under the lock; deleting a missing key is a no-op.
	c.lock.Lock()
	sender := c.senders[queueOrTopic]
	delete(c.senders, queueOrTopic)
	c.lock.Unlock()

	if sender == nil {
		return
	}

	log.Info("Closing sender: " + queueOrTopic)
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	closeErr := sender.Close(ctx)
	cancel()
	if closeErr != nil {
		// Best effort: report, don't propagate.
		log.Warnf("Error closing sender %s: %v", queueOrTopic, closeErr)
	}
	log.Debug("Closed sender: " + queueOrTopic)
}
// CloseAllSenders closes every cached sender and then empties the cache.
//
// At most three senders are closed concurrently. Each close is bounded by
// metadata.TimeoutInSec. Failures are logged as warnings and not returned.
// The write lock is held for the whole operation.
func (c *Client) CloseAllSenders(log logger.Logger) {
	c.lock.Lock()
	defer c.lock.Unlock()

	const maxParallel = 3
	slots := make(chan struct{}, maxParallel)
	var wg sync.WaitGroup

	for name, sender := range c.senders {
		// Acquire a slot; this blocks while maxParallel closes are in flight.
		slots <- struct{}{}
		wg.Add(1)
		go func(name string, sender *servicebus.Sender) {
			defer func() {
				<-slots
				wg.Done()
			}()

			log.Debug("Closing sender: " + name)
			ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.metadata.TimeoutInSec)*time.Second)
			defer cancel()
			if err := sender.Close(ctx); err != nil {
				// Log only
				log.Warnf("Error closing sender %s: %v", name, err)
			}
		}(name, sender)
	}

	// Wait for every close goroutine before touching the map.
	wg.Wait()
	maps.Clear(c.senders)
}
// Close closes the underlying Service Bus client, which also tears down every
// sender or consumer created through its connection, and empties the sender
// cache.
//
// The close is bounded by metadata.TimeoutInSec, and a failure is only logged.
// Calling Close again is safe: once the client is nil, only the cache is
// cleared.
func (c *Client) Close(log logger.Logger) {
	c.lock.Lock()
	defer c.lock.Unlock()

	if sbClient := c.client; sbClient != nil {
		timeout := time.Duration(c.metadata.TimeoutInSec) * time.Second
		ctx, cancel := context.WithTimeout(context.Background(), timeout)
		closeErr := sbClient.Close(ctx)
		cancel()
		if closeErr != nil {
			// Best effort: report the failure but still drop the client.
			log.Warnf("Error closing client: %v", closeErr)
		}
		c.client = nil
	}

	// The cached senders belonged to the connection that was just closed.
	maps.Clear(c.senders)
}
// EnsureTopic creates the topic if it does not already exist.
// It is a no-op (returning nil) when there is no admin client, i.e. when
// entity management is disabled.
func (c *Client) EnsureTopic(ctx context.Context, topic string) error {
	if c.adminClient == nil {
		return nil
	}

	create, err := c.shouldCreateTopic(ctx, topic)
	if err != nil || !create {
		return err
	}
	return c.createTopic(ctx, topic)
}
// SubscribeOptions carries the per-subscription settings passed to
// EnsureSubscription. When a subscription is created, the whole struct is
// forwarded to Metadata.CreateSubscriptionProperties.
type SubscribeOptions struct {
	// RequireSessions is compared against an existing subscription's
	// RequiresSession setting; a mismatch makes EnsureSubscription fail.
	RequireSessions bool
	// MaxConcurrentSesions is not read in this file; presumably consumed by the
	// subscriber or by CreateSubscriptionProperties — TODO confirm.
	// NOTE: the misspelling is part of the exported API and is kept for compatibility.
	MaxConcurrentSesions int
}
// EnsureSubscription makes sure the topic exists and then creates the
// subscription name on it if it does not already exist.
//
// It returns an error if an existing subscription's session requirement
// differs from opts.RequireSessions. It is a no-op (returning nil) when there
// is no admin client.
func (c *Client) EnsureSubscription(ctx context.Context, name string, topic string, opts SubscribeOptions) error {
	if c.adminClient == nil {
		return nil
	}

	if err := c.EnsureTopic(ctx, topic); err != nil {
		return err
	}

	create, err := c.shouldCreateSubscription(ctx, topic, name, opts)
	if err != nil || !create {
		return err
	}
	return c.createSubscription(ctx, topic, name, opts)
}
// EnsureQueue creates the queue if it does not already exist.
// It is a no-op (returning nil) when there is no admin client, i.e. when
// entity management is disabled.
func (c *Client) EnsureQueue(ctx context.Context, queue string) error {
	if c.adminClient == nil {
		return nil
	}

	create, err := c.shouldCreateQueue(ctx, queue)
	if err != nil || !create {
		return err
	}
	return c.createQueue(ctx, queue)
}
// shouldCreateTopic reports whether topic is missing and therefore needs to be
// created. The lookup is bounded by metadata.TimeoutInSec.
func (c *Client) shouldCreateTopic(parentCtx context.Context, topic string) (bool, error) {
	timeout := time.Duration(c.metadata.TimeoutInSec) * time.Second
	ctx, cancel := context.WithTimeout(parentCtx, timeout)
	defer cancel()

	res, err := c.adminClient.GetTopic(ctx, topic, nil)
	if err != nil {
		return false, fmt.Errorf("could not get topic %s: %w", topic, err)
	}
	// A nil response (with a nil error) means the topic does not exist.
	return res == nil, nil
}
// createTopic creates topic with default properties. The call is bounded by
// metadata.TimeoutInSec.
func (c *Client) createTopic(parentCtx context.Context, topic string) error {
	timeout := time.Duration(c.metadata.TimeoutInSec) * time.Second
	ctx, cancel := context.WithTimeout(parentCtx, timeout)
	defer cancel()

	if _, err := c.adminClient.CreateTopic(ctx, topic, nil); err != nil {
		return fmt.Errorf("could not create topic %s: %w", topic, err)
	}
	return nil
}
// shouldCreateSubscription reports whether the subscription is missing on topic.
//
// If the subscription exists but its RequiresSession setting differs from
// opts.RequireSessions, it returns an error. The lookup is bounded by
// metadata.TimeoutInSec.
func (c *Client) shouldCreateSubscription(parentCtx context.Context, topic, subscription string, opts SubscribeOptions) (bool, error) {
	timeout := time.Duration(c.metadata.TimeoutInSec) * time.Second
	ctx, cancel := context.WithTimeout(parentCtx, timeout)
	defer cancel()

	res, err := c.adminClient.GetSubscription(ctx, topic, subscription, nil)
	switch {
	case err != nil:
		return false, fmt.Errorf("could not get subscription %s: %w", subscription, err)
	case res == nil:
		// A nil response means the subscription does not exist yet.
		return true, nil
	case notEqual(res.RequiresSession, &opts.RequireSessions):
		return false, fmt.Errorf("subscription %s already exists but session requirement doesn't match", subscription)
	default:
		return false, nil
	}
}
// createSubscription creates the subscription on topic using the properties
// derived from the component metadata and opts. The call is bounded by
// metadata.TimeoutInSec.
func (c *Client) createSubscription(parentCtx context.Context, topic, subscription string, opts SubscribeOptions) error {
	createOpts := &sbadmin.CreateSubscriptionOptions{
		Properties: c.metadata.CreateSubscriptionProperties(opts),
	}

	timeout := time.Duration(c.metadata.TimeoutInSec) * time.Second
	ctx, cancel := context.WithTimeout(parentCtx, timeout)
	defer cancel()

	if _, err := c.adminClient.CreateSubscription(ctx, topic, subscription, createOpts); err != nil {
		return fmt.Errorf("could not create subscription %s: %w", subscription, err)
	}
	return nil
}
// shouldCreateQueue reports whether queue is missing and therefore needs to be
// created. The lookup is bounded by metadata.TimeoutInSec.
func (c *Client) shouldCreateQueue(parentCtx context.Context, queue string) (bool, error) {
	timeout := time.Duration(c.metadata.TimeoutInSec) * time.Second
	ctx, cancel := context.WithTimeout(parentCtx, timeout)
	defer cancel()

	res, err := c.adminClient.GetQueue(ctx, queue, nil)
	if err != nil {
		return false, fmt.Errorf("could not get queue %s: %w", queue, err)
	}
	// A nil response (with a nil error) means the queue does not exist.
	return res == nil, nil
}
// createQueue creates queue using the properties derived from the component
// metadata. The call is bounded by metadata.TimeoutInSec.
func (c *Client) createQueue(parentCtx context.Context, queue string) error {
	createOpts := &sbadmin.CreateQueueOptions{
		Properties: c.metadata.CreateQueueProperties(),
	}

	timeout := time.Duration(c.metadata.TimeoutInSec) * time.Second
	ctx, cancel := context.WithTimeout(parentCtx, timeout)
	defer cancel()

	if _, err := c.adminClient.CreateQueue(ctx, queue, createOpts); err != nil {
		return fmt.Errorf("could not create queue %s: %w", queue, err)
	}
	return nil
}
// ReconnectionBackoff returns the backoff used when reconnecting a subscription.
//
// It never gives up (MaxElapsedTime = 0). It starts at
// metadata.MinConnectionRecoveryInSec and grows exponentially up to
// metadata.MaxConnectionRecoveryInSec.
func (c *Client) ReconnectionBackoff() backoff.BackOff {
	bo := backoff.NewExponentialBackOff()
	bo.MaxElapsedTime = 0
	bo.InitialInterval = time.Duration(c.metadata.MinConnectionRecoveryInSec) * time.Second
	bo.MaxInterval = time.Duration(c.metadata.MaxConnectionRecoveryInSec) * time.Second
	// NewExponentialBackOff already called Reset() using the library's default
	// InitialInterval. Reset again so the first NextBackOff starts from the
	// configured minimum rather than the default.
	bo.Reset()
	return bo
}
// notEqual reports whether two optional booleans differ. Two nil pointers are
// equal, a nil and a non-nil pointer differ, and otherwise the pointed-to
// values are compared.
func notEqual(a, b *bool) bool {
	switch {
	case a == nil || b == nil:
		// Equal only when both are nil.
		return a != b
	default:
		return *a != *b
	}
}