diff --git a/bindings/azure/blobstorage/blobstorage.go b/bindings/azure/blobstorage/blobstorage.go index 5a8dbbc06..5150a9b0e 100644 --- a/bindings/azure/blobstorage/blobstorage.go +++ b/bindings/azure/blobstorage/blobstorage.go @@ -20,12 +20,8 @@ import ( "errors" "fmt" "io" - "net/url" "strconv" - "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob" @@ -33,8 +29,7 @@ import ( "github.com/google/uuid" "github.com/dapr/components-contrib/bindings" - azauth "github.com/dapr/components-contrib/internal/authentication/azure" - mdutils "github.com/dapr/components-contrib/metadata" + storageinternal "github.com/dapr/components-contrib/internal/component/azure/blobstorage" "github.com/dapr/kit/logger" "github.com/dapr/kit/ptr" ) @@ -62,9 +57,6 @@ const ( metadataKeyContentLanguage = "contentLanguage" metadataKeyContentDisposition = "contentDisposition" metadataKeyCacheControl = "cacheControl" - // Specifies the maximum number of HTTP requests that will be made to retry blob operations. A value - // of zero means that no additional HTTP requests will be made. - defaultBlobRetryCount = 3 // Specifies the maximum number of blobs to return, including all BlobPrefix elements. If the request does not // specify maxresults the server will return up to 5,000 items. // See: https://docs.microsoft.com/en-us/rest/api/storageservices/list-blobs#uri-parameters @@ -76,21 +68,12 @@ var ErrMissingBlobName = errors.New("blobName is a required attribute") // AzureBlobStorage allows saving blobs to an Azure Blob Storage account. type AzureBlobStorage struct { - metadata *blobStorageMetadata + metadata *storageinternal.BlobStorageMetadata containerClient *container.Client logger logger.Logger } -type blobStorageMetadata struct { - StorageAccount string `json:"storageAccount"` - StorageAccessKey string `json:"storageAccessKey"` - Container string `json:"container"` - GetBlobRetryCount int32 `json:"getBlobRetryCount,string"` - DecodeBase64 bool `json:"decodeBase64,string"` - PublicAccessLevel azblob.PublicAccessType `json:"publicAccessLevel"` -} - type createResponse struct { BlobURL string `json:"blobURL"` BlobName string `json:"blobName"` @@ -118,110 +101,14 @@ func NewAzureBlobStorage(logger logger.Logger) bindings.OutputBinding { // Init performs metadata parsing. func (a *AzureBlobStorage) Init(metadata bindings.Metadata) error { - m, err := a.parseMetadata(metadata) + var err error + a.containerClient, a.metadata, err = storageinternal.CreateContainerStorageClient(a.logger, metadata.Properties) if err != nil { return err } - a.metadata = m - - userAgent := "dapr-" + logger.DaprVersion - options := container.ClientOptions{ - ClientOptions: azcore.ClientOptions{ - Retry: policy.RetryOptions{ - MaxRetries: a.metadata.GetBlobRetryCount, - }, - Telemetry: policy.TelemetryOptions{ - ApplicationID: userAgent, - }, - }, - } - - settings, err := azauth.NewEnvironmentSettings("storage", metadata.Properties) - if err != nil { - return err - } - customEndpoint, ok := metadata.Properties[endpointKey] - var URL *url.URL - if ok && customEndpoint != "" { - var parseErr error - URL, parseErr = url.Parse(fmt.Sprintf("%s/%s/%s", customEndpoint, m.StorageAccount, m.Container)) - if parseErr != nil { - return parseErr - } - } else { - env := settings.AzureEnvironment - URL, _ = url.Parse(fmt.Sprintf("https://%s.blob.%s/%s", m.StorageAccount, env.StorageEndpointSuffix, m.Container)) - } - - var clientErr error - var client *container.Client - // Try using shared key credentials first - if m.StorageAccessKey != "" { - credential, newSharedKeyErr := azblob.NewSharedKeyCredential(m.StorageAccount, m.StorageAccessKey) - if err != nil { - return fmt.Errorf("invalid credentials with error: %w", newSharedKeyErr) - } - client, clientErr = container.NewClientWithSharedKeyCredential(URL.String(), credential, &options) - if clientErr != nil { - return fmt.Errorf("cannot init Blobstorage container client: %w", err) - } - a.containerClient = client - } else { - // fallback to AAD - credential, tokenErr := settings.GetTokenCredential() - if err != nil { - return fmt.Errorf("invalid credentials with error: %w", tokenErr) - } - client, clientErr = container.NewClient(URL.String(), credential, &options) - } - if clientErr != nil { - return fmt.Errorf("cannot init Blobstorage client: %w", clientErr) - } - - createContainerOptions := container.CreateOptions{ - Access: &m.PublicAccessLevel, - Metadata: map[string]string{}, - } - timeoutCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - _, err = client.Create(timeoutCtx, &createContainerOptions) - cancel() - // Don't return error, container might already exist - a.logger.Debugf("error creating container: %w", err) - a.containerClient = client - return nil } -func (a *AzureBlobStorage) parseMetadata(meta bindings.Metadata) (*blobStorageMetadata, error) { - m := blobStorageMetadata{ - GetBlobRetryCount: defaultBlobRetryCount, - } - mdutils.DecodeMetadata(meta.Properties, &m) - - if val, ok := mdutils.GetMetadataProperty(meta.Properties, azauth.StorageAccountNameKeys...); ok && val != "" { - m.StorageAccount = val - } else { - return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageAccountNameKeys[0]) - } - - if val, ok := mdutils.GetMetadataProperty(meta.Properties, azauth.StorageContainerNameKeys...); ok && val != "" { - m.Container = val - } else { - return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageContainerNameKeys[0]) - } - - // per the Dapr documentation "none" is a valid value - if m.PublicAccessLevel == "none" { - m.PublicAccessLevel = "" - } - if m.PublicAccessLevel != "" && !a.isValidPublicAccessType(m.PublicAccessLevel) { - return nil, fmt.Errorf("invalid public access level: %s; allowed: %s", - m.PublicAccessLevel, azblob.PossiblePublicAccessTypeValues()) - } - - return &m, nil -} - func (a *AzureBlobStorage) Operations() []bindings.OperationKind { return []bindings.OperationKind{ bindings.CreateOperation, @@ -288,7 +175,7 @@ func (a *AzureBlobStorage) create(ctx context.Context, req *bindings.InvokeReque } uploadOptions := azblob.UploadBufferOptions{ - Metadata: a.sanitizeMetadata(req.Metadata), + Metadata: storageinternal.SanitizeMetadata(a.logger, req.Metadata), HTTPHeaders: &blobHTTPHeaders, TransactionalContentMD5: contentMD5, } @@ -486,17 +373,6 @@ func (a *AzureBlobStorage) Invoke(ctx context.Context, req *bindings.InvokeReque } } -func (a *AzureBlobStorage) isValidPublicAccessType(accessType azblob.PublicAccessType) bool { - validTypes := azblob.PossiblePublicAccessTypeValues() - for _, item := range validTypes { - if item == accessType { - return true - } - } - - return false -} - func (a *AzureBlobStorage) isValidDeleteSnapshotsOptionType(accessType azblob.DeleteSnapshotsOptionType) bool { validTypes := azblob.PossibleDeleteSnapshotsOptionTypeValues() for _, item := range validTypes { @@ -507,41 +383,3 @@ func (a *AzureBlobStorage) isValidDeleteSnapshotsOptionType(accessType azblob.De return false } - -func (a *AzureBlobStorage) sanitizeMetadata(metadata map[string]string) map[string]string { - for key, val := range metadata { - // Keep only letters and digits - n := 0 - newKey := make([]byte, len(key)) - for i := 0; i < len(key); i++ { - if (key[i] >= 'A' && key[i] <= 'Z') || - (key[i] >= 'a' && key[i] <= 'z') || - (key[i] >= '0' && key[i] <= '9') { - newKey[n] = key[i] - n++ - } - } - - if n != len(key) { - nks := string(newKey[:n]) - a.logger.Warnf("metadata key %s contains disallowed characters, sanitized to %s", key, nks) - delete(metadata, key) - metadata[nks] = val - key = nks - } - - // Remove all non-ascii characters - n = 0 - newVal := make([]byte, len(val)) - for i := 0; i < len(val); i++ { - if val[i] > 127 { - continue - } - newVal[n] = val[i] - n++ - } - metadata[key] = string(newVal[:n]) - } - - return metadata -} diff --git a/bindings/azure/blobstorage/blobstorage_test.go b/bindings/azure/blobstorage/blobstorage_test.go index 8462e497d..fbcb6e54e 100644 --- a/bindings/azure/blobstorage/blobstorage_test.go +++ b/bindings/azure/blobstorage/blobstorage_test.go @@ -17,83 +17,12 @@ import ( "context" "testing" - "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" "github.com/stretchr/testify/assert" "github.com/dapr/components-contrib/bindings" "github.com/dapr/kit/logger" ) -func TestParseMetadata(t *testing.T) { - m := bindings.Metadata{} - blobStorage := NewAzureBlobStorage(logger.NewLogger("test")).(*AzureBlobStorage) - - t.Run("parse all metadata", func(t *testing.T) { - m.Properties = map[string]string{ - "storageAccount": "account", - "storageAccessKey": "key", - "container": "test", - "getBlobRetryCount": "5", - "decodeBase64": "true", - } - meta, err := blobStorage.parseMetadata(m) - assert.Nil(t, err) - assert.Equal(t, "test", meta.Container) - assert.Equal(t, "account", meta.StorageAccount) - // storageAccessKey is parsed in the azauth package - assert.Equal(t, true, meta.DecodeBase64) - assert.Equal(t, int32(5), meta.GetBlobRetryCount) - assert.Equal(t, "", string(meta.PublicAccessLevel)) - }) - - t.Run("parse metadata with publicAccessLevel = blob", func(t *testing.T) { - m.Properties = map[string]string{ - "storageAccount": "account", - "storageAccessKey": "key", - "container": "test", - "publicAccessLevel": "blob", - } - meta, err := blobStorage.parseMetadata(m) - assert.Nil(t, err) - assert.Equal(t, azblob.PublicAccessTypeBlob, meta.PublicAccessLevel) - }) - - t.Run("parse metadata with publicAccessLevel = container", func(t *testing.T) { - m.Properties = map[string]string{ - "storageAccount": "account", - "storageAccessKey": "key", - "container": "test", - "publicAccessLevel": "container", - } - meta, err := blobStorage.parseMetadata(m) - assert.Nil(t, err) - assert.Equal(t, azblob.PublicAccessTypeContainer, meta.PublicAccessLevel) - }) - - t.Run("parse metadata with invalid publicAccessLevel", func(t *testing.T) { - m.Properties = map[string]string{ - "storageAccount": "account", - "storageAccessKey": "key", - "container": "test", - "publicAccessLevel": "invalid", - } - _, err := blobStorage.parseMetadata(m) - assert.Error(t, err) - }) - - t.Run("sanitize metadata if necessary", func(t *testing.T) { - m.Properties = map[string]string{ - "somecustomfield": "some-custom-value", - "specialfield": "special:valueÜ", - "not-allowed:": "not-allowed", - } - meta := blobStorage.sanitizeMetadata(m.Properties) - assert.Equal(t, meta["somecustomfield"], "some-custom-value") - assert.Equal(t, meta["specialfield"], "special:value") - assert.Equal(t, meta["notallowed"], "not-allowed") - }) -} - func TestGetOption(t *testing.T) { blobStorage := NewAzureBlobStorage(logger.NewLogger("test")).(*AzureBlobStorage) diff --git a/internal/component/azure/blobstorage/client.go b/internal/component/azure/blobstorage/client.go new file mode 100644 index 000000000..64eb5a6d2 --- /dev/null +++ b/internal/component/azure/blobstorage/client.go @@ -0,0 +1,108 @@ +/* +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package blobstorage + +import ( + "context" + "fmt" + "net/url" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" + + azauth "github.com/dapr/components-contrib/internal/authentication/azure" + "github.com/dapr/kit/logger" +) + +const ( + endpointKey = "endpoint" + // Specifies the maximum number of HTTP requests that will be made to retry blob operations. A value + // of zero means that no additional HTTP requests will be made. + defaultBlobRetryCount = 3 +) + +func CreateContainerStorageClient(log logger.Logger, meta map[string]string) (*container.Client, *BlobStorageMetadata, error) { + m, err := parseMetadata(meta) + if err != nil { + return nil, nil, err + } + + userAgent := "dapr-" + logger.DaprVersion + options := container.ClientOptions{ + ClientOptions: azcore.ClientOptions{ + Retry: policy.RetryOptions{ + MaxRetries: m.RetryCount, + }, + Telemetry: policy.TelemetryOptions{ + ApplicationID: userAgent, + }, + }, + } + + settings, err := azauth.NewEnvironmentSettings("storage", meta) + if err != nil { + return nil, nil, err + } + customEndpoint, ok := meta[endpointKey] + var URL *url.URL + if ok && customEndpoint != "" { + var parseErr error + URL, parseErr = url.Parse(fmt.Sprintf("%s/%s/%s", customEndpoint, m.AccountName, m.ContainerName)) + if parseErr != nil { + return nil, nil, parseErr + } + } else { + env := settings.AzureEnvironment + URL, _ = url.Parse(fmt.Sprintf("https://%s.blob.%s/%s", m.AccountName, env.StorageEndpointSuffix, m.ContainerName)) + } + + var clientErr error + var client *container.Client + // Try using shared key credentials first + if m.AccountKey != "" { + credential, newSharedKeyErr := azblob.NewSharedKeyCredential(m.AccountName, m.AccountKey) + if err != nil { + return nil, nil, fmt.Errorf("invalid credentials with error: %w", newSharedKeyErr) + } + client, clientErr = container.NewClientWithSharedKeyCredential(URL.String(), credential, &options) + if clientErr != nil { + return nil, nil, fmt.Errorf("cannot init Blobstorage container client: %w", err) + } + } else { + // fallback to AAD + credential, tokenErr := settings.GetTokenCredential() + if err != nil { + return nil, nil, fmt.Errorf("invalid credentials with error: %w", tokenErr) + } + client, clientErr = container.NewClient(URL.String(), credential, &options) + } + if clientErr != nil { + return nil, nil, fmt.Errorf("cannot init Blobstorage client: %w", clientErr) + } + + createContainerOptions := container.CreateOptions{ + Access: &m.PublicAccessLevel, + Metadata: map[string]string{}, + } + timeoutCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + _, err = client.Create(timeoutCtx, &createContainerOptions) + cancel() + // Don't return error, container might already exist + log.Debugf("error creating container: %w", err) + + return client, m, nil +} diff --git a/internal/component/azure/blobstorage/metadata.go b/internal/component/azure/blobstorage/metadata.go new file mode 100644 index 000000000..14223c9ef --- /dev/null +++ b/internal/component/azure/blobstorage/metadata.go @@ -0,0 +1,123 @@ +/* +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package blobstorage + +import ( + "fmt" + "strconv" + + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + + azauth "github.com/dapr/components-contrib/internal/authentication/azure" + mdutils "github.com/dapr/components-contrib/metadata" + "github.com/dapr/kit/logger" +) + +type BlobStorageMetadata struct { + AccountName string + AccountKey string + ContainerName string + RetryCount int32 `json:"retryCount,string"` + DecodeBase64 bool `json:"decodeBase64,string"` + PublicAccessLevel azblob.PublicAccessType +} + +func parseMetadata(meta map[string]string) (*BlobStorageMetadata, error) { + m := BlobStorageMetadata{ + RetryCount: defaultBlobRetryCount, + } + mdutils.DecodeMetadata(meta, &m) + + if val, ok := mdutils.GetMetadataProperty(meta, azauth.StorageAccountNameKeys...); ok && val != "" { + m.AccountName = val + } else { + return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageAccountNameKeys[0]) + } + + if val, ok := mdutils.GetMetadataProperty(meta, azauth.StorageContainerNameKeys...); ok && val != "" { + m.ContainerName = val + } else { + return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageContainerNameKeys[0]) + } + + // per the Dapr documentation "none" is a valid value + if m.PublicAccessLevel == "none" { + m.PublicAccessLevel = "" + } + if m.PublicAccessLevel != "" && !isValidPublicAccessType(m.PublicAccessLevel) { + return nil, fmt.Errorf("invalid public access level: %s; allowed: %s", + m.PublicAccessLevel, azblob.PossiblePublicAccessTypeValues()) + } + + // we need this key for backwards compatibility + if val, ok := meta["getBlobRetryCount"]; ok && val != "" { + // convert val from string to int32 + parseInt, err := strconv.ParseInt(val, 10, 32) + if err != nil { + return nil, err + } + m.RetryCount = int32(parseInt) + } + + return &m, nil +} + +func isValidPublicAccessType(accessType azblob.PublicAccessType) bool { + validTypes := azblob.PossiblePublicAccessTypeValues() + for _, item := range validTypes { + if item == accessType { + return true + } + } + + return false +} + +func SanitizeMetadata(log logger.Logger, metadata map[string]string) map[string]string { + for key, val := range metadata { + // Keep only letters and digits + n := 0 + newKey := make([]byte, len(key)) + for i := 0; i < len(key); i++ { + if (key[i] >= 'A' && key[i] <= 'Z') || + (key[i] >= 'a' && key[i] <= 'z') || + (key[i] >= '0' && key[i] <= '9') { + newKey[n] = key[i] + n++ + } + } + + if n != len(key) { + nks := string(newKey[:n]) + log.Warnf("metadata key %s contains disallowed characters, sanitized to %s", key, nks) + delete(metadata, key) + metadata[nks] = val + key = nks + } + + // Remove all non-ascii characters + n = 0 + newVal := make([]byte, len(val)) + for i := 0; i < len(val); i++ { + if val[i] > 127 { + continue + } + newVal[n] = val[i] + n++ + } + metadata[key] = string(newVal[:n]) + } + + return metadata +} diff --git a/internal/component/azure/blobstorage/metadata_test.go b/internal/component/azure/blobstorage/metadata_test.go new file mode 100644 index 000000000..e60ef9c14 --- /dev/null +++ b/internal/component/azure/blobstorage/metadata_test.go @@ -0,0 +1,93 @@ +/* +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package blobstorage + +import ( + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/dapr/kit/logger" + "github.com/stretchr/testify/assert" +) + +func TestParseMetadata(t *testing.T) { + logger := logger.NewLogger("test") + var m map[string]string + + t.Run("parse all metadata", func(t *testing.T) { + m = map[string]string{ + "storageAccount": "account", + "storageAccessKey": "key", + "container": "test", + "getBlobRetryCount": "5", + "decodeBase64": "true", + } + meta, err := parseMetadata(m) + assert.Nil(t, err) + assert.Equal(t, "test", meta.ContainerName) + assert.Equal(t, "account", meta.AccountName) + // storageAccessKey is parsed in the azauth package + assert.Equal(t, true, meta.DecodeBase64) + assert.Equal(t, int32(5), meta.RetryCount) + assert.Equal(t, "", string(meta.PublicAccessLevel)) + }) + + t.Run("parse metadata with publicAccessLevel = blob", func(t *testing.T) { + m = map[string]string{ + "storageAccount": "account", + "storageAccessKey": "key", + "container": "test", + "publicAccessLevel": "blob", + } + meta, err := parseMetadata(m) + assert.Nil(t, err) + assert.Equal(t, azblob.PublicAccessTypeBlob, meta.PublicAccessLevel) + }) + + t.Run("parse metadata with publicAccessLevel = container", func(t *testing.T) { + m = map[string]string{ + "storageAccount": "account", + "storageAccessKey": "key", + "container": "test", + "publicAccessLevel": "container", + } + meta, err := parseMetadata(m) + assert.Nil(t, err) + assert.Equal(t, azblob.PublicAccessTypeContainer, meta.PublicAccessLevel) + }) + + t.Run("parse metadata with invalid publicAccessLevel", func(t *testing.T) { + m = map[string]string{ + "storageAccount": "account", + "storageAccessKey": "key", + "container": "test", + "publicAccessLevel": "invalid", + } + _, err := parseMetadata(m) + assert.Error(t, err) + }) + + t.Run("sanitize metadata if necessary", func(t *testing.T) { + m = map[string]string{ + "somecustomfield": "some-custom-value", + "specialfield": "special:valueÜ", + "not-allowed:": "not-allowed", + } + meta := SanitizeMetadata(logger, m) + assert.Equal(t, meta["somecustomfield"], "some-custom-value") + assert.Equal(t, meta["specialfield"], "special:value") + assert.Equal(t, meta["notallowed"], "not-allowed") + }) +} diff --git a/state/azure/blobstorage/blobstorage.go b/state/azure/blobstorage/blobstorage.go index 5c6f6cd38..fb92126dd 100644 --- a/state/azure/blobstorage/blobstorage.go +++ b/state/azure/blobstorage/blobstorage.go @@ -40,20 +40,17 @@ import ( b64 "encoding/base64" "fmt" "io" - "net/url" "reflect" "strings" - "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" jsoniter "github.com/json-iterator/go" - azauth "github.com/dapr/components-contrib/internal/authentication/azure" + storageinternal "github.com/dapr/components-contrib/internal/component/azure/blobstorage" mdutils "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/state" "github.com/dapr/kit/logger" @@ -81,80 +78,13 @@ type StateStore struct { logger logger.Logger } -type blobStorageMetadata struct { - AccountName string - ContainerName string - AccountKey string -} - // Init the connection to blob storage, optionally creates a blob container if it doesn't exist. func (r *StateStore) Init(metadata state.Metadata) error { - m, err := getBlobStorageMetadata(metadata.Properties) + var err error + r.containerClient, _, err = storageinternal.CreateContainerStorageClient(r.logger, metadata.Properties) if err != nil { return err } - - userAgent := "dapr-" + logger.DaprVersion - options := container.ClientOptions{ - ClientOptions: azcore.ClientOptions{ - Telemetry: policy.TelemetryOptions{ - ApplicationID: userAgent, - }, - }, - } - - settings, err := azauth.NewEnvironmentSettings("storage", metadata.Properties) - if err != nil { - return err - } - customEndpoint, ok := metadata.Properties[endpointKey] - var URL *url.URL - if ok && customEndpoint != "" { - var parseErr error - URL, parseErr = url.Parse(fmt.Sprintf("%s/%s/%s", customEndpoint, m.AccountName, m.ContainerName)) - if parseErr != nil { - return parseErr - } - } else { - env := settings.AzureEnvironment - URL, _ = url.Parse(fmt.Sprintf("https://%s.blob.%s/%s", m.AccountName, env.StorageEndpointSuffix, m.ContainerName)) - } - - var clientErr error - var client *container.Client - // Try using shared key credentials first - if m.AccountKey != "" { - credential, newSharedKeyErr := azblob.NewSharedKeyCredential(m.AccountName, m.AccountKey) - if err != nil { - return fmt.Errorf("invalid credentials with error: %w", newSharedKeyErr) - } - client, clientErr = container.NewClientWithSharedKeyCredential(URL.String(), credential, &options) - if clientErr != nil { - return fmt.Errorf("cannot init Blobstorage container client: %w", err) - } - r.containerClient = client - } else { - // fallback to AAD - credential, tokenErr := settings.GetTokenCredential() - if err != nil { - return fmt.Errorf("invalid credentials with error: %w", tokenErr) - } - client, clientErr = container.NewClient(URL.String(), credential, &options) - } - if clientErr != nil { - return fmt.Errorf("cannot init Blobstorage client: %w", clientErr) - } - - createContainerOptions := container.CreateOptions{ - Access: nil, - } - timeoutCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - _, err = client.Create(timeoutCtx, &createContainerOptions) - cancel() - // Don't return error, container might already exist - r.logger.Debugf("error creating container: %w", err) - r.containerClient = client - return nil } @@ -193,7 +123,7 @@ func (r *StateStore) Ping() error { } func (r *StateStore) GetComponentMetadata() map[string]string { - metadataStruct := blobStorageMetadata{} + metadataStruct := storageinternal.BlobStorageMetadata{} metadataInfo := map[string]string{} mdutils.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo) return metadataInfo @@ -211,25 +141,6 @@ func NewAzureBlobStorageStore(logger logger.Logger) state.Store { return s } -func getBlobStorageMetadata(meta map[string]string) (*blobStorageMetadata, error) { - m := blobStorageMetadata{} - err := mdutils.DecodeMetadata(meta, &m) - - if val, ok := mdutils.GetMetadataProperty(meta, azauth.StorageAccountNameKeys...); ok && val != "" { - m.AccountName = val - } else { - return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageAccountNameKeys[0]) - } - - if val, ok := mdutils.GetMetadataProperty(meta, azauth.StorageContainerNameKeys...); ok && val != "" { - m.ContainerName = val - } else { - return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageContainerNameKeys[0]) - } - - return &m, err -} - func (r *StateStore) readFile(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { blockBlobClient := r.containerClient.NewBlockBlobClient(getFileName(req.Key)) @@ -289,7 +200,7 @@ func (r *StateStore) writeFile(ctx context.Context, req *state.SetRequest) error uploadOptions := azblob.UploadBufferOptions{ AccessConditions: &accessConditions, - Metadata: req.Metadata, + Metadata: storageinternal.SanitizeMetadata(r.logger, req.Metadata), HTTPHeaders: &blobHTTPHeaders, Concurrency: 16, } diff --git a/state/azure/blobstorage/blobstorage_test.go b/state/azure/blobstorage/blobstorage_test.go index e86505d51..2e69af070 100644 --- a/state/azure/blobstorage/blobstorage_test.go +++ b/state/azure/blobstorage/blobstorage_test.go @@ -58,26 +58,6 @@ func TestInit(t *testing.T) { }) } -func TestGetBlobStorageMetaData(t *testing.T) { - t.Run("Nothing at all passed", func(t *testing.T) { - m := make(map[string]string) - _, err := getBlobStorageMetadata(m) - - assert.NotNil(t, err) - }) - - t.Run("All parameters passed and parsed", func(t *testing.T) { - m := make(map[string]string) - m["accountName"] = "acc" - m["containerName"] = "dapr" - meta, err := getBlobStorageMetadata(m) - - assert.Nil(t, err) - assert.Equal(t, "acc", meta.AccountName) - assert.Equal(t, "dapr", meta.ContainerName) - }) -} - func TestFileName(t *testing.T) { t.Run("Valid composite key", func(t *testing.T) { key := getFileName("app_id||key")