feat: update source.Response and source client interface (#945)

* feat: update source.Response and source client interface

Signed-off-by: Jim Ma <majinjing3@gmail.com>
This commit is contained in:
Jim Ma 2021-12-17 18:20:56 +08:00 committed by Gaius
parent 0a152e01e9
commit beaf4ce18d
No known key found for this signature in database
GPG Key ID: 8B4E5D1290FA2FFB
18 changed files with 161 additions and 187 deletions

View File

@ -47,16 +47,16 @@ func (cm *manager) download(ctx context.Context, seedTask *task.SeedTask, breakP
if !stringutils.IsBlank(breakRange) { if !stringutils.IsBlank(breakRange) {
downloadRequest.Header.Add(source.Range, breakRange) downloadRequest.Header.Add(source.Range, breakRange)
} }
body, expireInfo, err := source.DownloadWithExpireInfo(downloadRequest) response, err := source.Download(downloadRequest)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// update Expire info // update Expire info
cm.updateExpireInfo(seedTask.ID, map[string]string{ cm.updateExpireInfo(seedTask.ID, map[string]string{
source.LastModified: expireInfo.LastModified, source.LastModified: response.Header.Get(source.LastModified),
source.ETag: expireInfo.ETag, source.ETag: response.Header.Get(source.ETag),
}) })
return body, err return response.Body, err
} }
func getBreakRange(breakPoint int64, taskRange string, fileTotalLength int64) (string, error) { func getBreakRange(breakPoint int64, taskRange string, fileTotalLength int64) (string, error) {

View File

@ -94,31 +94,22 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() {
sourceClient.EXPECT().IsSupportRange(gomock.Any()).Return(true, nil).AnyTimes() sourceClient.EXPECT().IsSupportRange(gomock.Any()).Return(true, nil).AnyTimes()
sourceClient.EXPECT().IsExpired(gomock.Any(), gomock.Any()).Return(false, nil).AnyTimes() sourceClient.EXPECT().IsExpired(gomock.Any(), gomock.Any()).Return(false, nil).AnyTimes()
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn( sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
func(request *source.Request) (io.ReadCloser, error) { func(request *source.Request) (*source.Response, error) {
content, _ := os.ReadFile("../../testdata/cdn/go.html") content, _ := os.ReadFile("../../testdata/cdn/go.html")
if request.Header.Get(source.Range) != "" { if request.Header.Get(source.Range) != "" {
parsed, _ := rangeutils.GetRange(request.Header.Get(source.Range)) parsed, _ := rangeutils.GetRange(request.Header.Get(source.Range))
return io.NopCloser(io.NewSectionReader(strings.NewReader(string(content)), int64(parsed.StartIndex), int64(parsed.EndIndex))), nil return source.NewResponse(
} io.NopCloser(io.NewSectionReader(strings.NewReader(string(content)), int64(parsed.StartIndex), int64(parsed.EndIndex))),
return io.NopCloser(strings.NewReader(string(content))), nil source.WithExpireInfo(source.ExpireInfo{
},
).AnyTimes()
sourceClient.EXPECT().DownloadWithExpireInfo(gomock.Any()).DoAndReturn(
func(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
content, _ := os.ReadFile("../../testdata/cdn/go.html")
if request.Header.Get(source.Range) != "" {
parsed, _ := rangeutils.GetRange(request.Header.Get(source.Range))
return io.NopCloser(io.NewSectionReader(strings.NewReader(string(content)), int64(parsed.StartIndex), int64(parsed.EndIndex))),
&source.ExpireInfo{
LastModified: "Sun, 06 Jun 2021 12:52:30 GMT", LastModified: "Sun, 06 Jun 2021 12:52:30 GMT",
ETag: "etag", ETag: "etag",
}, nil })), nil
} }
return io.NopCloser(strings.NewReader(string(content))), return source.NewResponse(io.NopCloser(strings.NewReader(string(content))),
&source.ExpireInfo{ source.WithExpireInfo(source.ExpireInfo{
LastModified: "Sun, 06 Jun 2021 12:52:30 GMT", LastModified: "Sun, 06 Jun 2021 12:52:30 GMT",
ETag: "etag", ETag: "etag",
}, nil })), nil
}, },
).AnyTimes() ).AnyTimes()
sourceClient.EXPECT().GetLastModified(gomock.Any()).Return( sourceClient.EXPECT().GetLastModified(gomock.Any()).Return(

View File

@ -95,9 +95,9 @@ func TestFilePeerTask_BackSource_WithContentLength(t *testing.T) {
return -1, fmt.Errorf("unexpect url: %s", request.URL.String()) return -1, fmt.Errorf("unexpect url: %s", request.URL.String())
}) })
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn( sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
func(request *source.Request) (io.ReadCloser, error) { func(request *source.Request) (*source.Response, error) {
if request.URL.String() == url { if request.URL.String() == url {
return io.NopCloser(bytes.NewBuffer(testBytes)), nil return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
} }
return nil, fmt.Errorf("unexpect url: %s", request.URL.String()) return nil, fmt.Errorf("unexpect url: %s", request.URL.String())
}) })
@ -220,9 +220,9 @@ func TestFilePeerTask_BackSource_WithoutContentLength(t *testing.T) {
return -1, fmt.Errorf("unexpect url: %s", request.URL.String()) return -1, fmt.Errorf("unexpect url: %s", request.URL.String())
}) })
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn( sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
func(request *source.Request) (io.ReadCloser, error) { func(request *source.Request) (*source.Response, error) {
if request.URL.String() == url { if request.URL.String() == url {
return io.NopCloser(bytes.NewBuffer(testBytes)), nil return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
} }
return nil, fmt.Errorf("unexpect url: %s", request.URL.String()) return nil, fmt.Errorf("unexpect url: %s", request.URL.String())
}) })

