Move Service Bus Pubsub/Binding to common auth (#1201)
* Move Service Bus Pubsub/Binding to common auth Both the pubsub and input/output binding for Azure Service Bus were connecting via a connection string. This is still supported but will now fallback to using AAD from the common auth library. This is also the recommended auth pattern going forward. * Move AMPQ specific auth and fix linter issues * Make conn string and namespace mutually exclusive * Move resourceName to a constant * Update auth_amqp.go * Update auth.go Co-authored-by: Long Dai <long.dai@intel.com> Co-authored-by: Simon Leet <31784195+CodeMonkeyLeet@users.noreply.github.com> Co-authored-by: Artur Souza <artursouza.ms@outlook.com> Co-authored-by: Dapr Bot <56698301+dapr-bot@users.noreply.github.com>
This commit is contained in:
parent
f1be130563
commit
d5a68041c9
|
|
@ -44,6 +44,8 @@ func NewEnvironmentSettings(resourceName string, values map[string]string) (Envi
|
|||
case "cosmosdb":
|
||||
// Azure Cosmos DB (data plane)
|
||||
es.Resource = "https://" + azureEnv.CosmosDBDNSSuffix
|
||||
case "servicebus":
|
||||
es.Resource = azureEnv.ResourceIdentifiers.ServiceBus
|
||||
default:
|
||||
return es, errors.New("invalid resource name: " + resourceName)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,25 @@
|
|||
// ------------------------------------------------------------
|
||||
// Copyright (c) Microsoft Corporation and Dapr Contributors.
|
||||
// Licensed under the MIT License.
|
||||
// ------------------------------------------------------------
|
||||
|
||||
package azure
|
||||
|
||||
import "github.com/Azure/azure-amqp-common-go/v3/aad"
|
||||
|
||||
const (
|
||||
AzureServiceBusResourceName string = "servicebus"
|
||||
)
|
||||
|
||||
// GetTokenProvider creates a TokenProvider for AAD retrieved from, in order:
|
||||
// 1. Client credentials
|
||||
// 2. Client certificate
|
||||
// 3. MSI.
|
||||
func (s EnvironmentSettings) GetAADTokenProvider() (*aad.TokenProvider, error) {
|
||||
spt, err := s.GetServicePrincipalToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return aad.NewJWTProvider(aad.JWTProviderWithAADToken(spt), aad.JWTProviderWithAzureEnvironment(s.AzureEnvironment))
|
||||
}
|
||||
|
|
@ -8,6 +8,7 @@ package servicebusqueues
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
|
@ -15,6 +16,7 @@ import (
|
|||
servicebus "github.com/Azure/azure-service-bus-go"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
|
||||
azauth "github.com/dapr/components-contrib/authentication/azure"
|
||||
"github.com/dapr/components-contrib/bindings"
|
||||
contrib_metadata "github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/kit/logger"
|
||||
|
|
@ -43,6 +45,7 @@ type AzureServiceBusQueues struct {
|
|||
|
||||
type serviceBusQueuesMetadata struct {
|
||||
ConnectionString string `json:"connectionString"`
|
||||
NamespaceName string `json:"namespaceName,omitempty"`
|
||||
QueueName string `json:"queueName"`
|
||||
ttl time.Duration
|
||||
}
|
||||
|
|
@ -61,10 +64,36 @@ func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) error {
|
|||
userAgent := "dapr-" + logger.DaprVersion
|
||||
a.metadata = meta
|
||||
|
||||
ns, err := servicebus.NewNamespace(servicebus.NamespaceWithConnectionString(a.metadata.ConnectionString),
|
||||
servicebus.NamespaceWithUserAgent(userAgent))
|
||||
if err != nil {
|
||||
return err
|
||||
var ns *servicebus.Namespace
|
||||
if a.metadata.ConnectionString != "" {
|
||||
ns, err = servicebus.NewNamespace(servicebus.NamespaceWithConnectionString(a.metadata.ConnectionString),
|
||||
servicebus.NamespaceWithUserAgent(userAgent))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Initialization code
|
||||
settings, sErr := azauth.NewEnvironmentSettings(azauth.AzureServiceBusResourceName, metadata.Properties)
|
||||
if sErr != nil {
|
||||
return sErr
|
||||
}
|
||||
|
||||
tokenProvider, tErr := settings.GetAADTokenProvider()
|
||||
if tErr != nil {
|
||||
return tErr
|
||||
}
|
||||
|
||||
ns, err = servicebus.NewNamespace(servicebus.NamespaceWithTokenProvider(tokenProvider),
|
||||
servicebus.NamespaceWithUserAgent(userAgent))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We set these separately as the ServiceBus SDK does not provide a way to pass the environment via the options
|
||||
// pattern unless you allow it to recreate the entire environment which seems wasteful.
|
||||
ns.Name = a.metadata.NamespaceName
|
||||
ns.Environment = *settings.AzureEnvironment
|
||||
ns.Suffix = settings.AzureEnvironment.ServiceBusEndpointSuffix
|
||||
}
|
||||
a.ns = ns
|
||||
|
||||
|
|
@ -124,6 +153,10 @@ func (a *AzureServiceBusQueues) parseMetadata(metadata bindings.Metadata) (*serv
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if m.ConnectionString != "" && m.NamespaceName != "" {
|
||||
return nil, errors.New("connectionString and namespaceName are mutually exclusive")
|
||||
}
|
||||
|
||||
ttl, ok, err := contrib_metadata.TryGetTTL(metadata.Properties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -93,3 +93,55 @@ func TestParseMetadataWithInvalidTTL(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMetadataConnectionStringAndNamespaceNameExclusivity(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
properties map[string]string
|
||||
expectedConnectionString string
|
||||
expectedNamespaceName string
|
||||
expectedQueueName string
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "ConnectionString and queue name",
|
||||
properties: map[string]string{"connectionString": "connString", "queueName": "queue1"},
|
||||
expectedConnectionString: "connString",
|
||||
expectedNamespaceName: "",
|
||||
expectedQueueName: "queue1",
|
||||
expectedErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty TTL",
|
||||
properties: map[string]string{"namespaceName": "testNamespace", "queueName": "queue1", metadata.TTLMetadataKey: ""},
|
||||
expectedConnectionString: "",
|
||||
expectedNamespaceName: "testNamespace",
|
||||
expectedQueueName: "queue1",
|
||||
expectedErr: false,
|
||||
},
|
||||
{
|
||||
name: "With TTL",
|
||||
properties: map[string]string{"connectionString": "connString", "namespaceName": "testNamespace", "queueName": "queue1", metadata.TTLMetadataKey: "1"},
|
||||
expectedConnectionString: "",
|
||||
expectedNamespaceName: "",
|
||||
expectedQueueName: "queue1",
|
||||
expectedErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := bindings.Metadata{}
|
||||
m.Properties = tt.properties
|
||||
a := NewAzureServiceBusQueues(logger.NewLogger("test"))
|
||||
meta, err := a.parseMetadata(m)
|
||||
if tt.expectedErr {
|
||||
assert.NotNil(t, err)
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedConnectionString, meta.ConnectionString)
|
||||
assert.Equal(t, tt.expectedQueueName, meta.QueueName)
|
||||
assert.Equal(t, tt.expectedNamespaceName, meta.NamespaceName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,4 +26,5 @@ type metadata struct {
|
|||
PrefetchCount *int `json:"prefetchCount"`
|
||||
PublishMaxRetries int `json:"publishMaxRetries"`
|
||||
PublishInitialRetryIntervalInMs int `json:"publishInitialRetryInternalInMs"`
|
||||
NamespaceName string `json:"namespaceName,omitempty"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import (
|
|||
|
||||
azservicebus "github.com/Azure/azure-service-bus-go"
|
||||
|
||||
azauth "github.com/dapr/components-contrib/authentication/azure"
|
||||
"github.com/dapr/components-contrib/pubsub"
|
||||
"github.com/dapr/kit/logger"
|
||||
"github.com/dapr/kit/retry"
|
||||
|
|
@ -43,6 +44,7 @@ const (
|
|||
connectionRecoveryInSec = "connectionRecoveryInSec"
|
||||
publishMaxRetries = "publishMaxRetries"
|
||||
publishInitialRetryInternalInMs = "publishInitialRetryInternalInMs"
|
||||
namespaceName = "namespaceName"
|
||||
errorMessagePrefix = "azure service bus error:"
|
||||
|
||||
// Defaults.
|
||||
|
|
@ -93,8 +95,15 @@ func parseAzureServiceBusMetadata(meta pubsub.Metadata) (metadata, error) {
|
|||
/* Required configuration settings - no defaults. */
|
||||
if val, ok := meta.Properties[connectionString]; ok && val != "" {
|
||||
m.ConnectionString = val
|
||||
|
||||
// The connection string and the namespace cannot both be present.
|
||||
if namespace, present := meta.Properties[namespaceName]; present && namespace != "" {
|
||||
return m, fmt.Errorf("%s connectionString and namespaceName cannot both be specified", errorMessagePrefix)
|
||||
}
|
||||
} else if val, ok := meta.Properties[namespaceName]; ok && val != "" {
|
||||
m.NamespaceName = val
|
||||
} else {
|
||||
return m, fmt.Errorf("%s missing connection string", errorMessagePrefix)
|
||||
return m, fmt.Errorf("%s missing connection string and namespace name", errorMessagePrefix)
|
||||
}
|
||||
|
||||
if val, ok := meta.Properties[consumerID]; ok && val != "" {
|
||||
|
|
@ -258,12 +267,37 @@ func (a *azureServiceBus) Init(metadata pubsub.Metadata) error {
|
|||
|
||||
userAgent := "dapr-" + logger.DaprVersion
|
||||
a.metadata = m
|
||||
a.namespace, err = azservicebus.NewNamespace(
|
||||
azservicebus.NamespaceWithConnectionString(a.metadata.ConnectionString),
|
||||
azservicebus.NamespaceWithUserAgent(userAgent))
|
||||
if a.metadata.ConnectionString != "" {
|
||||
a.namespace, err = azservicebus.NewNamespace(
|
||||
azservicebus.NamespaceWithConnectionString(a.metadata.ConnectionString),
|
||||
azservicebus.NamespaceWithUserAgent(userAgent))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Initialization code
|
||||
settings, err := azauth.NewEnvironmentSettings(azauth.AzureServiceBusResourceName, metadata.Properties)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokenProvider, err := settings.GetAADTokenProvider()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.namespace, err = azservicebus.NewNamespace(azservicebus.NamespaceWithTokenProvider(tokenProvider),
|
||||
azservicebus.NamespaceWithUserAgent(userAgent))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We set these separately as the ServiceBus SDK does not provide a way to pass the environment via the options
|
||||
// pattern unless you allow it to recreate the entire environment which seems wasteful.
|
||||
a.namespace.Name = a.metadata.NamespaceName
|
||||
a.namespace.Environment = *settings.AzureEnvironment
|
||||
a.namespace.Suffix = settings.AzureEnvironment.ServiceBusEndpointSuffix
|
||||
}
|
||||
|
||||
a.topicManager = a.namespace.NewTopicManager()
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ const (
|
|||
func getFakeProperties() map[string]string {
|
||||
return map[string]string{
|
||||
connectionString: "fakeConnectionString",
|
||||
namespaceName: "",
|
||||
consumerID: "fakeConId",
|
||||
disableEntityManagement: "true",
|
||||
timeoutInSec: "90",
|
||||
|
|
@ -82,13 +83,14 @@ func TestParseServiceBusMetadata(t *testing.T) {
|
|||
assert.Equal(t, 10, *m.PrefetchCount)
|
||||
})
|
||||
|
||||
t.Run("missing required connectionString", func(t *testing.T) {
|
||||
t.Run("missing required connectionString and namespaceName", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[connectionString] = ""
|
||||
fakeMetaData.Properties[namespaceName] = ""
|
||||
|
||||
// act.
|
||||
m, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
|
@ -99,6 +101,56 @@ func TestParseServiceBusMetadata(t *testing.T) {
|
|||
assert.Empty(t, m.ConnectionString)
|
||||
})
|
||||
|
||||
t.Run("connectionString makes namespace optional", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[namespaceName] = ""
|
||||
|
||||
// act.
|
||||
m, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert.
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "fakeConnectionString", m.ConnectionString)
|
||||
})
|
||||
|
||||
t.Run("namespace makes conectionString optional", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
fakeMetaData.Properties[namespaceName] = "fakeNamespace"
|
||||
fakeMetaData.Properties[connectionString] = ""
|
||||
|
||||
// act.
|
||||
m, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert.
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "fakeNamespace", m.NamespaceName)
|
||||
})
|
||||
|
||||
t.Run("connectionString and namespace are mutually exclusive", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
fakeMetaData := pubsub.Metadata{
|
||||
Properties: fakeProperties,
|
||||
}
|
||||
|
||||
fakeMetaData.Properties[namespaceName] = "fakeNamespace"
|
||||
|
||||
// act.
|
||||
_, err := parseAzureServiceBusMetadata(fakeMetaData)
|
||||
|
||||
// assert.
|
||||
assert.Error(t, err)
|
||||
assertValidErrorMessage(t, err)
|
||||
})
|
||||
|
||||
t.Run("missing required consumerID", func(t *testing.T) {
|
||||
fakeProperties := getFakeProperties()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue