diff --git a/internal/component/azure/servicebus/client.go b/internal/component/azure/servicebus/client.go index d25f53b06..caad6c1a5 100644 --- a/internal/component/azure/servicebus/client.go +++ b/internal/component/azure/servicebus/client.go @@ -286,14 +286,7 @@ func (c *Client) shouldCreateSubscription(parentCtx context.Context, topic, subs return true, nil } - neq := func(a, b *bool) bool { - if a == nil || b == nil { - return true - } - return *a != *b - } - - if neq(res.RequiresSession, &opts.RequireSessions) { + if notEqual(res.RequiresSession, &opts.RequireSessions) { return false, fmt.Errorf("subscription %s already exists but session requirement doesn't match", subscription) } @@ -340,3 +333,13 @@ func (c *Client) createQueue(parentCtx context.Context, queue string) error { } return nil } + +func notEqual(a, b *bool) bool { + if a == nil && b == nil { + return false + } else if a == nil || b == nil { + return true + } else { + return *a != *b + } +} diff --git a/internal/component/azure/servicebus/receiver.go b/internal/component/azure/servicebus/receiver.go index 6fb77c74f..5fc14c465 100644 --- a/internal/component/azure/servicebus/receiver.go +++ b/internal/component/azure/servicebus/receiver.go @@ -16,6 +16,8 @@ package servicebus import ( "context" "fmt" + "sync" + "time" azservicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus" "go.uber.org/multierr" @@ -41,12 +43,15 @@ type SessionReceiver struct { *azservicebus.SessionReceiver } -func (s *SessionReceiver) RenewSessionLocks(ctx context.Context) error { +func (s *SessionReceiver) RenewSessionLocks(ctx context.Context, timeoutInSec int) error { if s == nil { return nil } - return s.RenewSessionLock(ctx, nil) + lockCtx, lockCancel := context.WithTimeout(ctx, time.Second*time.Duration(timeoutInSec)) + defer lockCancel() + + return s.RenewSessionLock(lockCtx, nil) } func NewMessageReceiver(r *azservicebus.Receiver) *MessageReceiver { @@ -57,19 +62,42 @@ type MessageReceiver struct { *azservicebus.Receiver } -func (m *MessageReceiver) RenewMessageLocks(ctx context.Context, msgs []*azservicebus.ReceivedMessage) error { +func (m *MessageReceiver) RenewMessageLocks(ctx context.Context, msgs []*azservicebus.ReceivedMessage, timeoutInSec int) error { if m == nil { return nil } - var errs []error + var wg sync.WaitGroup + + errChan := make(chan error, len(msgs)) for _, msg := range msgs { - // Renew the lock for the message. - err := m.RenewMessageLock(ctx, msg, nil) - if err != nil { - errs = append(errs, fmt.Errorf("couldn't renew all active message lock(s) for message %s, %w", msg.MessageID, err)) - } + wg.Add(1) + + go func(rmsg *azservicebus.ReceivedMessage) { + defer wg.Done() + + lockCtx, lockCancel := context.WithTimeout(ctx, time.Second*time.Duration(timeoutInSec)) + defer lockCancel() + + // Renew the lock for the message. + err := m.RenewMessageLock(lockCtx, rmsg, nil) + if err != nil { + errChan <- fmt.Errorf("couldn't renew active message lock for message %s, %w", rmsg.MessageID, err) + } + }(msg) } - return multierr.Combine(errs...) + wg.Wait() + close(errChan) + + var errs []error + for err := range errChan { + errs = append(errs, err) + } + + if len(errs) > 0 { + return multierr.Combine(errs...) + } + + return nil } diff --git a/internal/component/azure/servicebus/subscription.go b/internal/component/azure/servicebus/subscription.go index 81beda704..0e49ab487 100644 --- a/internal/component/azure/servicebus/subscription.go +++ b/internal/component/azure/servicebus/subscription.go @@ -28,6 +28,15 @@ import ( "github.com/dapr/kit/retry" ) +const ( + RequireSessionsMetadataKey = "requireSessions" + SessionIdleTimeoutMetadataKey = "sessionIdleTimeoutInSec" + MaxConcurrentSessionsMetadataKey = "maxConcurrentSessions" + + DefaultSesssionIdleTimeoutInSec = 60 + DefaultMaxConcurrentSessions = 8 +) + // HandlerResponseItem represents a response from the handler for each message. type HandlerResponseItem struct { EntryId string //nolint:stylecheck @@ -347,35 +356,31 @@ func (s *Subscription) RenewLocksBlocking(ctx context.Context, receiver Receiver s.logger.Infof("context canceled while renewing locks for %s", s.entity) return nil default: - func() { - lockCtx, lockCancel := context.WithTimeout(ctx, time.Second*time.Duration(opts.TimeoutInSec)) - defer lockCancel() - - if s.requireSessions { - sessionReceiver := receiver.(*SessionReceiver) - if err := sessionReceiver.RenewSessionLocks(lockCtx); err != nil { - s.logger.Errorf("error renewing session locks for %s: %s", s.entity, err) - } - s.logger.Debugf("Renewed session %s locks for %s", sessionReceiver.SessionID(), s.entity) - } else { - // Snapshot the messages to try to renew locks for. - msgs := make([]*azservicebus.ReceivedMessage, 0) - s.mu.RLock() - for _, m := range s.activeMessages { - msgs = append(msgs, m) - } - s.mu.RUnlock() - if len(msgs) == 0 { - s.logger.Debugf("No active messages require lock renewal for %s", s.entity) - return - } - msgReceiver := receiver.(*MessageReceiver) - if err := msgReceiver.RenewMessageLocks(lockCtx, msgs); err != nil { - s.logger.Errorf("error renewing message locks for %s: %s", s.entity, err) - } - s.logger.Debugf("Renewed message locks for %s", s.entity) + if s.requireSessions { + sessionReceiver := receiver.(*SessionReceiver) + if err := sessionReceiver.RenewSessionLocks(ctx, opts.TimeoutInSec); err != nil { + s.logger.Errorf("error renewing session locks for %s: %s", s.entity, err) } - }() + s.logger.Debugf("Renewed session %s locks for %s", sessionReceiver.SessionID(), s.entity) + } else { + // Snapshot the messages to try to renew locks for. + msgs := make([]*azservicebus.ReceivedMessage, len(s.activeMessages)) + s.mu.RLock() + for i, m := range s.activeMessages { + msgs[i] = m + } + s.mu.RUnlock() + + if len(msgs) == 0 { + s.logger.Debugf("No active messages require lock renewal for %s", s.entity) + continue + } + msgReceiver := receiver.(*MessageReceiver) + if err := msgReceiver.RenewMessageLocks(ctx, msgs, opts.TimeoutInSec); err != nil { + s.logger.Errorf("error renewing message locks for %s: %s", s.entity, err) + } + s.logger.Debugf("Renewed message locks for %s", s.entity) + } } } } diff --git a/pubsub/azure/servicebus/topics/servicebus.go b/pubsub/azure/servicebus/topics/servicebus.go index 8fc45f427..09d74c509 100644 --- a/pubsub/azure/servicebus/topics/servicebus.go +++ b/pubsub/azure/servicebus/topics/servicebus.go @@ -30,14 +30,8 @@ import ( ) const ( - requireSessionsMetadataKey = "requireSessions" - sessionIdleTimeoutMetadataKey = "sessionIdleTimeoutInSec" - maxConcurrentSessionsMetadataKey = "maxConcurrentSessions" - - defaultMaxBulkSubCount = 100 - defaultMaxBulkPubBytes uint64 = 1024 * 128 // 128 KiB - defaultSesssionIdleTimeoutInSec = 60 - defaultMaxConcurrentSessions = 8 + defaultMaxBulkSubCount = 100 + defaultMaxBulkPubBytes uint64 = 1024 * 128 // 128 KiB ) type azureServiceBus struct { @@ -183,11 +177,11 @@ func (a *azureServiceBus) BulkPublish(ctx context.Context, req *pubsub.BulkPubli func (a *azureServiceBus) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error { var requireSessions bool - if val, ok := req.Metadata[requireSessionsMetadataKey]; ok && val != "" { + if val, ok := req.Metadata[impl.RequireSessionsMetadataKey]; ok && val != "" { requireSessions = utils.IsTruthy(val) } - sessionIdleTimeout := time.Duration(utils.GetElemOrDefaultFromMap(req.Metadata, sessionIdleTimeoutMetadataKey, defaultSesssionIdleTimeoutInSec)) * time.Second - maxConcurrentSessions := utils.GetElemOrDefaultFromMap(req.Metadata, maxConcurrentSessionsMetadataKey, defaultMaxConcurrentSessions) + sessionIdleTimeout := time.Duration(utils.GetElemOrDefaultFromMap(req.Metadata, impl.SessionIdleTimeoutMetadataKey, impl.DefaultSesssionIdleTimeoutInSec)) * time.Second + maxConcurrentSessions := utils.GetElemOrDefaultFromMap(req.Metadata, impl.MaxConcurrentSessionsMetadataKey, impl.DefaultMaxConcurrentSessions) sub := impl.NewSubscription( subscribeCtx, impl.SubsriptionOptions{ @@ -223,11 +217,11 @@ func (a *azureServiceBus) Subscribe(subscribeCtx context.Context, req pubsub.Sub func (a *azureServiceBus) BulkSubscribe(subscribeCtx context.Context, req pubsub.SubscribeRequest, handler pubsub.BulkHandler) error { var requireSessions bool - if val, ok := req.Metadata[requireSessionsMetadataKey]; ok && val != "" { + if val, ok := req.Metadata[impl.RequireSessionsMetadataKey]; ok && val != "" { requireSessions = utils.IsTruthy(val) } - sessionIdleTimeout := time.Duration(utils.GetElemOrDefaultFromMap(req.Metadata, sessionIdleTimeoutMetadataKey, defaultSesssionIdleTimeoutInSec)) * time.Second - maxConcurrentSessions := utils.GetElemOrDefaultFromMap(req.Metadata, maxConcurrentSessionsMetadataKey, defaultMaxConcurrentSessions) + sessionIdleTimeout := time.Duration(utils.GetElemOrDefaultFromMap(req.Metadata, impl.SessionIdleTimeoutMetadataKey, impl.DefaultSesssionIdleTimeoutInSec)) * time.Second + maxConcurrentSessions := utils.GetElemOrDefaultFromMap(req.Metadata, impl.MaxConcurrentSessionsMetadataKey, impl.DefaultMaxConcurrentSessions) maxBulkSubCount := utils.GetElemOrDefaultFromMap(req.Metadata, contribMetadata.MaxBulkSubCountKey, defaultMaxBulkSubCount) sub := impl.NewSubscription(