dragonfly/pkg/source/clients/s3protocol/s3_source_client.go

/*
* Copyright 2022 The Dragonfly Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package s3protocol

import (
	"fmt"
	"net/http"
	"net/url"
	"path/filepath"
	"strings"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/go-http-utils/headers"

	"d7y.io/dragonfly/v2/pkg/source"
)
// S3Scheme is the URL scheme handled by this client; the bucket is carried in the
// URL host and the object key in the URL path, e.g. s3://bucket/key.
const S3Scheme = "s3"

const (
	// AWS Region ID
	region = "awsRegion"
	// AWS Endpoint
	endpoint = "awsEndpoint"
	// AWS Access Key ID
	accessKeyID = "awsAccessKeyID"
	// AWS Secret Access Key
	secretAccessKey = "awsSecretAccessKey"
	// AWS Session Token
	sessionToken = "awsSessionToken"
	// AWS S3 Force Path Style
	forcePathStyle = "awsS3ForcePathStyle"
)

var _ source.ResourceClient = (*s3SourceClient)(nil)

func init() {
	source.RegisterBuilder(S3Scheme, source.NewPlainResourceClientBuilder(Builder))
}

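// Builder constructs the s3 resource client together with its request adapter.
// No hooks are registered and the option YAML is currently unused.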
func Builder(optionYaml []byte) (source.ResourceClient, source.RequestAdapter, []source.Hook, error) {
	s3Client := &s3SourceClient{}
	return s3Client, s3Client.adaptor, nil, nil
}

// s3SourceClient is an implementation of the interface of source.ResourceClient.
type s3SourceClient struct {
}
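
// adaptor rewrites the Dragonfly range header into a standard HTTP Range header
// ("bytes=<range>") before the request reaches S3.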
func (s *s3SourceClient) adaptor(request *source.Request) *source.Request {
	clonedRequest := request.Clone(request.Context())
	if request.Header.Get(source.Range) != "" {
		clonedRequest.Header.Set(headers.Range, fmt.Sprintf("bytes=%s", request.Header.Get(source.Range)))
		clonedRequest.Header.Del(source.Range)
	}
	return clonedRequest
}

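// newAWSS3Client builds an aws-sdk-go S3 client from the credentials, endpoint,
// region and path-style settings carried in the request headers.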
func (s *s3SourceClient) newAWSS3Client(request *source.Request) (*s3.S3, error) {
	cfg := aws.NewConfig().WithCredentials(credentials.NewStaticCredentials(
		request.Header.Get(accessKeyID), request.Header.Get(secretAccessKey), request.Header.Get(sessionToken)))
	sess, err := session.NewSession(cfg)
	if err != nil {
		return nil, fmt.Errorf("new aws session failed: %s", err)
	}
	opts := []*aws.Config{cfg.WithEndpoint(request.Header.Get(endpoint))}
	if r := request.Header.Get(region); r != "" {
		opts = append(opts, cfg.WithRegion(r))
	}
	if pathStyle := request.Header.Get(forcePathStyle); strings.ToLower(pathStyle) == "true" {
		opts = append(opts, cfg.WithS3ForcePathStyle(true))
	}
	return s3.New(sess, opts...), nil
}

// GetContentLength gets the length of the resource content.
// It returns source.UnknownSourceFileLen if the response status is neither StatusOK nor StatusPartialContent.
func (s *s3SourceClient) GetContentLength(request *source.Request) (int64, error) {
	client, err := s.newAWSS3Client(request)
	if err != nil {
		return -1, err
	}
	resp, err := client.HeadObjectWithContext(request.Context(),
		&s3.HeadObjectInput{
			Bucket: aws.String(request.URL.Host),
			Key:    aws.String(request.URL.Path),
			Range:  aws.String(request.Header.Get(headers.Range)),
		})
	if err != nil {
		return -1, err
	}
	return *resp.ContentLength, nil
}

// IsSupportRange checks whether the resource supports breakpoint continuation.
// It returns false if the response status is not StatusPartialContent.
func (s *s3SourceClient) IsSupportRange(request *source.Request) (bool, error) {
	// TODO: do all S3 implementations support range requests?
	return true, nil
}

// IsExpired checks whether a received or stored resource is still the same as the source.
// Returning false with a non-nil error prevents callers from hammering the source; when the
// result cannot be obtained, the resource is treated as not expired.
func (s *s3SourceClient) IsExpired(request *source.Request, info *source.ExpireInfo) (bool, error) {
	return false, fmt.Errorf("not implemented") // TODO: Implement
}

// Download downloads the resource from the source.
func (s *s3SourceClient) Download(request *source.Request) (*source.Response, error) {
	client, err := s.newAWSS3Client(request)
	if err != nil {
		return nil, err
	}
	resp, err := client.GetObjectWithContext(request.Context(),
		&s3.GetObjectInput{
			Bucket: aws.String(request.URL.Host),
			Key:    aws.String(request.URL.Path),
			// TODO: pass more headers to GetObjectInput
			Range: aws.String(request.Header.Get(headers.Range)),
		})
	if err != nil {
		// TODO: parse error details
		return nil, err
	}
	hdr := source.Header{}
	if resp.Expires != nil {
		hdr[headers.Expires] = []string{*resp.Expires}
	}
	if resp.CacheControl != nil {
		hdr[headers.CacheControl] = []string{*resp.CacheControl}
	}
	var contentLength int64 = -1
	if resp.ContentLength != nil {
		contentLength = *resp.ContentLength
	}
	return &source.Response{
		Status:        "OK",
		StatusCode:    http.StatusOK,
		Header:        hdr,
		Body:          resp.Body,
		ContentLength: contentLength,
		Validate: func() error {
			return nil
		},
	}, nil
}

// GetLastModified gets the last modified timestamp of the resource in milliseconds.
func (s *s3SourceClient) GetLastModified(request *source.Request) (int64, error) {
	client, err := s.newAWSS3Client(request)
	if err != nil {
		return -1, err
	}
	resp, err := client.HeadObjectWithContext(request.Context(), &s3.HeadObjectInput{
		Bucket: aws.String(request.URL.Host),
		Key:    aws.String(request.URL.Path),
	})
	if err != nil {
		return -1, err
	}
	return resp.LastModified.UnixMilli(), nil
}

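// List returns the request URL itself when it points at a single object; otherwise
// it returns the objects and common prefixes found directly under the prefix.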
func (s *s3SourceClient) List(request *source.Request) (urls []source.URLEntry, err error) {
	client, err := s.newAWSS3Client(request)
	if err != nil {
		return nil, fmt.Errorf("get s3 client: %w", err)
	}
	isDir, err := s.isDirectory(client, request)
	if err != nil {
		return nil, err
	}
	// if the request targets a single object, just return it
	if !isDir {
		return []source.URLEntry{buildURLEntry(false, request.URL)}, nil
	}
	// list all objects and common prefixes under the directory, page by page
	path := addTrailingSlash(request.URL.Path)
	var continuationToken *string
	delimiter := "/"
	for {
		output, err := client.ListObjectsV2WithContext(
			request.Context(),
			&s3.ListObjectsV2Input{
				Bucket:            aws.String(request.URL.Host),
				Prefix:            aws.String(path),
				MaxKeys:           aws.Int64(1000),
				ContinuationToken: continuationToken,
				Delimiter:         &delimiter,
			})
		if err != nil {
			return urls, fmt.Errorf("list s3 object %s/%s: %w", request.URL.Host, path, err)
		}
		for _, object := range output.Contents {
			// skip the placeholder object that represents the directory itself
			if *object.Key != *output.Prefix {
				url := *request.URL
				url.Path = addLeadingSlash(*object.Key)
				urls = append(urls, buildURLEntry(false, &url))
			}
		}
		for _, prefix := range output.CommonPrefixes {
			url := *request.URL
			url.Path = addLeadingSlash(*prefix.Prefix)
			urls = append(urls, buildURLEntry(true, &url))
		}
		if output.IsTruncated == nil || !*output.IsTruncated {
			break
		}
		continuationToken = output.NextContinuationToken
	}
	return urls, nil
}

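// isDirectory reports whether the request URL behaves like a directory, i.e. whether
// listing the path with a trailing slash yields any object or common prefix.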
func (s *s3SourceClient) isDirectory(client *s3.S3, request *source.Request) (bool, error) {
	uPath := addTrailingSlash(request.URL.Path)
	delimiter := "/"
	output, err := client.ListObjectsV2WithContext(
		request.Context(),
		&s3.ListObjectsV2Input{
			Bucket:    aws.String(request.URL.Host),
			Prefix:    aws.String(uPath),
			MaxKeys:   aws.Int64(1),
			Delimiter: &delimiter,
		})
	if err != nil {
		return false, fmt.Errorf("list s3 object %s/%s: %w", request.URL.Host, uPath, err)
	}
	if len(output.Contents)+len(output.CommonPrefixes) > 0 {
		return true, nil
	}
	return false, nil
}

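// buildURLEntry converts an object or prefix URL into a source.URLEntry, using the
// last path segment as the entry name.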
func buildURLEntry(isDir bool, url *url.URL) source.URLEntry {
	if isDir {
		url.Path = addTrailingSlash(url.Path)
		list := strings.Split(url.Path, "/")
		return source.URLEntry{URL: url, Name: list[len(list)-2], IsDir: true}
	}
	_, name := filepath.Split(url.Path)
	return source.URLEntry{URL: url, Name: name, IsDir: false}
}

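// addLeadingSlash prepends "/" to s unless it already starts with one.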
func addLeadingSlash(s string) string {
	if strings.HasPrefix(s, "/") {
		return s
	}
	return "/" + s
}

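// addTrailingSlash appends "/" to s unless it already ends with one.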
func addTrailingSlash(s string) string {
	if strings.HasSuffix(s, "/") {
		return s
	}
	return s + "/"
}