feat: update source.Response and source client interface (#945)
* feat: update source.Response and source client interface Signed-off-by: Jim Ma <majinjing3@gmail.com>
This commit is contained in:
parent
0a152e01e9
commit
beaf4ce18d
|
|
@ -47,16 +47,16 @@ func (cm *manager) download(ctx context.Context, seedTask *task.SeedTask, breakP
|
|||
if !stringutils.IsBlank(breakRange) {
|
||||
downloadRequest.Header.Add(source.Range, breakRange)
|
||||
}
|
||||
body, expireInfo, err := source.DownloadWithExpireInfo(downloadRequest)
|
||||
response, err := source.Download(downloadRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// update Expire info
|
||||
cm.updateExpireInfo(seedTask.ID, map[string]string{
|
||||
source.LastModified: expireInfo.LastModified,
|
||||
source.ETag: expireInfo.ETag,
|
||||
source.LastModified: response.Header.Get(source.LastModified),
|
||||
source.ETag: response.Header.Get(source.ETag),
|
||||
})
|
||||
return body, err
|
||||
return response.Body, err
|
||||
}
|
||||
|
||||
func getBreakRange(breakPoint int64, taskRange string, fileTotalLength int64) (string, error) {
|
||||
|
|
|
|||
|
|
@ -94,31 +94,22 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() {
|
|||
sourceClient.EXPECT().IsSupportRange(gomock.Any()).Return(true, nil).AnyTimes()
|
||||
sourceClient.EXPECT().IsExpired(gomock.Any(), gomock.Any()).Return(false, nil).AnyTimes()
|
||||
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
|
||||
func(request *source.Request) (io.ReadCloser, error) {
|
||||
func(request *source.Request) (*source.Response, error) {
|
||||
content, _ := os.ReadFile("../../testdata/cdn/go.html")
|
||||
if request.Header.Get(source.Range) != "" {
|
||||
parsed, _ := rangeutils.GetRange(request.Header.Get(source.Range))
|
||||
return io.NopCloser(io.NewSectionReader(strings.NewReader(string(content)), int64(parsed.StartIndex), int64(parsed.EndIndex))), nil
|
||||
}
|
||||
return io.NopCloser(strings.NewReader(string(content))), nil
|
||||
},
|
||||
).AnyTimes()
|
||||
sourceClient.EXPECT().DownloadWithExpireInfo(gomock.Any()).DoAndReturn(
|
||||
func(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
|
||||
content, _ := os.ReadFile("../../testdata/cdn/go.html")
|
||||
if request.Header.Get(source.Range) != "" {
|
||||
parsed, _ := rangeutils.GetRange(request.Header.Get(source.Range))
|
||||
return io.NopCloser(io.NewSectionReader(strings.NewReader(string(content)), int64(parsed.StartIndex), int64(parsed.EndIndex))),
|
||||
&source.ExpireInfo{
|
||||
return source.NewResponse(
|
||||
io.NopCloser(io.NewSectionReader(strings.NewReader(string(content)), int64(parsed.StartIndex), int64(parsed.EndIndex))),
|
||||
source.WithExpireInfo(source.ExpireInfo{
|
||||
LastModified: "Sun, 06 Jun 2021 12:52:30 GMT",
|
||||
ETag: "etag",
|
||||
}, nil
|
||||
})), nil
|
||||
}
|
||||
return io.NopCloser(strings.NewReader(string(content))),
|
||||
&source.ExpireInfo{
|
||||
return source.NewResponse(io.NopCloser(strings.NewReader(string(content))),
|
||||
source.WithExpireInfo(source.ExpireInfo{
|
||||
LastModified: "Sun, 06 Jun 2021 12:52:30 GMT",
|
||||
ETag: "etag",
|
||||
}, nil
|
||||
})), nil
|
||||
},
|
||||
).AnyTimes()
|
||||
sourceClient.EXPECT().GetLastModified(gomock.Any()).Return(
|
||||
|
|
|
|||
|
|
@ -95,9 +95,9 @@ func TestFilePeerTask_BackSource_WithContentLength(t *testing.T) {
|
|||
return -1, fmt.Errorf("unexpect url: %s", request.URL.String())
|
||||
})
|
||||
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
|
||||
func(request *source.Request) (io.ReadCloser, error) {
|
||||
func(request *source.Request) (*source.Response, error) {
|
||||
if request.URL.String() == url {
|
||||
return io.NopCloser(bytes.NewBuffer(testBytes)), nil
|
||||
return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpect url: %s", request.URL.String())
|
||||
})
|
||||
|
|
@ -220,9 +220,9 @@ func TestFilePeerTask_BackSource_WithoutContentLength(t *testing.T) {
|
|||
return -1, fmt.Errorf("unexpect url: %s", request.URL.String())
|
||||
})
|
||||
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
|
||||
func(request *source.Request) (io.ReadCloser, error) {
|
||||
func(request *source.Request) (*source.Response, error) {
|
||||
if request.URL.String() == url {
|
||||
return io.NopCloser(bytes.NewBuffer(testBytes)), nil
|
||||
return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpect url: %s", request.URL.String())
|
||||
})
|
||||
|
|
|
|||
|
|
@ -222,8 +222,8 @@ func TestStreamPeerTask_BackSource_Partial_WithContentLength(t *testing.T) {
|
|||
return int64(len(testBytes)), nil
|
||||
})
|
||||
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
|
||||
func(request *source.Request) (io.ReadCloser, error) {
|
||||
return io.NopCloser(bytes.NewBuffer(testBytes)), nil
|
||||
func(request *source.Request) (*source.Response, error) {
|
||||
return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
|
||||
})
|
||||
|
||||
pm := &pieceManager{
|
||||
|
|
|
|||
|
|
@ -90,8 +90,8 @@ func TestStreamPeerTask_BackSource_WithContentLength(t *testing.T) {
|
|||
return int64(len(testBytes)), nil
|
||||
})
|
||||
sourceClient.EXPECT().Download(source.RequestEq(request.URL.String())).DoAndReturn(
|
||||
func(request *source.Request) (io.ReadCloser, error) {
|
||||
return io.NopCloser(bytes.NewBuffer(testBytes)), nil
|
||||
func(request *source.Request) (*source.Response, error) {
|
||||
return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
|
||||
})
|
||||
|
||||
ptm := &peerTaskManager{
|
||||
|
|
@ -190,8 +190,8 @@ func TestStreamPeerTask_BackSource_WithoutContentLength(t *testing.T) {
|
|||
return -1, nil
|
||||
})
|
||||
sourceClient.EXPECT().Download(source.RequestEq(request.URL.String())).DoAndReturn(
|
||||
func(request *source.Request) (io.ReadCloser, error) {
|
||||
return io.NopCloser(bytes.NewBuffer(testBytes)), nil
|
||||
func(request *source.Request) (*source.Response, error) {
|
||||
return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
|
||||
})
|
||||
|
||||
ptm := &peerTaskManager{
|
||||
|
|
|
|||
|
|
@ -376,16 +376,16 @@ func (pm *pieceManager) DownloadSource(ctx context.Context, pt Task, request *sc
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body, err := source.Download(downloadRequest)
|
||||
response, err := source.Download(downloadRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer body.Close()
|
||||
reader := body.(io.Reader)
|
||||
defer response.Body.Close()
|
||||
reader := response.Body.(io.Reader)
|
||||
|
||||
// calc total md5
|
||||
if pm.calculateDigest && request.UrlMeta.Digest != "" {
|
||||
reader = digestutils.NewDigestReader(pt.Log(), body, request.UrlMeta.Digest)
|
||||
reader = digestutils.NewDigestReader(pt.Log(), response.Body, request.UrlMeta.Digest)
|
||||
}
|
||||
|
||||
// 2. save to storage
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ func downloadFromSource(ctx context.Context, cfg *config.DfgetConfig, hdr map[st
|
|||
wLog = logger.With("url", cfg.URL)
|
||||
start = time.Now()
|
||||
target *os.File
|
||||
response io.ReadCloser
|
||||
response *source.Response
|
||||
err error
|
||||
written int64
|
||||
)
|
||||
|
|
@ -164,9 +164,9 @@ func downloadFromSource(ctx context.Context, cfg *config.DfgetConfig, hdr map[st
|
|||
if response, err = source.Download(downloadRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Close()
|
||||
defer response.Body.Close()
|
||||
|
||||
if written, err = io.Copy(target, response); err != nil {
|
||||
if written, err = io.Copy(target, response.Body); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ func Test_downloadFromSource(t *testing.T) {
|
|||
}
|
||||
request, err := source.NewRequest(cfg.URL)
|
||||
assert.Nil(t, err)
|
||||
sourceClient.EXPECT().Download(request).Return(io.NopCloser(strings.NewReader(content)), nil)
|
||||
sourceClient.EXPECT().Download(request).Return(source.NewResponse(io.NopCloser(strings.NewReader(content))), nil)
|
||||
|
||||
err = downloadFromSource(context.Background(), cfg, nil)
|
||||
assert.Nil(t, err)
|
||||
|
|
|
|||
|
|
@ -111,21 +111,25 @@ func (h *hdfsSourceClient) IsExpired(request *source.Request, info *source.Expir
|
|||
return fileInfo.ModTime().Format(source.LastModifiedLayout) != info.LastModified, nil
|
||||
}
|
||||
|
||||
func (h *hdfsSourceClient) Download(request *source.Request) (io.ReadCloser, error) {
|
||||
func (h *hdfsSourceClient) Download(request *source.Request) (*source.Response, error) {
|
||||
hdfsClient, path, err := h.getHDFSClientAndPath(request.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hdfsFile, err := hdfsClient.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fileInfo := hdfsFile.Stat()
|
||||
|
||||
// default read all data when rang is nil
|
||||
var limitReadN = hdfsFile.Stat().Size()
|
||||
var limitReadN = fileInfo.Size()
|
||||
if limitReadN < 0 {
|
||||
return nil, errors.Errorf("file length is illegal, length: %d", limitReadN)
|
||||
}
|
||||
|
||||
if request.Header.Get(source.Range) != "" {
|
||||
requestRange, err := rangeutils.ParseRange(request.Header.Get(source.Range), uint64(limitReadN))
|
||||
if err != nil {
|
||||
|
|
@ -139,44 +143,12 @@ func (h *hdfsSourceClient) Download(request *source.Request) (io.ReadCloser, err
|
|||
limitReadN = int64(requestRange.Length())
|
||||
}
|
||||
|
||||
return newHdfsFileReaderClose(hdfsFile, limitReadN, hdfsFile), nil
|
||||
}
|
||||
|
||||
func (h *hdfsSourceClient) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
|
||||
|
||||
hdfsClient, path, err := h.getHDFSClientAndPath(request.URL)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
hdfsFile, err := hdfsClient.Open(path)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
fileInfo := hdfsFile.Stat()
|
||||
|
||||
// default read all data when rang is nil
|
||||
var limitReadN = fileInfo.Size()
|
||||
if limitReadN < 0 {
|
||||
return nil, nil, errors.Errorf("file length is illegal, length: %d", limitReadN)
|
||||
}
|
||||
|
||||
if request.Header.Get(source.Range) != "" {
|
||||
requestRange, err := rangeutils.ParseRange(request.Header.Get(source.Range), uint64(limitReadN))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
_, err = hdfsFile.Seek(int64(requestRange.StartIndex), 0)
|
||||
if err != nil {
|
||||
hdfsFile.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
limitReadN = int64(requestRange.EndIndex - requestRange.StartIndex)
|
||||
}
|
||||
return newHdfsFileReaderClose(hdfsFile, limitReadN, hdfsFile), &source.ExpireInfo{
|
||||
response := source.NewResponse(
|
||||
newHdfsFileReaderClose(hdfsFile, limitReadN, hdfsFile),
|
||||
source.WithExpireInfo(source.ExpireInfo{
|
||||
LastModified: timeutils.Format(fileInfo.ModTime()),
|
||||
}, nil
|
||||
}))
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (h *hdfsSourceClient) GetLastModified(request *source.Request) (int64, error) {
|
||||
|
|
|
|||
|
|
@ -228,9 +228,9 @@ func Test_Download_FileExist_ByRange(t *testing.T) {
|
|||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
download, err := sourceClient.Download(request)
|
||||
response, err := sourceClient.Download(request)
|
||||
assert.Nil(t, err)
|
||||
data, _ := io.ReadAll(download)
|
||||
data, _ := io.ReadAll(response.Body)
|
||||
|
||||
assert.Equal(t, hdfsExistFileContent, string(data))
|
||||
}
|
||||
|
|
@ -285,11 +285,11 @@ func Test_DownloadWithResponseHeader_FileExist_ByRange(t *testing.T) {
|
|||
request, err := source.NewRequest(hdfsExistFileURL)
|
||||
assert.Nil(t, err)
|
||||
request.Header.Add(source.Range, rang.String())
|
||||
body, expireInfo, err := sourceClient.DownloadWithExpireInfo(request)
|
||||
response, err := sourceClient.Download(request)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, hdfsExistFileLastModified, expireInfo.LastModified)
|
||||
assert.Equal(t, hdfsExistFileLastModified, response.ExpireInfo().LastModified)
|
||||
|
||||
data, _ := io.ReadAll(body)
|
||||
data, _ := io.ReadAll(response.Body)
|
||||
assert.Equal(t, string(data), string([]byte(hdfsExistFileContent)[hdfsExistFileRangeStart:hdfsExistFileRangeEnd]))
|
||||
}
|
||||
|
||||
|
|
@ -303,10 +303,9 @@ func TestDownloadWithResponseHeader_FileNotExist(t *testing.T) {
|
|||
request, err := source.NewRequest(hdfsNotExistFileURL)
|
||||
assert.Nil(t, err)
|
||||
request.Header.Add(source.Range, rang.String())
|
||||
reader, expireInfo, err := sourceClient.DownloadWithExpireInfo(request)
|
||||
response, err := sourceClient.Download(request)
|
||||
assert.EqualError(t, err, "open /user/root/input/f3.txt: file does not exist")
|
||||
assert.Nil(t, reader)
|
||||
assert.Nil(t, expireInfo)
|
||||
assert.Nil(t, response)
|
||||
}
|
||||
|
||||
func TestGetLastModified_FileExist(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ package httpprotocol
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
|
@ -170,7 +169,7 @@ func (client *httpSourceClient) IsExpired(request *source.Request, info *source.
|
|||
LastModified)), nil
|
||||
}
|
||||
|
||||
func (client *httpSourceClient) Download(request *source.Request) (io.ReadCloser, error) {
|
||||
func (client *httpSourceClient) Download(request *source.Request) (*source.Response, error) {
|
||||
resp, err := client.doRequest(http.MethodGet, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -180,23 +179,15 @@ func (client *httpSourceClient) Download(request *source.Request) (io.ReadCloser
|
|||
resp.Body.Close()
|
||||
return nil, err
|
||||
}
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func (client *httpSourceClient) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
|
||||
resp, err := client.doRequest(http.MethodGet, request)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
err = source.CheckResponseCode(resp.StatusCode, []int{http.StatusOK, http.StatusPartialContent})
|
||||
if err != nil {
|
||||
resp.Body.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
return resp.Body, &source.ExpireInfo{
|
||||
response := source.NewResponse(
|
||||
resp.Body,
|
||||
source.WithExpireInfo(
|
||||
source.ExpireInfo{
|
||||
LastModified: resp.Header.Get(headers.LastModified),
|
||||
ETag: resp.Header.Get(headers.ETag),
|
||||
}, nil
|
||||
},
|
||||
))
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (client *httpSourceClient) GetLastModified(request *source.Request) (int64, error) {
|
||||
|
|
|
|||
|
|
@ -131,7 +131,6 @@ func (suite *HTTPSourceClientTestSuite) SetupTest() {
|
|||
httpmock.RegisterResponder(http.MethodGet, notfoundRawURL, httpmock.NewStringResponder(http.StatusNotFound, "not found"))
|
||||
httpmock.RegisterResponder(http.MethodGet, normalNotSupportRangeRawURL, httpmock.NewStringResponder(http.StatusOK, testContent))
|
||||
httpmock.RegisterResponder(http.MethodGet, errorRawURL, httpmock.NewErrorResponder(fmt.Errorf("error")))
|
||||
|
||||
}
|
||||
|
||||
func (suite *HTTPSourceClientTestSuite) TestNewHTTPSourceClient() {
|
||||
|
|
@ -150,12 +149,11 @@ func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDownloadWithResponse
|
|||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
timeoutRequest, err := source.NewRequestWithContext(ctx, timeoutRawURL, nil)
|
||||
suite.Nil(err)
|
||||
reader, expireInfo, err := suite.httpClient.DownloadWithExpireInfo(timeoutRequest)
|
||||
response, err := suite.httpClient.Download(timeoutRequest)
|
||||
cancel()
|
||||
suite.NotNil(err)
|
||||
suite.Equal("Get \"https://timeout.com\": context deadline exceeded", err.Error())
|
||||
suite.Nil(reader)
|
||||
suite.Nil(expireInfo)
|
||||
suite.Nil(response)
|
||||
|
||||
normalRequest, _ := source.NewRequest(normalRawURL)
|
||||
normalRangeRequest, _ := source.NewRequest(normalRawURL)
|
||||
|
|
@ -203,15 +201,16 @@ func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDownloadWithResponse
|
|||
}
|
||||
for _, tt := range tests {
|
||||
suite.Run(tt.name, func() {
|
||||
reader, expireInfo, err := suite.httpClient.DownloadWithExpireInfo(tt.request)
|
||||
response, err := suite.httpClient.Download(tt.request)
|
||||
if err != nil {
|
||||
suite.True(tt.wantErr.Error() == err.Error())
|
||||
return
|
||||
}
|
||||
bytes, err := io.ReadAll(reader)
|
||||
bytes, err := io.ReadAll(response.Body)
|
||||
suite.Nil(err)
|
||||
suite.Equal(tt.content, string(bytes))
|
||||
suite.Equal(tt.expireInfo, expireInfo)
|
||||
expireInfo := response.ExpireInfo()
|
||||
suite.Equal(tt.expireInfo, &expireInfo)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
package mock
|
||||
|
||||
import (
|
||||
io "io"
|
||||
reflect "reflect"
|
||||
|
||||
source "d7y.io/dragonfly/v2/pkg/source"
|
||||
|
|
@ -36,10 +35,10 @@ func (m *MockResourceClient) EXPECT() *MockResourceClientMockRecorder {
|
|||
}
|
||||
|
||||
// Download mocks base method.
|
||||
func (m *MockResourceClient) Download(arg0 *source.Request) (io.ReadCloser, error) {
|
||||
func (m *MockResourceClient) Download(arg0 *source.Request) (*source.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Download", arg0)
|
||||
ret0, _ := ret[0].(io.ReadCloser)
|
||||
ret0, _ := ret[0].(*source.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
|
@ -50,22 +49,6 @@ func (mr *MockResourceClientMockRecorder) Download(arg0 interface{}) *gomock.Cal
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Download", reflect.TypeOf((*MockResourceClient)(nil).Download), arg0)
|
||||
}
|
||||
|
||||
// DownloadWithExpireInfo mocks base method.
|
||||
func (m *MockResourceClient) DownloadWithExpireInfo(arg0 *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DownloadWithExpireInfo", arg0)
|
||||
ret0, _ := ret[0].(io.ReadCloser)
|
||||
ret1, _ := ret[1].(*source.ExpireInfo)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// DownloadWithExpireInfo indicates an expected call of DownloadWithExpireInfo.
|
||||
func (mr *MockResourceClientMockRecorder) DownloadWithExpireInfo(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DownloadWithExpireInfo", reflect.TypeOf((*MockResourceClient)(nil).DownloadWithExpireInfo), arg0)
|
||||
}
|
||||
|
||||
// GetContentLength mocks base method.
|
||||
func (m *MockResourceClient) GetContentLength(arg0 *source.Request) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ package ossprotocol
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
|
@ -147,44 +146,33 @@ func (osc *ossSourceClient) IsExpired(request *source.Request, info *source.Expi
|
|||
return !(resHeader.Get(oss.HTTPHeaderEtag) == info.ETag || resHeader.Get(oss.HTTPHeaderLastModified) == info.LastModified), nil
|
||||
}
|
||||
|
||||
func (osc *ossSourceClient) Download(request *source.Request) (io.ReadCloser, error) {
|
||||
func (osc *ossSourceClient) Download(request *source.Request) (*source.Response, error) {
|
||||
client, err := osc.getClient(request.Header)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get oss client")
|
||||
return nil, errors.Wrapf(err, "get oss client")
|
||||
}
|
||||
bucket, err := client.Bucket(request.URL.Host)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "get oss bucket %s", request.URL.Host)
|
||||
}
|
||||
resp, err := bucket.GetObject(request.URL.Path, getOptions(request.Header)...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "get oss object %s", request.URL.Path)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (osc *ossSourceClient) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
|
||||
client, err := osc.getClient(request.Header)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrapf(err, "get oss client")
|
||||
}
|
||||
bucket, err := client.Bucket(request.URL.Host)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrapf(err, "get oss bucket: %s", request.URL.Host)
|
||||
return nil, errors.Wrapf(err, "get oss bucket: %s", request.URL.Host)
|
||||
}
|
||||
objectResult, err := bucket.DoGetObject(&oss.GetObjectRequest{ObjectKey: request.URL.Path}, getOptions(request.Header))
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrapf(err, "get oss Object: %s", request.URL.Path)
|
||||
return nil, errors.Wrapf(err, "get oss Object: %s", request.URL.Path)
|
||||
}
|
||||
err = source.CheckResponseCode(objectResult.Response.StatusCode, []int{http.StatusOK, http.StatusPartialContent})
|
||||
if err != nil {
|
||||
objectResult.Response.Body.Close()
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
return objectResult.Response.Body, &source.ExpireInfo{
|
||||
response := source.NewResponse(
|
||||
objectResult.Response.Body,
|
||||
source.WithExpireInfo(
|
||||
source.ExpireInfo{
|
||||
LastModified: objectResult.Response.Headers.Get(headers.LastModified),
|
||||
ETag: objectResult.Response.Headers.Get(headers.ETag),
|
||||
}, nil
|
||||
},
|
||||
))
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (osc *ossSourceClient) GetLastModified(request *source.Request) (int64, error) {
|
||||
|
|
|
|||
|
|
@ -16,7 +16,11 @@
|
|||
|
||||
package source
|
||||
|
||||
import "io"
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Response struct {
|
||||
Status string
|
||||
|
|
@ -25,3 +29,70 @@ type Response struct {
|
|||
Body io.ReadCloser
|
||||
ContentLength int64
|
||||
}
|
||||
|
||||
func NewResponse(rc io.ReadCloser, opts ...func(*Response)) *Response {
|
||||
if rc == nil {
|
||||
// for custom plugin, return an error body
|
||||
rc = &errorBody{
|
||||
fmt.Errorf("empty io.ReadCloser, please check resource plugin implement")}
|
||||
}
|
||||
resp := &Response{
|
||||
Header: make(Header),
|
||||
Status: "OK",
|
||||
StatusCode: http.StatusOK,
|
||||
Body: rc,
|
||||
ContentLength: -1,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(resp)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func WithStatus(code int, status string) func(*Response) {
|
||||
return func(resp *Response) {
|
||||
resp.StatusCode = code
|
||||
resp.Status = status
|
||||
}
|
||||
}
|
||||
|
||||
func WithContentLength(length int64) func(*Response) {
|
||||
return func(resp *Response) {
|
||||
resp.ContentLength = length
|
||||
}
|
||||
}
|
||||
|
||||
func WithHeader(header map[string]string) func(*Response) {
|
||||
return func(resp *Response) {
|
||||
for k, v := range header {
|
||||
resp.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithExpireInfo(info ExpireInfo) func(*Response) {
|
||||
return func(resp *Response) {
|
||||
resp.Header.Set(LastModified, info.LastModified)
|
||||
resp.Header.Set(ETag, info.ETag)
|
||||
}
|
||||
}
|
||||
|
||||
func (resp *Response) ExpireInfo() ExpireInfo {
|
||||
return ExpireInfo{
|
||||
LastModified: resp.Header.Get(LastModified),
|
||||
ETag: resp.Header.Get(ETag),
|
||||
}
|
||||
}
|
||||
|
||||
type errorBody struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (e *errorBody) Read(p []byte) (n int, err error) {
|
||||
return 0, e.error
|
||||
}
|
||||
|
||||
func (e *errorBody) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ package source
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
|
@ -105,10 +104,7 @@ type ResourceClient interface {
|
|||
IsExpired(request *Request, info *ExpireInfo) (bool, error)
|
||||
|
||||
// Download downloads from source
|
||||
Download(request *Request) (io.ReadCloser, error)
|
||||
|
||||
// DownloadWithExpireInfo download from source with expireInfo
|
||||
DownloadWithExpireInfo(request *Request) (io.ReadCloser, *ExpireInfo, error)
|
||||
Download(request *Request) (*Response, error)
|
||||
|
||||
// GetLastModified gets last modified timestamp milliseconds of resource
|
||||
GetLastModified(request *Request) (int64, error)
|
||||
|
|
@ -251,14 +247,10 @@ func (c *clientWrapper) IsSupportRange(request *Request) (bool, error) {
|
|||
func (c *clientWrapper) IsExpired(request *Request, info *ExpireInfo) (bool, error) {
|
||||
return c.rc.IsExpired(c.adapter(request), info)
|
||||
}
|
||||
func (c *clientWrapper) Download(request *Request) (io.ReadCloser, error) {
|
||||
func (c *clientWrapper) Download(request *Request) (*Response, error) {
|
||||
return c.rc.Download(c.adapter(request))
|
||||
}
|
||||
|
||||
func (c *clientWrapper) DownloadWithExpireInfo(request *Request) (io.ReadCloser, *ExpireInfo, error) {
|
||||
return c.rc.DownloadWithExpireInfo(c.adapter(request))
|
||||
}
|
||||
|
||||
func (c *clientWrapper) GetLastModified(request *Request) (int64, error) {
|
||||
return c.rc.GetLastModified(c.adapter(request))
|
||||
}
|
||||
|
|
@ -318,7 +310,7 @@ func GetLastModified(request *Request) (int64, error) {
|
|||
return client.GetLastModified(request)
|
||||
}
|
||||
|
||||
func Download(request *Request) (io.ReadCloser, error) {
|
||||
func Download(request *Request) (*Response, error) {
|
||||
client, ok := _defaultManager.GetClient(request.URL.Scheme)
|
||||
if !ok {
|
||||
return nil, errors.Wrapf(ErrNoClientFound, "scheme: %s", request.URL.Scheme)
|
||||
|
|
@ -326,14 +318,6 @@ func Download(request *Request) (io.ReadCloser, error) {
|
|||
return client.Download(request)
|
||||
}
|
||||
|
||||
func DownloadWithExpireInfo(request *Request) (io.ReadCloser, *ExpireInfo, error) {
|
||||
client, ok := _defaultManager.GetClient(request.URL.Scheme)
|
||||
if !ok {
|
||||
return nil, nil, errors.Wrapf(ErrNoClientFound, "scheme: %s", request.URL.Scheme)
|
||||
}
|
||||
return client.DownloadWithExpireInfo(request)
|
||||
}
|
||||
|
||||
func List(request *Request) ([]*url.URL, error) {
|
||||
client, ok := _defaultManager.GetClient(request.URL.Scheme)
|
||||
if !ok {
|
||||
|
|
|
|||
|
|
@ -39,13 +39,13 @@ func main() {
|
|||
}
|
||||
|
||||
request, err = source.NewRequest("")
|
||||
rc, err := client.Download(request)
|
||||
response, err := client.Download(request)
|
||||
if err != nil {
|
||||
fmt.Printf("download error: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(rc)
|
||||
data, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
fmt.Printf("read error: %s\n", err)
|
||||
os.Exit(1)
|
||||
|
|
@ -56,7 +56,7 @@ func main() {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
err = rc.Close()
|
||||
err = response.Body.Close()
|
||||
if err != nil {
|
||||
fmt.Printf("close error: %s\n", err)
|
||||
os.Exit(1)
|
||||
|
|
|
|||
|
|
@ -48,12 +48,8 @@ func (c *client) IsExpired(request *source.Request, info *source.ExpireInfo) (bo
|
|||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *client) Download(request *source.Request) (io.ReadCloser, error) {
|
||||
return io.NopCloser(bytes.NewBufferString(data)), nil
|
||||
}
|
||||
|
||||
func (c *client) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
|
||||
return io.NopCloser(bytes.NewBufferString(data)), nil, nil
|
||||
func (c *client) Download(request *source.Request) (*source.Response, error) {
|
||||
return source.NewResponse(io.NopCloser(bytes.NewBufferString(data))), nil
|
||||
}
|
||||
|
||||
func (c *client) GetLastModified(request *source.Request) (int64, error) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue