445 lines
11 KiB
Go
445 lines
11 KiB
Go
/*
|
|
Copyright 2021 The Dapr Authors
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package aws
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"reflect"
	"strconv"
	"sync"
	"time"

	awsv2 "github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	v2creds "github.com/aws/aws-sdk-go-v2/credentials"
	"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
	"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
	"github.com/aws/aws-sdk-go-v2/service/sts"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/dapr/kit/logger"
)
|
|
|
|
type StaticAuth struct {
|
|
mu sync.RWMutex
|
|
logger logger.Logger
|
|
|
|
region *string
|
|
endpoint *string
|
|
accessKey *string
|
|
secretKey *string
|
|
sessionToken string
|
|
|
|
assumeRoleARN *string
|
|
sessionName string
|
|
|
|
session *session.Session
|
|
cfg *aws.Config
|
|
clients *Clients
|
|
}
|
|
|
|
func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) {
|
|
auth := &StaticAuth{
|
|
logger: opts.Logger,
|
|
cfg: func() *aws.Config {
|
|
// if nil is passed or it's just a default cfg,
|
|
// then we use the options to build the aws cfg.
|
|
if cfg != nil && cfg != aws.NewConfig() {
|
|
return cfg
|
|
}
|
|
return GetConfig(opts)
|
|
}(),
|
|
clients: newClients(),
|
|
}
|
|
|
|
if opts.Region != "" {
|
|
auth.region = &opts.Region
|
|
}
|
|
if opts.Endpoint != "" {
|
|
auth.endpoint = &opts.Endpoint
|
|
}
|
|
if opts.AccessKey != "" {
|
|
auth.accessKey = &opts.AccessKey
|
|
}
|
|
if opts.SecretKey != "" {
|
|
auth.secretKey = &opts.SecretKey
|
|
}
|
|
if opts.SessionToken != "" {
|
|
auth.sessionToken = opts.SessionToken
|
|
}
|
|
if opts.AssumeRoleARN != "" {
|
|
auth.assumeRoleARN = &opts.AssumeRoleARN
|
|
}
|
|
if opts.SessionName != "" {
|
|
auth.sessionName = opts.SessionName
|
|
}
|
|
|
|
initialSession, err := auth.createSession()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get token client: %v", err)
|
|
}
|
|
|
|
auth.session = initialSession
|
|
|
|
return auth, nil
|
|
}
|
|
|
|
// This is to be used only for test purposes to inject mocked clients
|
|
func (a *StaticAuth) WithMockClients(clients *Clients) {
|
|
a.clients = clients
|
|
}
|
|
|
|
func (a *StaticAuth) S3() *S3Clients {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
if a.clients.s3 != nil {
|
|
return a.clients.s3
|
|
}
|
|
|
|
s3Clients := S3Clients{}
|
|
a.clients.s3 = &s3Clients
|
|
a.clients.s3.New(a.session)
|
|
return a.clients.s3
|
|
}
|
|
|
|
func (a *StaticAuth) DynamoDB() *DynamoDBClients {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
if a.clients.Dynamo != nil {
|
|
return a.clients.Dynamo
|
|
}
|
|
|
|
clients := DynamoDBClients{}
|
|
a.clients.Dynamo = &clients
|
|
a.clients.Dynamo.New(a.session)
|
|
|
|
return a.clients.Dynamo
|
|
}
|
|
|
|
func (a *StaticAuth) Sqs() *SqsClients {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
if a.clients.sqs != nil {
|
|
return a.clients.sqs
|
|
}
|
|
|
|
clients := SqsClients{}
|
|
a.clients.sqs = &clients
|
|
a.clients.sqs.New(a.session)
|
|
|
|
return a.clients.sqs
|
|
}
|
|
|
|
func (a *StaticAuth) Sns() *SnsClients {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
if a.clients.sns != nil {
|
|
return a.clients.sns
|
|
}
|
|
|
|
clients := SnsClients{}
|
|
a.clients.sns = &clients
|
|
a.clients.sns.New(a.session)
|
|
return a.clients.sns
|
|
}
|
|
|
|
func (a *StaticAuth) SnsSqs() *SnsSqsClients {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
if a.clients.snssqs != nil {
|
|
return a.clients.snssqs
|
|
}
|
|
|
|
clients := SnsSqsClients{}
|
|
a.clients.snssqs = &clients
|
|
a.clients.snssqs.New(a.session)
|
|
return a.clients.snssqs
|
|
}
|
|
|
|
func (a *StaticAuth) SecretManager() *SecretManagerClients {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
if a.clients.Secret != nil {
|
|
return a.clients.Secret
|
|
}
|
|
|
|
clients := SecretManagerClients{}
|
|
a.clients.Secret = &clients
|
|
a.clients.Secret.New(a.session)
|
|
return a.clients.Secret
|
|
}
|
|
|
|
func (a *StaticAuth) ParameterStore() *ParameterStoreClients {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
if a.clients.ParameterStore != nil {
|
|
return a.clients.ParameterStore
|
|
}
|
|
|
|
clients := ParameterStoreClients{}
|
|
a.clients.ParameterStore = &clients
|
|
a.clients.ParameterStore.New(a.session)
|
|
return a.clients.ParameterStore
|
|
}
|
|
|
|
func (a *StaticAuth) Kinesis() *KinesisClients {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
if a.clients.kinesis != nil {
|
|
return a.clients.kinesis
|
|
}
|
|
|
|
clients := KinesisClients{}
|
|
a.clients.kinesis = &clients
|
|
a.clients.kinesis.New(a.session)
|
|
return a.clients.kinesis
|
|
}
|
|
|
|
func (a *StaticAuth) Ses() *SesClients {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
if a.clients.ses != nil {
|
|
return a.clients.ses
|
|
}
|
|
|
|
clients := SesClients{}
|
|
a.clients.ses = &clients
|
|
a.clients.ses.New(a.session)
|
|
return a.clients.ses
|
|
}
|
|
|
|
func (a *StaticAuth) UpdatePostgres(ctx context.Context, poolConfig *pgxpool.Config) {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
// Set max connection lifetime to 8 minutes in postgres connection pool configuration.
|
|
// Note: this will refresh connections before the 15 min expiration on the IAM AWS auth token,
|
|
// while leveraging the BeforeConnect hook to recreate the token in time dynamically.
|
|
poolConfig.MaxConnLifetime = time.Minute * 8
|
|
|
|
// Setup connection pool config needed for AWS IAM authentication
|
|
poolConfig.BeforeConnect = func(ctx context.Context, pgConfig *pgx.ConnConfig) error {
|
|
// Manually reset auth token with aws and reset the config password using the new iam token
|
|
pwd, err := a.getDatabaseToken(ctx, poolConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get database token: %w", err)
|
|
}
|
|
pgConfig.Password = pwd
|
|
poolConfig.ConnConfig.Password = pwd
|
|
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/UsingWithRDS.IAMDBAuth.Connecting.Go.html
|
|
func (a *StaticAuth) getDatabaseToken(ctx context.Context, poolConfig *pgxpool.Config) (string, error) {
|
|
dbEndpoint := poolConfig.ConnConfig.Host + ":" + strconv.Itoa(int(poolConfig.ConnConfig.Port))
|
|
|
|
// First, check if there are credentials set explicitly with accesskey and secretkey
|
|
var creds credentials.Value
|
|
if a.session != nil {
|
|
var err error
|
|
creds, err = a.session.Config.Credentials.Get()
|
|
if err != nil {
|
|
a.logger.Infof("failed to get access key and secret key, will fallback to reading the default AWS credentials file: %w", err)
|
|
}
|
|
}
|
|
|
|
if creds.AccessKeyID != "" && creds.SecretAccessKey != "" {
|
|
creds, err := a.session.Config.Credentials.Get()
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to retrieve session credentials: %w", err)
|
|
}
|
|
awsCfg := v2creds.NewStaticCredentialsProvider(creds.AccessKeyID, creds.SecretAccessKey, creds.SessionToken)
|
|
authenticationToken, err := auth.BuildAuthToken(
|
|
ctx, dbEndpoint, *a.region, poolConfig.ConnConfig.User, awsCfg)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create AWS authentication token: %w", err)
|
|
}
|
|
|
|
return authenticationToken, nil
|
|
}
|
|
|
|
// Second, check if we are assuming a role instead
|
|
if a.assumeRoleARN != nil {
|
|
awsCfg, err := config.LoadDefaultConfig(ctx)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to load default AWS authentication configuration %w", err)
|
|
}
|
|
stsClient := sts.NewFromConfig(awsCfg)
|
|
|
|
assumeRoleCfg, err := config.LoadDefaultConfig(ctx,
|
|
config.WithRegion(*a.region),
|
|
config.WithCredentialsProvider(
|
|
awsv2.NewCredentialsCache(
|
|
stscreds.NewAssumeRoleProvider(stsClient, *a.assumeRoleARN, func(aro *stscreds.AssumeRoleOptions) {
|
|
if a.sessionName != "" {
|
|
aro.RoleSessionName = a.sessionName
|
|
}
|
|
}),
|
|
),
|
|
),
|
|
)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to assume aws role %w", err)
|
|
}
|
|
|
|
authenticationToken, err := auth.BuildAuthToken(
|
|
ctx, dbEndpoint, *a.region, poolConfig.ConnConfig.User, assumeRoleCfg.Credentials)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create AWS authentication token: %w", err)
|
|
}
|
|
return authenticationToken, nil
|
|
}
|
|
|
|
// Lastly, and by default, just use the default aws configuration
|
|
awsCfg, err := config.LoadDefaultConfig(ctx)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to load default AWS authentication configuration %w", err)
|
|
}
|
|
|
|
authenticationToken, err := auth.BuildAuthToken(ctx, dbEndpoint, *a.region, poolConfig.ConnConfig.User, awsCfg.Credentials)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create AWS authentication token: %w", err)
|
|
}
|
|
|
|
return authenticationToken, nil
|
|
}
|
|
|
|
func (a *StaticAuth) Kafka(opts KafkaOptions) (*KafkaClients, error) {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
// This means we've already set the config in our New function
|
|
// to use the SASL token provider.
|
|
if a.clients.kafka != nil {
|
|
return a.clients.kafka, nil
|
|
}
|
|
|
|
a.clients.kafka = initKafkaClients(opts)
|
|
// static auth has additional fields we need added,
|
|
// so we add those static auth specific fields here,
|
|
// and the rest of the token provider fields are added in New()
|
|
tokenProvider := mskTokenProvider{}
|
|
if a.assumeRoleARN != nil {
|
|
tokenProvider.awsIamRoleArn = *a.assumeRoleARN
|
|
}
|
|
if a.sessionName != "" {
|
|
tokenProvider.awsStsSessionName = a.sessionName
|
|
}
|
|
|
|
err := a.clients.kafka.New(a.session, &tokenProvider)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create AWS IAM Kafka config: %w", err)
|
|
}
|
|
|
|
return a.clients.kafka, nil
|
|
}
|
|
|
|
func (a *StaticAuth) createSession() (*session.Session, error) {
|
|
var awsConfig *aws.Config
|
|
if a.cfg == nil {
|
|
awsConfig = aws.NewConfig()
|
|
} else {
|
|
awsConfig = a.cfg
|
|
}
|
|
|
|
if a.region != nil {
|
|
awsConfig = awsConfig.WithRegion(*a.region)
|
|
}
|
|
|
|
if a.accessKey != nil && a.secretKey != nil {
|
|
// session token is an option field
|
|
awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, a.sessionToken))
|
|
}
|
|
|
|
if a.endpoint != nil {
|
|
awsConfig = awsConfig.WithEndpoint(*a.endpoint)
|
|
}
|
|
|
|
// TODO support assume role for all aws components
|
|
|
|
awsSession, err := session.NewSessionWithOptions(session.Options{
|
|
Config: *awsConfig,
|
|
SharedConfigState: session.SharedConfigEnable,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
userAgentHandler := request.NamedHandler{
|
|
Name: "UserAgentHandler",
|
|
Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion),
|
|
}
|
|
awsSession.Handlers.Build.PushBackNamed(userAgentHandler)
|
|
|
|
return awsSession, nil
|
|
}
|
|
|
|
func (a *StaticAuth) Close() error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
errs := make([]error, 2)
|
|
if a.clients.kafka != nil {
|
|
if a.clients.kafka.Producer != nil {
|
|
errs[0] = a.clients.kafka.Producer.Close()
|
|
a.clients.kafka.Producer = nil
|
|
}
|
|
if a.clients.kafka.ConsumerGroup != nil {
|
|
errs[1] = a.clients.kafka.ConsumerGroup.Close()
|
|
a.clients.kafka.ConsumerGroup = nil
|
|
}
|
|
}
|
|
return errors.Join(errs...)
|
|
}
|
|
|
|
func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) {
|
|
optFns := []func(*config.LoadOptions) error{}
|
|
if region != "" {
|
|
optFns = append(optFns, config.WithRegion(region))
|
|
}
|
|
|
|
if accessKey != "" && secretKey != "" {
|
|
provider := v2creds.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken)
|
|
optFns = append(optFns, config.WithCredentialsProvider(provider))
|
|
}
|
|
|
|
awsCfg, err := config.LoadDefaultConfig(context.Background(), optFns...)
|
|
if err != nil {
|
|
return awsv2.Config{}, err
|
|
}
|
|
|
|
if endpoint != "" {
|
|
awsCfg.BaseEndpoint = &endpoint
|
|
}
|
|
|
|
return awsCfg, nil
|
|
}
|