171 lines
4.7 KiB
Go
171 lines
4.7 KiB
Go
package s3
|
|
|
|
import (
|
|
"context"
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
"github.com/aws/aws-sdk-go-v2/internal/sdk"
|
|
"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
|
|
"github.com/aws/smithy-go/container/private/cache"
|
|
"github.com/aws/smithy-go/container/private/cache/lru"
|
|
)
|
|
|
|
const s3ExpressCacheCap = 100
|
|
|
|
const s3ExpressRefreshWindow = 1 * time.Minute
|
|
|
|
type cacheKey struct {
|
|
CredentialsHash string // hmac(sigv4 akid, sigv4 secret)
|
|
Bucket string
|
|
}
|
|
|
|
func (c cacheKey) Slug() string {
|
|
return fmt.Sprintf("%s%s", c.CredentialsHash, c.Bucket)
|
|
}
|
|
|
|
type sessionCredsCache struct {
|
|
mu sync.Mutex
|
|
cache cache.Cache
|
|
}
|
|
|
|
func (c *sessionCredsCache) Get(key cacheKey) (*aws.Credentials, bool) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if v, ok := c.cache.Get(key); ok {
|
|
return v.(*aws.Credentials), true
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
func (c *sessionCredsCache) Put(key cacheKey, creds *aws.Credentials) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
c.cache.Put(key, creds)
|
|
}
|
|
|
|
// The default S3Express provider uses an LRU cache with a capacity of 100.
|
|
//
|
|
// Credentials will be refreshed asynchronously when a Retrieve() call is made
|
|
// for cached credentials within an expiry window (1 minute, currently
|
|
// non-configurable).
|
|
type defaultS3ExpressCredentialsProvider struct {
|
|
sf singleflight.Group
|
|
|
|
client createSessionAPIClient
|
|
cache *sessionCredsCache
|
|
refreshWindow time.Duration
|
|
v4creds aws.CredentialsProvider // underlying credentials used for CreateSession
|
|
}
|
|
|
|
type createSessionAPIClient interface {
|
|
CreateSession(context.Context, *CreateSessionInput, ...func(*Options)) (*CreateSessionOutput, error)
|
|
}
|
|
|
|
func newDefaultS3ExpressCredentialsProvider() *defaultS3ExpressCredentialsProvider {
|
|
return &defaultS3ExpressCredentialsProvider{
|
|
cache: &sessionCredsCache{
|
|
cache: lru.New(s3ExpressCacheCap),
|
|
},
|
|
refreshWindow: s3ExpressRefreshWindow,
|
|
}
|
|
}
|
|
|
|
// returns a cloned provider using new base credentials, used when per-op
|
|
// config mutations change the credentials provider
|
|
func (p *defaultS3ExpressCredentialsProvider) CloneWithBaseCredentials(v4creds aws.CredentialsProvider) *defaultS3ExpressCredentialsProvider {
|
|
return &defaultS3ExpressCredentialsProvider{
|
|
client: p.client,
|
|
cache: p.cache,
|
|
refreshWindow: p.refreshWindow,
|
|
v4creds: v4creds,
|
|
}
|
|
}
|
|
|
|
func (p *defaultS3ExpressCredentialsProvider) Retrieve(ctx context.Context, bucket string) (aws.Credentials, error) {
|
|
v4creds, err := p.v4creds.Retrieve(ctx)
|
|
if err != nil {
|
|
return aws.Credentials{}, fmt.Errorf("get sigv4 creds: %w", err)
|
|
}
|
|
|
|
key := cacheKey{
|
|
CredentialsHash: gethmac(v4creds.AccessKeyID, v4creds.SecretAccessKey),
|
|
Bucket: bucket,
|
|
}
|
|
creds, ok := p.cache.Get(key)
|
|
if !ok || creds.Expired() {
|
|
return p.awaitDoChanRetrieve(ctx, key)
|
|
}
|
|
|
|
if creds.Expires.Sub(sdk.NowTime()) <= p.refreshWindow {
|
|
p.doChanRetrieve(ctx, key)
|
|
}
|
|
|
|
return *creds, nil
|
|
}
|
|
|
|
func (p *defaultS3ExpressCredentialsProvider) doChanRetrieve(ctx context.Context, key cacheKey) <-chan singleflight.Result {
|
|
return p.sf.DoChan(key.Slug(), func() (interface{}, error) {
|
|
return p.retrieve(ctx, key)
|
|
})
|
|
}
|
|
|
|
func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) {
|
|
ch := p.doChanRetrieve(ctx, key)
|
|
|
|
select {
|
|
case r := <-ch:
|
|
return r.Val.(aws.Credentials), r.Err
|
|
case <-ctx.Done():
|
|
return aws.Credentials{}, errors.New("s3express retrieve credentials canceled")
|
|
}
|
|
}
|
|
|
|
func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) {
|
|
resp, err := p.client.CreateSession(ctx, &CreateSessionInput{
|
|
Bucket: aws.String(key.Bucket),
|
|
})
|
|
if err != nil {
|
|
return aws.Credentials{}, err
|
|
}
|
|
|
|
creds, err := credentialsFromResponse(resp)
|
|
if err != nil {
|
|
return aws.Credentials{}, err
|
|
}
|
|
|
|
p.cache.Put(key, creds)
|
|
return *creds, nil
|
|
}
|
|
|
|
func credentialsFromResponse(o *CreateSessionOutput) (*aws.Credentials, error) {
|
|
if o.Credentials == nil {
|
|
return nil, errors.New("s3express session credentials unset")
|
|
}
|
|
|
|
if o.Credentials.AccessKeyId == nil || o.Credentials.SecretAccessKey == nil || o.Credentials.SessionToken == nil || o.Credentials.Expiration == nil {
|
|
return nil, errors.New("s3express session credentials missing one or more required fields")
|
|
}
|
|
|
|
return &aws.Credentials{
|
|
AccessKeyID: *o.Credentials.AccessKeyId,
|
|
SecretAccessKey: *o.Credentials.SecretAccessKey,
|
|
SessionToken: *o.Credentials.SessionToken,
|
|
CanExpire: true,
|
|
Expires: *o.Credentials.Expiration,
|
|
}, nil
|
|
}
|
|
|
|
func gethmac(p, key string) string {
|
|
hash := hmac.New(sha256.New, []byte(key))
|
|
hash.Write([]byte(p))
|
|
return string(hash.Sum(nil))
|
|
}
|