289 lines
8.1 KiB
Go
289 lines
8.1 KiB
Go
/*
|
|
* Copyright 2022 The Dragonfly Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package s3protocol
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
"github.com/aws/aws-sdk-go/service/s3"
|
|
"github.com/go-http-utils/headers"
|
|
|
|
"d7y.io/dragonfly/v2/pkg/source"
|
|
)
|
|
|
|
const S3Scheme = "s3"
|
|
|
|
const (
|
|
// AWS Region ID
|
|
region = "awsRegion"
|
|
// AWS Endpoint
|
|
endpoint = "awsEndpoint"
|
|
// AWS Access key ID
|
|
accessKeyID = "awsAccessKeyID"
|
|
// AWS Secret Access Key
|
|
secretAccessKey = "awsSecretAccessKey"
|
|
// AWS Session Token
|
|
sessionToken = "awsSessionToken"
|
|
|
|
forcePathStyle = "awsS3ForcePathStyle"
|
|
)
|
|
|
|
var _ source.ResourceClient = (*s3SourceClient)(nil)
|
|
|
|
func init() {
|
|
source.RegisterBuilder(S3Scheme, source.NewPlainResourceClientBuilder(Builder))
|
|
}
|
|
|
|
func Builder(optionYaml []byte) (source.ResourceClient, source.RequestAdapter, []source.Hook, error) {
|
|
s3Client := &s3SourceClient{}
|
|
return s3Client, s3Client.adaptor, nil, nil
|
|
}
|
|
|
|
// s3SourceClient implements source.ResourceClient for the s3 scheme.
// It has no fields.
type s3SourceClient struct{}
|
|
|
|
func (s *s3SourceClient) adaptor(request *source.Request) *source.Request {
|
|
clonedRequest := request.Clone(request.Context())
|
|
if request.Header.Get(source.Range) != "" {
|
|
clonedRequest.Header.Set(headers.Range, fmt.Sprintf("bytes=%s", request.Header.Get(source.Range)))
|
|
clonedRequest.Header.Del(source.Range)
|
|
}
|
|
return clonedRequest
|
|
}
|
|
|
|
func (s *s3SourceClient) newAWSS3Client(request *source.Request) (*s3.S3, error) {
|
|
cfg := aws.NewConfig().WithCredentials(credentials.NewStaticCredentials(
|
|
request.Header.Get(accessKeyID), request.Header.Get(secretAccessKey), request.Header.Get(sessionToken)))
|
|
session, err := session.NewSession(cfg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("new aws session failed: %s", err)
|
|
}
|
|
|
|
opts := []*aws.Config{cfg.WithEndpoint(request.Header.Get(endpoint))}
|
|
if r := request.Header.Get(region); r != "" {
|
|
opts = append(opts, cfg.WithRegion(r))
|
|
}
|
|
if pathStyle := request.Header.Get(forcePathStyle); strings.ToLower(pathStyle) == "true" {
|
|
opts = append(opts, cfg.WithS3ForcePathStyle(true))
|
|
}
|
|
|
|
return s3.New(session, opts...), nil
|
|
}
|
|
|
|
// GetContentLength get length of resource content
|
|
// return source.UnknownSourceFileLen if response status is not StatusOK and StatusPartialContent
|
|
func (s *s3SourceClient) GetContentLength(request *source.Request) (int64, error) {
|
|
client, err := s.newAWSS3Client(request)
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
resp, err := client.HeadObjectWithContext(request.Context(),
|
|
&s3.HeadObjectInput{
|
|
Bucket: aws.String(request.URL.Host),
|
|
Key: aws.String(request.URL.Path),
|
|
Range: aws.String(request.Header.Get(headers.Range)),
|
|
})
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
return *resp.ContentLength, nil
|
|
}
|
|
|
|
// IsSupportRange checks if resource supports breakpoint continuation
|
|
// return false if response status is not StatusPartialContent
|
|
func (s *s3SourceClient) IsSupportRange(request *source.Request) (bool, error) {
|
|
// TODO whether all s3 implements support range ?
|
|
return true, nil
|
|
}
|
|
|
|
// IsExpired checks if a resource received or stored is the same.
|
|
// return false and non-nil err to prevent the source from exploding if
|
|
// fails to get the result, it is considered that the source has not expired
|
|
func (s *s3SourceClient) IsExpired(request *source.Request, info *source.ExpireInfo) (bool, error) {
|
|
return false, fmt.Errorf("not implemented") // TODO: Implement
|
|
}
|
|
|
|
// Download downloads from source
|
|
func (s *s3SourceClient) Download(request *source.Request) (*source.Response, error) {
|
|
client, err := s.newAWSS3Client(request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resp, err := client.GetObjectWithContext(request.Context(),
|
|
&s3.GetObjectInput{
|
|
Bucket: aws.String(request.URL.Host),
|
|
Key: aws.String(request.URL.Path),
|
|
// TODO more header pass to GetObjectInput
|
|
Range: aws.String(request.Header.Get(headers.Range)),
|
|
})
|
|
|
|
if err != nil {
|
|
// TODO parse error details
|
|
return nil, err
|
|
}
|
|
|
|
hdr := source.Header{}
|
|
if resp.Expires != nil {
|
|
hdr[headers.Expires] = []string{*resp.Expires}
|
|
}
|
|
|
|
if resp.CacheControl != nil {
|
|
hdr[headers.CacheControl] = []string{*resp.CacheControl}
|
|
}
|
|
|
|
var contengLength int64 = -1
|
|
if resp.ContentLength != nil {
|
|
contengLength = *resp.ContentLength
|
|
}
|
|
|
|
return &source.Response{
|
|
Status: "OK",
|
|
StatusCode: http.StatusOK,
|
|
Header: hdr,
|
|
Body: resp.Body,
|
|
ContentLength: contengLength,
|
|
Validate: func() error {
|
|
return nil
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// GetLastModified gets last modified timestamp milliseconds of resource
|
|
func (s *s3SourceClient) GetLastModified(request *source.Request) (int64, error) {
|
|
client, err := s.newAWSS3Client(request)
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
resp, err := client.HeadObjectWithContext(request.Context(), &s3.HeadObjectInput{
|
|
Bucket: aws.String(request.URL.Host),
|
|
Key: aws.String(request.URL.Path),
|
|
})
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
return resp.LastModified.UnixMilli(), nil
|
|
}
|
|
|
|
func (s *s3SourceClient) List(request *source.Request) (urls []source.URLEntry, err error) {
|
|
client, err := s.newAWSS3Client(request)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get s3 client: %w", err)
|
|
}
|
|
// if it's an object, just return it.
|
|
isDir, err := s.isDirectory(client, request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// if request is a single file, just return
|
|
if !isDir {
|
|
return []source.URLEntry{buildURLEntry(false, request.URL)}, nil
|
|
}
|
|
|
|
// list all files
|
|
path := addTrailingSlash(request.URL.Path)
|
|
var continuationToken *string
|
|
delimiter := "/"
|
|
|
|
for {
|
|
output, err := client.ListObjectsV2WithContext(
|
|
request.Context(),
|
|
&s3.ListObjectsV2Input{
|
|
Bucket: aws.String(request.URL.Host),
|
|
Prefix: aws.String(path),
|
|
MaxKeys: aws.Int64(1000),
|
|
ContinuationToken: continuationToken,
|
|
Delimiter: &delimiter,
|
|
})
|
|
if err != nil {
|
|
return urls, fmt.Errorf("list s3 object %s/%s: %w", request.URL.Host, path, err)
|
|
}
|
|
|
|
for _, object := range output.Contents {
|
|
if *object.Key != *output.Prefix {
|
|
url := *request.URL
|
|
url.Path = addLeadingSlash(*object.Key)
|
|
urls = append(urls, buildURLEntry(false, &url))
|
|
}
|
|
}
|
|
|
|
for _, prefix := range output.CommonPrefixes {
|
|
url := *request.URL
|
|
url.Path = addLeadingSlash(*prefix.Prefix)
|
|
urls = append(urls, buildURLEntry(true, &url))
|
|
}
|
|
|
|
if output.IsTruncated == nil || !*output.IsTruncated {
|
|
break
|
|
}
|
|
continuationToken = output.NextContinuationToken
|
|
}
|
|
return urls, nil
|
|
}
|
|
|
|
func (s *s3SourceClient) isDirectory(client *s3.S3, request *source.Request) (bool, error) {
|
|
uPath := addTrailingSlash(request.URL.Path)
|
|
delimiter := "/"
|
|
output, err := client.ListObjectsV2WithContext(
|
|
request.Context(),
|
|
&s3.ListObjectsV2Input{
|
|
Bucket: aws.String(request.URL.Host),
|
|
Prefix: aws.String(uPath),
|
|
MaxKeys: aws.Int64(1),
|
|
Delimiter: &delimiter,
|
|
})
|
|
if err != nil {
|
|
return false, fmt.Errorf("list s3 object %s/%s: %w", request.URL.Host, uPath, err)
|
|
}
|
|
if len(output.Contents)+len(output.CommonPrefixes) > 0 {
|
|
return true, nil
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func buildURLEntry(isDir bool, url *url.URL) source.URLEntry {
|
|
if isDir {
|
|
url.Path = addTrailingSlash(url.Path)
|
|
list := strings.Split(url.Path, "/")
|
|
return source.URLEntry{URL: url, Name: list[len(list)-2], IsDir: true}
|
|
}
|
|
_, name := filepath.Split(url.Path)
|
|
return source.URLEntry{URL: url, Name: name, IsDir: false}
|
|
}
|
|
|
|
// addLeadingSlash returns s with a "/" in front, unless s already starts
// with one, in which case s is returned as is.
func addLeadingSlash(s string) string {
	// Removing at most one leading slash and putting one back yields the
	// same string for inputs that already had it.
	return "/" + strings.TrimPrefix(s, "/")
}
|
|
|
|
// addTrailingSlash returns s with a "/" appended, unless s already ends
// with one, in which case s is returned as is.
func addTrailingSlash(s string) string {
	// Removing at most one trailing slash and putting one back yields the
	// same string for inputs that already had it.
	return strings.TrimSuffix(s, "/") + "/"
}
|