Azure ServiceBus PubSub Implementation (#93)
* initial skeleton code * updated sub message handler * lowercased errors * added deadletter * added context * remove nil args * added TODOs * add max delivery count and timeout config * refactored ensure methods * removed confusing comment * apply go fmt * removed consumerID default * updated go mod * fixed up linter issues * add package alias * removed TODOs * added additional servicebus config * fix linting
This commit is contained in:
parent
db7bf1f08b
commit
ed3c829578
6
go.mod
6
go.mod
|
|
@ -12,8 +12,8 @@ require (
|
|||
github.com/Azure/azure-service-bus-go v0.9.1
|
||||
github.com/Azure/azure-storage-blob-go v0.8.0
|
||||
github.com/Azure/go-autorest v13.0.1+incompatible // indirect
|
||||
github.com/Azure/go-autorest/autorest v0.9.1
|
||||
github.com/Azure/go-autorest/autorest/adal v0.6.0
|
||||
github.com/Azure/go-autorest/autorest v0.9.2
|
||||
github.com/Azure/go-autorest/autorest/adal v0.8.0
|
||||
github.com/Azure/go-autorest/autorest/azure/auth v0.3.0
|
||||
github.com/Azure/go-autorest/autorest/to v0.3.0 // indirect
|
||||
github.com/Azure/go-autorest/autorest/validation v0.2.0 // indirect
|
||||
|
|
@ -77,6 +77,8 @@ require (
|
|||
go.uber.org/multierr v1.2.0 // indirect
|
||||
golang.org/x/crypto v0.0.0-20190927123631-a832865fa7ad
|
||||
golang.org/x/exp v0.0.0-20190927203820-447a159532ef // indirect
|
||||
golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a // indirect
|
||||
golang.org/x/mobile v0.0.0-20190923204409-d3ece3b6da5f // indirect
|
||||
golang.org/x/net v0.0.0-20190926025831-c00fd9afed17
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e // indirect
|
||||
golang.org/x/time v0.0.0-20190921001708-c4c64cad1fd0 // indirect
|
||||
|
|
|
|||
12
go.sum
12
go.sum
|
|
@ -52,10 +52,14 @@ github.com/Azure/go-autorest v13.0.1+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSW
|
|||
github.com/Azure/go-autorest/autorest v0.9.0/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI=
|
||||
github.com/Azure/go-autorest/autorest v0.9.1 h1:JB7Mqhna/7J8gZfVHjxDSTLSD6ciz2YgSMb/4qLXTtY=
|
||||
github.com/Azure/go-autorest/autorest v0.9.1/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI=
|
||||
github.com/Azure/go-autorest/autorest v0.9.2 h1:6AWuh3uWrsZJcNoCHrCF/+g4aKPCU39kaMO6/qrnK/4=
|
||||
github.com/Azure/go-autorest/autorest v0.9.2/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI=
|
||||
github.com/Azure/go-autorest/autorest/adal v0.5.0 h1:q2gDruN08/guU9vAjuPWff0+QIrpH6ediguzdAzXAUU=
|
||||
github.com/Azure/go-autorest/autorest/adal v0.5.0/go.mod h1:8Z9fGy2MpX0PvDjB1pEgQTmVqjGhiHBW7RJJEciWzS0=
|
||||
github.com/Azure/go-autorest/autorest/adal v0.6.0 h1:UCTq22yE3RPgbU/8u4scfnnzuCW6pwQ9n+uBtV78ouo=
|
||||
github.com/Azure/go-autorest/autorest/adal v0.6.0/go.mod h1:Z6vX6WXXuyieHAXwMj0S6HY6e6wcHn37qQMBQlvY3lc=
|
||||
github.com/Azure/go-autorest/autorest/adal v0.8.0 h1:CxTzQrySOxDnKpLjFJeZAS5Qrv/qFPkgLjx5bOAi//I=
|
||||
github.com/Azure/go-autorest/autorest/adal v0.8.0/go.mod h1:Z6vX6WXXuyieHAXwMj0S6HY6e6wcHn37qQMBQlvY3lc=
|
||||
github.com/Azure/go-autorest/autorest/azure/auth v0.3.0 h1:JwftqZDtWkr3qt1kcEgPd7H57uCHsXKXf66agWUQcGw=
|
||||
github.com/Azure/go-autorest/autorest/azure/auth v0.3.0/go.mod h1:CI4BQYBct8NS7BXNBBX+RchsFsUu5+oz+OSyR/ZIi7U=
|
||||
github.com/Azure/go-autorest/autorest/azure/cli v0.3.0 h1:5PAqnv+CSTwW9mlZWZAizmzrazFWEgZykEZXpr2hDtY=
|
||||
|
|
@ -320,6 +324,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
|||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kubernetes-client/go v0.0.0-20190625181339-cd8e39e789c7 h1:NZlvd1Qf3MwoRhh87iVkJSHK3R31fX3D7kQfdJy6LnQ=
|
||||
github.com/kubernetes-client/go v0.0.0-20190625181339-cd8e39e789c7/go.mod h1:ks4KCmmxdXksTSu2dlnUanEOqNd/dsoyS6/7bay2RQ8=
|
||||
github.com/lithammer/shortuuid v3.0.0+incompatible h1:NcD0xWW/MZYXEHa6ITy6kaXN5nwm/V115vj2YXfhS0w=
|
||||
github.com/lithammer/shortuuid v3.0.0+incompatible/go.mod h1:FR74pbAuElzOUuenUHTK2Tciko1/vKuIKS9dSkDrA4w=
|
||||
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
||||
github.com/mattn/go-ieproxy v0.0.0-20190610004146-91bb50d98149 h1:HfxbT6/JcvIljmERptWhwa8XzP7H3T+Z2N26gTsaDaA=
|
||||
github.com/mattn/go-ieproxy v0.0.0-20190610004146-91bb50d98149/go.mod h1:31jz6HNzdxOmlERGGEc4v/dMssOfmp2p5bT/okiKFFc=
|
||||
|
|
@ -478,11 +484,13 @@ golang.org/x/crypto v0.0.0-20190927123631-a832865fa7ad/go.mod h1:yigFU9vqHzYiE8U
|
|||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
golang.org/x/exp v0.0.0-20190731235908-ec7cb31e5a56/go.mod h1:JhuoJpWY28nO4Vef9tZUw9qufEGTyX1+7lmHxV5q5G4=
|
||||
golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek=
|
||||
golang.org/x/exp v0.0.0-20190927203820-447a159532ef h1:0MEfU0Kh8iitbYr+L8WhnyAxLCVa5p0hV8tnPmdGDp0=
|
||||
golang.org/x/exp v0.0.0-20190927203820-447a159532ef/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||
golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
|
|
@ -492,6 +500,7 @@ golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac h1:8R1esu+8QioDxo4E4mX6bFzt
|
|||
golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
|
||||
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
|
||||
golang.org/x/mobile v0.0.0-20190923204409-d3ece3b6da5f/go.mod h1:p895TfNkDgPEmEQrNiOtIl3j98d/tGU95djDj7NfyjQ=
|
||||
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
|
||||
golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
|
|
@ -515,6 +524,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
|
|||
golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190926025831-c00fd9afed17 h1:qPnAdmjNA41t3QBTx2mFGf/SD1IoslhYu7AmdsVzCcs=
|
||||
golang.org/x/net v0.0.0-20190926025831-c00fd9afed17/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191014212845-da9a3fd4c582 h1:p9xBe/w/OzkeYVKm234g55gMdD1nSIooTir5kV11kfA=
|
||||
golang.org/x/net v0.0.0-20191014212845-da9a3fd4c582/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190402181905-9f3314589c9a/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
|
|
@ -575,6 +586,7 @@ golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgw
|
|||
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||
golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||
golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20190909214602-067311248421/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e h1:1xWUkZQQ9Z9UuZgNaIR6OQOE7rUFglXUUBZlO+dGg6I=
|
||||
golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ Currently supported pub-subs are:
|
|||
|
||||
* Redis Streams
|
||||
* NATS
|
||||
* Azure Service Bus
|
||||
|
||||
## Implementing a new Pub Sub
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,348 @@
|
|||
// ------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
// ------------------------------------------------------------
|
||||
|
||||
package azureservicebus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
servicebus "github.com/Azure/azure-service-bus-go"
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"github.com/dapr/components-contrib/pubsub"
|
||||
)
|
||||
|
||||
const (
|
||||
// Keys
|
||||
connectionString = "connectionString"
|
||||
consumerID = "consumerID"
|
||||
maxDeliveryCount = "maxDeliveryCount"
|
||||
timeoutInSec = "timeoutInSec"
|
||||
lockDurationInSec = "lockDurationInSec"
|
||||
defaultMessageTimeToLiveInSec = "defaultMessageTimeToLiveInSec"
|
||||
autoDeleteOnIdleInSec = "autoDeleteOnIdleInSec"
|
||||
disableEntityManagement = "disableEntityManagement"
|
||||
|
||||
// Defaults
|
||||
defaultTimeoutInSec = 60
|
||||
defaultDisableEntityManagement = false
|
||||
)
|
||||
|
||||
type azureServiceBus struct {
|
||||
metadata metadata
|
||||
namespace *servicebus.Namespace
|
||||
topicManager *servicebus.TopicManager
|
||||
}
|
||||
|
||||
type subscription interface {
|
||||
Close(ctx context.Context) error
|
||||
Receive(ctx context.Context, handler servicebus.Handler) error
|
||||
}
|
||||
|
||||
// NewAzureServiceBus returns a new Azure ServiceBus pub-sub implementation
|
||||
func NewAzureServiceBus() pubsub.PubSub {
|
||||
return &azureServiceBus{}
|
||||
}
|
||||
|
||||
func parseAzureServiceBusMetadata(meta pubsub.Metadata) (metadata, error) {
|
||||
m := metadata{}
|
||||
|
||||
/* Required configuration settings - no defaults */
|
||||
if val, ok := meta.Properties[connectionString]; ok && val != "" {
|
||||
m.ConnectionString = val
|
||||
} else {
|
||||
return m, errors.New("azure serivce bus error: missing connection string")
|
||||
}
|
||||
|
||||
if val, ok := meta.Properties[consumerID]; ok && val != "" {
|
||||
m.ConsumerID = val
|
||||
} else {
|
||||
return m, errors.New("azure service bus error: missing consumerID")
|
||||
}
|
||||
|
||||
/* Optional configuration settings - defaults will be set by the client */
|
||||
m.TimeoutInSec = defaultTimeoutInSec
|
||||
if val, ok := meta.Properties[timeoutInSec]; ok && val != "" {
|
||||
var err error
|
||||
m.TimeoutInSec, err = strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return m, fmt.Errorf("azure service bus error: invalid timeoutInSec %s, %s", val, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.DisableEntityManagement = defaultDisableEntityManagement
|
||||
if val, ok := meta.Properties[disableEntityManagement]; ok && val != "" {
|
||||
var err error
|
||||
m.DisableEntityManagement, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
return m, fmt.Errorf("azure service bus error: invalid disableEntityManagement %s, %s", val, err)
|
||||
}
|
||||
}
|
||||
|
||||
/* Nullable configuration settings - defaults will be set by the server */
|
||||
if val, ok := meta.Properties[maxDeliveryCount]; ok && val != "" {
|
||||
valAsInt, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return m, fmt.Errorf("azure service bus error: invalid maxDeliveryCount %s, %s", val, err)
|
||||
}
|
||||
m.MaxDeliveryCount = &valAsInt
|
||||
}
|
||||
|
||||
if val, ok := meta.Properties[lockDurationInSec]; ok && val != "" {
|
||||
valAsInt, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return m, fmt.Errorf("azure service bus error: invalid lockDurationInSec %s, %s", val, err)
|
||||
}
|
||||
m.LockDurationInSec = &valAsInt
|
||||
}
|
||||
|
||||
if val, ok := meta.Properties[defaultMessageTimeToLiveInSec]; ok && val != "" {
|
||||
valAsInt, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return m, fmt.Errorf("azure service bus error: invalid defaultMessageTimeToLiveInSec %s, %s", val, err)
|
||||
}
|
||||
m.DefaultMessageTimeToLiveInSec = &valAsInt
|
||||
}
|
||||
|
||||
if val, ok := meta.Properties[autoDeleteOnIdleInSec]; ok && val != "" {
|
||||
valAsInt, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return m, fmt.Errorf("azure service bus error: invalid autoDeleteOnIdleInSecKey %s, %s", val, err)
|
||||
}
|
||||
m.AutoDeleteOnIdleInSec = &valAsInt
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (a *azureServiceBus) Init(metadata pubsub.Metadata) error {
|
||||
m, err := parseAzureServiceBusMetadata(metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.metadata = m
|
||||
a.namespace, err = servicebus.NewNamespace(servicebus.NamespaceWithConnectionString(a.metadata.ConnectionString))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.topicManager = a.namespace.NewTopicManager()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *azureServiceBus) Publish(req *pubsub.PublishRequest) error {
|
||||
if !a.metadata.DisableEntityManagement {
|
||||
err := a.ensureTopic(req.Topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
sender, err := a.namespace.NewTopic(req.Topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(a.metadata.TimeoutInSec))
|
||||
defer cancel()
|
||||
|
||||
err = sender.Send(ctx, servicebus.NewMessage(req.Data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *azureServiceBus) Subscribe(req pubsub.SubscribeRequest, handler func(msg *pubsub.NewMessage) error) error {
|
||||
subID := a.metadata.ConsumerID
|
||||
if !a.metadata.DisableEntityManagement {
|
||||
err := a.ensureSubscription(subID, req.Topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
topic, err := a.namespace.NewTopic(req.Topic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service bus error: could not instantiate topic %s", req.Topic)
|
||||
}
|
||||
|
||||
var sub subscription
|
||||
sub, err = topic.NewSubscription(subID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service bus error: could not instantiate subscription %s for topic %s", subID, req.Topic)
|
||||
}
|
||||
|
||||
sbHandlerFunc := servicebus.HandlerFunc(a.getHandlerFunc(req.Topic, handler))
|
||||
|
||||
ctx := context.Background()
|
||||
go a.handleSubscriptionMessages(ctx, req.Topic, sub, sbHandlerFunc)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *azureServiceBus) getHandlerFunc(topic string, handler func(msg *pubsub.NewMessage) error) func(ctx context.Context, message *servicebus.Message) error {
|
||||
return func(ctx context.Context, message *servicebus.Message) error {
|
||||
msg := &pubsub.NewMessage{
|
||||
Data: message.Data,
|
||||
Topic: topic,
|
||||
}
|
||||
err := handler(msg)
|
||||
if err != nil {
|
||||
return message.Abandon(ctx)
|
||||
}
|
||||
return message.Complete(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *azureServiceBus) handleSubscriptionMessages(ctx context.Context, topic string, sub subscription, handlerFunc servicebus.HandlerFunc) {
|
||||
for {
|
||||
if err := sub.Receive(ctx, handlerFunc); err != nil {
|
||||
log.Errorf("service bus error: error receiving from topic %s, %s", topic, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *azureServiceBus) ensureTopic(topic string) error {
|
||||
entity, err := a.getTopicEntity(topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if entity == nil {
|
||||
err = a.createTopicEntity(topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *azureServiceBus) ensureSubscription(name string, topic string) error {
|
||||
err := a.ensureTopic(topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
subManager, err := a.namespace.NewSubscriptionManager(topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
entity, err := a.getSubscriptionEntity(subManager, topic, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if entity == nil {
|
||||
err = a.createSubscriptionEntity(subManager, topic, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *azureServiceBus) getTopicEntity(topic string) (*servicebus.TopicEntity, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(a.metadata.TimeoutInSec))
|
||||
defer cancel()
|
||||
|
||||
if a.topicManager == nil {
|
||||
return nil, fmt.Errorf("service bus error: init() has not been called")
|
||||
}
|
||||
topicEntity, err := a.topicManager.Get(ctx, topic)
|
||||
if err != nil && !servicebus.IsErrNotFound(err) {
|
||||
return nil, fmt.Errorf("service bus error: could not get topic %s, %s", topic, err)
|
||||
}
|
||||
return topicEntity, nil
|
||||
}
|
||||
|
||||
func (a *azureServiceBus) createTopicEntity(topic string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(a.metadata.TimeoutInSec))
|
||||
defer cancel()
|
||||
_, err := a.topicManager.Put(ctx, topic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service bus error: could not put topic %s, %s", topic, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *azureServiceBus) getSubscriptionEntity(mgr *servicebus.SubscriptionManager, topic, subscription string) (*servicebus.SubscriptionEntity, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(a.metadata.TimeoutInSec))
|
||||
defer cancel()
|
||||
entity, err := mgr.Get(ctx, subscription)
|
||||
if err != nil && !servicebus.IsErrNotFound(err) {
|
||||
return nil, fmt.Errorf("service bus error: could not get subscription %s, %s", subscription, err)
|
||||
}
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
func (a *azureServiceBus) createSubscriptionEntity(mgr *servicebus.SubscriptionManager, topic, subscription string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(a.metadata.TimeoutInSec))
|
||||
defer cancel()
|
||||
|
||||
opts, err := a.createSubscriptionManagementOptions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = mgr.Put(ctx, subscription, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service bus error: could not put subscription %s, %s", subscription, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *azureServiceBus) createSubscriptionManagementOptions() ([]servicebus.SubscriptionManagementOption, error) {
|
||||
var opts []servicebus.SubscriptionManagementOption
|
||||
if a.metadata.MaxDeliveryCount != nil {
|
||||
opts = append(opts, subscriptionManagementOptionsWithMaxDeliveryCount(a.metadata.MaxDeliveryCount))
|
||||
}
|
||||
if a.metadata.LockDurationInSec != nil {
|
||||
opts = append(opts, subscriptionManagementOptionsWithLockDuration(a.metadata.LockDurationInSec))
|
||||
}
|
||||
if a.metadata.DefaultMessageTimeToLiveInSec != nil {
|
||||
opts = append(opts, subscriptionManagementOptionsWithDefaultMessageTimeToLive(a.metadata.DefaultMessageTimeToLiveInSec))
|
||||
}
|
||||
if a.metadata.DefaultMessageTimeToLiveInSec != nil {
|
||||
opts = append(opts, subscriptionManagementOptionsWithAutoDeleteOnIdle(a.metadata.AutoDeleteOnIdleInSec))
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func subscriptionManagementOptionsWithMaxDeliveryCount(maxDeliveryCount *int) servicebus.SubscriptionManagementOption {
|
||||
return func(d *servicebus.SubscriptionDescription) error {
|
||||
mdc := int32(*maxDeliveryCount)
|
||||
d.MaxDeliveryCount = &mdc
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func subscriptionManagementOptionsWithAutoDeleteOnIdle(durationInSec *int) servicebus.SubscriptionManagementOption {
|
||||
return func(d *servicebus.SubscriptionDescription) error {
|
||||
duration := fmt.Sprintf("PT%dS", *durationInSec)
|
||||
d.AutoDeleteOnIdle = &duration
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func subscriptionManagementOptionsWithDefaultMessageTimeToLive(durationInSec *int) servicebus.SubscriptionManagementOption {
|
||||
return func(d *servicebus.SubscriptionDescription) error {
|
||||
duration := fmt.Sprintf("PT%dS", *durationInSec)
|
||||
d.DefaultMessageTimeToLive = &duration
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func subscriptionManagementOptionsWithLockDuration(durationInSec *int) servicebus.SubscriptionManagementOption {
|
||||
return func(d *servicebus.SubscriptionDescription) error {
|
||||
duration := fmt.Sprintf("PT%dS", *durationInSec)
|
||||
d.LockDuration = &duration
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,279 @@
|
|||
// ------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
// ------------------------------------------------------------
|
||||
|
||||
package azureservicebus
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/dapr/components-contrib/pubsub"
|
||||
)
|
||||
|
||||
const (
|
||||
invalidNumber = "invalid_number"
|
||||
)
|
||||
|
||||
func getFakeProperties() map[string]string {
|
||||
return map[string]string{
|
||||
connectionString: "fakeConnectionString",
|
||||
consumerID: "fakeConId",
|
||||
disableEntityManagement: "true",
|
||||
timeoutInSec: "90",
|
||||
maxDeliveryCount: "10",
|
||||
autoDeleteOnIdleInSec: "240",
|
||||
defaultMessageTimeToLiveInSec: "2400",
|
||||
lockDurationInSec: "120",
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseServiceBusMetadata(t *testing.T) {
|
||||
t.Run("metadata is correct", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
|
||||
// act
|
||||
m, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, fakeProperties[connectionString], m.ConnectionString)
|
||||
assert.Equal(t, fakeProperties[consumerID], m.ConsumerID)
|
||||
|
||||
assert.Equal(t, 90, m.TimeoutInSec)
|
||||
assert.Equal(t, true, m.DisableEntityManagement)
|
||||
|
||||
assert.NotNil(t, m.AutoDeleteOnIdleInSec)
|
||||
assert.Equal(t, 240, *m.AutoDeleteOnIdleInSec)
|
||||
assert.NotNil(t, m.MaxDeliveryCount)
|
||||
assert.Equal(t, 10, *m.MaxDeliveryCount)
|
||||
assert.NotNil(t, m.DefaultMessageTimeToLiveInSec)
|
||||
assert.Equal(t, 2400, *m.DefaultMessageTimeToLiveInSec)
|
||||
assert.NotNil(t, m.LockDurationInSec)
|
||||
assert.Equal(t, 120, *m.LockDurationInSec)
|
||||
})
|
||||
|
||||
t.Run("missing required connectionString", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[connectionString] = ""
|
||||
|
||||
// act
|
||||
m, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, m.ConnectionString)
|
||||
})
|
||||
|
||||
t.Run("missing required consumerID", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[consumerID] = ""
|
||||
|
||||
// act
|
||||
m, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, m.ConsumerID)
|
||||
})
|
||||
|
||||
t.Run("missing optional timeoutInSec", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[timeoutInSec] = ""
|
||||
|
||||
// act
|
||||
m, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, m.TimeoutInSec, 60)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid optional timeoutInSec", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[timeoutInSec] = invalidNumber
|
||||
|
||||
// act
|
||||
_, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("missing optional disableEntityManagement", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[disableEntityManagement] = ""
|
||||
|
||||
// act
|
||||
m, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, m.DisableEntityManagement, false)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid optional disableEntityManagement", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[disableEntityManagement] = "invalid_bool"
|
||||
|
||||
// act
|
||||
_, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("missing nullable maxDeliveryCount", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[maxDeliveryCount] = ""
|
||||
|
||||
// act
|
||||
m, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.Nil(t, m.MaxDeliveryCount)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid nullable maxDeliveryCount", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[maxDeliveryCount] = invalidNumber
|
||||
|
||||
// act
|
||||
_, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("missing nullable defaultMessageTimeToLiveInSec", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[defaultMessageTimeToLiveInSec] = ""
|
||||
|
||||
// act
|
||||
m, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.Nil(t, m.DefaultMessageTimeToLiveInSec)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid nullable defaultMessageTimeToLiveInSec", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[defaultMessageTimeToLiveInSec] = invalidNumber
|
||||
|
||||
// act
|
||||
_, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("missing nullable autoDeleteOnIdleInSec", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[autoDeleteOnIdleInSec] = ""
|
||||
|
||||
// act
|
||||
m, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.Nil(t, m.AutoDeleteOnIdleInSec)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid nullable autoDeleteOnIdleInSec", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[autoDeleteOnIdleInSec] = invalidNumber
|
||||
|
||||
// act
|
||||
_, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("missing nullable lockDurationInSec", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[lockDurationInSec] = ""
|
||||
|
||||
// act
|
||||
m, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.Nil(t, m.LockDurationInSec)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid nullable lockDurationInSec", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[lockDurationInSec] = invalidNumber
|
||||
|
||||
// act
|
||||
_, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
// ------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
// ------------------------------------------------------------
|
||||
|
||||
package azureservicebus
|
||||
|
||||
// Reference for settings:
|
||||
// https://github.com/Azure/azure-service-bus-go/blob/54b2faa53e5216616e59725281be692acc120c34/subscription_manager.go#L101
|
||||
type metadata struct {
|
||||
ConnectionString string `json:"connectionString"`
|
||||
ConsumerID string `json:"consumerID"`
|
||||
TimeoutInSec int `json:"timeoutInSec"`
|
||||
DisableEntityManagement bool `json:"disableEntityManagement"`
|
||||
MaxDeliveryCount *int `json:"maxDeliveryCount"`
|
||||
LockDurationInSec *int `json:"lockDurationInSec"`
|
||||
DefaultMessageTimeToLiveInSec *int `json:"defaultMessageTimeToLiveInSec"`
|
||||
AutoDeleteOnIdleInSec *int `json:"autoDeleteOnIdleInSec"`
|
||||
}
|
||||
|
|
@ -24,6 +24,7 @@ package adal
|
|||
*/
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
|
@ -101,7 +102,14 @@ type deviceToken struct {
|
|||
|
||||
// InitiateDeviceAuth initiates a device auth flow. It returns a DeviceCode
|
||||
// that can be used with CheckForUserCompletion or WaitForUserCompletion.
|
||||
// Deprecated: use InitiateDeviceAuthWithContext() instead.
|
||||
func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
|
||||
return InitiateDeviceAuthWithContext(context.Background(), sender, oauthConfig, clientID, resource)
|
||||
}
|
||||
|
||||
// InitiateDeviceAuthWithContext initiates a device auth flow. It returns a DeviceCode
|
||||
// that can be used with CheckForUserCompletion or WaitForUserCompletion.
|
||||
func InitiateDeviceAuthWithContext(ctx context.Context, sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
|
||||
v := url.Values{
|
||||
"client_id": []string{clientID},
|
||||
"resource": []string{resource},
|
||||
|
|
@ -117,7 +125,7 @@ func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resour
|
|||
|
||||
req.ContentLength = int64(len(s))
|
||||
req.Header.Set(contentType, mimeTypeFormPost)
|
||||
resp, err := sender.Do(req)
|
||||
resp, err := sender.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
|
||||
}
|
||||
|
|
@ -151,7 +159,14 @@ func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resour
|
|||
|
||||
// CheckForUserCompletion takes a DeviceCode and checks with the Azure AD OAuth endpoint
|
||||
// to see if the device flow has: been completed, timed out, or otherwise failed
|
||||
// Deprecated: use CheckForUserCompletionWithContext() instead.
|
||||
func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
|
||||
return CheckForUserCompletionWithContext(context.Background(), sender, code)
|
||||
}
|
||||
|
||||
// CheckForUserCompletionWithContext takes a DeviceCode and checks with the Azure AD OAuth endpoint
|
||||
// to see if the device flow has: been completed, timed out, or otherwise failed
|
||||
func CheckForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) {
|
||||
v := url.Values{
|
||||
"client_id": []string{code.ClientID},
|
||||
"code": []string{*code.DeviceCode},
|
||||
|
|
@ -169,7 +184,7 @@ func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
|
|||
|
||||
req.ContentLength = int64(len(s))
|
||||
req.Header.Set(contentType, mimeTypeFormPost)
|
||||
resp, err := sender.Do(req)
|
||||
resp, err := sender.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
|
||||
}
|
||||
|
|
@ -213,12 +228,19 @@ func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
|
|||
|
||||
// WaitForUserCompletion calls CheckForUserCompletion repeatedly until a token is granted or an error state occurs.
|
||||
// This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
|
||||
// Deprecated: use WaitForUserCompletionWithContext() instead.
|
||||
func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
|
||||
return WaitForUserCompletionWithContext(context.Background(), sender, code)
|
||||
}
|
||||
|
||||
// WaitForUserCompletionWithContext calls CheckForUserCompletion repeatedly until a token is granted or an error
|
||||
// state occurs. This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
|
||||
func WaitForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) {
|
||||
intervalDuration := time.Duration(*code.Interval) * time.Second
|
||||
waitDuration := intervalDuration
|
||||
|
||||
for {
|
||||
token, err := CheckForUserCompletion(sender, code)
|
||||
token, err := CheckForUserCompletionWithContext(ctx, sender, code)
|
||||
|
||||
if err == nil {
|
||||
return token, nil
|
||||
|
|
@ -237,6 +259,11 @@ func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
|
|||
return nil, fmt.Errorf("%s Error waiting for user to complete device flow. Server told us to slow_down too much", logPrefix)
|
||||
}
|
||||
|
||||
time.Sleep(waitDuration)
|
||||
select {
|
||||
case <-time.After(waitDuration):
|
||||
// noop
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,9 +26,9 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
|
@ -63,6 +63,12 @@ const (
|
|||
|
||||
// the default number of attempts to refresh an MSI authentication token
|
||||
defaultMaxMSIRefreshAttempts = 5
|
||||
|
||||
// asMSIEndpointEnv is the environment variable used to store the endpoint on App Service and Functions
|
||||
asMSIEndpointEnv = "MSI_ENDPOINT"
|
||||
|
||||
// asMSISecretEnv is the environment variable used to store the request secret on App Service and Functions
|
||||
asMSISecretEnv = "MSI_SECRET"
|
||||
)
|
||||
|
||||
// OAuthTokenProvider is an interface which should be implemented by an access token retriever
|
||||
|
|
@ -634,6 +640,31 @@ func GetMSIVMEndpoint() (string, error) {
|
|||
return msiEndpoint, nil
|
||||
}
|
||||
|
||||
func isAppService() bool {
|
||||
_, asMSIEndpointEnvExists := os.LookupEnv(asMSIEndpointEnv)
|
||||
_, asMSISecretEnvExists := os.LookupEnv(asMSISecretEnv)
|
||||
|
||||
return asMSIEndpointEnvExists && asMSISecretEnvExists
|
||||
}
|
||||
|
||||
// GetMSIAppServiceEndpoint get the MSI endpoint for App Service and Functions
|
||||
func GetMSIAppServiceEndpoint() (string, error) {
|
||||
asMSIEndpoint, asMSIEndpointEnvExists := os.LookupEnv(asMSIEndpointEnv)
|
||||
|
||||
if asMSIEndpointEnvExists {
|
||||
return asMSIEndpoint, nil
|
||||
}
|
||||
return "", errors.New("MSI endpoint not found")
|
||||
}
|
||||
|
||||
// GetMSIEndpoint get the appropriate MSI endpoint depending on the runtime environment
|
||||
func GetMSIEndpoint() (string, error) {
|
||||
if isAppService() {
|
||||
return GetMSIAppServiceEndpoint()
|
||||
}
|
||||
return GetMSIVMEndpoint()
|
||||
}
|
||||
|
||||
// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
|
||||
// It will use the system assigned identity when creating the token.
|
||||
func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
|
||||
|
|
@ -666,7 +697,12 @@ func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedI
|
|||
|
||||
v := url.Values{}
|
||||
v.Set("resource", resource)
|
||||
v.Set("api-version", "2018-02-01")
|
||||
// App Service MSI currently only supports token API version 2017-09-01
|
||||
if isAppService() {
|
||||
v.Set("api-version", "2017-09-01")
|
||||
} else {
|
||||
v.Set("api-version", "2018-02-01")
|
||||
}
|
||||
if userAssignedID != nil {
|
||||
v.Set("client_id", *userAssignedID)
|
||||
}
|
||||
|
|
@ -793,7 +829,7 @@ func isIMDS(u url.URL) bool {
|
|||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return u.Host == imds.Host && u.Path == imds.Path
|
||||
return (u.Host == imds.Host && u.Path == imds.Path) || isAppService()
|
||||
}
|
||||
|
||||
func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
|
||||
|
|
@ -802,6 +838,11 @@ func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource
|
|||
return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
|
||||
}
|
||||
req.Header.Add("User-Agent", UserAgent())
|
||||
// Add header when runtime is on App Service or Functions
|
||||
if isAppService() {
|
||||
asMSISecret, _ := os.LookupEnv(asMSISecretEnv)
|
||||
req.Header.Add("Secret", asMSISecret)
|
||||
}
|
||||
req = req.WithContext(ctx)
|
||||
if !isIMDS(spt.inner.OauthConfig.TokenEndpoint) {
|
||||
v := url.Values{}
|
||||
|
|
@ -846,7 +887,8 @@ func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource
|
|||
resp, err = spt.sender.Do(req)
|
||||
}
|
||||
if err != nil {
|
||||
return newTokenRefreshError(fmt.Sprintf("adal: Failed to execute the refresh request. Error = '%v'", err), nil)
|
||||
// don't return a TokenRefreshError here; this will allow retry logic to apply
|
||||
return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
|
@ -913,10 +955,8 @@ func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http
|
|||
|
||||
for attempt < maxAttempts {
|
||||
resp, err = sender.Do(req)
|
||||
// retry on temporary network errors, e.g. transient network failures.
|
||||
// if we don't receive a response then assume we can't connect to the
|
||||
// endpoint so we're likely not running on an Azure VM so don't retry.
|
||||
if (err != nil && !isTemporaryNetworkError(err)) || resp == nil || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) {
|
||||
// we want to retry if err is not nil or the status code is in the list of retry codes
|
||||
if err == nil && !responseHasStatusCode(resp, retries...) {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -940,20 +980,12 @@ func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http
|
|||
return
|
||||
}
|
||||
|
||||
// returns true if the specified error is a temporary network error or false if it's not.
|
||||
// if the error doesn't implement the net.Error interface the return value is true.
|
||||
func isTemporaryNetworkError(err error) bool {
|
||||
if netErr, ok := err.(net.Error); !ok || (ok && netErr.Temporary()) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// returns true if slice ints contains the value n
|
||||
func containsInt(ints []int, n int) bool {
|
||||
for _, i := range ints {
|
||||
if i == n {
|
||||
return true
|
||||
func responseHasStatusCode(resp *http.Response, codes ...int) bool {
|
||||
if resp != nil {
|
||||
for _, i := range codes {
|
||||
if i == resp.StatusCode {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
|
|
|||
|
|
@ -289,10 +289,6 @@ func doRetryForStatusCodesImpl(s Sender, r *http.Request, count429 bool, attempt
|
|||
return
|
||||
}
|
||||
resp, err = s.Do(rr.Request())
|
||||
// if the error isn't temporary don't bother retrying
|
||||
if err != nil && !IsTemporaryNetworkError(err) {
|
||||
return
|
||||
}
|
||||
// we want to retry if err is not nil (e.g. transient network failure). note that for failed authentication
|
||||
// resp and err will both have a value, so in this case we don't want to retry as it will never succeed.
|
||||
if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) {
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import (
|
|||
"runtime"
|
||||
)
|
||||
|
||||
const number = "v13.0.1"
|
||||
const number = "v13.0.2"
|
||||
|
||||
var (
|
||||
userAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s",
|
||||
|
|
|
|||
|
|
@ -54,10 +54,10 @@ github.com/Azure/azure-service-bus-go
|
|||
github.com/Azure/azure-service-bus-go/atom
|
||||
# github.com/Azure/azure-storage-blob-go v0.8.0
|
||||
github.com/Azure/azure-storage-blob-go/azblob
|
||||
# github.com/Azure/go-autorest/autorest v0.9.1
|
||||
# github.com/Azure/go-autorest/autorest v0.9.2
|
||||
github.com/Azure/go-autorest/autorest
|
||||
github.com/Azure/go-autorest/autorest/azure
|
||||
# github.com/Azure/go-autorest/autorest/adal v0.6.0
|
||||
# github.com/Azure/go-autorest/autorest/adal v0.8.0
|
||||
github.com/Azure/go-autorest/autorest/adal
|
||||
# github.com/Azure/go-autorest/autorest/azure/auth v0.3.0
|
||||
github.com/Azure/go-autorest/autorest/azure/auth
|
||||
|
|
|
|||
Loading…
Reference in New Issue