Address feedback

Signed-off-by: Joni Collinge <jonathancollinge@live.com>
This commit is contained in:
Joni Collinge 2022-12-20 09:59:20 +00:00
parent b6911b67fd
commit 5051e28b9c
No known key found for this signature in database
GPG Key ID: BF9B59005264DD95
4 changed files with 90 additions and 60 deletions

View File

@ -286,14 +286,7 @@ func (c *Client) shouldCreateSubscription(parentCtx context.Context, topic, subs
return true, nil
}
neq := func(a, b *bool) bool {
if a == nil || b == nil {
return true
}
return *a != *b
}
if neq(res.RequiresSession, &opts.RequireSessions) {
if notEqual(res.RequiresSession, &opts.RequireSessions) {
return false, fmt.Errorf("subscription %s already exists but session requirement doesn't match", subscription)
}
@ -340,3 +333,13 @@ func (c *Client) createQueue(parentCtx context.Context, queue string) error {
}
return nil
}
func notEqual(a, b *bool) bool {
if a == nil && b == nil {
return false
} else if a == nil || b == nil {
return true
} else {
return *a != *b
}
}

View File

@ -16,6 +16,8 @@ package servicebus
import (
"context"
"fmt"
"sync"
"time"
azservicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
"go.uber.org/multierr"
@ -41,12 +43,15 @@ type SessionReceiver struct {
*azservicebus.SessionReceiver
}
func (s *SessionReceiver) RenewSessionLocks(ctx context.Context) error {
func (s *SessionReceiver) RenewSessionLocks(ctx context.Context, timeoutInSec int) error {
if s == nil {
return nil
}
return s.RenewSessionLock(ctx, nil)
lockCtx, lockCancel := context.WithTimeout(ctx, time.Second*time.Duration(timeoutInSec))
defer lockCancel()
return s.RenewSessionLock(lockCtx, nil)
}
func NewMessageReceiver(r *azservicebus.Receiver) *MessageReceiver {
@ -57,19 +62,42 @@ type MessageReceiver struct {
*azservicebus.Receiver
}
func (m *MessageReceiver) RenewMessageLocks(ctx context.Context, msgs []*azservicebus.ReceivedMessage) error {
func (m *MessageReceiver) RenewMessageLocks(ctx context.Context, msgs []*azservicebus.ReceivedMessage, timeoutInSec int) error {
if m == nil {
return nil
}
var errs []error
var wg sync.WaitGroup
errChan := make(chan error, len(msgs))
for _, msg := range msgs {
// Renew the lock for the message.
err := m.RenewMessageLock(ctx, msg, nil)
if err != nil {
errs = append(errs, fmt.Errorf("couldn't renew all active message lock(s) for message %s, %w", msg.MessageID, err))
}
wg.Add(1)
go func(rmsg *azservicebus.ReceivedMessage) {
defer wg.Done()
lockCtx, lockCancel := context.WithTimeout(ctx, time.Second*time.Duration(timeoutInSec))
defer lockCancel()
// Renew the lock for the message.
err := m.RenewMessageLock(lockCtx, rmsg, nil)
if err != nil {
errChan <- fmt.Errorf("couldn't renew active message lock for message %s, %w", rmsg.MessageID, err)
}
}(msg)
}
return multierr.Combine(errs...)
wg.Wait()
close(errChan)
var errs []error
for err := range errChan {
errs = append(errs, err)
}
if len(errs) > 0 {
return multierr.Combine(errs...)
}
return nil
}

View File

@ -28,6 +28,15 @@ import (
"github.com/dapr/kit/retry"
)
const (
RequireSessionsMetadataKey = "requireSessions"
SessionIdleTimeoutMetadataKey = "sessionIdleTimeoutInSec"
MaxConcurrentSessionsMetadataKey = "maxConcurrentSessions"
DefaultSesssionIdleTimeoutInSec = 60
DefaultMaxConcurrentSessions = 8
)
// HandlerResponseItem represents a response from the handler for each message.
type HandlerResponseItem struct {
EntryId string //nolint:stylecheck
@ -347,35 +356,31 @@ func (s *Subscription) RenewLocksBlocking(ctx context.Context, receiver Receiver
s.logger.Infof("context canceled while renewing locks for %s", s.entity)
return nil
default:
func() {
lockCtx, lockCancel := context.WithTimeout(ctx, time.Second*time.Duration(opts.TimeoutInSec))
defer lockCancel()
if s.requireSessions {
sessionReceiver := receiver.(*SessionReceiver)
if err := sessionReceiver.RenewSessionLocks(lockCtx); err != nil {
s.logger.Errorf("error renewing session locks for %s: %s", s.entity, err)
}
s.logger.Debugf("Renewed session %s locks for %s", sessionReceiver.SessionID(), s.entity)
} else {
// Snapshot the messages to try to renew locks for.
msgs := make([]*azservicebus.ReceivedMessage, 0)
s.mu.RLock()
for _, m := range s.activeMessages {
msgs = append(msgs, m)
}
s.mu.RUnlock()
if len(msgs) == 0 {
s.logger.Debugf("No active messages require lock renewal for %s", s.entity)
return
}
msgReceiver := receiver.(*MessageReceiver)
if err := msgReceiver.RenewMessageLocks(lockCtx, msgs); err != nil {
s.logger.Errorf("error renewing message locks for %s: %s", s.entity, err)
}
s.logger.Debugf("Renewed message locks for %s", s.entity)
if s.requireSessions {
sessionReceiver := receiver.(*SessionReceiver)
if err := sessionReceiver.RenewSessionLocks(ctx, opts.TimeoutInSec); err != nil {
s.logger.Errorf("error renewing session locks for %s: %s", s.entity, err)
}
}()
s.logger.Debugf("Renewed session %s locks for %s", sessionReceiver.SessionID(), s.entity)
} else {
// Snapshot the messages to try to renew locks for.
msgs := make([]*azservicebus.ReceivedMessage, len(s.activeMessages))
s.mu.RLock()
for i, m := range s.activeMessages {
msgs[i] = m
}
s.mu.RUnlock()
if len(msgs) == 0 {
s.logger.Debugf("No active messages require lock renewal for %s", s.entity)
continue
}
msgReceiver := receiver.(*MessageReceiver)
if err := msgReceiver.RenewMessageLocks(ctx, msgs, opts.TimeoutInSec); err != nil {
s.logger.Errorf("error renewing message locks for %s: %s", s.entity, err)
}
s.logger.Debugf("Renewed message locks for %s", s.entity)
}
}
}
}

View File

@ -30,14 +30,8 @@ import (
)
const (
requireSessionsMetadataKey = "requireSessions"
sessionIdleTimeoutMetadataKey = "sessionIdleTimeoutInSec"
maxConcurrentSessionsMetadataKey = "maxConcurrentSessions"
defaultMaxBulkSubCount = 100
defaultMaxBulkPubBytes uint64 = 1024 * 128 // 128 KiB
defaultSesssionIdleTimeoutInSec = 60
defaultMaxConcurrentSessions = 8
defaultMaxBulkSubCount = 100
defaultMaxBulkPubBytes uint64 = 1024 * 128 // 128 KiB
)
type azureServiceBus struct {
@ -183,11 +177,11 @@ func (a *azureServiceBus) BulkPublish(ctx context.Context, req *pubsub.BulkPubli
func (a *azureServiceBus) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error {
var requireSessions bool
if val, ok := req.Metadata[requireSessionsMetadataKey]; ok && val != "" {
if val, ok := req.Metadata[impl.RequireSessionsMetadataKey]; ok && val != "" {
requireSessions = utils.IsTruthy(val)
}
sessionIdleTimeout := time.Duration(utils.GetElemOrDefaultFromMap(req.Metadata, sessionIdleTimeoutMetadataKey, defaultSesssionIdleTimeoutInSec)) * time.Second
maxConcurrentSessions := utils.GetElemOrDefaultFromMap(req.Metadata, maxConcurrentSessionsMetadataKey, defaultMaxConcurrentSessions)
sessionIdleTimeout := time.Duration(utils.GetElemOrDefaultFromMap(req.Metadata, impl.SessionIdleTimeoutMetadataKey, impl.DefaultSesssionIdleTimeoutInSec)) * time.Second
maxConcurrentSessions := utils.GetElemOrDefaultFromMap(req.Metadata, impl.MaxConcurrentSessionsMetadataKey, impl.DefaultMaxConcurrentSessions)
sub := impl.NewSubscription(
subscribeCtx, impl.SubsriptionOptions{
@ -223,11 +217,11 @@ func (a *azureServiceBus) Subscribe(subscribeCtx context.Context, req pubsub.Sub
func (a *azureServiceBus) BulkSubscribe(subscribeCtx context.Context, req pubsub.SubscribeRequest, handler pubsub.BulkHandler) error {
var requireSessions bool
if val, ok := req.Metadata[requireSessionsMetadataKey]; ok && val != "" {
if val, ok := req.Metadata[impl.RequireSessionsMetadataKey]; ok && val != "" {
requireSessions = utils.IsTruthy(val)
}
sessionIdleTimeout := time.Duration(utils.GetElemOrDefaultFromMap(req.Metadata, sessionIdleTimeoutMetadataKey, defaultSesssionIdleTimeoutInSec)) * time.Second
maxConcurrentSessions := utils.GetElemOrDefaultFromMap(req.Metadata, maxConcurrentSessionsMetadataKey, defaultMaxConcurrentSessions)
sessionIdleTimeout := time.Duration(utils.GetElemOrDefaultFromMap(req.Metadata, impl.SessionIdleTimeoutMetadataKey, impl.DefaultSesssionIdleTimeoutInSec)) * time.Second
maxConcurrentSessions := utils.GetElemOrDefaultFromMap(req.Metadata, impl.MaxConcurrentSessionsMetadataKey, impl.DefaultMaxConcurrentSessions)
maxBulkSubCount := utils.GetElemOrDefaultFromMap(req.Metadata, contribMetadata.MaxBulkSubCountKey, defaultMaxBulkSubCount)
sub := impl.NewSubscription(