View File

@ -222,8 +222,8 @@ func TestStreamPeerTask_BackSource_Partial_WithContentLength(t *testing.T) {
return int64(len(testBytes)), nil return int64(len(testBytes)), nil
}) })
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn( sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
func(request *source.Request) (io.ReadCloser, error) { func(request *source.Request) (*source.Response, error) {
return io.NopCloser(bytes.NewBuffer(testBytes)), nil return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
}) })
pm := &pieceManager{ pm := &pieceManager{

View File

@ -90,8 +90,8 @@ func TestStreamPeerTask_BackSource_WithContentLength(t *testing.T) {
return int64(len(testBytes)), nil return int64(len(testBytes)), nil
}) })
sourceClient.EXPECT().Download(source.RequestEq(request.URL.String())).DoAndReturn( sourceClient.EXPECT().Download(source.RequestEq(request.URL.String())).DoAndReturn(
func(request *source.Request) (io.ReadCloser, error) { func(request *source.Request) (*source.Response, error) {
return io.NopCloser(bytes.NewBuffer(testBytes)), nil return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
}) })
ptm := &peerTaskManager{ ptm := &peerTaskManager{
@ -190,8 +190,8 @@ func TestStreamPeerTask_BackSource_WithoutContentLength(t *testing.T) {
return -1, nil return -1, nil
}) })
sourceClient.EXPECT().Download(source.RequestEq(request.URL.String())).DoAndReturn( sourceClient.EXPECT().Download(source.RequestEq(request.URL.String())).DoAndReturn(
func(request *source.Request) (io.ReadCloser, error) { func(request *source.Request) (*source.Response, error) {
return io.NopCloser(bytes.NewBuffer(testBytes)), nil return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
}) })
ptm := &peerTaskManager{ ptm := &peerTaskManager{

View File

@ -376,16 +376,16 @@ func (pm *pieceManager) DownloadSource(ctx context.Context, pt Task, request *sc
if err != nil { if err != nil {
return err return err
} }
body, err := source.Download(downloadRequest) response, err := source.Download(downloadRequest)
if err != nil { if err != nil {
return err return err
} }
defer body.Close() defer response.Body.Close()
reader := body.(io.Reader) reader := response.Body.(io.Reader)
// calc total md5 // calc total md5
if pm.calculateDigest && request.UrlMeta.Digest != "" { if pm.calculateDigest && request.UrlMeta.Digest != "" {
reader = digestutils.NewDigestReader(pt.Log(), body, request.UrlMeta.Digest) reader = digestutils.NewDigestReader(pt.Log(), response.Body, request.UrlMeta.Digest)
} }
// 2. save to storage // 2. save to storage

View File

@ -143,7 +143,7 @@ func downloadFromSource(ctx context.Context, cfg *config.DfgetConfig, hdr map[st
wLog = logger.With("url", cfg.URL) wLog = logger.With("url", cfg.URL)
start = time.Now() start = time.Now()
target *os.File target *os.File
response io.ReadCloser response *source.Response
err error err error
written int64 written int64
) )
@ -164,9 +164,9 @@ func downloadFromSource(ctx context.Context, cfg *config.DfgetConfig, hdr map[st
if response, err = source.Download(downloadRequest); err != nil { if response, err = source.Download(downloadRequest); err != nil {
return err return err
} }
defer response.Close() defer response.Body.Close()
if written, err = io.Copy(target, response); err != nil { if written, err = io.Copy(target, response.Body); err != nil {
return err return err
} }

View File

@ -57,7 +57,7 @@ func Test_downloadFromSource(t *testing.T) {
} }
request, err := source.NewRequest(cfg.URL) request, err := source.NewRequest(cfg.URL)
assert.Nil(t, err) assert.Nil(t, err)
sourceClient.EXPECT().Download(request).Return(io.NopCloser(strings.NewReader(content)), nil) sourceClient.EXPECT().Download(request).Return(source.NewResponse(io.NopCloser(strings.NewReader(content))), nil)
err = downloadFromSource(context.Background(), cfg, nil) err = downloadFromSource(context.Background(), cfg, nil)
assert.Nil(t, err) assert.Nil(t, err)

View File

@ -111,21 +111,25 @@ func (h *hdfsSourceClient) IsExpired(request *source.Request, info *source.Expir
return fileInfo.ModTime().Format(source.LastModifiedLayout) != info.LastModified, nil return fileInfo.ModTime().Format(source.LastModifiedLayout) != info.LastModified, nil
} }
func (h *hdfsSourceClient) Download(request *source.Request) (io.ReadCloser, error) { func (h *hdfsSourceClient) Download(request *source.Request) (*source.Response, error) {
hdfsClient, path, err := h.getHDFSClientAndPath(request.URL) hdfsClient, path, err := h.getHDFSClientAndPath(request.URL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
hdfsFile, err := hdfsClient.Open(path) hdfsFile, err := hdfsClient.Open(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
fileInfo := hdfsFile.Stat()
// default read all data when rang is nil // default read all data when rang is nil
var limitReadN = hdfsFile.Stat().Size() var limitReadN = fileInfo.Size()
if limitReadN < 0 { if limitReadN < 0 {
return nil, errors.Errorf("file length is illegal, length: %d", limitReadN) return nil, errors.Errorf("file length is illegal, length: %d", limitReadN)
} }
if request.Header.Get(source.Range) != "" { if request.Header.Get(source.Range) != "" {
requestRange, err := rangeutils.ParseRange(request.Header.Get(source.Range), uint64(limitReadN)) requestRange, err := rangeutils.ParseRange(request.Header.Get(source.Range), uint64(limitReadN))
if err != nil { if err != nil {
@ -139,44 +143,12 @@ func (h *hdfsSourceClient) Download(request *source.Request) (io.ReadCloser, err
limitReadN = int64(requestRange.Length()) limitReadN = int64(requestRange.Length())
} }
return newHdfsFileReaderClose(hdfsFile, limitReadN, hdfsFile), nil response := source.NewResponse(
} newHdfsFileReaderClose(hdfsFile, limitReadN, hdfsFile),
source.WithExpireInfo(source.ExpireInfo{
func (h *hdfsSourceClient) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
hdfsClient, path, err := h.getHDFSClientAndPath(request.URL)
if err != nil {
return nil, nil, err
}
hdfsFile, err := hdfsClient.Open(path)
if err != nil {
return nil, nil, err
}
fileInfo := hdfsFile.Stat()
// default read all data when rang is nil
var limitReadN = fileInfo.Size()
if limitReadN < 0 {
return nil, nil, errors.Errorf("file length is illegal, length: %d", limitReadN)
}
if request.Header.Get(source.Range) != "" {
requestRange, err := rangeutils.ParseRange(request.Header.Get(source.Range), uint64(limitReadN))
if err != nil {
return nil, nil, err
}
_, err = hdfsFile.Seek(int64(requestRange.StartIndex), 0)
if err != nil {
hdfsFile.Close()
return nil, nil, err
}
limitReadN = int64(requestRange.EndIndex - requestRange.StartIndex)
}
return newHdfsFileReaderClose(hdfsFile, limitReadN, hdfsFile), &source.ExpireInfo{
LastModified: timeutils.Format(fileInfo.ModTime()), LastModified: timeutils.Format(fileInfo.ModTime()),
}, nil }))
return response, nil
} }
func (h *hdfsSourceClient) GetLastModified(request *source.Request) (int64, error) { func (h *hdfsSourceClient) GetLastModified(request *source.Request) (int64, error) {

View File

@ -228,9 +228,9 @@ func Test_Download_FileExist_ByRange(t *testing.T) {
}) })
assert.Nil(t, err) assert.Nil(t, err)
download, err := sourceClient.Download(request) response, err := sourceClient.Download(request)
assert.Nil(t, err) assert.Nil(t, err)
data, _ := io.ReadAll(download) data, _ := io.ReadAll(response.Body)
assert.Equal(t, hdfsExistFileContent, string(data)) assert.Equal(t, hdfsExistFileContent, string(data))
} }
@ -285,11 +285,11 @@ func Test_DownloadWithResponseHeader_FileExist_ByRange(t *testing.T) {
request, err := source.NewRequest(hdfsExistFileURL) request, err := source.NewRequest(hdfsExistFileURL)
assert.Nil(t, err) assert.Nil(t, err)
request.Header.Add(source.Range, rang.String()) request.Header.Add(source.Range, rang.String())
body, expireInfo, err := sourceClient.DownloadWithExpireInfo(request) response, err := sourceClient.Download(request)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, hdfsExistFileLastModified, expireInfo.LastModified) assert.Equal(t, hdfsExistFileLastModified, response.ExpireInfo().LastModified)
data, _ := io.ReadAll(body) data, _ := io.ReadAll(response.Body)
assert.Equal(t, string(data), string([]byte(hdfsExistFileContent)[hdfsExistFileRangeStart:hdfsExistFileRangeEnd])) assert.Equal(t, string(data), string([]byte(hdfsExistFileContent)[hdfsExistFileRangeStart:hdfsExistFileRangeEnd]))
} }
@ -303,10 +303,9 @@ func TestDownloadWithResponseHeader_FileNotExist(t *testing.T) {
request, err := source.NewRequest(hdfsNotExistFileURL) request, err := source.NewRequest(hdfsNotExistFileURL)
assert.Nil(t, err) assert.Nil(t, err)
request.Header.Add(source.Range, rang.String()) request.Header.Add(source.Range, rang.String())
reader, expireInfo, err := sourceClient.DownloadWithExpireInfo(request) response, err := sourceClient.Download(request)
assert.EqualError(t, err, "open /user/root/input/f3.txt: file does not exist") assert.EqualError(t, err, "open /user/root/input/f3.txt: file does not exist")
assert.Nil(t, reader) assert.Nil(t, response)
assert.Nil(t, expireInfo)
} }
func TestGetLastModified_FileExist(t *testing.T) { func TestGetLastModified_FileExist(t *testing.T) {

View File

@ -18,7 +18,6 @@ package httpprotocol
import ( import (
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -170,7 +169,7 @@ func (client *httpSourceClient) IsExpired(request *source.Request, info *source.
LastModified)), nil LastModified)), nil
} }
func (client *httpSourceClient) Download(request *source.Request) (io.ReadCloser, error) { func (client *httpSourceClient) Download(request *source.Request) (*source.Response, error) {
resp, err := client.doRequest(http.MethodGet, request) resp, err := client.doRequest(http.MethodGet, request)
if err != nil { if err != nil {
return nil, err return nil, err
@ -180,23 +179,15 @@ func (client *httpSourceClient) Download(request *source.Request) (io.ReadCloser
resp.Body.Close() resp.Body.Close()
return nil, err return nil, err
} }
return resp.Body, nil response := source.NewResponse(
} resp.Body,
source.WithExpireInfo(
func (client *httpSourceClient) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) { source.ExpireInfo{
resp, err := client.doRequest(http.MethodGet, request)
if err != nil {
return nil, nil, err
}
err = source.CheckResponseCode(resp.StatusCode, []int{http.StatusOK, http.StatusPartialContent})
if err != nil {
resp.Body.Close()
return nil, nil, err
}
return resp.Body, &source.ExpireInfo{
LastModified: resp.Header.Get(headers.LastModified), LastModified: resp.Header.Get(headers.LastModified),
ETag: resp.Header.Get(headers.ETag), ETag: resp.Header.Get(headers.ETag),
}, nil },
))
return response, nil
} }
func (client *httpSourceClient) GetLastModified(request *source.Request) (int64, error) { func (client *httpSourceClient) GetLastModified(request *source.Request) (int64, error) {

View File

@ -131,7 +131,6 @@ func (suite *HTTPSourceClientTestSuite) SetupTest() {
httpmock.RegisterResponder(http.MethodGet, notfoundRawURL, httpmock.NewStringResponder(http.StatusNotFound, "not found")) httpmock.RegisterResponder(http.MethodGet, notfoundRawURL, httpmock.NewStringResponder(http.StatusNotFound, "not found"))
httpmock.RegisterResponder(http.MethodGet, normalNotSupportRangeRawURL, httpmock.NewStringResponder(http.StatusOK, testContent)) httpmock.RegisterResponder(http.MethodGet, normalNotSupportRangeRawURL, httpmock.NewStringResponder(http.StatusOK, testContent))
httpmock.RegisterResponder(http.MethodGet, errorRawURL, httpmock.NewErrorResponder(fmt.Errorf("error"))) httpmock.RegisterResponder(http.MethodGet, errorRawURL, httpmock.NewErrorResponder(fmt.Errorf("error")))
} }
func (suite *HTTPSourceClientTestSuite) TestNewHTTPSourceClient() { func (suite *HTTPSourceClientTestSuite) TestNewHTTPSourceClient() {
@ -150,12 +149,11 @@ func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDownloadWithResponse
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
timeoutRequest, err := source.NewRequestWithContext(ctx, timeoutRawURL, nil) timeoutRequest, err := source.NewRequestWithContext(ctx, timeoutRawURL, nil)
suite.Nil(err) suite.Nil(err)
reader, expireInfo, err := suite.httpClient.DownloadWithExpireInfo(timeoutRequest) response, err := suite.httpClient.Download(timeoutRequest)
cancel() cancel()
suite.NotNil(err) suite.NotNil(err)
suite.Equal("Get \"https://timeout.com\": context deadline exceeded", err.Error()) suite.Equal("Get \"https://timeout.com\": context deadline exceeded", err.Error())
suite.Nil(reader) suite.Nil(response)
suite.Nil(expireInfo)
normalRequest, _ := source.NewRequest(normalRawURL) normalRequest, _ := source.NewRequest(normalRawURL)
normalRangeRequest, _ := source.NewRequest(normalRawURL) normalRangeRequest, _ := source.NewRequest(normalRawURL)
@ -203,15 +201,16 @@ func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDownloadWithResponse
} }
for _, tt := range tests { for _, tt := range tests {
suite.Run(tt.name, func() { suite.Run(tt.name, func() {
reader, expireInfo, err := suite.httpClient.DownloadWithExpireInfo(tt.request) response, err := suite.httpClient.Download(tt.request)
if err != nil { if err != nil {
suite.True(tt.wantErr.Error() == err.Error()) suite.True(tt.wantErr.Error() == err.Error())
return return
} }
bytes, err := io.ReadAll(reader) bytes, err := io.ReadAll(response.Body)
suite.Nil(err) suite.Nil(err)
suite.Equal(tt.content, string(bytes)) suite.Equal(tt.content, string(bytes))
suite.Equal(tt.expireInfo, expireInfo) expireInfo := response.ExpireInfo()
suite.Equal(tt.expireInfo, &expireInfo)
}) })
} }
} }

