266 lines
8.2 KiB
Go
266 lines
8.2 KiB
Go
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
|
|
|
package aws
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
|
"github.com/aws/aws-sdk-go/aws/request"
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
"github.com/aws/aws-sdk-go/service/kinesis"
|
|
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
|
|
"github.com/aws/aws-sdk-go/service/sqs"
|
|
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/vmware/vmware-go-kcl/clientlibrary/config"
|
|
)
|
|
|
|
type mockedSQS struct {
|
|
sqsiface.SQSAPI
|
|
GetQueueURLFn func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error)
|
|
}
|
|
|
|
func (m *mockedSQS) GetQueueUrlWithContext(ctx context.Context, input *sqs.GetQueueUrlInput, opts ...request.Option) (*sqs.GetQueueUrlOutput, error) { //nolint:stylecheck
|
|
return m.GetQueueURLFn(ctx, input)
|
|
}
|
|
|
|
type mockedKinesis struct {
|
|
kinesisiface.KinesisAPI
|
|
DescribeStreamFn func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error)
|
|
}
|
|
|
|
func (m *mockedKinesis) DescribeStreamWithContext(ctx context.Context, input *kinesis.DescribeStreamInput, opts ...request.Option) (*kinesis.DescribeStreamOutput, error) {
|
|
return m.DescribeStreamFn(ctx, input)
|
|
}
|
|
|
|
func TestS3Clients_New(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
s3Client *S3Clients
|
|
session *session.Session
|
|
}{
|
|
{"initializes S3 client", &S3Clients{}, session.Must(session.NewSession())},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
tt.s3Client.New(tt.session)
|
|
require.NotNil(t, tt.s3Client.S3)
|
|
require.NotNil(t, tt.s3Client.Uploader)
|
|
require.NotNil(t, tt.s3Client.Downloader)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSqsClients_QueueURL(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
mockFn func() *mockedSQS
|
|
queueName string
|
|
expectedURL *string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "returns queue URL successfully",
|
|
mockFn: func() *mockedSQS {
|
|
return &mockedSQS{
|
|
GetQueueURLFn: func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) {
|
|
return &sqs.GetQueueUrlOutput{
|
|
QueueUrl: aws.String("https://sqs.aws.com/123456789012/queue"),
|
|
}, nil
|
|
},
|
|
}
|
|
},
|
|
queueName: "valid-queue",
|
|
expectedURL: aws.String("https://sqs.aws.com/123456789012/queue"),
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "returns error when queue URL not found",
|
|
mockFn: func() *mockedSQS {
|
|
return &mockedSQS{
|
|
GetQueueURLFn: func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) {
|
|
return nil, errors.New("unable to get stream arn due to empty client")
|
|
},
|
|
}
|
|
},
|
|
queueName: "missing-queue",
|
|
expectedURL: nil,
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockSQS := tt.mockFn()
|
|
|
|
// Initialize SqsClients with the mocked SQS client
|
|
client := &SqsClients{
|
|
Sqs: mockSQS,
|
|
}
|
|
|
|
url, err := client.QueueURL(t.Context(), tt.queueName)
|
|
|
|
if tt.expectError {
|
|
require.Error(t, err)
|
|
} else {
|
|
require.NoError(t, err)
|
|
assert.Equal(t, tt.expectedURL, url)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestKinesisClients_Stream(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
kinesisClient *KinesisClients
|
|
streamName string
|
|
mockStreamARN *string
|
|
mockError error
|
|
expectedStream *string
|
|
expectedErr error
|
|
}{
|
|
{
|
|
name: "successfully retrieves stream ARN",
|
|
kinesisClient: &KinesisClients{
|
|
Kinesis: &mockedKinesis{DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
|
|
return &kinesis.DescribeStreamOutput{
|
|
StreamDescription: &kinesis.StreamDescription{
|
|
StreamARN: aws.String("arn:aws:kinesis:some-region:123456789012:stream/some-stream"),
|
|
},
|
|
}, nil
|
|
}},
|
|
Region: "us-west-1",
|
|
Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""),
|
|
},
|
|
streamName: "some-stream",
|
|
expectedStream: aws.String("arn:aws:kinesis:some-region:123456789012:stream/some-stream"),
|
|
expectedErr: nil,
|
|
},
|
|
{
|
|
name: "returns error when stream not found",
|
|
kinesisClient: &KinesisClients{
|
|
Kinesis: &mockedKinesis{DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
|
|
return nil, errors.New("stream not found")
|
|
}},
|
|
Region: "us-west-1",
|
|
Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""),
|
|
},
|
|
streamName: "nonexistent-stream",
|
|
expectedStream: nil,
|
|
expectedErr: errors.New("unable to get stream arn due to empty client"),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := tt.kinesisClient.Stream(t.Context(), tt.streamName)
|
|
if tt.expectedErr != nil {
|
|
require.Error(t, err)
|
|
assert.Equal(t, tt.expectedErr.Error(), err.Error())
|
|
} else {
|
|
require.NoError(t, err)
|
|
assert.Equal(t, tt.expectedStream, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestKinesisClients_WorkerCfg(t *testing.T) {
|
|
testCreds := credentials.NewStaticCredentials("accessKey", "secretKey", "")
|
|
tests := []struct {
|
|
name string
|
|
kinesisClient *KinesisClients
|
|
streamName string
|
|
consumer string
|
|
mode string
|
|
expectedConfig *config.KinesisClientLibConfiguration
|
|
}{
|
|
{
|
|
name: "successfully creates shared mode worker config",
|
|
kinesisClient: &KinesisClients{
|
|
Kinesis: &mockedKinesis{
|
|
DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
|
|
return &kinesis.DescribeStreamOutput{
|
|
StreamDescription: &kinesis.StreamDescription{
|
|
StreamARN: aws.String("arn:aws:kinesis:us-east-1:123456789012:stream/existing-stream"),
|
|
},
|
|
}, nil
|
|
},
|
|
},
|
|
Region: "us-west-1",
|
|
Credentials: testCreds,
|
|
},
|
|
streamName: "existing-stream",
|
|
consumer: "consumer1",
|
|
mode: "shared",
|
|
expectedConfig: config.NewKinesisClientLibConfigWithCredential(
|
|
"consumer1", "existing-stream", "us-west-1", "consumer1", testCreds,
|
|
),
|
|
},
|
|
{
|
|
name: "returns nil when mode is not shared",
|
|
kinesisClient: &KinesisClients{
|
|
Kinesis: &mockedKinesis{
|
|
DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) {
|
|
return &kinesis.DescribeStreamOutput{
|
|
StreamDescription: &kinesis.StreamDescription{
|
|
StreamARN: aws.String("arn:aws:kinesis:us-east-1:123456789012:stream/existing-stream"),
|
|
},
|
|
}, nil
|
|
},
|
|
},
|
|
Region: "us-west-1",
|
|
Credentials: testCreds,
|
|
},
|
|
streamName: "existing-stream",
|
|
consumer: "consumer1",
|
|
mode: "exclusive",
|
|
expectedConfig: nil,
|
|
},
|
|
{
|
|
name: "returns nil when client is nil",
|
|
kinesisClient: &KinesisClients{
|
|
Kinesis: nil,
|
|
Region: "us-west-1",
|
|
Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""),
|
|
},
|
|
streamName: "existing-stream",
|
|
consumer: "consumer1",
|
|
mode: "shared",
|
|
expectedConfig: nil,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := tt.kinesisClient.WorkerCfg(t.Context(), tt.streamName, tt.consumer, tt.mode)
|
|
if tt.expectedConfig == nil {
|
|
assert.Equal(t, tt.expectedConfig, cfg)
|
|
return
|
|
}
|
|
assert.Equal(t, tt.expectedConfig.StreamName, cfg.StreamName)
|
|
assert.Equal(t, tt.expectedConfig.EnhancedFanOutConsumerName, cfg.EnhancedFanOutConsumerName)
|
|
assert.Equal(t, tt.expectedConfig.EnableEnhancedFanOutConsumer, cfg.EnableEnhancedFanOutConsumer)
|
|
assert.Equal(t, tt.expectedConfig.RegionName, cfg.RegionName)
|
|
})
|
|
}
|
|
}
|