source-controller/pkg/azure/blob_test.go

379 lines
8.6 KiB
Go

/*
Copyright 2022 The Flux authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"math/big"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
. "github.com/onsi/gomega"
corev1 "k8s.io/api/core/v1"
)
func TestValidateSecret(t *testing.T) {
tests := []struct {
name string
secret *corev1.Secret
wantErr bool
}{
{
name: "valid UserManagedIdentity Secret",
secret: &corev1.Secret{
Data: map[string][]byte{
clientIDField: []byte("some-client-id-"),
},
},
},
{
name: "valid ServicePrincipal Certificate Secret",
secret: &corev1.Secret{
Data: map[string][]byte{
tenantIDField: []byte("some-tenant-id-"),
clientIDField: []byte("some-client-id-"),
clientCertificateField: []byte("some-certificate"),
},
},
},
{
name: "valid ServicePrincipal Secret",
secret: &corev1.Secret{
Data: map[string][]byte{
tenantIDField: []byte("some-tenant-id-"),
clientIDField: []byte("some-client-id-"),
clientSecretField: []byte("some-client-secret-"),
},
},
},
{
name: "valid SharedKey Secret",
secret: &corev1.Secret{
Data: map[string][]byte{
accountKeyField: []byte("some-account-key"),
},
},
},
{
name: "valid AuthorityHost Secret",
secret: &corev1.Secret{
Data: map[string][]byte{
authorityHostField: []byte("some.host.com"),
},
},
},
{
name: "invalid ServicePrincipal Secret with missing ClientID and ClientSecret",
secret: &corev1.Secret{
Data: map[string][]byte{
tenantIDField: []byte("some-tenant-id-"),
},
},
wantErr: true,
},
{
name: "invalid empty secret",
secret: &corev1.Secret{},
wantErr: true,
},
{
name: "valid nil secret",
secret: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := NewWithT(t)
want := BeNil()
if tt.wantErr {
want = HaveOccurred()
}
g.Expect(ValidateSecret(tt.secret)).To(want)
})
}
}
func TestBlobClient_ObjectIsNotFound(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{
name: "StorageError with BlobNotFound code",
err: &azblob.StorageError{ErrorCode: azblob.StorageErrorCodeBlobNotFound},
want: true,
},
{
name: "StorageError with different code",
err: &azblob.StorageError{ErrorCode: azblob.StorageErrorCodeInternalError},
},
{
name: "other error",
err: errors.New("an error"),
},
{
name: "nil error",
err: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := NewWithT(t)
c := &BlobClient{}
g.Expect(c.ObjectIsNotFound(tt.err)).To(Equal(tt.want))
})
}
}
func Test_extractAccountNameFromEndpoint(t *testing.T) {
tests := []struct {
name string
endpoint string
want string
wantErr bool
}{
{
name: "returns account name for endpoint",
endpoint: "https://foo.blob.core.windows.net",
want: "foo",
},
{
name: "error for endpoint URL parse err",
endpoint: "#http//foo.blob.core.windows.net",
wantErr: true,
},
{
name: "error for endpoint URL without subdomain",
endpoint: "https://windows.net",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := NewWithT(t)
got, err := extractAccountNameFromEndpoint(tt.endpoint)
g.Expect(err != nil).To(Equal(tt.wantErr))
g.Expect(got).To(Equal(tt.want))
})
}
}
func Test_tokenCredentialFromSecret(t *testing.T) {
tests := []struct {
name string
secret *corev1.Secret
want azcore.TokenCredential
wantErr bool
}{
{
name: "with ClientID field",
secret: &corev1.Secret{
Data: map[string][]byte{
clientIDField: []byte("client-id"),
},
},
want: &azidentity.ManagedIdentityCredential{},
},
{
name: "with TenantID, ClientID and ClientCertificate fields",
secret: &corev1.Secret{
Data: map[string][]byte{
clientIDField: []byte("client-id"),
tenantIDField: []byte("tenant-id"),
clientCertificateField: validTls(t),
},
},
want: &azidentity.ClientCertificateCredential{},
},
{
name: "with TenantID, ClientID and ClientSecret fields",
secret: &corev1.Secret{
Data: map[string][]byte{
clientIDField: []byte("client-id"),
tenantIDField: []byte("tenant-id"),
clientSecretField: []byte("client-secret"),
},
},
want: &azidentity.ClientSecretCredential{},
},
{
name: "empty secret",
secret: &corev1.Secret{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := NewWithT(t)
got, err := tokenCredentialFromSecret(tt.secret)
g.Expect(err != nil).To(Equal(tt.wantErr))
if tt.want != nil {
g.Expect(got).ToNot(BeNil())
g.Expect(got).To(BeAssignableToTypeOf(tt.want))
return
}
g.Expect(got).To(BeNil())
})
}
}
func Test_sharedCredentialFromSecret(t *testing.T) {
var testKey = []byte("dGVzdA==")
tests := []struct {
name string
endpoint string
secret *corev1.Secret
want *azblob.SharedKeyCredential
wantErr bool
}{
{
name: "with AccountKey field",
endpoint: "https://some.endpoint.com",
secret: &corev1.Secret{
Data: map[string][]byte{
accountKeyField: testKey,
},
},
want: &azblob.SharedKeyCredential{},
},
{
name: "invalid endpoint",
endpoint: "#http//some.endpoint.com",
secret: &corev1.Secret{
Data: map[string][]byte{
accountKeyField: testKey,
},
},
wantErr: true,
},
{
name: "empty secret",
secret: &corev1.Secret{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := NewWithT(t)
got, err := sharedCredentialFromSecret(tt.endpoint, tt.secret)
g.Expect(err != nil).To(Equal(tt.wantErr))
if tt.want != nil {
g.Expect(got).ToNot(BeNil())
g.Expect(got).To(BeAssignableToTypeOf(tt.want))
return
}
g.Expect(got).To(BeNil())
})
}
}
func Test_chainCredentialWithSecret(t *testing.T) {
g := NewWithT(t)
got, err := chainCredentialWithSecret(nil)
g.Expect(err).ToNot(HaveOccurred())
g.Expect(got).To(BeAssignableToTypeOf(&azidentity.ChainedTokenCredential{}))
}
func Test_extractAccountNameFromEndpoint1(t *testing.T) {
tests := []struct {
name string
endpoint string
want string
wantErr string
}{
{
name: "valid URL",
endpoint: endpointURL("foo"),
want: "foo",
},
{
name: "URL parse error",
endpoint: " https://example.com",
wantErr: "first path segment in URL cannot contain colon",
},
{
name: "error on non subdomain",
endpoint: "https://example.com",
wantErr: "expected 'example.com' to be a subdomain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := NewWithT(t)
got, err := extractAccountNameFromEndpoint(tt.endpoint)
if tt.wantErr != "" {
g.Expect(err.Error()).To(ContainSubstring(tt.wantErr))
g.Expect(got).To(BeEmpty())
return
}
g.Expect(err).To(BeNil())
g.Expect(got).To(Equal(tt.want))
})
}
}
func endpointURL(accountName string) string {
return fmt.Sprintf("https://%s.blob.core.windows.net", accountName)
}
func validTls(t *testing.T) []byte {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal("Private key cannot be created.", err.Error())
}
out := bytes.NewBuffer(nil)
var privateKey = &pem.Block{
Type: "PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
}
if err = pem.Encode(out, privateKey); err != nil {
t.Fatal("Private key cannot be PEM encoded.", err.Error())
}
certTemplate := x509.Certificate{
SerialNumber: big.NewInt(1337),
}
cert, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &key.PublicKey, key)
if err != nil {
t.Fatal("Certificate cannot be created.", err.Error())
}
var certificate = &pem.Block{
Type: "CERTIFICATE",
Bytes: cert,
}
if err = pem.Encode(out, certificate); err != nil {
t.Fatal("Certificate cannot be PEM encoded.", err.Error())
}
return out.Bytes()
}