View File

@ -5,7 +5,6 @@
package mock package mock
import ( import (
io "io"
reflect "reflect" reflect "reflect"
source "d7y.io/dragonfly/v2/pkg/source" source "d7y.io/dragonfly/v2/pkg/source"
@ -36,10 +35,10 @@ func (m *MockResourceClient) EXPECT() *MockResourceClientMockRecorder {
} }
// Download mocks base method. // Download mocks base method.
func (m *MockResourceClient) Download(arg0 *source.Request) (io.ReadCloser, error) { func (m *MockResourceClient) Download(arg0 *source.Request) (*source.Response, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Download", arg0) ret := m.ctrl.Call(m, "Download", arg0)
ret0, _ := ret[0].(io.ReadCloser) ret0, _ := ret[0].(*source.Response)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@ -50,22 +49,6 @@ func (mr *MockResourceClientMockRecorder) Download(arg0 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Download", reflect.TypeOf((*MockResourceClient)(nil).Download), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Download", reflect.TypeOf((*MockResourceClient)(nil).Download), arg0)
} }
// DownloadWithExpireInfo mocks base method.
func (m *MockResourceClient) DownloadWithExpireInfo(arg0 *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DownloadWithExpireInfo", arg0)
ret0, _ := ret[0].(io.ReadCloser)
ret1, _ := ret[1].(*source.ExpireInfo)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// DownloadWithExpireInfo indicates an expected call of DownloadWithExpireInfo.
func (mr *MockResourceClientMockRecorder) DownloadWithExpireInfo(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DownloadWithExpireInfo", reflect.TypeOf((*MockResourceClient)(nil).DownloadWithExpireInfo), arg0)
}
// GetContentLength mocks base method. // GetContentLength mocks base method.
func (m *MockResourceClient) GetContentLength(arg0 *source.Request) (int64, error) { func (m *MockResourceClient) GetContentLength(arg0 *source.Request) (int64, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -18,7 +18,6 @@ package ossprotocol
import ( import (
"fmt" "fmt"
"io"
"net/http" "net/http"
"strconv" "strconv"
"sync" "sync"
@ -147,44 +146,33 @@ func (osc *ossSourceClient) IsExpired(request *source.Request, info *source.Expi
return !(resHeader.Get(oss.HTTPHeaderEtag) == info.ETag || resHeader.Get(oss.HTTPHeaderLastModified) == info.LastModified), nil return !(resHeader.Get(oss.HTTPHeaderEtag) == info.ETag || resHeader.Get(oss.HTTPHeaderLastModified) == info.LastModified), nil
} }
func (osc *ossSourceClient) Download(request *source.Request) (io.ReadCloser, error) { func (osc *ossSourceClient) Download(request *source.Request) (*source.Response, error) {
client, err := osc.getClient(request.Header) client, err := osc.getClient(request.Header)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "get oss client") return nil, errors.Wrapf(err, "get oss client")
} }
bucket, err := client.Bucket(request.URL.Host) bucket, err := client.Bucket(request.URL.Host)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "get oss bucket %s", request.URL.Host) return nil, errors.Wrapf(err, "get oss bucket: %s", request.URL.Host)
}
resp, err := bucket.GetObject(request.URL.Path, getOptions(request.Header)...)
if err != nil {
return nil, errors.Wrapf(err, "get oss object %s", request.URL.Path)
}
return resp, nil
}
func (osc *ossSourceClient) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
client, err := osc.getClient(request.Header)
if err != nil {
return nil, nil, errors.Wrapf(err, "get oss client")
}
bucket, err := client.Bucket(request.URL.Host)
if err != nil {
return nil, nil, errors.Wrapf(err, "get oss bucket: %s", request.URL.Host)
} }
objectResult, err := bucket.DoGetObject(&oss.GetObjectRequest{ObjectKey: request.URL.Path}, getOptions(request.Header)) objectResult, err := bucket.DoGetObject(&oss.GetObjectRequest{ObjectKey: request.URL.Path}, getOptions(request.Header))
if err != nil { if err != nil {
return nil, nil, errors.Wrapf(err, "get oss Object: %s", request.URL.Path) return nil, errors.Wrapf(err, "get oss Object: %s", request.URL.Path)
} }
err = source.CheckResponseCode(objectResult.Response.StatusCode, []int{http.StatusOK, http.StatusPartialContent}) err = source.CheckResponseCode(objectResult.Response.StatusCode, []int{http.StatusOK, http.StatusPartialContent})
if err != nil { if err != nil {
objectResult.Response.Body.Close() objectResult.Response.Body.Close()
return nil, nil, err return nil, err
} }
return objectResult.Response.Body, &source.ExpireInfo{ response := source.NewResponse(
objectResult.Response.Body,
source.WithExpireInfo(
source.ExpireInfo{
LastModified: objectResult.Response.Headers.Get(headers.LastModified), LastModified: objectResult.Response.Headers.Get(headers.LastModified),
ETag: objectResult.Response.Headers.Get(headers.ETag), ETag: objectResult.Response.Headers.Get(headers.ETag),
}, nil },
))
return response, nil
} }
func (osc *ossSourceClient) GetLastModified(request *source.Request) (int64, error) { func (osc *ossSourceClient) GetLastModified(request *source.Request) (int64, error) {

View File

@ -16,7 +16,11 @@
package source package source
import "io" import (
"fmt"
"io"
"net/http"
)
type Response struct { type Response struct {
Status string Status string
@ -25,3 +29,70 @@ type Response struct {
Body io.ReadCloser Body io.ReadCloser
ContentLength int64 ContentLength int64
} }
func NewResponse(rc io.ReadCloser, opts ...func(*Response)) *Response {
if rc == nil {
// for custom plugin, return an error body
rc = &errorBody{
fmt.Errorf("empty io.ReadCloser, please check resource plugin implement")}
}
resp := &Response{
Header: make(Header),
Status: "OK",
StatusCode: http.StatusOK,
Body: rc,
ContentLength: -1,
}
for _, opt := range opts {
opt(resp)
}
return resp
}
func WithStatus(code int, status string) func(*Response) {
return func(resp *Response) {
resp.StatusCode = code
resp.Status = status
}
}
func WithContentLength(length int64) func(*Response) {
return func(resp *Response) {
resp.ContentLength = length
}
}
func WithHeader(header map[string]string) func(*Response) {
return func(resp *Response) {
for k, v := range header {
resp.Header.Set(k, v)
}
}
}
func WithExpireInfo(info ExpireInfo) func(*Response) {
return func(resp *Response) {
resp.Header.Set(LastModified, info.LastModified)
resp.Header.Set(ETag, info.ETag)
}
}
func (resp *Response) ExpireInfo() ExpireInfo {
return ExpireInfo{
LastModified: resp.Header.Get(LastModified),
ETag: resp.Header.Get(ETag),
}
}
type errorBody struct {
error
}
func (e *errorBody) Read(p []byte) (n int, err error) {
return 0, e.error
}
func (e *errorBody) Close() error {
return nil
}

View File

@ -20,7 +20,6 @@ package source
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
@ -105,10 +104,7 @@ type ResourceClient interface {
IsExpired(request *Request, info *ExpireInfo) (bool, error) IsExpired(request *Request, info *ExpireInfo) (bool, error)
// Download downloads from source // Download downloads from source
Download(request *Request) (io.ReadCloser, error) Download(request *Request) (*Response, error)
// DownloadWithExpireInfo download from source with expireInfo
DownloadWithExpireInfo(request *Request) (io.ReadCloser, *ExpireInfo, error)
// GetLastModified gets last modified timestamp milliseconds of resource // GetLastModified gets last modified timestamp milliseconds of resource
GetLastModified(request *Request) (int64, error) GetLastModified(request *Request) (int64, error)
@ -251,14 +247,10 @@ func (c *clientWrapper) IsSupportRange(request *Request) (bool, error) {
func (c *clientWrapper) IsExpired(request *Request, info *ExpireInfo) (bool, error) { func (c *clientWrapper) IsExpired(request *Request, info *ExpireInfo) (bool, error) {
return c.rc.IsExpired(c.adapter(request), info) return c.rc.IsExpired(c.adapter(request), info)
} }
func (c *clientWrapper) Download(request *Request) (io.ReadCloser, error) { func (c *clientWrapper) Download(request *Request) (*Response, error) {
return c.rc.Download(c.adapter(request)) return c.rc.Download(c.adapter(request))
} }
func (c *clientWrapper) DownloadWithExpireInfo(request *Request) (io.ReadCloser, *ExpireInfo, error) {
return c.rc.DownloadWithExpireInfo(c.adapter(request))
}
func (c *clientWrapper) GetLastModified(request *Request) (int64, error) { func (c *clientWrapper) GetLastModified(request *Request) (int64, error) {
return c.rc.GetLastModified(c.adapter(request)) return c.rc.GetLastModified(c.adapter(request))
} }
@ -318,7 +310,7 @@ func GetLastModified(request *Request) (int64, error) {
return client.GetLastModified(request) return client.GetLastModified(request)
} }
func Download(request *Request) (io.ReadCloser, error) { func Download(request *Request) (*Response, error) {
client, ok := _defaultManager.GetClient(request.URL.Scheme) client, ok := _defaultManager.GetClient(request.URL.Scheme)
if !ok { if !ok {
return nil, errors.Wrapf(ErrNoClientFound, "scheme: %s", request.URL.Scheme) return nil, errors.Wrapf(ErrNoClientFound, "scheme: %s", request.URL.Scheme)
@ -326,14 +318,6 @@ func Download(request *Request) (io.ReadCloser, error) {
return client.Download(request) return client.Download(request)
} }
func DownloadWithExpireInfo(request *Request) (io.ReadCloser, *ExpireInfo, error) {
client, ok := _defaultManager.GetClient(request.URL.Scheme)
if !ok {
return nil, nil, errors.Wrapf(ErrNoClientFound, "scheme: %s", request.URL.Scheme)
}
return client.DownloadWithExpireInfo(request)
}
func List(request *Request) ([]*url.URL, error) { func List(request *Request) ([]*url.URL, error) {
client, ok := _defaultManager.GetClient(request.URL.Scheme) client, ok := _defaultManager.GetClient(request.URL.Scheme)
if !ok { if !ok {

View File

@ -39,13 +39,13 @@ func main() {
} }
request, err = source.NewRequest("") request, err = source.NewRequest("")
rc, err := client.Download(request) response, err := client.Download(request)
if err != nil { if err != nil {
fmt.Printf("download error: %s\n", err) fmt.Printf("download error: %s\n", err)
os.Exit(1) os.Exit(1)
} }
data, err := io.ReadAll(rc) data, err := io.ReadAll(response.Body)
if err != nil { if err != nil {
fmt.Printf("read error: %s\n", err) fmt.Printf("read error: %s\n", err)
os.Exit(1) os.Exit(1)
@ -56,7 +56,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
err = rc.Close() err = response.Body.Close()
if err != nil { if err != nil {
fmt.Printf("close error: %s\n", err) fmt.Printf("close error: %s\n", err)
os.Exit(1) os.Exit(1)

View File

@ -48,12 +48,8 @@ func (c *client) IsExpired(request *source.Request, info *source.ExpireInfo) (bo
panic("implement me") panic("implement me")
} }
func (c *client) Download(request *source.Request) (io.ReadCloser, error) { func (c *client) Download(request *source.Request) (*source.Response, error) {
return io.NopCloser(bytes.NewBufferString(data)), nil return source.NewResponse(io.NopCloser(bytes.NewBufferString(data))), nil
}
func (c *client) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
return io.NopCloser(bytes.NewBufferString(data)), nil, nil
} }
func (c *client) GetLastModified(request *source.Request) (int64, error) { func (c *client) GetLastModified(request *source.Request) (int64, error) {