diff --git a/cdn/supervisor/cdn/downloader.go b/cdn/supervisor/cdn/downloader.go index 0bc68b317..45b44fac3 100644 --- a/cdn/supervisor/cdn/downloader.go +++ b/cdn/supervisor/cdn/downloader.go @@ -47,16 +47,16 @@ func (cm *manager) download(ctx context.Context, seedTask *task.SeedTask, breakP if !stringutils.IsBlank(breakRange) { downloadRequest.Header.Add(source.Range, breakRange) } - body, expireInfo, err := source.DownloadWithExpireInfo(downloadRequest) + response, err := source.Download(downloadRequest) if err != nil { return nil, err } // update Expire info cm.updateExpireInfo(seedTask.ID, map[string]string{ - source.LastModified: expireInfo.LastModified, - source.ETag: expireInfo.ETag, + source.LastModified: response.Header.Get(source.LastModified), + source.ETag: response.Header.Get(source.ETag), }) - return body, err + return response.Body, err } func getBreakRange(breakPoint int64, taskRange string, fileTotalLength int64) (string, error) { diff --git a/cdn/supervisor/cdn/manager_test.go b/cdn/supervisor/cdn/manager_test.go index 14b14e349..f850d1092 100644 --- a/cdn/supervisor/cdn/manager_test.go +++ b/cdn/supervisor/cdn/manager_test.go @@ -94,31 +94,22 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() { sourceClient.EXPECT().IsSupportRange(gomock.Any()).Return(true, nil).AnyTimes() sourceClient.EXPECT().IsExpired(gomock.Any(), gomock.Any()).Return(false, nil).AnyTimes() sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn( - func(request *source.Request) (io.ReadCloser, error) { + func(request *source.Request) (*source.Response, error) { content, _ := os.ReadFile("../../testdata/cdn/go.html") if request.Header.Get(source.Range) != "" { parsed, _ := rangeutils.GetRange(request.Header.Get(source.Range)) - return io.NopCloser(io.NewSectionReader(strings.NewReader(string(content)), int64(parsed.StartIndex), int64(parsed.EndIndex))), nil - } - return io.NopCloser(strings.NewReader(string(content))), nil - }, - ).AnyTimes() - sourceClient.EXPECT().DownloadWithExpireInfo(gomock.Any()).DoAndReturn( - func(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) { - content, _ := os.ReadFile("../../testdata/cdn/go.html") - if request.Header.Get(source.Range) != "" { - parsed, _ := rangeutils.GetRange(request.Header.Get(source.Range)) - return io.NopCloser(io.NewSectionReader(strings.NewReader(string(content)), int64(parsed.StartIndex), int64(parsed.EndIndex))), - &source.ExpireInfo{ + return source.NewResponse( + io.NopCloser(io.NewSectionReader(strings.NewReader(string(content)), int64(parsed.StartIndex), int64(parsed.EndIndex))), + source.WithExpireInfo(source.ExpireInfo{ LastModified: "Sun, 06 Jun 2021 12:52:30 GMT", ETag: "etag", - }, nil + })), nil } - return io.NopCloser(strings.NewReader(string(content))), - &source.ExpireInfo{ + return source.NewResponse(io.NopCloser(strings.NewReader(string(content))), + source.WithExpireInfo(source.ExpireInfo{ LastModified: "Sun, 06 Jun 2021 12:52:30 GMT", ETag: "etag", - }, nil + })), nil }, ).AnyTimes() sourceClient.EXPECT().GetLastModified(gomock.Any()).Return( diff --git a/client/daemon/peer/peertask_file_test.go b/client/daemon/peer/peertask_file_test.go index 430becd87..9053cd043 100644 --- a/client/daemon/peer/peertask_file_test.go +++ b/client/daemon/peer/peertask_file_test.go @@ -95,9 +95,9 @@ func TestFilePeerTask_BackSource_WithContentLength(t *testing.T) { return -1, fmt.Errorf("unexpect url: %s", request.URL.String()) }) sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn( - func(request *source.Request) (io.ReadCloser, error) { + func(request *source.Request) (*source.Response, error) { if request.URL.String() == url { - return io.NopCloser(bytes.NewBuffer(testBytes)), nil + return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil } return nil, fmt.Errorf("unexpect url: %s", request.URL.String()) }) @@ -220,9 +220,9 @@ func TestFilePeerTask_BackSource_WithoutContentLength(t *testing.T) { return -1, fmt.Errorf("unexpect url: %s", request.URL.String()) }) sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn( - func(request *source.Request) (io.ReadCloser, error) { + func(request *source.Request) (*source.Response, error) { if request.URL.String() == url { - return io.NopCloser(bytes.NewBuffer(testBytes)), nil + return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil } return nil, fmt.Errorf("unexpect url: %s", request.URL.String()) }) diff --git a/client/daemon/peer/peertask_stream_backsource_partial_test.go b/client/daemon/peer/peertask_stream_backsource_partial_test.go index ab6801fe8..02a0d8bac 100644 --- a/client/daemon/peer/peertask_stream_backsource_partial_test.go +++ b/client/daemon/peer/peertask_stream_backsource_partial_test.go @@ -222,8 +222,8 @@ func TestStreamPeerTask_BackSource_Partial_WithContentLength(t *testing.T) { return int64(len(testBytes)), nil }) sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn( - func(request *source.Request) (io.ReadCloser, error) { - return io.NopCloser(bytes.NewBuffer(testBytes)), nil + func(request *source.Request) (*source.Response, error) { + return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil }) pm := &pieceManager{ diff --git a/client/daemon/peer/peertask_stream_test.go b/client/daemon/peer/peertask_stream_test.go index 4205b0f1f..ba0e85474 100644 --- a/client/daemon/peer/peertask_stream_test.go +++ b/client/daemon/peer/peertask_stream_test.go @@ -90,8 +90,8 @@ func TestStreamPeerTask_BackSource_WithContentLength(t *testing.T) { return int64(len(testBytes)), nil }) sourceClient.EXPECT().Download(source.RequestEq(request.URL.String())).DoAndReturn( - func(request *source.Request) (io.ReadCloser, error) { - return io.NopCloser(bytes.NewBuffer(testBytes)), nil + func(request *source.Request) (*source.Response, error) { + return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil }) ptm := &peerTaskManager{ @@ -190,8 +190,8 @@ func TestStreamPeerTask_BackSource_WithoutContentLength(t *testing.T) { return -1, nil }) sourceClient.EXPECT().Download(source.RequestEq(request.URL.String())).DoAndReturn( - func(request *source.Request) (io.ReadCloser, error) { - return io.NopCloser(bytes.NewBuffer(testBytes)), nil + func(request *source.Request) (*source.Response, error) { + return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil }) ptm := &peerTaskManager{ diff --git a/client/daemon/peer/piece_manager.go b/client/daemon/peer/piece_manager.go index 59e4c8363..c26483289 100644 --- a/client/daemon/peer/piece_manager.go +++ b/client/daemon/peer/piece_manager.go @@ -376,16 +376,16 @@ func (pm *pieceManager) DownloadSource(ctx context.Context, pt Task, request *sc if err != nil { return err } - body, err := source.Download(downloadRequest) + response, err := source.Download(downloadRequest) if err != nil { return err } - defer body.Close() - reader := body.(io.Reader) + defer response.Body.Close() + reader := response.Body.(io.Reader) // calc total md5 if pm.calculateDigest && request.UrlMeta.Digest != "" { - reader = digestutils.NewDigestReader(pt.Log(), body, request.UrlMeta.Digest) + reader = digestutils.NewDigestReader(pt.Log(), response.Body, request.UrlMeta.Digest) } // 2. save to storage diff --git a/client/dfget/dfget.go b/client/dfget/dfget.go index 88d425216..b55575391 100644 --- a/client/dfget/dfget.go +++ b/client/dfget/dfget.go @@ -143,7 +143,7 @@ func downloadFromSource(ctx context.Context, cfg *config.DfgetConfig, hdr map[st wLog = logger.With("url", cfg.URL) start = time.Now() target *os.File - response io.ReadCloser + response *source.Response err error written int64 ) @@ -164,9 +164,9 @@ func downloadFromSource(ctx context.Context, cfg *config.DfgetConfig, hdr map[st if response, err = source.Download(downloadRequest); err != nil { return err } - defer response.Close() + defer response.Body.Close() - if written, err = io.Copy(target, response); err != nil { + if written, err = io.Copy(target, response.Body); err != nil { return err } diff --git a/client/dfget/dfget_test.go b/client/dfget/dfget_test.go index b1b7b8735..2fdf4e650 100644 --- a/client/dfget/dfget_test.go +++ b/client/dfget/dfget_test.go @@ -57,7 +57,7 @@ func Test_downloadFromSource(t *testing.T) { } request, err := source.NewRequest(cfg.URL) assert.Nil(t, err) - sourceClient.EXPECT().Download(request).Return(io.NopCloser(strings.NewReader(content)), nil) + sourceClient.EXPECT().Download(request).Return(source.NewResponse(io.NopCloser(strings.NewReader(content))), nil) err = downloadFromSource(context.Background(), cfg, nil) assert.Nil(t, err) diff --git a/pkg/source/hdfsprotocol/hdfs_source_client.go b/pkg/source/hdfsprotocol/hdfs_source_client.go index 90a74f407..47c9d15d7 100644 --- a/pkg/source/hdfsprotocol/hdfs_source_client.go +++ b/pkg/source/hdfsprotocol/hdfs_source_client.go @@ -111,21 +111,25 @@ func (h *hdfsSourceClient) IsExpired(request *source.Request, info *source.Expir return fileInfo.ModTime().Format(source.LastModifiedLayout) != info.LastModified, nil } -func (h *hdfsSourceClient) Download(request *source.Request) (io.ReadCloser, error) { +func (h *hdfsSourceClient) Download(request *source.Request) (*source.Response, error) { hdfsClient, path, err := h.getHDFSClientAndPath(request.URL) if err != nil { return nil, err } + hdfsFile, err := hdfsClient.Open(path) if err != nil { return nil, err } + fileInfo := hdfsFile.Stat() + // default read all data when rang is nil - var limitReadN = hdfsFile.Stat().Size() + var limitReadN = fileInfo.Size() if limitReadN < 0 { return nil, errors.Errorf("file length is illegal, length: %d", limitReadN) } + if request.Header.Get(source.Range) != "" { requestRange, err := rangeutils.ParseRange(request.Header.Get(source.Range), uint64(limitReadN)) if err != nil { @@ -139,44 +143,12 @@ func (h *hdfsSourceClient) Download(request *source.Request) (io.ReadCloser, err limitReadN = int64(requestRange.Length()) } - return newHdfsFileReaderClose(hdfsFile, limitReadN, hdfsFile), nil -} - -func (h *hdfsSourceClient) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) { - - hdfsClient, path, err := h.getHDFSClientAndPath(request.URL) - if err != nil { - return nil, nil, err - } - - hdfsFile, err := hdfsClient.Open(path) - if err != nil { - return nil, nil, err - } - - fileInfo := hdfsFile.Stat() - - // default read all data when rang is nil - var limitReadN = fileInfo.Size() - if limitReadN < 0 { - return nil, nil, errors.Errorf("file length is illegal, length: %d", limitReadN) - } - - if request.Header.Get(source.Range) != "" { - requestRange, err := rangeutils.ParseRange(request.Header.Get(source.Range), uint64(limitReadN)) - if err != nil { - return nil, nil, err - } - _, err = hdfsFile.Seek(int64(requestRange.StartIndex), 0) - if err != nil { - hdfsFile.Close() - return nil, nil, err - } - limitReadN = int64(requestRange.EndIndex - requestRange.StartIndex) - } - return newHdfsFileReaderClose(hdfsFile, limitReadN, hdfsFile), &source.ExpireInfo{ - LastModified: timeutils.Format(fileInfo.ModTime()), - }, nil + response := source.NewResponse( + newHdfsFileReaderClose(hdfsFile, limitReadN, hdfsFile), + source.WithExpireInfo(source.ExpireInfo{ + LastModified: timeutils.Format(fileInfo.ModTime()), + })) + return response, nil } func (h *hdfsSourceClient) GetLastModified(request *source.Request) (int64, error) { diff --git a/pkg/source/hdfsprotocol/hdfs_source_client_test.go b/pkg/source/hdfsprotocol/hdfs_source_client_test.go index c05c5197a..d0c1a9c59 100644 --- a/pkg/source/hdfsprotocol/hdfs_source_client_test.go +++ b/pkg/source/hdfsprotocol/hdfs_source_client_test.go @@ -228,9 +228,9 @@ func Test_Download_FileExist_ByRange(t *testing.T) { }) assert.Nil(t, err) - download, err := sourceClient.Download(request) + response, err := sourceClient.Download(request) assert.Nil(t, err) - data, _ := io.ReadAll(download) + data, _ := io.ReadAll(response.Body) assert.Equal(t, hdfsExistFileContent, string(data)) } @@ -285,11 +285,11 @@ func Test_DownloadWithResponseHeader_FileExist_ByRange(t *testing.T) { request, err := source.NewRequest(hdfsExistFileURL) assert.Nil(t, err) request.Header.Add(source.Range, rang.String()) - body, expireInfo, err := sourceClient.DownloadWithExpireInfo(request) + response, err := sourceClient.Download(request) assert.Nil(t, err) - assert.Equal(t, hdfsExistFileLastModified, expireInfo.LastModified) + assert.Equal(t, hdfsExistFileLastModified, response.ExpireInfo().LastModified) - data, _ := io.ReadAll(body) + data, _ := io.ReadAll(response.Body) assert.Equal(t, string(data), string([]byte(hdfsExistFileContent)[hdfsExistFileRangeStart:hdfsExistFileRangeEnd])) } @@ -303,10 +303,9 @@ func TestDownloadWithResponseHeader_FileNotExist(t *testing.T) { request, err := source.NewRequest(hdfsNotExistFileURL) assert.Nil(t, err) request.Header.Add(source.Range, rang.String()) - reader, expireInfo, err := sourceClient.DownloadWithExpireInfo(request) + response, err := sourceClient.Download(request) assert.EqualError(t, err, "open /user/root/input/f3.txt: file does not exist") - assert.Nil(t, reader) - assert.Nil(t, expireInfo) + assert.Nil(t, response) } func TestGetLastModified_FileExist(t *testing.T) { diff --git a/pkg/source/httpprotocol/http_source_client.go b/pkg/source/httpprotocol/http_source_client.go index 7f691d548..cfd383be5 100644 --- a/pkg/source/httpprotocol/http_source_client.go +++ b/pkg/source/httpprotocol/http_source_client.go @@ -18,7 +18,6 @@ package httpprotocol import ( "fmt" - "io" "net" "net/http" "net/url" @@ -170,7 +169,7 @@ func (client *httpSourceClient) IsExpired(request *source.Request, info *source. LastModified)), nil } -func (client *httpSourceClient) Download(request *source.Request) (io.ReadCloser, error) { +func (client *httpSourceClient) Download(request *source.Request) (*source.Response, error) { resp, err := client.doRequest(http.MethodGet, request) if err != nil { return nil, err @@ -180,23 +179,15 @@ func (client *httpSourceClient) Download(request *source.Request) (io.ReadCloser resp.Body.Close() return nil, err } - return resp.Body, nil -} - -func (client *httpSourceClient) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) { - resp, err := client.doRequest(http.MethodGet, request) - if err != nil { - return nil, nil, err - } - err = source.CheckResponseCode(resp.StatusCode, []int{http.StatusOK, http.StatusPartialContent}) - if err != nil { - resp.Body.Close() - return nil, nil, err - } - return resp.Body, &source.ExpireInfo{ - LastModified: resp.Header.Get(headers.LastModified), - ETag: resp.Header.Get(headers.ETag), - }, nil + response := source.NewResponse( + resp.Body, + source.WithExpireInfo( + source.ExpireInfo{ + LastModified: resp.Header.Get(headers.LastModified), + ETag: resp.Header.Get(headers.ETag), + }, + )) + return response, nil } func (client *httpSourceClient) GetLastModified(request *source.Request) (int64, error) { diff --git a/pkg/source/httpprotocol/http_source_client_test.go b/pkg/source/httpprotocol/http_source_client_test.go index c05a1a1e4..083da2ca0 100644 --- a/pkg/source/httpprotocol/http_source_client_test.go +++ b/pkg/source/httpprotocol/http_source_client_test.go @@ -131,7 +131,6 @@ func (suite *HTTPSourceClientTestSuite) SetupTest() { httpmock.RegisterResponder(http.MethodGet, notfoundRawURL, httpmock.NewStringResponder(http.StatusNotFound, "not found")) httpmock.RegisterResponder(http.MethodGet, normalNotSupportRangeRawURL, httpmock.NewStringResponder(http.StatusOK, testContent)) httpmock.RegisterResponder(http.MethodGet, errorRawURL, httpmock.NewErrorResponder(fmt.Errorf("error"))) - } func (suite *HTTPSourceClientTestSuite) TestNewHTTPSourceClient() { @@ -150,12 +149,11 @@ func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDownloadWithResponse ctx, cancel := context.WithTimeout(context.Background(), time.Second) timeoutRequest, err := source.NewRequestWithContext(ctx, timeoutRawURL, nil) suite.Nil(err) - reader, expireInfo, err := suite.httpClient.DownloadWithExpireInfo(timeoutRequest) + response, err := suite.httpClient.Download(timeoutRequest) cancel() suite.NotNil(err) suite.Equal("Get \"https://timeout.com\": context deadline exceeded", err.Error()) - suite.Nil(reader) - suite.Nil(expireInfo) + suite.Nil(response) normalRequest, _ := source.NewRequest(normalRawURL) normalRangeRequest, _ := source.NewRequest(normalRawURL) @@ -203,15 +201,16 @@ func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDownloadWithResponse } for _, tt := range tests { suite.Run(tt.name, func() { - reader, expireInfo, err := suite.httpClient.DownloadWithExpireInfo(tt.request) + response, err := suite.httpClient.Download(tt.request) if err != nil { suite.True(tt.wantErr.Error() == err.Error()) return } - bytes, err := io.ReadAll(reader) + bytes, err := io.ReadAll(response.Body) suite.Nil(err) suite.Equal(tt.content, string(bytes)) - suite.Equal(tt.expireInfo, expireInfo) + expireInfo := response.ExpireInfo() + suite.Equal(tt.expireInfo, &expireInfo) }) } } diff --git a/pkg/source/mock/mock_source_client.go b/pkg/source/mock/mock_source_client.go index bd4d0293b..3096c53ac 100644 --- a/pkg/source/mock/mock_source_client.go +++ b/pkg/source/mock/mock_source_client.go @@ -5,7 +5,6 @@ package mock import ( - io "io" reflect "reflect" source "d7y.io/dragonfly/v2/pkg/source" @@ -36,10 +35,10 @@ func (m *MockResourceClient) EXPECT() *MockResourceClientMockRecorder { } // Download mocks base method. -func (m *MockResourceClient) Download(arg0 *source.Request) (io.ReadCloser, error) { +func (m *MockResourceClient) Download(arg0 *source.Request) (*source.Response, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Download", arg0) - ret0, _ := ret[0].(io.ReadCloser) + ret0, _ := ret[0].(*source.Response) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -50,22 +49,6 @@ func (mr *MockResourceClientMockRecorder) Download(arg0 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Download", reflect.TypeOf((*MockResourceClient)(nil).Download), arg0) } -// DownloadWithExpireInfo mocks base method. -func (m *MockResourceClient) DownloadWithExpireInfo(arg0 *source.Request) (io.ReadCloser, *source.ExpireInfo, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DownloadWithExpireInfo", arg0) - ret0, _ := ret[0].(io.ReadCloser) - ret1, _ := ret[1].(*source.ExpireInfo) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// DownloadWithExpireInfo indicates an expected call of DownloadWithExpireInfo. -func (mr *MockResourceClientMockRecorder) DownloadWithExpireInfo(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DownloadWithExpireInfo", reflect.TypeOf((*MockResourceClient)(nil).DownloadWithExpireInfo), arg0) -} - // GetContentLength mocks base method. func (m *MockResourceClient) GetContentLength(arg0 *source.Request) (int64, error) { m.ctrl.T.Helper() diff --git a/pkg/source/ossprotocol/oss_source_client.go b/pkg/source/ossprotocol/oss_source_client.go index c56ba1e0a..44fc47935 100644 --- a/pkg/source/ossprotocol/oss_source_client.go +++ b/pkg/source/ossprotocol/oss_source_client.go @@ -18,7 +18,6 @@ package ossprotocol import ( "fmt" - "io" "net/http" "strconv" "sync" @@ -147,44 +146,33 @@ func (osc *ossSourceClient) IsExpired(request *source.Request, info *source.Expi return !(resHeader.Get(oss.HTTPHeaderEtag) == info.ETag || resHeader.Get(oss.HTTPHeaderLastModified) == info.LastModified), nil } -func (osc *ossSourceClient) Download(request *source.Request) (io.ReadCloser, error) { +func (osc *ossSourceClient) Download(request *source.Request) (*source.Response, error) { client, err := osc.getClient(request.Header) if err != nil { - return nil, errors.Wrap(err, "get oss client") + return nil, errors.Wrapf(err, "get oss client") } bucket, err := client.Bucket(request.URL.Host) if err != nil { - return nil, errors.Wrapf(err, "get oss bucket %s", request.URL.Host) - } - resp, err := bucket.GetObject(request.URL.Path, getOptions(request.Header)...) - if err != nil { - return nil, errors.Wrapf(err, "get oss object %s", request.URL.Path) - } - return resp, nil -} - -func (osc *ossSourceClient) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) { - client, err := osc.getClient(request.Header) - if err != nil { - return nil, nil, errors.Wrapf(err, "get oss client") - } - bucket, err := client.Bucket(request.URL.Host) - if err != nil { - return nil, nil, errors.Wrapf(err, "get oss bucket: %s", request.URL.Host) + return nil, errors.Wrapf(err, "get oss bucket: %s", request.URL.Host) } objectResult, err := bucket.DoGetObject(&oss.GetObjectRequest{ObjectKey: request.URL.Path}, getOptions(request.Header)) if err != nil { - return nil, nil, errors.Wrapf(err, "get oss Object: %s", request.URL.Path) + return nil, errors.Wrapf(err, "get oss Object: %s", request.URL.Path) } err = source.CheckResponseCode(objectResult.Response.StatusCode, []int{http.StatusOK, http.StatusPartialContent}) if err != nil { objectResult.Response.Body.Close() - return nil, nil, err + return nil, err } - return objectResult.Response.Body, &source.ExpireInfo{ - LastModified: objectResult.Response.Headers.Get(headers.LastModified), - ETag: objectResult.Response.Headers.Get(headers.ETag), - }, nil + response := source.NewResponse( + objectResult.Response.Body, + source.WithExpireInfo( + source.ExpireInfo{ + LastModified: objectResult.Response.Headers.Get(headers.LastModified), + ETag: objectResult.Response.Headers.Get(headers.ETag), + }, + )) + return response, nil } func (osc *ossSourceClient) GetLastModified(request *source.Request) (int64, error) { diff --git a/pkg/source/response.go b/pkg/source/response.go index cb1c86b19..4b4220d75 100644 --- a/pkg/source/response.go +++ b/pkg/source/response.go @@ -16,7 +16,11 @@ package source -import "io" +import ( + "fmt" + "io" + "net/http" +) type Response struct { Status string @@ -25,3 +29,70 @@ type Response struct { Body io.ReadCloser ContentLength int64 } + +func NewResponse(rc io.ReadCloser, opts ...func(*Response)) *Response { + if rc == nil { + // for custom plugin, return an error body + rc = &errorBody{ + fmt.Errorf("empty io.ReadCloser, please check resource plugin implement")} + } + resp := &Response{ + Header: make(Header), + Status: "OK", + StatusCode: http.StatusOK, + Body: rc, + ContentLength: -1, + } + + for _, opt := range opts { + opt(resp) + } + return resp +} + +func WithStatus(code int, status string) func(*Response) { + return func(resp *Response) { + resp.StatusCode = code + resp.Status = status + } +} + +func WithContentLength(length int64) func(*Response) { + return func(resp *Response) { + resp.ContentLength = length + } +} + +func WithHeader(header map[string]string) func(*Response) { + return func(resp *Response) { + for k, v := range header { + resp.Header.Set(k, v) + } + } +} + +func WithExpireInfo(info ExpireInfo) func(*Response) { + return func(resp *Response) { + resp.Header.Set(LastModified, info.LastModified) + resp.Header.Set(ETag, info.ETag) + } +} + +func (resp *Response) ExpireInfo() ExpireInfo { + return ExpireInfo{ + LastModified: resp.Header.Get(LastModified), + ETag: resp.Header.Get(ETag), + } +} + +type errorBody struct { + error +} + +func (e *errorBody) Read(p []byte) (n int, err error) { + return 0, e.error +} + +func (e *errorBody) Close() error { + return nil +} diff --git a/pkg/source/source_client.go b/pkg/source/source_client.go index 0398bc24f..150ee9ea6 100644 --- a/pkg/source/source_client.go +++ b/pkg/source/source_client.go @@ -20,7 +20,6 @@ package source import ( "context" "fmt" - "io" "net/url" "strconv" "strings" @@ -105,10 +104,7 @@ type ResourceClient interface { IsExpired(request *Request, info *ExpireInfo) (bool, error) // Download downloads from source - Download(request *Request) (io.ReadCloser, error) - - // DownloadWithExpireInfo download from source with expireInfo - DownloadWithExpireInfo(request *Request) (io.ReadCloser, *ExpireInfo, error) + Download(request *Request) (*Response, error) // GetLastModified gets last modified timestamp milliseconds of resource GetLastModified(request *Request) (int64, error) @@ -251,14 +247,10 @@ func (c *clientWrapper) IsSupportRange(request *Request) (bool, error) { func (c *clientWrapper) IsExpired(request *Request, info *ExpireInfo) (bool, error) { return c.rc.IsExpired(c.adapter(request), info) } -func (c *clientWrapper) Download(request *Request) (io.ReadCloser, error) { +func (c *clientWrapper) Download(request *Request) (*Response, error) { return c.rc.Download(c.adapter(request)) } -func (c *clientWrapper) DownloadWithExpireInfo(request *Request) (io.ReadCloser, *ExpireInfo, error) { - return c.rc.DownloadWithExpireInfo(c.adapter(request)) -} - func (c *clientWrapper) GetLastModified(request *Request) (int64, error) { return c.rc.GetLastModified(c.adapter(request)) } @@ -318,7 +310,7 @@ func GetLastModified(request *Request) (int64, error) { return client.GetLastModified(request) } -func Download(request *Request) (io.ReadCloser, error) { +func Download(request *Request) (*Response, error) { client, ok := _defaultManager.GetClient(request.URL.Scheme) if !ok { return nil, errors.Wrapf(ErrNoClientFound, "scheme: %s", request.URL.Scheme) @@ -326,14 +318,6 @@ func Download(request *Request) (io.ReadCloser, error) { return client.Download(request) } -func DownloadWithExpireInfo(request *Request) (io.ReadCloser, *ExpireInfo, error) { - client, ok := _defaultManager.GetClient(request.URL.Scheme) - if !ok { - return nil, nil, errors.Wrapf(ErrNoClientFound, "scheme: %s", request.URL.Scheme) - } - return client.DownloadWithExpireInfo(request) -} - func List(request *Request) ([]*url.URL, error) { client, ok := _defaultManager.GetClient(request.URL.Scheme) if !ok { diff --git a/pkg/source/testdata/main.go b/pkg/source/testdata/main.go index b6db53ab3..679ad058e 100644 --- a/pkg/source/testdata/main.go +++ b/pkg/source/testdata/main.go @@ -39,13 +39,13 @@ func main() { } request, err = source.NewRequest("") - rc, err := client.Download(request) + response, err := client.Download(request) if err != nil { fmt.Printf("download error: %s\n", err) os.Exit(1) } - data, err := io.ReadAll(rc) + data, err := io.ReadAll(response.Body) if err != nil { fmt.Printf("read error: %s\n", err) os.Exit(1) @@ -56,7 +56,7 @@ func main() { os.Exit(1) } - err = rc.Close() + err = response.Body.Close() if err != nil { fmt.Printf("close error: %s\n", err) os.Exit(1) diff --git a/pkg/source/testdata/plugin/dfs.go b/pkg/source/testdata/plugin/dfs.go index b71dbaac6..60232687b 100644 --- a/pkg/source/testdata/plugin/dfs.go +++ b/pkg/source/testdata/plugin/dfs.go @@ -48,12 +48,8 @@ func (c *client) IsExpired(request *source.Request, info *source.ExpireInfo) (bo panic("implement me") } -func (c *client) Download(request *source.Request) (io.ReadCloser, error) { - return io.NopCloser(bytes.NewBufferString(data)), nil -} - -func (c *client) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) { - return io.NopCloser(bytes.NewBufferString(data)), nil, nil +func (c *client) Download(request *source.Request) (*source.Response, error) { + return source.NewResponse(io.NopCloser(bytes.NewBufferString(data))), nil } func (c *client) GetLastModified(request *source.Request) (int64, error) {