components-contrib/common/authentication/aws/static.go

445 lines
11 KiB
Go

/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package aws
import (
	"context"
	"errors"
	"fmt"
	"reflect"
	"strconv"
	"sync"
	"time"

	awsv2 "github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	v2creds "github.com/aws/aws-sdk-go-v2/credentials"
	"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
	"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
	"github.com/aws/aws-sdk-go-v2/service/sts"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/dapr/kit/logger"
)
// StaticAuth hands out AWS service clients built from a single AWS SDK v1 session whose
// settings come from Options supplied at construction time.
type StaticAuth struct {
	// mu guards lazy creation of the entries in clients.
	mu     sync.RWMutex
	logger logger.Logger

	// Optional overrides; a nil pointer means "not provided in Options".
	region, endpoint     *string
	accessKey, secretKey *string
	// Optional session token paired with accessKey/secretKey; empty when not provided.
	sessionToken string
	// Optional IAM role to assume, and the STS session name to use for it.
	assumeRoleARN *string
	sessionName   string

	// session is the shared SDK v1 session, built from cfg by createSession.
	session *session.Session
	cfg     *aws.Config
	// clients caches the lazily created per-service client sets.
	clients *Clients
}
// newStaticIAM creates a StaticAuth from opts and eagerly builds the AWS SDK v1 session
// that every client it hands out will share.
//
// cfg is used as the base SDK configuration when it is non-nil and not an empty
// aws.Config. Otherwise the base configuration is built from opts via GetConfig.
// Region, endpoint, static credentials and assume-role settings are copied from opts.
// Empty strings are treated as "not provided".
//
// An error is returned if the session cannot be created.
func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) {
	// If nil is passed, or it's just a default (empty) config, build the config from opts.
	// Comparing with `cfg != aws.NewConfig()` only compares pointers, and a freshly
	// allocated pointer never matches, so the "default config" case must be detected by value.
	baseCfg := cfg
	if baseCfg == nil || reflect.DeepEqual(*baseCfg, aws.Config{}) {
		baseCfg = GetConfig(opts)
	}

	// Named `a` rather than `auth` to avoid shadowing the imported rds `auth` package.
	a := &StaticAuth{
		logger:       opts.Logger,
		cfg:          baseCfg,
		clients:      newClients(),
		sessionToken: opts.SessionToken,
		sessionName:  opts.SessionName,
	}

	// Pointer fields stay nil when the option is empty, so later code can tell
	// "not configured" apart from a configured value.
	if opts.Region != "" {
		a.region = &opts.Region
	}
	if opts.Endpoint != "" {
		a.endpoint = &opts.Endpoint
	}
	if opts.AccessKey != "" {
		a.accessKey = &opts.AccessKey
	}
	if opts.SecretKey != "" {
		a.secretKey = &opts.SecretKey
	}
	if opts.AssumeRoleARN != "" {
		a.assumeRoleARN = &opts.AssumeRoleARN
	}

	initialSession, err := a.createSession()
	if err != nil {
		// %w keeps the underlying SDK error inspectable with errors.Is/errors.As.
		return nil, fmt.Errorf("failed to get token client: %w", err)
	}
	a.session = initialSession

	return a, nil
}
// WithMockClients replaces the cached client set with clients.
// It is intended only for tests that need to inject mocked clients.
//
// It takes a.mu like every other method that touches a.clients, so swapping the clients
// cannot race with a concurrent lazy client initialisation.
func (a *StaticAuth) WithMockClients(clients *Clients) {
	a.mu.Lock()
	defer a.mu.Unlock()

	a.clients = clients
}
// S3 returns the S3 client set. It is created from the shared session on the
// first call and cached for later calls.
func (a *StaticAuth) S3() *S3Clients {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.clients.s3 == nil {
		a.clients.s3 = &S3Clients{}
		a.clients.s3.New(a.session)
	}

	return a.clients.s3
}
// DynamoDB returns the DynamoDB client set. It is created from the shared session
// on the first call and cached for later calls.
func (a *StaticAuth) DynamoDB() *DynamoDBClients {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.clients.Dynamo == nil {
		a.clients.Dynamo = &DynamoDBClients{}
		a.clients.Dynamo.New(a.session)
	}

	return a.clients.Dynamo
}
// Sqs returns the SQS client set. It is created from the shared session on the
// first call and cached for later calls.
func (a *StaticAuth) Sqs() *SqsClients {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.clients.sqs == nil {
		a.clients.sqs = &SqsClients{}
		a.clients.sqs.New(a.session)
	}

	return a.clients.sqs
}
// Sns returns the SNS client set. It is created from the shared session on the
// first call and cached for later calls.
func (a *StaticAuth) Sns() *SnsClients {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.clients.sns == nil {
		a.clients.sns = &SnsClients{}
		a.clients.sns.New(a.session)
	}

	return a.clients.sns
}
// SnsSqs returns the combined SNS/SQS client set. It is created from the shared
// session on the first call and cached for later calls.
func (a *StaticAuth) SnsSqs() *SnsSqsClients {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.clients.snssqs == nil {
		a.clients.snssqs = &SnsSqsClients{}
		a.clients.snssqs.New(a.session)
	}

	return a.clients.snssqs
}
// SecretManager returns the Secrets Manager client set. It is created from the
// shared session on the first call and cached for later calls.
func (a *StaticAuth) SecretManager() *SecretManagerClients {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.clients.Secret == nil {
		a.clients.Secret = &SecretManagerClients{}
		a.clients.Secret.New(a.session)
	}

	return a.clients.Secret
}
// ParameterStore returns the Parameter Store client set. It is created from the
// shared session on the first call and cached for later calls.
func (a *StaticAuth) ParameterStore() *ParameterStoreClients {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.clients.ParameterStore == nil {
		a.clients.ParameterStore = &ParameterStoreClients{}
		a.clients.ParameterStore.New(a.session)
	}

	return a.clients.ParameterStore
}
// Kinesis returns the Kinesis client set. It is created from the shared session
// on the first call and cached for later calls.
func (a *StaticAuth) Kinesis() *KinesisClients {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.clients.kinesis == nil {
		a.clients.kinesis = &KinesisClients{}
		a.clients.kinesis.New(a.session)
	}

	return a.clients.kinesis
}
// Ses returns the SES client set. It is created from the shared session on the
// first call and cached for later calls.
func (a *StaticAuth) Ses() *SesClients {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.clients.ses == nil {
		a.clients.ses = &SesClients{}
		a.clients.ses.New(a.session)
	}

	return a.clients.ses
}
// UpdatePostgres configures poolConfig for AWS IAM database authentication.
//
// It caps each connection's lifetime at 8 minutes. It also installs a BeforeConnect
// hook that generates a fresh IAM auth token for every new connection and uses it as
// the password. The token is stored on the connection being opened and on
// poolConfig.ConnConfig.
func (a *StaticAuth) UpdatePostgres(ctx context.Context, poolConfig *pgxpool.Config) {
	a.mu.Lock()
	defer a.mu.Unlock()

	// IAM auth tokens expire after 15 minutes. Recycling connections every 8 minutes
	// forces reconnects, and therefore new tokens via BeforeConnect, well before expiry.
	poolConfig.MaxConnLifetime = 8 * time.Minute

	poolConfig.BeforeConnect = func(connCtx context.Context, connCfg *pgx.ConnConfig) error {
		token, err := a.getDatabaseToken(connCtx, poolConfig)
		if err != nil {
			return fmt.Errorf("failed to get database token: %w", err)
		}

		connCfg.Password = token
		poolConfig.ConnConfig.Password = token

		return nil
	}
}
// getDatabaseToken generates an RDS IAM authentication token for the host, port and
// user in poolConfig.ConnConfig. The token is used as the connection password.
// https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/UsingWithRDS.IAMDBAuth.Connecting.Go.html
//
// Credentials are chosen in this order:
//  1. the credentials of the session built by createSession, if they carry both an
//     access key ID and a secret access key;
//  2. an STS AssumeRole provider for assumeRoleARN, if one was configured;
//  3. the default AWS SDK v2 configuration.
//
// The region is the configured one or, when none was configured, the session's resolved
// region. An error is returned if no region is available or if any step fails.
func (a *StaticAuth) getDatabaseToken(ctx context.Context, poolConfig *pgxpool.Config) (string, error) {
	dbEndpoint := poolConfig.ConnConfig.Host + ":" + strconv.Itoa(int(poolConfig.ConnConfig.Port))
	dbUser := poolConfig.ConnConfig.User

	// a.region is nil when Options.Region was empty, so it must not be dereferenced
	// blindly. Fall back to the region the session resolved, then fail cleanly.
	var region string
	if a.region != nil {
		region = *a.region
	} else if a.session != nil {
		region = aws.StringValue(a.session.Config.Region)
	}
	if region == "" {
		return "", errors.New("failed to create AWS authentication token: no AWS region configured")
	}

	// buildToken signs an auth token for this endpoint, region and user with the given credentials.
	buildToken := func(provider awsv2.CredentialsProvider) (string, error) {
		token, err := auth.BuildAuthToken(ctx, dbEndpoint, region, dbUser, provider)
		if err != nil {
			return "", fmt.Errorf("failed to create AWS authentication token: %w", err)
		}
		return token, nil
	}

	// First, check if there are credentials set explicitly with accesskey and secretkey.
	var creds credentials.Value
	if a.session != nil {
		var err error
		creds, err = a.session.Config.Credentials.Get()
		if err != nil {
			// %w is only meaningful to fmt.Errorf; in a log format string it renders as
			// "%!w(...)", so use %v.
			a.logger.Infof("failed to get access key and secret key, will fallback to reading the default AWS credentials file: %v", err)
		}
	}
	if creds.AccessKeyID != "" && creds.SecretAccessKey != "" {
		// Reuse the credentials already retrieved above instead of fetching them again.
		return buildToken(v2creds.NewStaticCredentialsProvider(creds.AccessKeyID, creds.SecretAccessKey, creds.SessionToken))
	}

	// Second, check if we are assuming a role instead.
	if a.assumeRoleARN != nil {
		baseCfg, err := config.LoadDefaultConfig(ctx)
		if err != nil {
			return "", fmt.Errorf("failed to load default AWS authentication configuration %w", err)
		}
		stsClient := sts.NewFromConfig(baseCfg)
		assumeRoleCfg, err := config.LoadDefaultConfig(ctx,
			config.WithRegion(region),
			config.WithCredentialsProvider(
				awsv2.NewCredentialsCache(
					stscreds.NewAssumeRoleProvider(stsClient, *a.assumeRoleARN, func(aro *stscreds.AssumeRoleOptions) {
						if a.sessionName != "" {
							aro.RoleSessionName = a.sessionName
						}
					}),
				),
			),
		)
		if err != nil {
			return "", fmt.Errorf("failed to assume aws role %w", err)
		}
		return buildToken(assumeRoleCfg.Credentials)
	}

	// Lastly, and by default, just use the default AWS configuration.
	defaultCfg, err := config.LoadDefaultConfig(ctx)
	if err != nil {
		return "", fmt.Errorf("failed to load default AWS authentication configuration %w", err)
	}
	return buildToken(defaultCfg.Credentials)
}
// Kafka returns the Kafka clients for this auth profile, creating them on first use.
//
// The clients are configured with an MSK token provider. The provider also receives
// assumeRoleARN (when set) and sessionName. The result is cached only once creation
// succeeds. A failed attempt therefore returns an error and can be retried; it no longer
// leaves a half-initialised client that later calls would return with a nil error.
func (a *StaticAuth) Kafka(opts KafkaOptions) (*KafkaClients, error) {
	a.mu.Lock()
	defer a.mu.Unlock()

	// This means we've already set the config in our New function
	// to use the SASL token provider.
	if a.clients.kafka != nil {
		return a.clients.kafka, nil
	}

	kafkaClients := initKafkaClients(opts)

	// Static auth has additional fields we need added, so we add those static auth
	// specific fields here. The rest of the token provider fields are added in New().
	tokenProvider := mskTokenProvider{awsStsSessionName: a.sessionName}
	if a.assumeRoleARN != nil {
		tokenProvider.awsIamRoleArn = *a.assumeRoleARN
	}

	if err := kafkaClients.New(a.session, &tokenProvider); err != nil {
		return nil, fmt.Errorf("failed to create AWS IAM Kafka config: %w", err)
	}

	a.clients.kafka = kafkaClients
	return kafkaClients, nil
}
// createSession builds the AWS SDK v1 session shared by every client this StaticAuth
// hands out.
//
// It starts from a copy of a.cfg, or an empty config when a.cfg is nil. On top of that it
// applies the configured region, static credentials (when both access key and secret key
// are set) and endpoint. It then enables shared-config loading and appends a build
// handler that adds a "dapr" user-agent entry with logger.DaprVersion.
func (a *StaticAuth) createSession() (*session.Session, error) {
	// The aws.Config With* setters mutate their receiver. Work on a copy so the region,
	// credentials and endpoint are never written into a caller-supplied config.
	var awsConfig *aws.Config
	if a.cfg == nil {
		awsConfig = aws.NewConfig()
	} else {
		awsConfig = a.cfg.Copy()
	}

	if a.region != nil {
		awsConfig = awsConfig.WithRegion(*a.region)
	}
	if a.accessKey != nil && a.secretKey != nil {
		// The session token is optional and may be empty.
		awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, a.sessionToken))
	}
	if a.endpoint != nil {
		awsConfig = awsConfig.WithEndpoint(*a.endpoint)
	}

	// TODO support assume role for all aws components
	awsSession, err := session.NewSessionWithOptions(session.Options{
		Config:            *awsConfig,
		SharedConfigState: session.SharedConfigEnable,
	})
	if err != nil {
		return nil, err
	}

	userAgentHandler := request.NamedHandler{
		Name: "UserAgentHandler",
		Fn:   request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion),
	}
	awsSession.Handlers.Build.PushBackNamed(userAgentHandler)

	return awsSession, nil
}
// Close closes the Kafka producer and consumer group, if they were created, and clears
// them from the cached Kafka clients. Errors from both closes are joined. It returns nil
// when there is nothing to close.
func (a *StaticAuth) Close() error {
	a.mu.Lock()
	defer a.mu.Unlock()

	kc := a.clients.kafka
	if kc == nil {
		return nil
	}

	var producerErr, consumerErr error
	if kc.Producer != nil {
		producerErr = kc.Producer.Close()
		kc.Producer = nil
	}
	if kc.ConsumerGroup != nil {
		consumerErr = kc.ConsumerGroup.Close()
		kc.ConsumerGroup = nil
	}

	return errors.Join(producerErr, consumerErr)
}
// GetConfigV2 loads an AWS SDK v2 configuration from the SDK's default sources.
//
// Non-empty arguments override those sources:
//   - region sets the region;
//   - accessKey together with secretKey (plus the optional sessionToken) sets static credentials;
//   - endpoint sets the base endpoint.
//
// Any error from loading the default configuration is returned unchanged.
func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) {
	var loadOpts []func(*config.LoadOptions) error

	if region != "" {
		loadOpts = append(loadOpts, config.WithRegion(region))
	}
	if accessKey != "" && secretKey != "" {
		// The session token is optional and may be empty.
		staticCreds := v2creds.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken)
		loadOpts = append(loadOpts, config.WithCredentialsProvider(staticCreds))
	}

	cfg, err := config.LoadDefaultConfig(context.Background(), loadOpts...)
	if err != nil {
		return awsv2.Config{}, err
	}

	if endpoint != "" {
		cfg.BaseEndpoint = awsv2.String(endpoint)
	}

	return cfg, nil
}