diff --git a/bindings/apns/authorization_builder.go b/bindings/apns/authorization_builder.go index 9ae8aa1b6..2e60df006 100644 --- a/bindings/apns/authorization_builder.go +++ b/bindings/apns/authorization_builder.go @@ -17,7 +17,7 @@ import ( "sync" "time" - "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v4" "github.com/dapr/kit/logger" ) diff --git a/bindings/azure/signalr/signalr.go b/bindings/azure/signalr/signalr.go index 6a90d7410..b88e5ba3f 100644 --- a/bindings/azure/signalr/signalr.go +++ b/bindings/azure/signalr/signalr.go @@ -22,102 +22,167 @@ import ( "strings" "time" - "github.com/golang-jwt/jwt" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + jwt "github.com/golang-jwt/jwt/v4" "github.com/pkg/errors" "github.com/dapr/components-contrib/bindings" + azauth "github.com/dapr/components-contrib/internal/authentication/azure" "github.com/dapr/kit/logger" ) const ( - errorPrefix = "azure signalr error:" - logPrefix = "azure signalr:" + errorPrefix = "azure signalr error:" + logPrefix = "azure signalr:" + + // Metadata keys. + // Azure AD credentials are parsed separately and not listed here. connectionStringKey = "connectionString" + accessKeyKey = "accessKey" + endpointKey = "endpoint" hubKey = "hub" - groupKey = "group" - userKey = "user" + + // Invoke metadata keys. + groupKey = "group" + userKey = "user" ) -// NewSignalR creates a new pub/sub based on Azure SignalR. -func NewSignalR(logger logger.Logger) *SignalR { - return &SignalR{ - tokens: make(map[string]signalrCachedToken), - httpClient: &http.Client{Timeout: 30 * time.Second}, - logger: logger, +// Global HTTP client +var httpClient *http.Client + +func init() { + httpClient = &http.Client{ + Timeout: 30 * time.Second, } } -type signalrCachedToken struct { - token string - expiration time.Time +// NewSignalR creates a new output binding for Azure SignalR. +func NewSignalR(logger logger.Logger) *SignalR { + return &SignalR{ + logger: logger, + httpClient: httpClient, + } } // SignalR is an output binding for Azure SignalR. type SignalR struct { - endpoint string - accessKey string - version string - hub string - userAgent string - tokens map[string]signalrCachedToken - httpClient *http.Client + endpoint string + accessKey string + hub string + userAgent string + aadToken azcore.TokenCredential - logger logger.Logger + httpClient *http.Client + logger logger.Logger } // Init is responsible for initializing the SignalR output based on the metadata. -func (s *SignalR) Init(metadata bindings.Metadata) error { +func (s *SignalR) Init(metadata bindings.Metadata) (err error) { s.userAgent = "dapr-" + logger.DaprVersion - connectionString, ok := metadata.Properties[connectionStringKey] - if !ok || connectionString == "" { - return fmt.Errorf("missing connection string") + err = s.parseMetadata(metadata.Properties) + if err != nil { + return err } - if hub, ok := metadata.Properties[hubKey]; ok && hub != "" { - s.hub = hub - } - - // Expected: Endpoint=https://.service.signalr.net;AccessKey=;Version=1.0; - connectionValues := strings.Split(strings.TrimSpace(connectionString), ";") - for _, connectionValue := range connectionValues { - if i := strings.Index(connectionValue, "="); i != -1 && len(connectionValue) > (i+1) { - k := connectionValue[0:i] - switch k { - case "Endpoint": - s.endpoint = connectionValue[i+1:] - if s.endpoint[len(s.endpoint)-1] == '/' { - s.endpoint = s.endpoint[:len(s.endpoint)-1] - } - case "AccessKey": - s.accessKey = connectionValue[i+1:] - case "Version": - s.version = connectionValue[i+1:] - } + // If using AAD for authentication, init the token provider + if s.accessKey == "" { + var settings azauth.EnvironmentSettings + settings, err = azauth.NewEnvironmentSettings("signalr", metadata.Properties) + if err != nil { + return err + } + s.aadToken, err = settings.GetTokenCredential() + if err != nil { + return err } } - if s.endpoint == "" { - return fmt.Errorf("missing endpoint in connection string") + return nil +} + +func (s *SignalR) parseMetadata(md map[string]string) (err error) { + // Start by parsing the connection string if present + connectionString, ok := md[connectionStringKey] + if ok && connectionString != "" { + // Expected options: + // Access key: "Endpoint=https://.service.signalr.net;AccessKey=;Version=1.0;" + // System-assigned managed identity: "Endpoint=https://.service.signalr.net;AuthType=aad;Version=1.0;" + // User-assigned managed identity: "Endpoint=https://.service.signalr.net;AuthType=aad;ClientId=;Version=1.0;" + // Azure AD application: "Endpoint=https://.service.signalr.net;AuthType=aad;ClientId=;ClientSecret=;TenantId=;Version=1.0;" + // Note: connection string can't be used if the client secret contains the ; key + connectionValues := strings.Split(strings.TrimSpace(connectionString), ";") + useAAD := false + for _, connectionValue := range connectionValues { + if i := strings.Index(connectionValue, "="); i != -1 && len(connectionValue) > (i+1) { + k := connectionValue[0:i] + switch k { + case "Endpoint": + s.endpoint = connectionValue[i+1:] + case "AccessKey": + s.accessKey = connectionValue[i+1:] + case "AuthType": + if connectionValue[i+1:] != "aad" { + return fmt.Errorf("invalid value for AuthType in the connection string; only 'aad' is supported") + } + useAAD = true + case "ClientId", "ClientSecret", "TenantId": + v := connectionValue[i+1:] + // Set the values in the metadata map so they can be picked up by the azauth module + md["azure"+k] = v + case "Version": + v := connectionValue[i+1:] + // We only support version "1.0" + if v != "1.0" { + return fmt.Errorf("invalid value for Version in the connection string: '%s'; only version '1.0' is supported", v) + } + } + } else if len(connectionValue) != 0 { + return fmt.Errorf("the connection string is invalid or malformed") + } + } + + // Check here because if we use a connection string, we'd have an explicit "AuthType=aad" option + // We would otherwise catch this issue later, but here we can be more explicit with the error + if s.accessKey == "" && !useAAD { + return fmt.Errorf("missing AccessKey in the connection string") + } } - if s.accessKey == "" { - return fmt.Errorf("missing access key in connection string") + // Parse the other metadata keys, which could also override the values from the connection string + if v, ok := md[hubKey]; ok && v != "" { + s.hub = v + } + if v, ok := md[endpointKey]; ok && v != "" { + s.endpoint = v + } + if v, ok := md[accessKeyKey]; ok && v != "" { + s.accessKey = v + } + + // Trim ending "/" from endpoint + s.endpoint = strings.TrimSuffix(s.endpoint, "/") + + // Check for required values + if s.endpoint == "" { + return fmt.Errorf("missing endpoint in the metadata or connection string") } return nil } func (s *SignalR) resolveAPIURL(req *bindings.InvokeRequest) (string, error) { - hub := s.hub - if hub == "" { - hubFromRequest, ok := req.Metadata[hubKey] - if !ok || hubFromRequest == "" { - return "", fmt.Errorf("%s missing hub", errorPrefix) - } - - hub = hubFromRequest + hub, ok := req.Metadata[hubKey] + if !ok || hub == "" { + hub = s.hub } + if hub == "" { + return "", fmt.Errorf("%s missing hub", errorPrefix) + } + + // Hub name is lower-cased in the official SDKs (e.g. .NET) + hub = strings.ToLower(hub) var url string if group, ok := req.Metadata[groupKey]; ok && group != "" { @@ -138,26 +203,26 @@ func (s *SignalR) sendMessageToSignalR(ctx context.Context, url string, token st } httpReq.Header.Set("Authorization", "Bearer "+token) - httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Content-Type", "application/json; charset=utf-8") httpReq.Header.Set("User-Agent", s.userAgent) resp, err := s.httpClient.Do(httpReq) if err != nil { return errors.Wrap(err, "request to azure signalr api failed") } - defer resp.Body.Close() - if resp.StatusCode != 200 && resp.StatusCode != 202 { - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return err - } - - return fmt.Errorf("%s azure signalr returned code %d, content is '%s'", errorPrefix, resp.StatusCode, string(body)) + // Read the body regardless to drain it and ensure the connection can be reused + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err } - s.logger.Debugf("%s azure signalr call to '%s' returned with status code %d", logPrefix, url, resp.StatusCode) + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + return fmt.Errorf("%s azure signalr failed with code %d, content is '%s'", errorPrefix, resp.StatusCode, string(body)) + } + + s.logger.Debugf("%s azure signalr call to '%s' completed with code %d", logPrefix, url, resp.StatusCode) return nil } @@ -172,7 +237,7 @@ func (s *SignalR) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin return nil, err } - token, err := s.ensureValidToken(url) + token, err := s.getToken(ctx, url) if err != nil { return nil, err } @@ -185,34 +250,34 @@ func (s *SignalR) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin return nil, nil } -func (s *SignalR) ensureValidToken(url string) (string, error) { - now := time.Now() +// Returns an access token for a request to the given URL +func (s *SignalR) getToken(ctx context.Context, url string) (token string, err error) { + // If we have an Azure AD token provider, use that first + if s.aadToken != nil { + var at azcore.AccessToken + at, err = s.aadToken.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{"https://signalr.azure.com/.default"}, + }) + if err != nil { + return "", err + } + token = at.Token + } else { + claims := &jwt.StandardClaims{ + ExpiresAt: time.Now().Add(15 * time.Minute).Unix(), + Audience: url, + } + err = claims.Valid() + if err != nil { + return "", err + } - if existing, ok := s.tokens[url]; ok { - if existing.token != "" && now.Before(existing.expiration) { - return existing.token, nil + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token, err = jwtToken.SignedString([]byte(s.accessKey)) + if err != nil { + return "", err } } - expiration := now.Add(1 * time.Hour) - - claims := &jwt.StandardClaims{ - ExpiresAt: expiration.Unix(), - Audience: url, - } - - if err := claims.Valid(); err != nil { - return "", err - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - - newToken, err := token.SignedString([]byte(s.accessKey)) - if err != nil { - return "", err - } - - s.tokens[url] = signalrCachedToken{token: newToken, expiration: expiration.Add(time.Minute * -5)} - - return newToken, nil + return token, nil } diff --git a/bindings/azure/signalr/signalr_test.go b/bindings/azure/signalr/signalr_test.go index 2ef424d16..153bc2fe3 100644 --- a/bindings/azure/signalr/signalr_test.go +++ b/bindings/azure/signalr/signalr_test.go @@ -31,12 +31,12 @@ import ( func TestConfigurationValid(t *testing.T) { tests := []struct { - name string - properties map[string]string - expectedEndpoint string - expectedAccessKey string - expectedVersion string - expectedHub string + name string + properties map[string]string + expectedEndpoint string + expectedAccessKey string + expectedHub string + additionalMetadata map[string]string }{ { "With all properties", @@ -45,8 +45,8 @@ func TestConfigurationValid(t *testing.T) { }, "https://fake.service.signalr.net", "fakekey", - "1.0", "", + nil, }, { "With missing version", @@ -56,7 +56,7 @@ func TestConfigurationValid(t *testing.T) { "https://fake.service.signalr.net", "fakekey", "", - "", + nil, }, { "With semicolon after access key", @@ -66,40 +66,137 @@ func TestConfigurationValid(t *testing.T) { "https://fake.service.signalr.net", "fakekey", "", - "", + nil, }, { "With trailing slash in endpoint", map[string]string{ - "connectionString": "Endpoint=https://fake.service.signalr.net/;AccessKey=fakekey;Version=1.1", + "connectionString": "Endpoint=https://fake.service.signalr.net/;AccessKey=fakekey;Version=1.0", }, "https://fake.service.signalr.net", "fakekey", - "1.1", "", + nil, }, { "With hub", map[string]string{ - "connectionString": "Endpoint=https://fake.service.signalr.net/;AccessKey=fakekey;Version=1.1", + "connectionString": "Endpoint=https://fake.service.signalr.net/;AccessKey=fakekey;Version=1.0", "hub": "myhub", }, "https://fake.service.signalr.net", "fakekey", - "1.1", "myhub", + nil, + }, + { + "With AAD and no access key (system-assigned MSI)", + map[string]string{ + "connectionString": "Endpoint=https://fake.service.signalr.net/;AuthType=aad;Version=1.0", + }, + "https://fake.service.signalr.net", + "", + "", + nil, + }, + { + "Add azureClientId to metadata map (user-assigned MSI)", + map[string]string{ + "connectionString": "Endpoint=https://fake.service.signalr.net/;AuthType=aad;ClientId=b83aec5c-54a3-4e4a-8831-ba3f849b79a1;Version=1.0", + }, + "https://fake.service.signalr.net", + "", + "", + map[string]string{ + "azureClientId": "b83aec5c-54a3-4e4a-8831-ba3f849b79a1", + }, + }, + { + "Add Azure AD credentials to metadata map (Azure AD app)", + map[string]string{ + "connectionString": "Endpoint=https://fake.service.signalr.net/;AuthType=aad;ClientId=b83aec5c-54a3-4e4a-8831-ba3f849b79a1;ClientSecret=fakesecret;TenantId=f0f4622e-e476-46b5-bd0c-1866d27117d4;Version=1.0", + }, + "https://fake.service.signalr.net", + "", + "", + map[string]string{ + "azureClientId": "b83aec5c-54a3-4e4a-8831-ba3f849b79a1", + "azureClientSecret": "fakesecret", + "azureTenantId": "f0f4622e-e476-46b5-bd0c-1866d27117d4", + }, + }, + { + "No connection string, access key", + map[string]string{ + "endpoint": "https://fake.service.signalr.net/", + "accessKey": "fakekey", + }, + "https://fake.service.signalr.net", + "fakekey", + "", + nil, + }, + { + "No connection string, access key and hub", + map[string]string{ + "endpoint": "https://fake.service.signalr.net/", + "accessKey": "fakekey", + "hub": "myhub", + }, + "https://fake.service.signalr.net", + "fakekey", + "myhub", + nil, + }, + { + "No connection string, Azure AD", + map[string]string{ + "endpoint": "https://fake.service.signalr.net/", + "azureClientId": "b83aec5c-54a3-4e4a-8831-ba3f849b79a1", + "azureClientSecret": "fakesecret", + "azureTenantId": "f0f4622e-e476-46b5-bd0c-1866d27117d4", + }, + "https://fake.service.signalr.net", + "", + "", + map[string]string{ + "azureClientId": "b83aec5c-54a3-4e4a-8831-ba3f849b79a1", + "azureClientSecret": "fakesecret", + "azureTenantId": "f0f4622e-e476-46b5-bd0c-1866d27117d4", + }, + }, + { + "No connection string, Azure AD with aliased names", + map[string]string{ + "endpoint": "https://fake.service.signalr.net/", + "spnClientId": "b83aec5c-54a3-4e4a-8831-ba3f849b79a1", + "spnClientSecret": "fakesecret", + "spnTenantId": "f0f4622e-e476-46b5-bd0c-1866d27117d4", + }, + "https://fake.service.signalr.net", + "", + "", + map[string]string{ + "spnClientId": "b83aec5c-54a3-4e4a-8831-ba3f849b79a1", + "spnClientSecret": "fakesecret", + "spnTenantId": "f0f4622e-e476-46b5-bd0c-1866d27117d4", + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := NewSignalR(logger.NewLogger("test")) - err := s.Init(bindings.Metadata{Properties: tt.properties}) + err := s.parseMetadata(tt.properties) assert.Nil(t, err) assert.Equal(t, tt.expectedEndpoint, s.endpoint) assert.Equal(t, tt.expectedAccessKey, s.accessKey) - assert.Equal(t, tt.expectedVersion, s.version) assert.Equal(t, tt.expectedHub, s.hub) + if len(tt.additionalMetadata) > 0 { + for k := range tt.additionalMetadata { + assert.Equal(t, tt.properties[k], tt.additionalMetadata[k]) + } + } }) } } @@ -138,7 +235,7 @@ func TestInvalidConfigurations(t *testing.T) { }, }, { - "Missing access key", + "Missing access key (no AAD)", map[string]string{ "connectionString1": "Endpoint=https://fake.service.signalr.net;", }, @@ -146,7 +243,13 @@ func TestInvalidConfigurations(t *testing.T) { { "With empty endpoint value", map[string]string{ - "connectionString": "Endpoint=;AccessKey=fakekey;Version=1.1", + "connectionString": "Endpoint=;AccessKey=fakekey;Version=1.0", + }, + }, + { + "With invalid version", + map[string]string{ + "connectionString": "Endpoint=https://fake.service.signalr.net;AccessKey=fakekey;Version=2.0", }, }, } @@ -154,7 +257,7 @@ func TestInvalidConfigurations(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := NewSignalR(logger.NewLogger("test")) - err := s.Init(bindings.Metadata{Properties: tt.properties}) + err := s.parseMetadata(tt.properties) assert.NotNil(t, err) }) } @@ -265,12 +368,12 @@ func TestWriteShouldSucceed(t *testing.T) { userID string expectedURL string }{ - {"Broadcast receiving hub should call SignalR service", "testHub", "", "", "", "https://fake.service.signalr.net/api/v1/hubs/testHub"}, - {"Broadcast with hub metadata should call SignalR service", "", "testHub", "", "", "https://fake.service.signalr.net/api/v1/hubs/testHub"}, - {"Group receiving hub should call SignalR service", "testHub", "", "mygroup", "", "https://fake.service.signalr.net/api/v1/hubs/testHub/groups/mygroup"}, - {"Group with hub metadata should call SignalR service", "", "testHub", "mygroup", "", "https://fake.service.signalr.net/api/v1/hubs/testHub/groups/mygroup"}, - {"User receiving hub should call SignalR service", "testHub", "", "", "myuser", "https://fake.service.signalr.net/api/v1/hubs/testHub/users/myuser"}, - {"User with hub metadata should call SignalR service", "", "testHub", "", "myuser", "https://fake.service.signalr.net/api/v1/hubs/testHub/users/myuser"}, + {"Broadcast receiving hub should call SignalR service", "testHub", "", "", "", "https://fake.service.signalr.net/api/v1/hubs/testhub"}, + {"Broadcast with hub metadata should call SignalR service", "", "testHub", "", "", "https://fake.service.signalr.net/api/v1/hubs/testhub"}, + {"Group receiving hub should call SignalR service", "testHub", "", "mygroup", "", "https://fake.service.signalr.net/api/v1/hubs/testhub/groups/mygroup"}, + {"Group with hub metadata should call SignalR service", "", "testHub", "mygroup", "", "https://fake.service.signalr.net/api/v1/hubs/testhub/groups/mygroup"}, + {"User receiving hub should call SignalR service", "testHub", "", "", "myuser", "https://fake.service.signalr.net/api/v1/hubs/testhub/users/myuser"}, + {"User with hub metadata should call SignalR service", "", "testHub", "", "myuser", "https://fake.service.signalr.net/api/v1/hubs/testhub/users/myuser"}, } for _, tt := range tests { @@ -291,7 +394,7 @@ func TestWriteShouldSucceed(t *testing.T) { assert.Equal(t, int32(1), httpTransport.requestCount) assert.Equal(t, tt.expectedURL, httpTransport.request.URL.String()) assert.NotNil(t, httpTransport.request) - assert.Equal(t, "application/json", httpTransport.request.Header.Get("Content-Type")) + assert.Equal(t, "application/json; charset=utf-8", httpTransport.request.Header.Get("Content-Type")) }) } } diff --git a/go.mod b/go.mod index 17f2bcf52..d8c803d12 100644 --- a/go.mod +++ b/go.mod @@ -65,7 +65,7 @@ require ( github.com/go-redis/redis/v8 v8.11.5 github.com/go-sql-driver/mysql v1.6.0 github.com/gocql/gocql v0.0.0-20210515062232-b7ef815b4556 - github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang/mock v1.6.0 github.com/golang/snappy v0.0.4 // indirect github.com/google/uuid v1.3.0 @@ -270,7 +270,7 @@ require ( github.com/gogap/errors v0.0.0-20200228125012-531a6449b28c // indirect github.com/gogap/stack v0.0.0-20150131034635-fef68dddd4f8 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang-jwt/jwt/v4 v4.2.0 // indirect + github.com/golang-jwt/jwt/v4 v4.2.0 github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.2 // indirect diff --git a/internal/authentication/azure/auth.go b/internal/authentication/azure/auth.go index a3a8fc989..a13febdfc 100644 --- a/internal/authentication/azure/auth.go +++ b/internal/authentication/azure/auth.go @@ -63,6 +63,9 @@ func NewEnvironmentSettings(resourceName string, values map[string]string) (Envi // The resource name to request a token is https://eventhubs.azure.net/, and it's the same for all clouds/tenants. // Kafka connection does not factor in here. es.Resource = "https://eventhubs.azure.net" + case "signalr": + // Azure SignalR (data plane) + es.Resource = "https://signalr.azure.com" default: return es, errors.New("invalid resource name: " + resourceName) }