Address feedback
Signed-off-by: Joni Collinge <jonathancollinge@live.com>
This commit is contained in:
parent
b6911b67fd
commit
5051e28b9c
|
@ -286,14 +286,7 @@ func (c *Client) shouldCreateSubscription(parentCtx context.Context, topic, subs
|
|||
return true, nil
|
||||
}
|
||||
|
||||
neq := func(a, b *bool) bool {
|
||||
if a == nil || b == nil {
|
||||
return true
|
||||
}
|
||||
return *a != *b
|
||||
}
|
||||
|
||||
if neq(res.RequiresSession, &opts.RequireSessions) {
|
||||
if notEqual(res.RequiresSession, &opts.RequireSessions) {
|
||||
return false, fmt.Errorf("subscription %s already exists but session requirement doesn't match", subscription)
|
||||
}
|
||||
|
||||
|
@ -340,3 +333,13 @@ func (c *Client) createQueue(parentCtx context.Context, queue string) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func notEqual(a, b *bool) bool {
|
||||
if a == nil && b == nil {
|
||||
return false
|
||||
} else if a == nil || b == nil {
|
||||
return true
|
||||
} else {
|
||||
return *a != *b
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,6 +16,8 @@ package servicebus
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
azservicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
|
||||
"go.uber.org/multierr"
|
||||
|
@ -41,12 +43,15 @@ type SessionReceiver struct {
|
|||
*azservicebus.SessionReceiver
|
||||
}
|
||||
|
||||
func (s *SessionReceiver) RenewSessionLocks(ctx context.Context) error {
|
||||
func (s *SessionReceiver) RenewSessionLocks(ctx context.Context, timeoutInSec int) error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.RenewSessionLock(ctx, nil)
|
||||
lockCtx, lockCancel := context.WithTimeout(ctx, time.Second*time.Duration(timeoutInSec))
|
||||
defer lockCancel()
|
||||
|
||||
return s.RenewSessionLock(lockCtx, nil)
|
||||
}
|
||||
|
||||
func NewMessageReceiver(r *azservicebus.Receiver) *MessageReceiver {
|
||||
|
@ -57,19 +62,42 @@ type MessageReceiver struct {
|
|||
*azservicebus.Receiver
|
||||
}
|
||||
|
||||
func (m *MessageReceiver) RenewMessageLocks(ctx context.Context, msgs []*azservicebus.ReceivedMessage) error {
|
||||
func (m *MessageReceiver) RenewMessageLocks(ctx context.Context, msgs []*azservicebus.ReceivedMessage, timeoutInSec int) error {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs []error
|
||||
var wg sync.WaitGroup
|
||||
|
||||
errChan := make(chan error, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
// Renew the lock for the message.
|
||||
err := m.RenewMessageLock(ctx, msg, nil)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("couldn't renew all active message lock(s) for message %s, %w", msg.MessageID, err))
|
||||
}
|
||||
wg.Add(1)
|
||||
|
||||
go func(rmsg *azservicebus.ReceivedMessage) {
|
||||
defer wg.Done()
|
||||
|
||||
lockCtx, lockCancel := context.WithTimeout(ctx, time.Second*time.Duration(timeoutInSec))
|
||||
defer lockCancel()
|
||||
|
||||
// Renew the lock for the message.
|
||||
err := m.RenewMessageLock(lockCtx, rmsg, nil)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("couldn't renew active message lock for message %s, %w", rmsg.MessageID, err)
|
||||
}
|
||||
}(msg)
|
||||
}
|
||||
|
||||
return multierr.Combine(errs...)
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
var errs []error
|
||||
for err := range errChan {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return multierr.Combine(errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -28,6 +28,15 @@ import (
|
|||
"github.com/dapr/kit/retry"
|
||||
)
|
||||
|
||||
const (
|
||||
RequireSessionsMetadataKey = "requireSessions"
|
||||
SessionIdleTimeoutMetadataKey = "sessionIdleTimeoutInSec"
|
||||
MaxConcurrentSessionsMetadataKey = "maxConcurrentSessions"
|
||||
|
||||
DefaultSesssionIdleTimeoutInSec = 60
|
||||
DefaultMaxConcurrentSessions = 8
|
||||
)
|
||||
|
||||
// HandlerResponseItem represents a response from the handler for each message.
|
||||
type HandlerResponseItem struct {
|
||||
EntryId string //nolint:stylecheck
|
||||
|
@ -347,35 +356,31 @@ func (s *Subscription) RenewLocksBlocking(ctx context.Context, receiver Receiver
|
|||
s.logger.Infof("context canceled while renewing locks for %s", s.entity)
|
||||
return nil
|
||||
default:
|
||||
func() {
|
||||
lockCtx, lockCancel := context.WithTimeout(ctx, time.Second*time.Duration(opts.TimeoutInSec))
|
||||
defer lockCancel()
|
||||
|
||||
if s.requireSessions {
|
||||
sessionReceiver := receiver.(*SessionReceiver)
|
||||
if err := sessionReceiver.RenewSessionLocks(lockCtx); err != nil {
|
||||
s.logger.Errorf("error renewing session locks for %s: %s", s.entity, err)
|
||||
}
|
||||
s.logger.Debugf("Renewed session %s locks for %s", sessionReceiver.SessionID(), s.entity)
|
||||
} else {
|
||||
// Snapshot the messages to try to renew locks for.
|
||||
msgs := make([]*azservicebus.ReceivedMessage, 0)
|
||||
s.mu.RLock()
|
||||
for _, m := range s.activeMessages {
|
||||
msgs = append(msgs, m)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
if len(msgs) == 0 {
|
||||
s.logger.Debugf("No active messages require lock renewal for %s", s.entity)
|
||||
return
|
||||
}
|
||||
msgReceiver := receiver.(*MessageReceiver)
|
||||
if err := msgReceiver.RenewMessageLocks(lockCtx, msgs); err != nil {
|
||||
s.logger.Errorf("error renewing message locks for %s: %s", s.entity, err)
|
||||
}
|
||||
s.logger.Debugf("Renewed message locks for %s", s.entity)
|
||||
if s.requireSessions {
|
||||
sessionReceiver := receiver.(*SessionReceiver)
|
||||
if err := sessionReceiver.RenewSessionLocks(ctx, opts.TimeoutInSec); err != nil {
|
||||
s.logger.Errorf("error renewing session locks for %s: %s", s.entity, err)
|
||||
}
|
||||
}()
|
||||
s.logger.Debugf("Renewed session %s locks for %s", sessionReceiver.SessionID(), s.entity)
|
||||
} else {
|
||||
// Snapshot the messages to try to renew locks for.
|
||||
msgs := make([]*azservicebus.ReceivedMessage, len(s.activeMessages))
|
||||
s.mu.RLock()
|
||||
for i, m := range s.activeMessages {
|
||||
msgs[i] = m
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
if len(msgs) == 0 {
|
||||
s.logger.Debugf("No active messages require lock renewal for %s", s.entity)
|
||||
continue
|
||||
}
|
||||
msgReceiver := receiver.(*MessageReceiver)
|
||||
if err := msgReceiver.RenewMessageLocks(ctx, msgs, opts.TimeoutInSec); err != nil {
|
||||
s.logger.Errorf("error renewing message locks for %s: %s", s.entity, err)
|
||||
}
|
||||
s.logger.Debugf("Renewed message locks for %s", s.entity)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,14 +30,8 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
requireSessionsMetadataKey = "requireSessions"
|
||||
sessionIdleTimeoutMetadataKey = "sessionIdleTimeoutInSec"
|
||||
maxConcurrentSessionsMetadataKey = "maxConcurrentSessions"
|
||||
|
||||
defaultMaxBulkSubCount = 100
|
||||
defaultMaxBulkPubBytes uint64 = 1024 * 128 // 128 KiB
|
||||
defaultSesssionIdleTimeoutInSec = 60
|
||||
defaultMaxConcurrentSessions = 8
|
||||
defaultMaxBulkSubCount = 100
|
||||
defaultMaxBulkPubBytes uint64 = 1024 * 128 // 128 KiB
|
||||
)
|
||||
|
||||
type azureServiceBus struct {
|
||||
|
@ -183,11 +177,11 @@ func (a *azureServiceBus) BulkPublish(ctx context.Context, req *pubsub.BulkPubli
|
|||
|
||||
func (a *azureServiceBus) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error {
|
||||
var requireSessions bool
|
||||
if val, ok := req.Metadata[requireSessionsMetadataKey]; ok && val != "" {
|
||||
if val, ok := req.Metadata[impl.RequireSessionsMetadataKey]; ok && val != "" {
|
||||
requireSessions = utils.IsTruthy(val)
|
||||
}
|
||||
sessionIdleTimeout := time.Duration(utils.GetElemOrDefaultFromMap(req.Metadata, sessionIdleTimeoutMetadataKey, defaultSesssionIdleTimeoutInSec)) * time.Second
|
||||
maxConcurrentSessions := utils.GetElemOrDefaultFromMap(req.Metadata, maxConcurrentSessionsMetadataKey, defaultMaxConcurrentSessions)
|
||||
sessionIdleTimeout := time.Duration(utils.GetElemOrDefaultFromMap(req.Metadata, impl.SessionIdleTimeoutMetadataKey, impl.DefaultSesssionIdleTimeoutInSec)) * time.Second
|
||||
maxConcurrentSessions := utils.GetElemOrDefaultFromMap(req.Metadata, impl.MaxConcurrentSessionsMetadataKey, impl.DefaultMaxConcurrentSessions)
|
||||
|
||||
sub := impl.NewSubscription(
|
||||
subscribeCtx, impl.SubsriptionOptions{
|
||||
|
@ -223,11 +217,11 @@ func (a *azureServiceBus) Subscribe(subscribeCtx context.Context, req pubsub.Sub
|
|||
|
||||
func (a *azureServiceBus) BulkSubscribe(subscribeCtx context.Context, req pubsub.SubscribeRequest, handler pubsub.BulkHandler) error {
|
||||
var requireSessions bool
|
||||
if val, ok := req.Metadata[requireSessionsMetadataKey]; ok && val != "" {
|
||||
if val, ok := req.Metadata[impl.RequireSessionsMetadataKey]; ok && val != "" {
|
||||
requireSessions = utils.IsTruthy(val)
|
||||
}
|
||||
sessionIdleTimeout := time.Duration(utils.GetElemOrDefaultFromMap(req.Metadata, sessionIdleTimeoutMetadataKey, defaultSesssionIdleTimeoutInSec)) * time.Second
|
||||
maxConcurrentSessions := utils.GetElemOrDefaultFromMap(req.Metadata, maxConcurrentSessionsMetadataKey, defaultMaxConcurrentSessions)
|
||||
sessionIdleTimeout := time.Duration(utils.GetElemOrDefaultFromMap(req.Metadata, impl.SessionIdleTimeoutMetadataKey, impl.DefaultSesssionIdleTimeoutInSec)) * time.Second
|
||||
maxConcurrentSessions := utils.GetElemOrDefaultFromMap(req.Metadata, impl.MaxConcurrentSessionsMetadataKey, impl.DefaultMaxConcurrentSessions)
|
||||
|
||||
maxBulkSubCount := utils.GetElemOrDefaultFromMap(req.Metadata, contribMetadata.MaxBulkSubCountKey, defaultMaxBulkSubCount)
|
||||
sub := impl.NewSubscription(
|
||||
|
|
Loading…
Reference in New Issue