Azure AD support in SignalR (#1852)

* WIP: Azure AD support in SignalR

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* Correct SignalR AAD details

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* Misc fixes

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* azauth package name

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
Alessandro (Ale) Segala 2022-07-06 14:05:24 -07:00 committed by GitHub
parent a2f3a84b96
commit 3821c00131
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 294 additions and 123 deletions

View File

@ -17,7 +17,7 @@ import (
"sync"
"time"
"github.com/golang-jwt/jwt"
jwt "github.com/golang-jwt/jwt/v4"
"github.com/dapr/kit/logger"
)

View File

@ -22,102 +22,167 @@ import (
"strings"
"time"
"github.com/golang-jwt/jwt"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
jwt "github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
"github.com/dapr/components-contrib/bindings"
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
"github.com/dapr/kit/logger"
)
const (
errorPrefix = "azure signalr error:"
logPrefix = "azure signalr:"
errorPrefix = "azure signalr error:"
logPrefix = "azure signalr:"
// Metadata keys.
// Azure AD credentials are parsed separately and not listed here.
connectionStringKey = "connectionString"
accessKeyKey = "accessKey"
endpointKey = "endpoint"
hubKey = "hub"
groupKey = "group"
userKey = "user"
// Invoke metadata keys.
groupKey = "group"
userKey = "user"
)
// NewSignalR creates a new pub/sub based on Azure SignalR.
func NewSignalR(logger logger.Logger) *SignalR {
return &SignalR{
tokens: make(map[string]signalrCachedToken),
httpClient: &http.Client{Timeout: 30 * time.Second},
logger: logger,
// Global HTTP client
var httpClient *http.Client
func init() {
httpClient = &http.Client{
Timeout: 30 * time.Second,
}
}
type signalrCachedToken struct {
token string
expiration time.Time
// NewSignalR creates a new output binding for Azure SignalR.
func NewSignalR(logger logger.Logger) *SignalR {
return &SignalR{
logger: logger,
httpClient: httpClient,
}
}
// SignalR is an output binding for Azure SignalR.
type SignalR struct {
endpoint string
accessKey string
version string
hub string
userAgent string
tokens map[string]signalrCachedToken
httpClient *http.Client
endpoint string
accessKey string
hub string
userAgent string
aadToken azcore.TokenCredential
logger logger.Logger
httpClient *http.Client
logger logger.Logger
}
// Init is responsible for initializing the SignalR output based on the metadata.
func (s *SignalR) Init(metadata bindings.Metadata) error {
func (s *SignalR) Init(metadata bindings.Metadata) (err error) {
s.userAgent = "dapr-" + logger.DaprVersion
connectionString, ok := metadata.Properties[connectionStringKey]
if !ok || connectionString == "" {
return fmt.Errorf("missing connection string")
err = s.parseMetadata(metadata.Properties)
if err != nil {
return err
}
if hub, ok := metadata.Properties[hubKey]; ok && hub != "" {
s.hub = hub
}
// Expected: Endpoint=https://<servicename>.service.signalr.net;AccessKey=<access key>;Version=1.0;
connectionValues := strings.Split(strings.TrimSpace(connectionString), ";")
for _, connectionValue := range connectionValues {
if i := strings.Index(connectionValue, "="); i != -1 && len(connectionValue) > (i+1) {
k := connectionValue[0:i]
switch k {
case "Endpoint":
s.endpoint = connectionValue[i+1:]
if s.endpoint[len(s.endpoint)-1] == '/' {
s.endpoint = s.endpoint[:len(s.endpoint)-1]
}
case "AccessKey":
s.accessKey = connectionValue[i+1:]
case "Version":
s.version = connectionValue[i+1:]
}
// If using AAD for authentication, init the token provider
if s.accessKey == "" {
var settings azauth.EnvironmentSettings
settings, err = azauth.NewEnvironmentSettings("signalr", metadata.Properties)
if err != nil {
return err
}
s.aadToken, err = settings.GetTokenCredential()
if err != nil {
return err
}
}
if s.endpoint == "" {
return fmt.Errorf("missing endpoint in connection string")
return nil
}
func (s *SignalR) parseMetadata(md map[string]string) (err error) {
// Start by parsing the connection string if present
connectionString, ok := md[connectionStringKey]
if ok && connectionString != "" {
// Expected options:
// Access key: "Endpoint=https://<servicename>.service.signalr.net;AccessKey=<access key>;Version=1.0;"
// System-assigned managed identity: "Endpoint=https://<servicename>.service.signalr.net;AuthType=aad;Version=1.0;"
// User-assigned managed identity: "Endpoint=https://<servicename>.service.signalr.net;AuthType=aad;ClientId=<clientid>;Version=1.0;"
// Azure AD application: "Endpoint=https://<servicename>.service.signalr.net;AuthType=aad;ClientId=<clientid>;ClientSecret=<clientsecret>;TenantId=<tenantid>;Version=1.0;"
// Note: connection string can't be used if the client secret contains the ; key
connectionValues := strings.Split(strings.TrimSpace(connectionString), ";")
useAAD := false
for _, connectionValue := range connectionValues {
if i := strings.Index(connectionValue, "="); i != -1 && len(connectionValue) > (i+1) {
k := connectionValue[0:i]
switch k {
case "Endpoint":
s.endpoint = connectionValue[i+1:]
case "AccessKey":
s.accessKey = connectionValue[i+1:]
case "AuthType":
if connectionValue[i+1:] != "aad" {
return fmt.Errorf("invalid value for AuthType in the connection string; only 'aad' is supported")
}
useAAD = true
case "ClientId", "ClientSecret", "TenantId":
v := connectionValue[i+1:]
// Set the values in the metadata map so they can be picked up by the azauth module
md["azure"+k] = v
case "Version":
v := connectionValue[i+1:]
// We only support version "1.0"
if v != "1.0" {
return fmt.Errorf("invalid value for Version in the connection string: '%s'; only version '1.0' is supported", v)
}
}
} else if len(connectionValue) != 0 {
return fmt.Errorf("the connection string is invalid or malformed")
}
}
// Check here because if we use a connection string, we'd have an explicit "AuthType=aad" option
// We would otherwise catch this issue later, but here we can be more explicit with the error
if s.accessKey == "" && !useAAD {
return fmt.Errorf("missing AccessKey in the connection string")
}
}
if s.accessKey == "" {
return fmt.Errorf("missing access key in connection string")
// Parse the other metadata keys, which could also override the values from the connection string
if v, ok := md[hubKey]; ok && v != "" {
s.hub = v
}
if v, ok := md[endpointKey]; ok && v != "" {
s.endpoint = v
}
if v, ok := md[accessKeyKey]; ok && v != "" {
s.accessKey = v
}
// Trim ending "/" from endpoint
s.endpoint = strings.TrimSuffix(s.endpoint, "/")
// Check for required values
if s.endpoint == "" {
return fmt.Errorf("missing endpoint in the metadata or connection string")
}
return nil
}
func (s *SignalR) resolveAPIURL(req *bindings.InvokeRequest) (string, error) {
hub := s.hub
if hub == "" {
hubFromRequest, ok := req.Metadata[hubKey]
if !ok || hubFromRequest == "" {
return "", fmt.Errorf("%s missing hub", errorPrefix)
}
hub = hubFromRequest
hub, ok := req.Metadata[hubKey]
if !ok || hub == "" {
hub = s.hub
}
if hub == "" {
return "", fmt.Errorf("%s missing hub", errorPrefix)
}
// Hub name is lower-cased in the official SDKs (e.g. .NET)
hub = strings.ToLower(hub)
var url string
if group, ok := req.Metadata[groupKey]; ok && group != "" {
@ -138,26 +203,26 @@ func (s *SignalR) sendMessageToSignalR(ctx context.Context, url string, token st
}
httpReq.Header.Set("Authorization", "Bearer "+token)
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Content-Type", "application/json; charset=utf-8")
httpReq.Header.Set("User-Agent", s.userAgent)
resp, err := s.httpClient.Do(httpReq)
if err != nil {
return errors.Wrap(err, "request to azure signalr api failed")
}
defer resp.Body.Close()
if resp.StatusCode != 200 && resp.StatusCode != 202 {
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
return fmt.Errorf("%s azure signalr returned code %d, content is '%s'", errorPrefix, resp.StatusCode, string(body))
// Read the body regardless to drain it and ensure the connection can be reused
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
s.logger.Debugf("%s azure signalr call to '%s' returned with status code %d", logPrefix, url, resp.StatusCode)
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
return fmt.Errorf("%s azure signalr failed with code %d, content is '%s'", errorPrefix, resp.StatusCode, string(body))
}
s.logger.Debugf("%s azure signalr call to '%s' completed with code %d", logPrefix, url, resp.StatusCode)
return nil
}
@ -172,7 +237,7 @@ func (s *SignalR) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin
return nil, err
}
token, err := s.ensureValidToken(url)
token, err := s.getToken(ctx, url)
if err != nil {
return nil, err
}
@ -185,34 +250,34 @@ func (s *SignalR) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin
return nil, nil
}
func (s *SignalR) ensureValidToken(url string) (string, error) {
now := time.Now()
// Returns an access token for a request to the given URL
func (s *SignalR) getToken(ctx context.Context, url string) (token string, err error) {
// If we have an Azure AD token provider, use that first
if s.aadToken != nil {
var at azcore.AccessToken
at, err = s.aadToken.GetToken(ctx, policy.TokenRequestOptions{
Scopes: []string{"https://signalr.azure.com/.default"},
})
if err != nil {
return "", err
}
token = at.Token
} else {
claims := &jwt.StandardClaims{
ExpiresAt: time.Now().Add(15 * time.Minute).Unix(),
Audience: url,
}
err = claims.Valid()
if err != nil {
return "", err
}
if existing, ok := s.tokens[url]; ok {
if existing.token != "" && now.Before(existing.expiration) {
return existing.token, nil
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token, err = jwtToken.SignedString([]byte(s.accessKey))
if err != nil {
return "", err
}
}
expiration := now.Add(1 * time.Hour)
claims := &jwt.StandardClaims{
ExpiresAt: expiration.Unix(),
Audience: url,
}
if err := claims.Valid(); err != nil {
return "", err
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
newToken, err := token.SignedString([]byte(s.accessKey))
if err != nil {
return "", err
}
s.tokens[url] = signalrCachedToken{token: newToken, expiration: expiration.Add(time.Minute * -5)}
return newToken, nil
return token, nil
}

View File

@ -31,12 +31,12 @@ import (
func TestConfigurationValid(t *testing.T) {
tests := []struct {
name string
properties map[string]string
expectedEndpoint string
expectedAccessKey string
expectedVersion string
expectedHub string
name string
properties map[string]string
expectedEndpoint string
expectedAccessKey string
expectedHub string
additionalMetadata map[string]string
}{
{
"With all properties",
@ -45,8 +45,8 @@ func TestConfigurationValid(t *testing.T) {
},
"https://fake.service.signalr.net",
"fakekey",
"1.0",
"",
nil,
},
{
"With missing version",
@ -56,7 +56,7 @@ func TestConfigurationValid(t *testing.T) {
"https://fake.service.signalr.net",
"fakekey",
"",
"",
nil,
},
{
"With semicolon after access key",
@ -66,40 +66,137 @@ func TestConfigurationValid(t *testing.T) {
"https://fake.service.signalr.net",
"fakekey",
"",
"",
nil,
},
{
"With trailing slash in endpoint",
map[string]string{
"connectionString": "Endpoint=https://fake.service.signalr.net/;AccessKey=fakekey;Version=1.1",
"connectionString": "Endpoint=https://fake.service.signalr.net/;AccessKey=fakekey;Version=1.0",
},
"https://fake.service.signalr.net",
"fakekey",
"1.1",
"",
nil,
},
{
"With hub",
map[string]string{
"connectionString": "Endpoint=https://fake.service.signalr.net/;AccessKey=fakekey;Version=1.1",
"connectionString": "Endpoint=https://fake.service.signalr.net/;AccessKey=fakekey;Version=1.0",
"hub": "myhub",
},
"https://fake.service.signalr.net",
"fakekey",
"1.1",
"myhub",
nil,
},
{
"With AAD and no access key (system-assigned MSI)",
map[string]string{
"connectionString": "Endpoint=https://fake.service.signalr.net/;AuthType=aad;Version=1.0",
},
"https://fake.service.signalr.net",
"",
"",
nil,
},
{
"Add azureClientId to metadata map (user-assigned MSI)",
map[string]string{
"connectionString": "Endpoint=https://fake.service.signalr.net/;AuthType=aad;ClientId=b83aec5c-54a3-4e4a-8831-ba3f849b79a1;Version=1.0",
},
"https://fake.service.signalr.net",
"",
"",
map[string]string{
"azureClientId": "b83aec5c-54a3-4e4a-8831-ba3f849b79a1",
},
},
{
"Add Azure AD credentials to metadata map (Azure AD app)",
map[string]string{
"connectionString": "Endpoint=https://fake.service.signalr.net/;AuthType=aad;ClientId=b83aec5c-54a3-4e4a-8831-ba3f849b79a1;ClientSecret=fakesecret;TenantId=f0f4622e-e476-46b5-bd0c-1866d27117d4;Version=1.0",
},
"https://fake.service.signalr.net",
"",
"",
map[string]string{
"azureClientId": "b83aec5c-54a3-4e4a-8831-ba3f849b79a1",
"azureClientSecret": "fakesecret",
"azureTenantId": "f0f4622e-e476-46b5-bd0c-1866d27117d4",
},
},
{
"No connection string, access key",
map[string]string{
"endpoint": "https://fake.service.signalr.net/",
"accessKey": "fakekey",
},
"https://fake.service.signalr.net",
"fakekey",
"",
nil,
},
{
"No connection string, access key and hub",
map[string]string{
"endpoint": "https://fake.service.signalr.net/",
"accessKey": "fakekey",
"hub": "myhub",
},
"https://fake.service.signalr.net",
"fakekey",
"myhub",
nil,
},
{
"No connection string, Azure AD",
map[string]string{
"endpoint": "https://fake.service.signalr.net/",
"azureClientId": "b83aec5c-54a3-4e4a-8831-ba3f849b79a1",
"azureClientSecret": "fakesecret",
"azureTenantId": "f0f4622e-e476-46b5-bd0c-1866d27117d4",
},
"https://fake.service.signalr.net",
"",
"",
map[string]string{
"azureClientId": "b83aec5c-54a3-4e4a-8831-ba3f849b79a1",
"azureClientSecret": "fakesecret",
"azureTenantId": "f0f4622e-e476-46b5-bd0c-1866d27117d4",
},
},
{
"No connection string, Azure AD with aliased names",
map[string]string{
"endpoint": "https://fake.service.signalr.net/",
"spnClientId": "b83aec5c-54a3-4e4a-8831-ba3f849b79a1",
"spnClientSecret": "fakesecret",
"spnTenantId": "f0f4622e-e476-46b5-bd0c-1866d27117d4",
},
"https://fake.service.signalr.net",
"",
"",
map[string]string{
"spnClientId": "b83aec5c-54a3-4e4a-8831-ba3f849b79a1",
"spnClientSecret": "fakesecret",
"spnTenantId": "f0f4622e-e476-46b5-bd0c-1866d27117d4",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := NewSignalR(logger.NewLogger("test"))
err := s.Init(bindings.Metadata{Properties: tt.properties})
err := s.parseMetadata(tt.properties)
assert.Nil(t, err)
assert.Equal(t, tt.expectedEndpoint, s.endpoint)
assert.Equal(t, tt.expectedAccessKey, s.accessKey)
assert.Equal(t, tt.expectedVersion, s.version)
assert.Equal(t, tt.expectedHub, s.hub)
if len(tt.additionalMetadata) > 0 {
for k := range tt.additionalMetadata {
assert.Equal(t, tt.properties[k], tt.additionalMetadata[k])
}
}
})
}
}
@ -138,7 +235,7 @@ func TestInvalidConfigurations(t *testing.T) {
},
},
{
"Missing access key",
"Missing access key (no AAD)",
map[string]string{
"connectionString1": "Endpoint=https://fake.service.signalr.net;",
},
@ -146,7 +243,13 @@ func TestInvalidConfigurations(t *testing.T) {
{
"With empty endpoint value",
map[string]string{
"connectionString": "Endpoint=;AccessKey=fakekey;Version=1.1",
"connectionString": "Endpoint=;AccessKey=fakekey;Version=1.0",
},
},
{
"With invalid version",
map[string]string{
"connectionString": "Endpoint=https://fake.service.signalr.net;AccessKey=fakekey;Version=2.0",
},
},
}
@ -154,7 +257,7 @@ func TestInvalidConfigurations(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := NewSignalR(logger.NewLogger("test"))
err := s.Init(bindings.Metadata{Properties: tt.properties})
err := s.parseMetadata(tt.properties)
assert.NotNil(t, err)
})
}
@ -265,12 +368,12 @@ func TestWriteShouldSucceed(t *testing.T) {
userID string
expectedURL string
}{
{"Broadcast receiving hub should call SignalR service", "testHub", "", "", "", "https://fake.service.signalr.net/api/v1/hubs/testHub"},
{"Broadcast with hub metadata should call SignalR service", "", "testHub", "", "", "https://fake.service.signalr.net/api/v1/hubs/testHub"},
{"Group receiving hub should call SignalR service", "testHub", "", "mygroup", "", "https://fake.service.signalr.net/api/v1/hubs/testHub/groups/mygroup"},
{"Group with hub metadata should call SignalR service", "", "testHub", "mygroup", "", "https://fake.service.signalr.net/api/v1/hubs/testHub/groups/mygroup"},
{"User receiving hub should call SignalR service", "testHub", "", "", "myuser", "https://fake.service.signalr.net/api/v1/hubs/testHub/users/myuser"},
{"User with hub metadata should call SignalR service", "", "testHub", "", "myuser", "https://fake.service.signalr.net/api/v1/hubs/testHub/users/myuser"},
{"Broadcast receiving hub should call SignalR service", "testHub", "", "", "", "https://fake.service.signalr.net/api/v1/hubs/testhub"},
{"Broadcast with hub metadata should call SignalR service", "", "testHub", "", "", "https://fake.service.signalr.net/api/v1/hubs/testhub"},
{"Group receiving hub should call SignalR service", "testHub", "", "mygroup", "", "https://fake.service.signalr.net/api/v1/hubs/testhub/groups/mygroup"},
{"Group with hub metadata should call SignalR service", "", "testHub", "mygroup", "", "https://fake.service.signalr.net/api/v1/hubs/testhub/groups/mygroup"},
{"User receiving hub should call SignalR service", "testHub", "", "", "myuser", "https://fake.service.signalr.net/api/v1/hubs/testhub/users/myuser"},
{"User with hub metadata should call SignalR service", "", "testHub", "", "myuser", "https://fake.service.signalr.net/api/v1/hubs/testhub/users/myuser"},
}
for _, tt := range tests {
@ -291,7 +394,7 @@ func TestWriteShouldSucceed(t *testing.T) {
assert.Equal(t, int32(1), httpTransport.requestCount)
assert.Equal(t, tt.expectedURL, httpTransport.request.URL.String())
assert.NotNil(t, httpTransport.request)
assert.Equal(t, "application/json", httpTransport.request.Header.Get("Content-Type"))
assert.Equal(t, "application/json; charset=utf-8", httpTransport.request.Header.Get("Content-Type"))
})
}
}

4
go.mod
View File

@ -65,7 +65,7 @@ require (
github.com/go-redis/redis/v8 v8.11.5
github.com/go-sql-driver/mysql v1.6.0
github.com/gocql/gocql v0.0.0-20210515062232-b7ef815b4556
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/golang/mock v1.6.0
github.com/golang/snappy v0.0.4 // indirect
github.com/google/uuid v1.3.0
@ -270,7 +270,7 @@ require (
github.com/gogap/errors v0.0.0-20200228125012-531a6449b28c // indirect
github.com/gogap/stack v0.0.0-20150131034635-fef68dddd4f8 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.2.0 // indirect
github.com/golang-jwt/jwt/v4 v4.2.0
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.2 // indirect

View File

@ -63,6 +63,9 @@ func NewEnvironmentSettings(resourceName string, values map[string]string) (Envi
// The resource name to request a token is https://eventhubs.azure.net/, and it's the same for all clouds/tenants.
// Kafka connection does not factor in here.
es.Resource = "https://eventhubs.azure.net"
case "signalr":
// Azure SignalR (data plane)
es.Resource = "https://signalr.azure.com"
default:
return es, errors.New("invalid resource name: " + resourceName)
}