diff --git a/pkg/azure/blob.go b/pkg/azure/blob.go index 7ef27865..efd99ef5 100644 --- a/pkg/azure/blob.go +++ b/pkg/azure/blob.go @@ -44,11 +44,13 @@ var ( ) const ( - resourceIDField = "resourceId" - clientIDField = "clientId" - tenantIDField = "tenantId" - clientSecretField = "clientSecret" - accountKeyField = "accountKey" + resourceIDField = "resourceId" + clientIDField = "clientId" + tenantIDField = "tenantId" + clientSecretField = "clientSecret" + clientCertificateField = "clientCertificate" + clientCertificatePasswordField = "clientCertificatePassword" + accountKeyField = "accountKey" ) // BlobClient is a minimal Azure Blob client for fetching objects. @@ -62,13 +64,17 @@ type BlobClient struct { // order: // // - azidentity.ManagedIdentityCredential for a Resource ID, when a -// resourceIDField is found. -// - azidentity.ManagedIdentityCredential for a User ID, when a clientIDField -// but no tenantIDField found. -// - azidentity.ClientSecretCredential when a tenantIDField, clientIDField and -// clientSecretField are found. -// - azblob.SharedKeyCredential when an accountKeyField is found. The Account -// Name is extracted from the endpoint specified on the Bucket object. +// `resourceId` field is found. +// - azidentity.ManagedIdentityCredential for a User ID, when a `clientId` +// field but no `tenantId` is found. +// - azidentity.ClientCertificateCredential when `tenantId`, +// `clientCertificate` (and optionally `clientCertificatePassword`) fields +// are found. +// - azidentity.ClientSecretCredential when `tenantId`, `clientId` and +// `clientSecret` fields are found. +// - azblob.SharedKeyCredential when an `accountKey` field is found. +// The account name is extracted from the endpoint specified on the Bucket +// object. // // If no credentials are found, a simple client without credentials is // returned. @@ -119,6 +125,9 @@ func ValidateSecret(secret *corev1.Secret) error { if _, hasClientSecret := secret.Data[clientSecretField]; hasClientSecret { valid = true } + if _, hasClientCertificate := secret.Data[clientCertificateField]; hasClientCertificate { + valid = true + } } } if _, hasResourceID := secret.Data[resourceIDField]; hasResourceID { @@ -132,8 +141,8 @@ func ValidateSecret(secret *corev1.Secret) error { } if !valid { - return fmt.Errorf("invalid '%s' secret data: requires a '%s', '%s', or '%s' field, or a combination of '%s', '%s' and '%s'", - secret.Name, resourceIDField, clientIDField, accountKeyField, tenantIDField, clientIDField, clientSecretField) + return fmt.Errorf("invalid '%s' secret data: requires a '%s', '%s', or '%s' field, a combination of '%s', '%s' and '%s', or '%s', '%s' and '%s'", + secret.Name, resourceIDField, clientIDField, accountKeyField, tenantIDField, clientIDField, clientSecretField, tenantIDField, clientIDField, clientCertificateField) } return nil } @@ -275,6 +284,13 @@ func tokenCredentialFromSecret(secret *corev1.Secret) (azcore.TokenCredential, e ID: azidentity.ClientID(clientID), }) } + if clientCertificate, hasClientCertificate := secret.Data[clientCertificateField]; hasClientCertificate { + certs, key, err := azidentity.ParseCertificates(clientCertificate, secret.Data[clientCertificatePasswordField]) + if err != nil { + return nil, fmt.Errorf("failed to parse client certificates: %w", err) + } + return azidentity.NewClientCertificateCredential(string(tenantID), string(clientID), certs, key, nil) + } if clientSecret, hasClientSecret := secret.Data[clientSecretField]; hasClientSecret { return azidentity.NewClientSecretCredential(string(tenantID), string(clientID), string(clientSecret), nil) } diff --git a/pkg/azure/blob_test.go b/pkg/azure/blob_test.go index 3dd59156..5002f647 100644 --- a/pkg/azure/blob_test.go +++ b/pkg/azure/blob_test.go @@ -17,8 +17,14 @@ limitations under the License. package azure import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" "errors" "fmt" + "math/big" "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -50,6 +56,16 @@ func TestValidateSecret(t *testing.T) { }, }, }, + { + name: "valid ServicePrincipal Certificate Secret", + secret: &corev1.Secret{ + Data: map[string][]byte{ + tenantIDField: []byte("some-tenant-id-"), + clientIDField: []byte("some-client-id-"), + clientCertificateField: []byte("some-certificate"), + }, + }, + }, { name: "valid ServicePrincipal Secret", secret: &corev1.Secret{ @@ -192,6 +208,17 @@ func Test_tokenCredentialFromSecret(t *testing.T) { }, want: &azidentity.ManagedIdentityCredential{}, }, + { + name: "with TenantID, ClientID and ClientCertificate fields", + secret: &corev1.Secret{ + Data: map[string][]byte{ + clientIDField: []byte("client-id"), + tenantIDField: []byte("tenant-id"), + clientCertificateField: validTls(t), + }, + }, + want: &azidentity.ClientCertificateCredential{}, + }, { name: "with TenantID, ClientID and ClientSecret fields", secret: &corev1.Secret{ @@ -316,3 +343,37 @@ func Test_extractAccountNameFromEndpoint1(t *testing.T) { func endpointURL(accountName string) string { return fmt.Sprintf("https://%s.blob.core.windows.net", accountName) } + +func validTls(t *testing.T) []byte { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal("Private key cannot be created.", err.Error()) + } + + out := bytes.NewBuffer(nil) + + var privateKey = &pem.Block{ + Type: "PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + } + if err = pem.Encode(out, privateKey); err != nil { + t.Fatal("Private key cannot be PEM encoded.", err.Error()) + } + + certTemplate := x509.Certificate{ + SerialNumber: big.NewInt(1337), + } + cert, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &key.PublicKey, key) + if err != nil { + t.Fatal("Certificate cannot be created.", err.Error()) + } + var certificate = &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert, + } + if err = pem.Encode(out, certificate); err != nil { + t.Fatal("Certificate cannot be PEM encoded.", err.Error()) + } + + return out.Bytes() +}