AzBlob components: Extract shared code
Signed-off-by: Bernd Verst <4535280+berndverst@users.noreply.github.com>
This commit is contained in:
parent
70eb9f3a9c
commit
8811d5e64f
|
|
@ -20,12 +20,8 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
|
||||
|
|
@ -33,8 +29,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
|
||||
"github.com/dapr/components-contrib/bindings"
|
||||
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
|
||||
mdutils "github.com/dapr/components-contrib/metadata"
|
||||
storageinternal "github.com/dapr/components-contrib/internal/component/azure/blobstorage"
|
||||
"github.com/dapr/kit/logger"
|
||||
"github.com/dapr/kit/ptr"
|
||||
)
|
||||
|
|
@ -62,9 +57,6 @@ const (
|
|||
metadataKeyContentLanguage = "contentLanguage"
|
||||
metadataKeyContentDisposition = "contentDisposition"
|
||||
metadataKeyCacheControl = "cacheControl"
|
||||
// Specifies the maximum number of HTTP requests that will be made to retry blob operations. A value
|
||||
// of zero means that no additional HTTP requests will be made.
|
||||
defaultBlobRetryCount = 3
|
||||
// Specifies the maximum number of blobs to return, including all BlobPrefix elements. If the request does not
|
||||
// specify maxresults the server will return up to 5,000 items.
|
||||
// See: https://docs.microsoft.com/en-us/rest/api/storageservices/list-blobs#uri-parameters
|
||||
|
|
@ -76,21 +68,12 @@ var ErrMissingBlobName = errors.New("blobName is a required attribute")
|
|||
|
||||
// AzureBlobStorage allows saving blobs to an Azure Blob Storage account.
|
||||
type AzureBlobStorage struct {
|
||||
metadata *blobStorageMetadata
|
||||
metadata *storageinternal.BlobStorageMetadata
|
||||
containerClient *container.Client
|
||||
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
type blobStorageMetadata struct {
|
||||
StorageAccount string `json:"storageAccount"`
|
||||
StorageAccessKey string `json:"storageAccessKey"`
|
||||
Container string `json:"container"`
|
||||
GetBlobRetryCount int32 `json:"getBlobRetryCount,string"`
|
||||
DecodeBase64 bool `json:"decodeBase64,string"`
|
||||
PublicAccessLevel azblob.PublicAccessType `json:"publicAccessLevel"`
|
||||
}
|
||||
|
||||
type createResponse struct {
|
||||
BlobURL string `json:"blobURL"`
|
||||
BlobName string `json:"blobName"`
|
||||
|
|
@ -118,110 +101,14 @@ func NewAzureBlobStorage(logger logger.Logger) bindings.OutputBinding {
|
|||
|
||||
// Init performs metadata parsing.
|
||||
func (a *AzureBlobStorage) Init(metadata bindings.Metadata) error {
|
||||
m, err := a.parseMetadata(metadata)
|
||||
var err error
|
||||
a.containerClient, a.metadata, err = storageinternal.CreateContainerStorageClient(a.logger, metadata.Properties)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.metadata = m
|
||||
|
||||
userAgent := "dapr-" + logger.DaprVersion
|
||||
options := container.ClientOptions{
|
||||
ClientOptions: azcore.ClientOptions{
|
||||
Retry: policy.RetryOptions{
|
||||
MaxRetries: a.metadata.GetBlobRetryCount,
|
||||
},
|
||||
Telemetry: policy.TelemetryOptions{
|
||||
ApplicationID: userAgent,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
settings, err := azauth.NewEnvironmentSettings("storage", metadata.Properties)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
customEndpoint, ok := metadata.Properties[endpointKey]
|
||||
var URL *url.URL
|
||||
if ok && customEndpoint != "" {
|
||||
var parseErr error
|
||||
URL, parseErr = url.Parse(fmt.Sprintf("%s/%s/%s", customEndpoint, m.StorageAccount, m.Container))
|
||||
if parseErr != nil {
|
||||
return parseErr
|
||||
}
|
||||
} else {
|
||||
env := settings.AzureEnvironment
|
||||
URL, _ = url.Parse(fmt.Sprintf("https://%s.blob.%s/%s", m.StorageAccount, env.StorageEndpointSuffix, m.Container))
|
||||
}
|
||||
|
||||
var clientErr error
|
||||
var client *container.Client
|
||||
// Try using shared key credentials first
|
||||
if m.StorageAccessKey != "" {
|
||||
credential, newSharedKeyErr := azblob.NewSharedKeyCredential(m.StorageAccount, m.StorageAccessKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid credentials with error: %w", newSharedKeyErr)
|
||||
}
|
||||
client, clientErr = container.NewClientWithSharedKeyCredential(URL.String(), credential, &options)
|
||||
if clientErr != nil {
|
||||
return fmt.Errorf("cannot init Blobstorage container client: %w", err)
|
||||
}
|
||||
a.containerClient = client
|
||||
} else {
|
||||
// fallback to AAD
|
||||
credential, tokenErr := settings.GetTokenCredential()
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid credentials with error: %w", tokenErr)
|
||||
}
|
||||
client, clientErr = container.NewClient(URL.String(), credential, &options)
|
||||
}
|
||||
if clientErr != nil {
|
||||
return fmt.Errorf("cannot init Blobstorage client: %w", clientErr)
|
||||
}
|
||||
|
||||
createContainerOptions := container.CreateOptions{
|
||||
Access: &m.PublicAccessLevel,
|
||||
Metadata: map[string]string{},
|
||||
}
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
_, err = client.Create(timeoutCtx, &createContainerOptions)
|
||||
cancel()
|
||||
// Don't return error, container might already exist
|
||||
a.logger.Debugf("error creating container: %w", err)
|
||||
a.containerClient = client
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AzureBlobStorage) parseMetadata(meta bindings.Metadata) (*blobStorageMetadata, error) {
|
||||
m := blobStorageMetadata{
|
||||
GetBlobRetryCount: defaultBlobRetryCount,
|
||||
}
|
||||
mdutils.DecodeMetadata(meta.Properties, &m)
|
||||
|
||||
if val, ok := mdutils.GetMetadataProperty(meta.Properties, azauth.StorageAccountNameKeys...); ok && val != "" {
|
||||
m.StorageAccount = val
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageAccountNameKeys[0])
|
||||
}
|
||||
|
||||
if val, ok := mdutils.GetMetadataProperty(meta.Properties, azauth.StorageContainerNameKeys...); ok && val != "" {
|
||||
m.Container = val
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageContainerNameKeys[0])
|
||||
}
|
||||
|
||||
// per the Dapr documentation "none" is a valid value
|
||||
if m.PublicAccessLevel == "none" {
|
||||
m.PublicAccessLevel = ""
|
||||
}
|
||||
if m.PublicAccessLevel != "" && !a.isValidPublicAccessType(m.PublicAccessLevel) {
|
||||
return nil, fmt.Errorf("invalid public access level: %s; allowed: %s",
|
||||
m.PublicAccessLevel, azblob.PossiblePublicAccessTypeValues())
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func (a *AzureBlobStorage) Operations() []bindings.OperationKind {
|
||||
return []bindings.OperationKind{
|
||||
bindings.CreateOperation,
|
||||
|
|
@ -288,7 +175,7 @@ func (a *AzureBlobStorage) create(ctx context.Context, req *bindings.InvokeReque
|
|||
}
|
||||
|
||||
uploadOptions := azblob.UploadBufferOptions{
|
||||
Metadata: a.sanitizeMetadata(req.Metadata),
|
||||
Metadata: storageinternal.SanitizeMetadata(a.logger, req.Metadata),
|
||||
HTTPHeaders: &blobHTTPHeaders,
|
||||
TransactionalContentMD5: contentMD5,
|
||||
}
|
||||
|
|
@ -486,17 +373,6 @@ func (a *AzureBlobStorage) Invoke(ctx context.Context, req *bindings.InvokeReque
|
|||
}
|
||||
}
|
||||
|
||||
func (a *AzureBlobStorage) isValidPublicAccessType(accessType azblob.PublicAccessType) bool {
|
||||
validTypes := azblob.PossiblePublicAccessTypeValues()
|
||||
for _, item := range validTypes {
|
||||
if item == accessType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *AzureBlobStorage) isValidDeleteSnapshotsOptionType(accessType azblob.DeleteSnapshotsOptionType) bool {
|
||||
validTypes := azblob.PossibleDeleteSnapshotsOptionTypeValues()
|
||||
for _, item := range validTypes {
|
||||
|
|
@ -507,41 +383,3 @@ func (a *AzureBlobStorage) isValidDeleteSnapshotsOptionType(accessType azblob.De
|
|||
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *AzureBlobStorage) sanitizeMetadata(metadata map[string]string) map[string]string {
|
||||
for key, val := range metadata {
|
||||
// Keep only letters and digits
|
||||
n := 0
|
||||
newKey := make([]byte, len(key))
|
||||
for i := 0; i < len(key); i++ {
|
||||
if (key[i] >= 'A' && key[i] <= 'Z') ||
|
||||
(key[i] >= 'a' && key[i] <= 'z') ||
|
||||
(key[i] >= '0' && key[i] <= '9') {
|
||||
newKey[n] = key[i]
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
if n != len(key) {
|
||||
nks := string(newKey[:n])
|
||||
a.logger.Warnf("metadata key %s contains disallowed characters, sanitized to %s", key, nks)
|
||||
delete(metadata, key)
|
||||
metadata[nks] = val
|
||||
key = nks
|
||||
}
|
||||
|
||||
// Remove all non-ascii characters
|
||||
n = 0
|
||||
newVal := make([]byte, len(val))
|
||||
for i := 0; i < len(val); i++ {
|
||||
if val[i] > 127 {
|
||||
continue
|
||||
}
|
||||
newVal[n] = val[i]
|
||||
n++
|
||||
}
|
||||
metadata[key] = string(newVal[:n])
|
||||
}
|
||||
|
||||
return metadata
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,83 +17,12 @@ import (
|
|||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/dapr/components-contrib/bindings"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
func TestParseMetadata(t *testing.T) {
|
||||
m := bindings.Metadata{}
|
||||
blobStorage := NewAzureBlobStorage(logger.NewLogger("test")).(*AzureBlobStorage)
|
||||
|
||||
t.Run("parse all metadata", func(t *testing.T) {
|
||||
m.Properties = map[string]string{
|
||||
"storageAccount": "account",
|
||||
"storageAccessKey": "key",
|
||||
"container": "test",
|
||||
"getBlobRetryCount": "5",
|
||||
"decodeBase64": "true",
|
||||
}
|
||||
meta, err := blobStorage.parseMetadata(m)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "test", meta.Container)
|
||||
assert.Equal(t, "account", meta.StorageAccount)
|
||||
// storageAccessKey is parsed in the azauth package
|
||||
assert.Equal(t, true, meta.DecodeBase64)
|
||||
assert.Equal(t, int32(5), meta.GetBlobRetryCount)
|
||||
assert.Equal(t, "", string(meta.PublicAccessLevel))
|
||||
})
|
||||
|
||||
t.Run("parse metadata with publicAccessLevel = blob", func(t *testing.T) {
|
||||
m.Properties = map[string]string{
|
||||
"storageAccount": "account",
|
||||
"storageAccessKey": "key",
|
||||
"container": "test",
|
||||
"publicAccessLevel": "blob",
|
||||
}
|
||||
meta, err := blobStorage.parseMetadata(m)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, azblob.PublicAccessTypeBlob, meta.PublicAccessLevel)
|
||||
})
|
||||
|
||||
t.Run("parse metadata with publicAccessLevel = container", func(t *testing.T) {
|
||||
m.Properties = map[string]string{
|
||||
"storageAccount": "account",
|
||||
"storageAccessKey": "key",
|
||||
"container": "test",
|
||||
"publicAccessLevel": "container",
|
||||
}
|
||||
meta, err := blobStorage.parseMetadata(m)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, azblob.PublicAccessTypeContainer, meta.PublicAccessLevel)
|
||||
})
|
||||
|
||||
t.Run("parse metadata with invalid publicAccessLevel", func(t *testing.T) {
|
||||
m.Properties = map[string]string{
|
||||
"storageAccount": "account",
|
||||
"storageAccessKey": "key",
|
||||
"container": "test",
|
||||
"publicAccessLevel": "invalid",
|
||||
}
|
||||
_, err := blobStorage.parseMetadata(m)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("sanitize metadata if necessary", func(t *testing.T) {
|
||||
m.Properties = map[string]string{
|
||||
"somecustomfield": "some-custom-value",
|
||||
"specialfield": "special:valueÜ",
|
||||
"not-allowed:": "not-allowed",
|
||||
}
|
||||
meta := blobStorage.sanitizeMetadata(m.Properties)
|
||||
assert.Equal(t, meta["somecustomfield"], "some-custom-value")
|
||||
assert.Equal(t, meta["specialfield"], "special:value")
|
||||
assert.Equal(t, meta["notallowed"], "not-allowed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetOption(t *testing.T) {
|
||||
blobStorage := NewAzureBlobStorage(logger.NewLogger("test")).(*AzureBlobStorage)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,108 @@
|
|||
/*
|
||||
Copyright 2021 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package blobstorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
|
||||
|
||||
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
endpointKey = "endpoint"
|
||||
// Specifies the maximum number of HTTP requests that will be made to retry blob operations. A value
|
||||
// of zero means that no additional HTTP requests will be made.
|
||||
defaultBlobRetryCount = 3
|
||||
)
|
||||
|
||||
func CreateContainerStorageClient(log logger.Logger, meta map[string]string) (*container.Client, *BlobStorageMetadata, error) {
|
||||
m, err := parseMetadata(meta)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userAgent := "dapr-" + logger.DaprVersion
|
||||
options := container.ClientOptions{
|
||||
ClientOptions: azcore.ClientOptions{
|
||||
Retry: policy.RetryOptions{
|
||||
MaxRetries: m.RetryCount,
|
||||
},
|
||||
Telemetry: policy.TelemetryOptions{
|
||||
ApplicationID: userAgent,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
settings, err := azauth.NewEnvironmentSettings("storage", meta)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
customEndpoint, ok := meta[endpointKey]
|
||||
var URL *url.URL
|
||||
if ok && customEndpoint != "" {
|
||||
var parseErr error
|
||||
URL, parseErr = url.Parse(fmt.Sprintf("%s/%s/%s", customEndpoint, m.AccountName, m.ContainerName))
|
||||
if parseErr != nil {
|
||||
return nil, nil, parseErr
|
||||
}
|
||||
} else {
|
||||
env := settings.AzureEnvironment
|
||||
URL, _ = url.Parse(fmt.Sprintf("https://%s.blob.%s/%s", m.AccountName, env.StorageEndpointSuffix, m.ContainerName))
|
||||
}
|
||||
|
||||
var clientErr error
|
||||
var client *container.Client
|
||||
// Try using shared key credentials first
|
||||
if m.AccountKey != "" {
|
||||
credential, newSharedKeyErr := azblob.NewSharedKeyCredential(m.AccountName, m.AccountKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid credentials with error: %w", newSharedKeyErr)
|
||||
}
|
||||
client, clientErr = container.NewClientWithSharedKeyCredential(URL.String(), credential, &options)
|
||||
if clientErr != nil {
|
||||
return nil, nil, fmt.Errorf("cannot init Blobstorage container client: %w", err)
|
||||
}
|
||||
} else {
|
||||
// fallback to AAD
|
||||
credential, tokenErr := settings.GetTokenCredential()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid credentials with error: %w", tokenErr)
|
||||
}
|
||||
client, clientErr = container.NewClient(URL.String(), credential, &options)
|
||||
}
|
||||
if clientErr != nil {
|
||||
return nil, nil, fmt.Errorf("cannot init Blobstorage client: %w", clientErr)
|
||||
}
|
||||
|
||||
createContainerOptions := container.CreateOptions{
|
||||
Access: &m.PublicAccessLevel,
|
||||
Metadata: map[string]string{},
|
||||
}
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
_, err = client.Create(timeoutCtx, &createContainerOptions)
|
||||
cancel()
|
||||
// Don't return error, container might already exist
|
||||
log.Debugf("error creating container: %w", err)
|
||||
|
||||
return client, m, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
/*
|
||||
Copyright 2021 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package blobstorage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||
|
||||
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
|
||||
mdutils "github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
type BlobStorageMetadata struct {
|
||||
AccountName string
|
||||
AccountKey string
|
||||
ContainerName string
|
||||
RetryCount int32 `json:"retryCount,string"`
|
||||
DecodeBase64 bool `json:"decodeBase64,string"`
|
||||
PublicAccessLevel azblob.PublicAccessType
|
||||
}
|
||||
|
||||
func parseMetadata(meta map[string]string) (*BlobStorageMetadata, error) {
|
||||
m := BlobStorageMetadata{
|
||||
RetryCount: defaultBlobRetryCount,
|
||||
}
|
||||
mdutils.DecodeMetadata(meta, &m)
|
||||
|
||||
if val, ok := mdutils.GetMetadataProperty(meta, azauth.StorageAccountNameKeys...); ok && val != "" {
|
||||
m.AccountName = val
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageAccountNameKeys[0])
|
||||
}
|
||||
|
||||
if val, ok := mdutils.GetMetadataProperty(meta, azauth.StorageContainerNameKeys...); ok && val != "" {
|
||||
m.ContainerName = val
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageContainerNameKeys[0])
|
||||
}
|
||||
|
||||
// per the Dapr documentation "none" is a valid value
|
||||
if m.PublicAccessLevel == "none" {
|
||||
m.PublicAccessLevel = ""
|
||||
}
|
||||
if m.PublicAccessLevel != "" && !isValidPublicAccessType(m.PublicAccessLevel) {
|
||||
return nil, fmt.Errorf("invalid public access level: %s; allowed: %s",
|
||||
m.PublicAccessLevel, azblob.PossiblePublicAccessTypeValues())
|
||||
}
|
||||
|
||||
// we need this key for backwards compatibility
|
||||
if val, ok := meta["getBlobRetryCount"]; ok && val != "" {
|
||||
// convert val from string to int32
|
||||
parseInt, err := strconv.ParseInt(val, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.RetryCount = int32(parseInt)
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func isValidPublicAccessType(accessType azblob.PublicAccessType) bool {
|
||||
validTypes := azblob.PossiblePublicAccessTypeValues()
|
||||
for _, item := range validTypes {
|
||||
if item == accessType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func SanitizeMetadata(log logger.Logger, metadata map[string]string) map[string]string {
|
||||
for key, val := range metadata {
|
||||
// Keep only letters and digits
|
||||
n := 0
|
||||
newKey := make([]byte, len(key))
|
||||
for i := 0; i < len(key); i++ {
|
||||
if (key[i] >= 'A' && key[i] <= 'Z') ||
|
||||
(key[i] >= 'a' && key[i] <= 'z') ||
|
||||
(key[i] >= '0' && key[i] <= '9') {
|
||||
newKey[n] = key[i]
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
if n != len(key) {
|
||||
nks := string(newKey[:n])
|
||||
log.Warnf("metadata key %s contains disallowed characters, sanitized to %s", key, nks)
|
||||
delete(metadata, key)
|
||||
metadata[nks] = val
|
||||
key = nks
|
||||
}
|
||||
|
||||
// Remove all non-ascii characters
|
||||
n = 0
|
||||
newVal := make([]byte, len(val))
|
||||
for i := 0; i < len(val); i++ {
|
||||
if val[i] > 127 {
|
||||
continue
|
||||
}
|
||||
newVal[n] = val[i]
|
||||
n++
|
||||
}
|
||||
metadata[key] = string(newVal[:n])
|
||||
}
|
||||
|
||||
return metadata
|
||||
}
|
||||
|
|
@ -0,0 +1,93 @@
|
|||
/*
|
||||
Copyright 2021 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package blobstorage
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||
"github.com/dapr/kit/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseMetadata(t *testing.T) {
|
||||
logger := logger.NewLogger("test")
|
||||
var m map[string]string
|
||||
|
||||
t.Run("parse all metadata", func(t *testing.T) {
|
||||
m = map[string]string{
|
||||
"storageAccount": "account",
|
||||
"storageAccessKey": "key",
|
||||
"container": "test",
|
||||
"getBlobRetryCount": "5",
|
||||
"decodeBase64": "true",
|
||||
}
|
||||
meta, err := parseMetadata(m)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "test", meta.ContainerName)
|
||||
assert.Equal(t, "account", meta.AccountName)
|
||||
// storageAccessKey is parsed in the azauth package
|
||||
assert.Equal(t, true, meta.DecodeBase64)
|
||||
assert.Equal(t, int32(5), meta.RetryCount)
|
||||
assert.Equal(t, "", string(meta.PublicAccessLevel))
|
||||
})
|
||||
|
||||
t.Run("parse metadata with publicAccessLevel = blob", func(t *testing.T) {
|
||||
m = map[string]string{
|
||||
"storageAccount": "account",
|
||||
"storageAccessKey": "key",
|
||||
"container": "test",
|
||||
"publicAccessLevel": "blob",
|
||||
}
|
||||
meta, err := parseMetadata(m)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, azblob.PublicAccessTypeBlob, meta.PublicAccessLevel)
|
||||
})
|
||||
|
||||
t.Run("parse metadata with publicAccessLevel = container", func(t *testing.T) {
|
||||
m = map[string]string{
|
||||
"storageAccount": "account",
|
||||
"storageAccessKey": "key",
|
||||
"container": "test",
|
||||
"publicAccessLevel": "container",
|
||||
}
|
||||
meta, err := parseMetadata(m)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, azblob.PublicAccessTypeContainer, meta.PublicAccessLevel)
|
||||
})
|
||||
|
||||
t.Run("parse metadata with invalid publicAccessLevel", func(t *testing.T) {
|
||||
m = map[string]string{
|
||||
"storageAccount": "account",
|
||||
"storageAccessKey": "key",
|
||||
"container": "test",
|
||||
"publicAccessLevel": "invalid",
|
||||
}
|
||||
_, err := parseMetadata(m)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("sanitize metadata if necessary", func(t *testing.T) {
|
||||
m = map[string]string{
|
||||
"somecustomfield": "some-custom-value",
|
||||
"specialfield": "special:valueÜ",
|
||||
"not-allowed:": "not-allowed",
|
||||
}
|
||||
meta := SanitizeMetadata(logger, m)
|
||||
assert.Equal(t, meta["somecustomfield"], "some-custom-value")
|
||||
assert.Equal(t, meta["specialfield"], "special:value")
|
||||
assert.Equal(t, meta["notallowed"], "not-allowed")
|
||||
})
|
||||
}
|
||||
|
|
@ -40,20 +40,17 @@ import (
|
|||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
|
||||
jsoniter "github.com/json-iterator/go"
|
||||
|
||||
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
|
||||
storageinternal "github.com/dapr/components-contrib/internal/component/azure/blobstorage"
|
||||
mdutils "github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
"github.com/dapr/kit/logger"
|
||||
|
|
@ -81,80 +78,13 @@ type StateStore struct {
|
|||
logger logger.Logger
|
||||
}
|
||||
|
||||
type blobStorageMetadata struct {
|
||||
AccountName string
|
||||
ContainerName string
|
||||
AccountKey string
|
||||
}
|
||||
|
||||
// Init the connection to blob storage, optionally creates a blob container if it doesn't exist.
|
||||
func (r *StateStore) Init(metadata state.Metadata) error {
|
||||
m, err := getBlobStorageMetadata(metadata.Properties)
|
||||
var err error
|
||||
r.containerClient, _, err = storageinternal.CreateContainerStorageClient(r.logger, metadata.Properties)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userAgent := "dapr-" + logger.DaprVersion
|
||||
options := container.ClientOptions{
|
||||
ClientOptions: azcore.ClientOptions{
|
||||
Telemetry: policy.TelemetryOptions{
|
||||
ApplicationID: userAgent,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
settings, err := azauth.NewEnvironmentSettings("storage", metadata.Properties)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
customEndpoint, ok := metadata.Properties[endpointKey]
|
||||
var URL *url.URL
|
||||
if ok && customEndpoint != "" {
|
||||
var parseErr error
|
||||
URL, parseErr = url.Parse(fmt.Sprintf("%s/%s/%s", customEndpoint, m.AccountName, m.ContainerName))
|
||||
if parseErr != nil {
|
||||
return parseErr
|
||||
}
|
||||
} else {
|
||||
env := settings.AzureEnvironment
|
||||
URL, _ = url.Parse(fmt.Sprintf("https://%s.blob.%s/%s", m.AccountName, env.StorageEndpointSuffix, m.ContainerName))
|
||||
}
|
||||
|
||||
var clientErr error
|
||||
var client *container.Client
|
||||
// Try using shared key credentials first
|
||||
if m.AccountKey != "" {
|
||||
credential, newSharedKeyErr := azblob.NewSharedKeyCredential(m.AccountName, m.AccountKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid credentials with error: %w", newSharedKeyErr)
|
||||
}
|
||||
client, clientErr = container.NewClientWithSharedKeyCredential(URL.String(), credential, &options)
|
||||
if clientErr != nil {
|
||||
return fmt.Errorf("cannot init Blobstorage container client: %w", err)
|
||||
}
|
||||
r.containerClient = client
|
||||
} else {
|
||||
// fallback to AAD
|
||||
credential, tokenErr := settings.GetTokenCredential()
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid credentials with error: %w", tokenErr)
|
||||
}
|
||||
client, clientErr = container.NewClient(URL.String(), credential, &options)
|
||||
}
|
||||
if clientErr != nil {
|
||||
return fmt.Errorf("cannot init Blobstorage client: %w", clientErr)
|
||||
}
|
||||
|
||||
createContainerOptions := container.CreateOptions{
|
||||
Access: nil,
|
||||
}
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
_, err = client.Create(timeoutCtx, &createContainerOptions)
|
||||
cancel()
|
||||
// Don't return error, container might already exist
|
||||
r.logger.Debugf("error creating container: %w", err)
|
||||
r.containerClient = client
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -193,7 +123,7 @@ func (r *StateStore) Ping() error {
|
|||
}
|
||||
|
||||
func (r *StateStore) GetComponentMetadata() map[string]string {
|
||||
metadataStruct := blobStorageMetadata{}
|
||||
metadataStruct := storageinternal.BlobStorageMetadata{}
|
||||
metadataInfo := map[string]string{}
|
||||
mdutils.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo)
|
||||
return metadataInfo
|
||||
|
|
@ -211,25 +141,6 @@ func NewAzureBlobStorageStore(logger logger.Logger) state.Store {
|
|||
return s
|
||||
}
|
||||
|
||||
func getBlobStorageMetadata(meta map[string]string) (*blobStorageMetadata, error) {
|
||||
m := blobStorageMetadata{}
|
||||
err := mdutils.DecodeMetadata(meta, &m)
|
||||
|
||||
if val, ok := mdutils.GetMetadataProperty(meta, azauth.StorageAccountNameKeys...); ok && val != "" {
|
||||
m.AccountName = val
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageAccountNameKeys[0])
|
||||
}
|
||||
|
||||
if val, ok := mdutils.GetMetadataProperty(meta, azauth.StorageContainerNameKeys...); ok && val != "" {
|
||||
m.ContainerName = val
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageContainerNameKeys[0])
|
||||
}
|
||||
|
||||
return &m, err
|
||||
}
|
||||
|
||||
func (r *StateStore) readFile(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
|
||||
blockBlobClient := r.containerClient.NewBlockBlobClient(getFileName(req.Key))
|
||||
|
||||
|
|
@ -289,7 +200,7 @@ func (r *StateStore) writeFile(ctx context.Context, req *state.SetRequest) error
|
|||
|
||||
uploadOptions := azblob.UploadBufferOptions{
|
||||
AccessConditions: &accessConditions,
|
||||
Metadata: req.Metadata,
|
||||
Metadata: storageinternal.SanitizeMetadata(r.logger, req.Metadata),
|
||||
HTTPHeaders: &blobHTTPHeaders,
|
||||
Concurrency: 16,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,26 +58,6 @@ func TestInit(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestGetBlobStorageMetaData(t *testing.T) {
|
||||
t.Run("Nothing at all passed", func(t *testing.T) {
|
||||
m := make(map[string]string)
|
||||
_, err := getBlobStorageMetadata(m)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("All parameters passed and parsed", func(t *testing.T) {
|
||||
m := make(map[string]string)
|
||||
m["accountName"] = "acc"
|
||||
m["containerName"] = "dapr"
|
||||
meta, err := getBlobStorageMetadata(m)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "acc", meta.AccountName)
|
||||
assert.Equal(t, "dapr", meta.ContainerName)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFileName(t *testing.T) {
|
||||
t.Run("Valid composite key", func(t *testing.T) {
|
||||
key := getFileName("app_id||key")
|
||||
|
|
|
|||
Loading…
Reference in New Issue