AzBlob components: Extract shared code
Signed-off-by: Bernd Verst <4535280+berndverst@users.noreply.github.com>
This commit is contained in:
parent
70eb9f3a9c
commit
8811d5e64f
|
|
@ -20,12 +20,8 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/url"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
|
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
|
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
|
||||||
|
|
@ -33,8 +29,7 @@ import (
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/dapr/components-contrib/bindings"
|
"github.com/dapr/components-contrib/bindings"
|
||||||
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
|
storageinternal "github.com/dapr/components-contrib/internal/component/azure/blobstorage"
|
||||||
mdutils "github.com/dapr/components-contrib/metadata"
|
|
||||||
"github.com/dapr/kit/logger"
|
"github.com/dapr/kit/logger"
|
||||||
"github.com/dapr/kit/ptr"
|
"github.com/dapr/kit/ptr"
|
||||||
)
|
)
|
||||||
|
|
@ -62,9 +57,6 @@ const (
|
||||||
metadataKeyContentLanguage = "contentLanguage"
|
metadataKeyContentLanguage = "contentLanguage"
|
||||||
metadataKeyContentDisposition = "contentDisposition"
|
metadataKeyContentDisposition = "contentDisposition"
|
||||||
metadataKeyCacheControl = "cacheControl"
|
metadataKeyCacheControl = "cacheControl"
|
||||||
// Specifies the maximum number of HTTP requests that will be made to retry blob operations. A value
|
|
||||||
// of zero means that no additional HTTP requests will be made.
|
|
||||||
defaultBlobRetryCount = 3
|
|
||||||
// Specifies the maximum number of blobs to return, including all BlobPrefix elements. If the request does not
|
// Specifies the maximum number of blobs to return, including all BlobPrefix elements. If the request does not
|
||||||
// specify maxresults the server will return up to 5,000 items.
|
// specify maxresults the server will return up to 5,000 items.
|
||||||
// See: https://docs.microsoft.com/en-us/rest/api/storageservices/list-blobs#uri-parameters
|
// See: https://docs.microsoft.com/en-us/rest/api/storageservices/list-blobs#uri-parameters
|
||||||
|
|
@ -76,21 +68,12 @@ var ErrMissingBlobName = errors.New("blobName is a required attribute")
|
||||||
|
|
||||||
// AzureBlobStorage allows saving blobs to an Azure Blob Storage account.
|
// AzureBlobStorage allows saving blobs to an Azure Blob Storage account.
|
||||||
type AzureBlobStorage struct {
|
type AzureBlobStorage struct {
|
||||||
metadata *blobStorageMetadata
|
metadata *storageinternal.BlobStorageMetadata
|
||||||
containerClient *container.Client
|
containerClient *container.Client
|
||||||
|
|
||||||
logger logger.Logger
|
logger logger.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type blobStorageMetadata struct {
|
|
||||||
StorageAccount string `json:"storageAccount"`
|
|
||||||
StorageAccessKey string `json:"storageAccessKey"`
|
|
||||||
Container string `json:"container"`
|
|
||||||
GetBlobRetryCount int32 `json:"getBlobRetryCount,string"`
|
|
||||||
DecodeBase64 bool `json:"decodeBase64,string"`
|
|
||||||
PublicAccessLevel azblob.PublicAccessType `json:"publicAccessLevel"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type createResponse struct {
|
type createResponse struct {
|
||||||
BlobURL string `json:"blobURL"`
|
BlobURL string `json:"blobURL"`
|
||||||
BlobName string `json:"blobName"`
|
BlobName string `json:"blobName"`
|
||||||
|
|
@ -118,110 +101,14 @@ func NewAzureBlobStorage(logger logger.Logger) bindings.OutputBinding {
|
||||||
|
|
||||||
// Init performs metadata parsing.
|
// Init performs metadata parsing.
|
||||||
func (a *AzureBlobStorage) Init(metadata bindings.Metadata) error {
|
func (a *AzureBlobStorage) Init(metadata bindings.Metadata) error {
|
||||||
m, err := a.parseMetadata(metadata)
|
var err error
|
||||||
|
a.containerClient, a.metadata, err = storageinternal.CreateContainerStorageClient(a.logger, metadata.Properties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
a.metadata = m
|
|
||||||
|
|
||||||
userAgent := "dapr-" + logger.DaprVersion
|
|
||||||
options := container.ClientOptions{
|
|
||||||
ClientOptions: azcore.ClientOptions{
|
|
||||||
Retry: policy.RetryOptions{
|
|
||||||
MaxRetries: a.metadata.GetBlobRetryCount,
|
|
||||||
},
|
|
||||||
Telemetry: policy.TelemetryOptions{
|
|
||||||
ApplicationID: userAgent,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
settings, err := azauth.NewEnvironmentSettings("storage", metadata.Properties)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
customEndpoint, ok := metadata.Properties[endpointKey]
|
|
||||||
var URL *url.URL
|
|
||||||
if ok && customEndpoint != "" {
|
|
||||||
var parseErr error
|
|
||||||
URL, parseErr = url.Parse(fmt.Sprintf("%s/%s/%s", customEndpoint, m.StorageAccount, m.Container))
|
|
||||||
if parseErr != nil {
|
|
||||||
return parseErr
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
env := settings.AzureEnvironment
|
|
||||||
URL, _ = url.Parse(fmt.Sprintf("https://%s.blob.%s/%s", m.StorageAccount, env.StorageEndpointSuffix, m.Container))
|
|
||||||
}
|
|
||||||
|
|
||||||
var clientErr error
|
|
||||||
var client *container.Client
|
|
||||||
// Try using shared key credentials first
|
|
||||||
if m.StorageAccessKey != "" {
|
|
||||||
credential, newSharedKeyErr := azblob.NewSharedKeyCredential(m.StorageAccount, m.StorageAccessKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid credentials with error: %w", newSharedKeyErr)
|
|
||||||
}
|
|
||||||
client, clientErr = container.NewClientWithSharedKeyCredential(URL.String(), credential, &options)
|
|
||||||
if clientErr != nil {
|
|
||||||
return fmt.Errorf("cannot init Blobstorage container client: %w", err)
|
|
||||||
}
|
|
||||||
a.containerClient = client
|
|
||||||
} else {
|
|
||||||
// fallback to AAD
|
|
||||||
credential, tokenErr := settings.GetTokenCredential()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid credentials with error: %w", tokenErr)
|
|
||||||
}
|
|
||||||
client, clientErr = container.NewClient(URL.String(), credential, &options)
|
|
||||||
}
|
|
||||||
if clientErr != nil {
|
|
||||||
return fmt.Errorf("cannot init Blobstorage client: %w", clientErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
createContainerOptions := container.CreateOptions{
|
|
||||||
Access: &m.PublicAccessLevel,
|
|
||||||
Metadata: map[string]string{},
|
|
||||||
}
|
|
||||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
_, err = client.Create(timeoutCtx, &createContainerOptions)
|
|
||||||
cancel()
|
|
||||||
// Don't return error, container might already exist
|
|
||||||
a.logger.Debugf("error creating container: %w", err)
|
|
||||||
a.containerClient = client
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AzureBlobStorage) parseMetadata(meta bindings.Metadata) (*blobStorageMetadata, error) {
|
|
||||||
m := blobStorageMetadata{
|
|
||||||
GetBlobRetryCount: defaultBlobRetryCount,
|
|
||||||
}
|
|
||||||
mdutils.DecodeMetadata(meta.Properties, &m)
|
|
||||||
|
|
||||||
if val, ok := mdutils.GetMetadataProperty(meta.Properties, azauth.StorageAccountNameKeys...); ok && val != "" {
|
|
||||||
m.StorageAccount = val
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageAccountNameKeys[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
if val, ok := mdutils.GetMetadataProperty(meta.Properties, azauth.StorageContainerNameKeys...); ok && val != "" {
|
|
||||||
m.Container = val
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageContainerNameKeys[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
// per the Dapr documentation "none" is a valid value
|
|
||||||
if m.PublicAccessLevel == "none" {
|
|
||||||
m.PublicAccessLevel = ""
|
|
||||||
}
|
|
||||||
if m.PublicAccessLevel != "" && !a.isValidPublicAccessType(m.PublicAccessLevel) {
|
|
||||||
return nil, fmt.Errorf("invalid public access level: %s; allowed: %s",
|
|
||||||
m.PublicAccessLevel, azblob.PossiblePublicAccessTypeValues())
|
|
||||||
}
|
|
||||||
|
|
||||||
return &m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AzureBlobStorage) Operations() []bindings.OperationKind {
|
func (a *AzureBlobStorage) Operations() []bindings.OperationKind {
|
||||||
return []bindings.OperationKind{
|
return []bindings.OperationKind{
|
||||||
bindings.CreateOperation,
|
bindings.CreateOperation,
|
||||||
|
|
@ -288,7 +175,7 @@ func (a *AzureBlobStorage) create(ctx context.Context, req *bindings.InvokeReque
|
||||||
}
|
}
|
||||||
|
|
||||||
uploadOptions := azblob.UploadBufferOptions{
|
uploadOptions := azblob.UploadBufferOptions{
|
||||||
Metadata: a.sanitizeMetadata(req.Metadata),
|
Metadata: storageinternal.SanitizeMetadata(a.logger, req.Metadata),
|
||||||
HTTPHeaders: &blobHTTPHeaders,
|
HTTPHeaders: &blobHTTPHeaders,
|
||||||
TransactionalContentMD5: contentMD5,
|
TransactionalContentMD5: contentMD5,
|
||||||
}
|
}
|
||||||
|
|
@ -486,17 +373,6 @@ func (a *AzureBlobStorage) Invoke(ctx context.Context, req *bindings.InvokeReque
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AzureBlobStorage) isValidPublicAccessType(accessType azblob.PublicAccessType) bool {
|
|
||||||
validTypes := azblob.PossiblePublicAccessTypeValues()
|
|
||||||
for _, item := range validTypes {
|
|
||||||
if item == accessType {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AzureBlobStorage) isValidDeleteSnapshotsOptionType(accessType azblob.DeleteSnapshotsOptionType) bool {
|
func (a *AzureBlobStorage) isValidDeleteSnapshotsOptionType(accessType azblob.DeleteSnapshotsOptionType) bool {
|
||||||
validTypes := azblob.PossibleDeleteSnapshotsOptionTypeValues()
|
validTypes := azblob.PossibleDeleteSnapshotsOptionTypeValues()
|
||||||
for _, item := range validTypes {
|
for _, item := range validTypes {
|
||||||
|
|
@ -507,41 +383,3 @@ func (a *AzureBlobStorage) isValidDeleteSnapshotsOptionType(accessType azblob.De
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AzureBlobStorage) sanitizeMetadata(metadata map[string]string) map[string]string {
|
|
||||||
for key, val := range metadata {
|
|
||||||
// Keep only letters and digits
|
|
||||||
n := 0
|
|
||||||
newKey := make([]byte, len(key))
|
|
||||||
for i := 0; i < len(key); i++ {
|
|
||||||
if (key[i] >= 'A' && key[i] <= 'Z') ||
|
|
||||||
(key[i] >= 'a' && key[i] <= 'z') ||
|
|
||||||
(key[i] >= '0' && key[i] <= '9') {
|
|
||||||
newKey[n] = key[i]
|
|
||||||
n++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if n != len(key) {
|
|
||||||
nks := string(newKey[:n])
|
|
||||||
a.logger.Warnf("metadata key %s contains disallowed characters, sanitized to %s", key, nks)
|
|
||||||
delete(metadata, key)
|
|
||||||
metadata[nks] = val
|
|
||||||
key = nks
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove all non-ascii characters
|
|
||||||
n = 0
|
|
||||||
newVal := make([]byte, len(val))
|
|
||||||
for i := 0; i < len(val); i++ {
|
|
||||||
if val[i] > 127 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newVal[n] = val[i]
|
|
||||||
n++
|
|
||||||
}
|
|
||||||
metadata[key] = string(newVal[:n])
|
|
||||||
}
|
|
||||||
|
|
||||||
return metadata
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -17,83 +17,12 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/dapr/components-contrib/bindings"
|
"github.com/dapr/components-contrib/bindings"
|
||||||
"github.com/dapr/kit/logger"
|
"github.com/dapr/kit/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseMetadata(t *testing.T) {
|
|
||||||
m := bindings.Metadata{}
|
|
||||||
blobStorage := NewAzureBlobStorage(logger.NewLogger("test")).(*AzureBlobStorage)
|
|
||||||
|
|
||||||
t.Run("parse all metadata", func(t *testing.T) {
|
|
||||||
m.Properties = map[string]string{
|
|
||||||
"storageAccount": "account",
|
|
||||||
"storageAccessKey": "key",
|
|
||||||
"container": "test",
|
|
||||||
"getBlobRetryCount": "5",
|
|
||||||
"decodeBase64": "true",
|
|
||||||
}
|
|
||||||
meta, err := blobStorage.parseMetadata(m)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "test", meta.Container)
|
|
||||||
assert.Equal(t, "account", meta.StorageAccount)
|
|
||||||
// storageAccessKey is parsed in the azauth package
|
|
||||||
assert.Equal(t, true, meta.DecodeBase64)
|
|
||||||
assert.Equal(t, int32(5), meta.GetBlobRetryCount)
|
|
||||||
assert.Equal(t, "", string(meta.PublicAccessLevel))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("parse metadata with publicAccessLevel = blob", func(t *testing.T) {
|
|
||||||
m.Properties = map[string]string{
|
|
||||||
"storageAccount": "account",
|
|
||||||
"storageAccessKey": "key",
|
|
||||||
"container": "test",
|
|
||||||
"publicAccessLevel": "blob",
|
|
||||||
}
|
|
||||||
meta, err := blobStorage.parseMetadata(m)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, azblob.PublicAccessTypeBlob, meta.PublicAccessLevel)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("parse metadata with publicAccessLevel = container", func(t *testing.T) {
|
|
||||||
m.Properties = map[string]string{
|
|
||||||
"storageAccount": "account",
|
|
||||||
"storageAccessKey": "key",
|
|
||||||
"container": "test",
|
|
||||||
"publicAccessLevel": "container",
|
|
||||||
}
|
|
||||||
meta, err := blobStorage.parseMetadata(m)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, azblob.PublicAccessTypeContainer, meta.PublicAccessLevel)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("parse metadata with invalid publicAccessLevel", func(t *testing.T) {
|
|
||||||
m.Properties = map[string]string{
|
|
||||||
"storageAccount": "account",
|
|
||||||
"storageAccessKey": "key",
|
|
||||||
"container": "test",
|
|
||||||
"publicAccessLevel": "invalid",
|
|
||||||
}
|
|
||||||
_, err := blobStorage.parseMetadata(m)
|
|
||||||
assert.Error(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("sanitize metadata if necessary", func(t *testing.T) {
|
|
||||||
m.Properties = map[string]string{
|
|
||||||
"somecustomfield": "some-custom-value",
|
|
||||||
"specialfield": "special:valueÜ",
|
|
||||||
"not-allowed:": "not-allowed",
|
|
||||||
}
|
|
||||||
meta := blobStorage.sanitizeMetadata(m.Properties)
|
|
||||||
assert.Equal(t, meta["somecustomfield"], "some-custom-value")
|
|
||||||
assert.Equal(t, meta["specialfield"], "special:value")
|
|
||||||
assert.Equal(t, meta["notallowed"], "not-allowed")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetOption(t *testing.T) {
|
func TestGetOption(t *testing.T) {
|
||||||
blobStorage := NewAzureBlobStorage(logger.NewLogger("test")).(*AzureBlobStorage)
|
blobStorage := NewAzureBlobStorage(logger.NewLogger("test")).(*AzureBlobStorage)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,108 @@
|
||||||
|
/*
|
||||||
|
Copyright 2021 The Dapr Authors
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package blobstorage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||||
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||||
|
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||||
|
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
|
||||||
|
|
||||||
|
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
|
||||||
|
"github.com/dapr/kit/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
endpointKey = "endpoint"
|
||||||
|
// Specifies the maximum number of HTTP requests that will be made to retry blob operations. A value
|
||||||
|
// of zero means that no additional HTTP requests will be made.
|
||||||
|
defaultBlobRetryCount = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateContainerStorageClient(log logger.Logger, meta map[string]string) (*container.Client, *BlobStorageMetadata, error) {
|
||||||
|
m, err := parseMetadata(meta)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
userAgent := "dapr-" + logger.DaprVersion
|
||||||
|
options := container.ClientOptions{
|
||||||
|
ClientOptions: azcore.ClientOptions{
|
||||||
|
Retry: policy.RetryOptions{
|
||||||
|
MaxRetries: m.RetryCount,
|
||||||
|
},
|
||||||
|
Telemetry: policy.TelemetryOptions{
|
||||||
|
ApplicationID: userAgent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
settings, err := azauth.NewEnvironmentSettings("storage", meta)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
customEndpoint, ok := meta[endpointKey]
|
||||||
|
var URL *url.URL
|
||||||
|
if ok && customEndpoint != "" {
|
||||||
|
var parseErr error
|
||||||
|
URL, parseErr = url.Parse(fmt.Sprintf("%s/%s/%s", customEndpoint, m.AccountName, m.ContainerName))
|
||||||
|
if parseErr != nil {
|
||||||
|
return nil, nil, parseErr
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
env := settings.AzureEnvironment
|
||||||
|
URL, _ = url.Parse(fmt.Sprintf("https://%s.blob.%s/%s", m.AccountName, env.StorageEndpointSuffix, m.ContainerName))
|
||||||
|
}
|
||||||
|
|
||||||
|
var clientErr error
|
||||||
|
var client *container.Client
|
||||||
|
// Try using shared key credentials first
|
||||||
|
if m.AccountKey != "" {
|
||||||
|
credential, newSharedKeyErr := azblob.NewSharedKeyCredential(m.AccountName, m.AccountKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("invalid credentials with error: %w", newSharedKeyErr)
|
||||||
|
}
|
||||||
|
client, clientErr = container.NewClientWithSharedKeyCredential(URL.String(), credential, &options)
|
||||||
|
if clientErr != nil {
|
||||||
|
return nil, nil, fmt.Errorf("cannot init Blobstorage container client: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// fallback to AAD
|
||||||
|
credential, tokenErr := settings.GetTokenCredential()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("invalid credentials with error: %w", tokenErr)
|
||||||
|
}
|
||||||
|
client, clientErr = container.NewClient(URL.String(), credential, &options)
|
||||||
|
}
|
||||||
|
if clientErr != nil {
|
||||||
|
return nil, nil, fmt.Errorf("cannot init Blobstorage client: %w", clientErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
createContainerOptions := container.CreateOptions{
|
||||||
|
Access: &m.PublicAccessLevel,
|
||||||
|
Metadata: map[string]string{},
|
||||||
|
}
|
||||||
|
timeoutCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
_, err = client.Create(timeoutCtx, &createContainerOptions)
|
||||||
|
cancel()
|
||||||
|
// Don't return error, container might already exist
|
||||||
|
log.Debugf("error creating container: %w", err)
|
||||||
|
|
||||||
|
return client, m, nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,123 @@
|
||||||
|
/*
|
||||||
|
Copyright 2021 The Dapr Authors
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package blobstorage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||||
|
|
||||||
|
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
|
||||||
|
mdutils "github.com/dapr/components-contrib/metadata"
|
||||||
|
"github.com/dapr/kit/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BlobStorageMetadata struct {
|
||||||
|
AccountName string
|
||||||
|
AccountKey string
|
||||||
|
ContainerName string
|
||||||
|
RetryCount int32 `json:"retryCount,string"`
|
||||||
|
DecodeBase64 bool `json:"decodeBase64,string"`
|
||||||
|
PublicAccessLevel azblob.PublicAccessType
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMetadata(meta map[string]string) (*BlobStorageMetadata, error) {
|
||||||
|
m := BlobStorageMetadata{
|
||||||
|
RetryCount: defaultBlobRetryCount,
|
||||||
|
}
|
||||||
|
mdutils.DecodeMetadata(meta, &m)
|
||||||
|
|
||||||
|
if val, ok := mdutils.GetMetadataProperty(meta, azauth.StorageAccountNameKeys...); ok && val != "" {
|
||||||
|
m.AccountName = val
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageAccountNameKeys[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if val, ok := mdutils.GetMetadataProperty(meta, azauth.StorageContainerNameKeys...); ok && val != "" {
|
||||||
|
m.ContainerName = val
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageContainerNameKeys[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// per the Dapr documentation "none" is a valid value
|
||||||
|
if m.PublicAccessLevel == "none" {
|
||||||
|
m.PublicAccessLevel = ""
|
||||||
|
}
|
||||||
|
if m.PublicAccessLevel != "" && !isValidPublicAccessType(m.PublicAccessLevel) {
|
||||||
|
return nil, fmt.Errorf("invalid public access level: %s; allowed: %s",
|
||||||
|
m.PublicAccessLevel, azblob.PossiblePublicAccessTypeValues())
|
||||||
|
}
|
||||||
|
|
||||||
|
// we need this key for backwards compatibility
|
||||||
|
if val, ok := meta["getBlobRetryCount"]; ok && val != "" {
|
||||||
|
// convert val from string to int32
|
||||||
|
parseInt, err := strconv.ParseInt(val, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.RetryCount = int32(parseInt)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidPublicAccessType(accessType azblob.PublicAccessType) bool {
|
||||||
|
validTypes := azblob.PossiblePublicAccessTypeValues()
|
||||||
|
for _, item := range validTypes {
|
||||||
|
if item == accessType {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func SanitizeMetadata(log logger.Logger, metadata map[string]string) map[string]string {
|
||||||
|
for key, val := range metadata {
|
||||||
|
// Keep only letters and digits
|
||||||
|
n := 0
|
||||||
|
newKey := make([]byte, len(key))
|
||||||
|
for i := 0; i < len(key); i++ {
|
||||||
|
if (key[i] >= 'A' && key[i] <= 'Z') ||
|
||||||
|
(key[i] >= 'a' && key[i] <= 'z') ||
|
||||||
|
(key[i] >= '0' && key[i] <= '9') {
|
||||||
|
newKey[n] = key[i]
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != len(key) {
|
||||||
|
nks := string(newKey[:n])
|
||||||
|
log.Warnf("metadata key %s contains disallowed characters, sanitized to %s", key, nks)
|
||||||
|
delete(metadata, key)
|
||||||
|
metadata[nks] = val
|
||||||
|
key = nks
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove all non-ascii characters
|
||||||
|
n = 0
|
||||||
|
newVal := make([]byte, len(val))
|
||||||
|
for i := 0; i < len(val); i++ {
|
||||||
|
if val[i] > 127 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newVal[n] = val[i]
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
metadata[key] = string(newVal[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,93 @@
|
||||||
|
/*
|
||||||
|
Copyright 2021 The Dapr Authors
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package blobstorage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||||
|
"github.com/dapr/kit/logger"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseMetadata(t *testing.T) {
|
||||||
|
logger := logger.NewLogger("test")
|
||||||
|
var m map[string]string
|
||||||
|
|
||||||
|
t.Run("parse all metadata", func(t *testing.T) {
|
||||||
|
m = map[string]string{
|
||||||
|
"storageAccount": "account",
|
||||||
|
"storageAccessKey": "key",
|
||||||
|
"container": "test",
|
||||||
|
"getBlobRetryCount": "5",
|
||||||
|
"decodeBase64": "true",
|
||||||
|
}
|
||||||
|
meta, err := parseMetadata(m)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "test", meta.ContainerName)
|
||||||
|
assert.Equal(t, "account", meta.AccountName)
|
||||||
|
// storageAccessKey is parsed in the azauth package
|
||||||
|
assert.Equal(t, true, meta.DecodeBase64)
|
||||||
|
assert.Equal(t, int32(5), meta.RetryCount)
|
||||||
|
assert.Equal(t, "", string(meta.PublicAccessLevel))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("parse metadata with publicAccessLevel = blob", func(t *testing.T) {
|
||||||
|
m = map[string]string{
|
||||||
|
"storageAccount": "account",
|
||||||
|
"storageAccessKey": "key",
|
||||||
|
"container": "test",
|
||||||
|
"publicAccessLevel": "blob",
|
||||||
|
}
|
||||||
|
meta, err := parseMetadata(m)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, azblob.PublicAccessTypeBlob, meta.PublicAccessLevel)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("parse metadata with publicAccessLevel = container", func(t *testing.T) {
|
||||||
|
m = map[string]string{
|
||||||
|
"storageAccount": "account",
|
||||||
|
"storageAccessKey": "key",
|
||||||
|
"container": "test",
|
||||||
|
"publicAccessLevel": "container",
|
||||||
|
}
|
||||||
|
meta, err := parseMetadata(m)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, azblob.PublicAccessTypeContainer, meta.PublicAccessLevel)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("parse metadata with invalid publicAccessLevel", func(t *testing.T) {
|
||||||
|
m = map[string]string{
|
||||||
|
"storageAccount": "account",
|
||||||
|
"storageAccessKey": "key",
|
||||||
|
"container": "test",
|
||||||
|
"publicAccessLevel": "invalid",
|
||||||
|
}
|
||||||
|
_, err := parseMetadata(m)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("sanitize metadata if necessary", func(t *testing.T) {
|
||||||
|
m = map[string]string{
|
||||||
|
"somecustomfield": "some-custom-value",
|
||||||
|
"specialfield": "special:valueÜ",
|
||||||
|
"not-allowed:": "not-allowed",
|
||||||
|
}
|
||||||
|
meta := SanitizeMetadata(logger, m)
|
||||||
|
assert.Equal(t, meta["somecustomfield"], "some-custom-value")
|
||||||
|
assert.Equal(t, meta["specialfield"], "special:value")
|
||||||
|
assert.Equal(t, meta["notallowed"], "not-allowed")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -40,20 +40,17 @@ import (
|
||||||
b64 "encoding/base64"
|
b64 "encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/url"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
|
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
|
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
|
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
|
||||||
jsoniter "github.com/json-iterator/go"
|
jsoniter "github.com/json-iterator/go"
|
||||||
|
|
||||||
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
|
storageinternal "github.com/dapr/components-contrib/internal/component/azure/blobstorage"
|
||||||
mdutils "github.com/dapr/components-contrib/metadata"
|
mdutils "github.com/dapr/components-contrib/metadata"
|
||||||
"github.com/dapr/components-contrib/state"
|
"github.com/dapr/components-contrib/state"
|
||||||
"github.com/dapr/kit/logger"
|
"github.com/dapr/kit/logger"
|
||||||
|
|
@ -81,80 +78,13 @@ type StateStore struct {
|
||||||
logger logger.Logger
|
logger logger.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type blobStorageMetadata struct {
|
|
||||||
AccountName string
|
|
||||||
ContainerName string
|
|
||||||
AccountKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Init the connection to blob storage, optionally creates a blob container if it doesn't exist.
|
// Init the connection to blob storage, optionally creates a blob container if it doesn't exist.
|
||||||
func (r *StateStore) Init(metadata state.Metadata) error {
|
func (r *StateStore) Init(metadata state.Metadata) error {
|
||||||
m, err := getBlobStorageMetadata(metadata.Properties)
|
var err error
|
||||||
|
r.containerClient, _, err = storageinternal.CreateContainerStorageClient(r.logger, metadata.Properties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
userAgent := "dapr-" + logger.DaprVersion
|
|
||||||
options := container.ClientOptions{
|
|
||||||
ClientOptions: azcore.ClientOptions{
|
|
||||||
Telemetry: policy.TelemetryOptions{
|
|
||||||
ApplicationID: userAgent,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
settings, err := azauth.NewEnvironmentSettings("storage", metadata.Properties)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
customEndpoint, ok := metadata.Properties[endpointKey]
|
|
||||||
var URL *url.URL
|
|
||||||
if ok && customEndpoint != "" {
|
|
||||||
var parseErr error
|
|
||||||
URL, parseErr = url.Parse(fmt.Sprintf("%s/%s/%s", customEndpoint, m.AccountName, m.ContainerName))
|
|
||||||
if parseErr != nil {
|
|
||||||
return parseErr
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
env := settings.AzureEnvironment
|
|
||||||
URL, _ = url.Parse(fmt.Sprintf("https://%s.blob.%s/%s", m.AccountName, env.StorageEndpointSuffix, m.ContainerName))
|
|
||||||
}
|
|
||||||
|
|
||||||
var clientErr error
|
|
||||||
var client *container.Client
|
|
||||||
// Try using shared key credentials first
|
|
||||||
if m.AccountKey != "" {
|
|
||||||
credential, newSharedKeyErr := azblob.NewSharedKeyCredential(m.AccountName, m.AccountKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid credentials with error: %w", newSharedKeyErr)
|
|
||||||
}
|
|
||||||
client, clientErr = container.NewClientWithSharedKeyCredential(URL.String(), credential, &options)
|
|
||||||
if clientErr != nil {
|
|
||||||
return fmt.Errorf("cannot init Blobstorage container client: %w", err)
|
|
||||||
}
|
|
||||||
r.containerClient = client
|
|
||||||
} else {
|
|
||||||
// fallback to AAD
|
|
||||||
credential, tokenErr := settings.GetTokenCredential()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid credentials with error: %w", tokenErr)
|
|
||||||
}
|
|
||||||
client, clientErr = container.NewClient(URL.String(), credential, &options)
|
|
||||||
}
|
|
||||||
if clientErr != nil {
|
|
||||||
return fmt.Errorf("cannot init Blobstorage client: %w", clientErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
createContainerOptions := container.CreateOptions{
|
|
||||||
Access: nil,
|
|
||||||
}
|
|
||||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
_, err = client.Create(timeoutCtx, &createContainerOptions)
|
|
||||||
cancel()
|
|
||||||
// Don't return error, container might already exist
|
|
||||||
r.logger.Debugf("error creating container: %w", err)
|
|
||||||
r.containerClient = client
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -193,7 +123,7 @@ func (r *StateStore) Ping() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *StateStore) GetComponentMetadata() map[string]string {
|
func (r *StateStore) GetComponentMetadata() map[string]string {
|
||||||
metadataStruct := blobStorageMetadata{}
|
metadataStruct := storageinternal.BlobStorageMetadata{}
|
||||||
metadataInfo := map[string]string{}
|
metadataInfo := map[string]string{}
|
||||||
mdutils.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo)
|
mdutils.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo)
|
||||||
return metadataInfo
|
return metadataInfo
|
||||||
|
|
@ -211,25 +141,6 @@ func NewAzureBlobStorageStore(logger logger.Logger) state.Store {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func getBlobStorageMetadata(meta map[string]string) (*blobStorageMetadata, error) {
|
|
||||||
m := blobStorageMetadata{}
|
|
||||||
err := mdutils.DecodeMetadata(meta, &m)
|
|
||||||
|
|
||||||
if val, ok := mdutils.GetMetadataProperty(meta, azauth.StorageAccountNameKeys...); ok && val != "" {
|
|
||||||
m.AccountName = val
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageAccountNameKeys[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
if val, ok := mdutils.GetMetadataProperty(meta, azauth.StorageContainerNameKeys...); ok && val != "" {
|
|
||||||
m.ContainerName = val
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("missing or empty %s field from metadata", azauth.StorageContainerNameKeys[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
return &m, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *StateStore) readFile(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
|
func (r *StateStore) readFile(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
|
||||||
blockBlobClient := r.containerClient.NewBlockBlobClient(getFileName(req.Key))
|
blockBlobClient := r.containerClient.NewBlockBlobClient(getFileName(req.Key))
|
||||||
|
|
||||||
|
|
@ -289,7 +200,7 @@ func (r *StateStore) writeFile(ctx context.Context, req *state.SetRequest) error
|
||||||
|
|
||||||
uploadOptions := azblob.UploadBufferOptions{
|
uploadOptions := azblob.UploadBufferOptions{
|
||||||
AccessConditions: &accessConditions,
|
AccessConditions: &accessConditions,
|
||||||
Metadata: req.Metadata,
|
Metadata: storageinternal.SanitizeMetadata(r.logger, req.Metadata),
|
||||||
HTTPHeaders: &blobHTTPHeaders,
|
HTTPHeaders: &blobHTTPHeaders,
|
||||||
Concurrency: 16,
|
Concurrency: 16,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -58,26 +58,6 @@ func TestInit(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetBlobStorageMetaData(t *testing.T) {
|
|
||||||
t.Run("Nothing at all passed", func(t *testing.T) {
|
|
||||||
m := make(map[string]string)
|
|
||||||
_, err := getBlobStorageMetadata(m)
|
|
||||||
|
|
||||||
assert.NotNil(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("All parameters passed and parsed", func(t *testing.T) {
|
|
||||||
m := make(map[string]string)
|
|
||||||
m["accountName"] = "acc"
|
|
||||||
m["containerName"] = "dapr"
|
|
||||||
meta, err := getBlobStorageMetadata(m)
|
|
||||||
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "acc", meta.AccountName)
|
|
||||||
assert.Equal(t, "dapr", meta.ContainerName)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileName(t *testing.T) {
|
func TestFileName(t *testing.T) {
|
||||||
t.Run("Valid composite key", func(t *testing.T) {
|
t.Run("Valid composite key", func(t *testing.T) {
|
||||||
key := getFileName("app_id||key")
|
key := getFileName("app_id||key")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue