feat: update source.Response and source client interface (#945)
* feat: update source.Response and source client interface Signed-off-by: Jim Ma <majinjing3@gmail.com>
This commit is contained in:
parent
0a152e01e9
commit
beaf4ce18d
|
|
@ -47,16 +47,16 @@ func (cm *manager) download(ctx context.Context, seedTask *task.SeedTask, breakP
|
||||||
if !stringutils.IsBlank(breakRange) {
|
if !stringutils.IsBlank(breakRange) {
|
||||||
downloadRequest.Header.Add(source.Range, breakRange)
|
downloadRequest.Header.Add(source.Range, breakRange)
|
||||||
}
|
}
|
||||||
body, expireInfo, err := source.DownloadWithExpireInfo(downloadRequest)
|
response, err := source.Download(downloadRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// update Expire info
|
// update Expire info
|
||||||
cm.updateExpireInfo(seedTask.ID, map[string]string{
|
cm.updateExpireInfo(seedTask.ID, map[string]string{
|
||||||
source.LastModified: expireInfo.LastModified,
|
source.LastModified: response.Header.Get(source.LastModified),
|
||||||
source.ETag: expireInfo.ETag,
|
source.ETag: response.Header.Get(source.ETag),
|
||||||
})
|
})
|
||||||
return body, err
|
return response.Body, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getBreakRange(breakPoint int64, taskRange string, fileTotalLength int64) (string, error) {
|
func getBreakRange(breakPoint int64, taskRange string, fileTotalLength int64) (string, error) {
|
||||||
|
|
|
||||||
|
|
@ -94,31 +94,22 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() {
|
||||||
sourceClient.EXPECT().IsSupportRange(gomock.Any()).Return(true, nil).AnyTimes()
|
sourceClient.EXPECT().IsSupportRange(gomock.Any()).Return(true, nil).AnyTimes()
|
||||||
sourceClient.EXPECT().IsExpired(gomock.Any(), gomock.Any()).Return(false, nil).AnyTimes()
|
sourceClient.EXPECT().IsExpired(gomock.Any(), gomock.Any()).Return(false, nil).AnyTimes()
|
||||||
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
|
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
|
||||||
func(request *source.Request) (io.ReadCloser, error) {
|
func(request *source.Request) (*source.Response, error) {
|
||||||
content, _ := os.ReadFile("../../testdata/cdn/go.html")
|
content, _ := os.ReadFile("../../testdata/cdn/go.html")
|
||||||
if request.Header.Get(source.Range) != "" {
|
if request.Header.Get(source.Range) != "" {
|
||||||
parsed, _ := rangeutils.GetRange(request.Header.Get(source.Range))
|
parsed, _ := rangeutils.GetRange(request.Header.Get(source.Range))
|
||||||
return io.NopCloser(io.NewSectionReader(strings.NewReader(string(content)), int64(parsed.StartIndex), int64(parsed.EndIndex))), nil
|
return source.NewResponse(
|
||||||
}
|
io.NopCloser(io.NewSectionReader(strings.NewReader(string(content)), int64(parsed.StartIndex), int64(parsed.EndIndex))),
|
||||||
return io.NopCloser(strings.NewReader(string(content))), nil
|
source.WithExpireInfo(source.ExpireInfo{
|
||||||
},
|
|
||||||
).AnyTimes()
|
|
||||||
sourceClient.EXPECT().DownloadWithExpireInfo(gomock.Any()).DoAndReturn(
|
|
||||||
func(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
|
|
||||||
content, _ := os.ReadFile("../../testdata/cdn/go.html")
|
|
||||||
if request.Header.Get(source.Range) != "" {
|
|
||||||
parsed, _ := rangeutils.GetRange(request.Header.Get(source.Range))
|
|
||||||
return io.NopCloser(io.NewSectionReader(strings.NewReader(string(content)), int64(parsed.StartIndex), int64(parsed.EndIndex))),
|
|
||||||
&source.ExpireInfo{
|
|
||||||
LastModified: "Sun, 06 Jun 2021 12:52:30 GMT",
|
LastModified: "Sun, 06 Jun 2021 12:52:30 GMT",
|
||||||
ETag: "etag",
|
ETag: "etag",
|
||||||
}, nil
|
})), nil
|
||||||
}
|
}
|
||||||
return io.NopCloser(strings.NewReader(string(content))),
|
return source.NewResponse(io.NopCloser(strings.NewReader(string(content))),
|
||||||
&source.ExpireInfo{
|
source.WithExpireInfo(source.ExpireInfo{
|
||||||
LastModified: "Sun, 06 Jun 2021 12:52:30 GMT",
|
LastModified: "Sun, 06 Jun 2021 12:52:30 GMT",
|
||||||
ETag: "etag",
|
ETag: "etag",
|
||||||
}, nil
|
})), nil
|
||||||
},
|
},
|
||||||
).AnyTimes()
|
).AnyTimes()
|
||||||
sourceClient.EXPECT().GetLastModified(gomock.Any()).Return(
|
sourceClient.EXPECT().GetLastModified(gomock.Any()).Return(
|
||||||
|
|
|
||||||
|
|
@ -95,9 +95,9 @@ func TestFilePeerTask_BackSource_WithContentLength(t *testing.T) {
|
||||||
return -1, fmt.Errorf("unexpect url: %s", request.URL.String())
|
return -1, fmt.Errorf("unexpect url: %s", request.URL.String())
|
||||||
})
|
})
|
||||||
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
|
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
|
||||||
func(request *source.Request) (io.ReadCloser, error) {
|
func(request *source.Request) (*source.Response, error) {
|
||||||
if request.URL.String() == url {
|
if request.URL.String() == url {
|
||||||
return io.NopCloser(bytes.NewBuffer(testBytes)), nil
|
return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unexpect url: %s", request.URL.String())
|
return nil, fmt.Errorf("unexpect url: %s", request.URL.String())
|
||||||
})
|
})
|
||||||
|
|
@ -220,9 +220,9 @@ func TestFilePeerTask_BackSource_WithoutContentLength(t *testing.T) {
|
||||||
return -1, fmt.Errorf("unexpect url: %s", request.URL.String())
|
return -1, fmt.Errorf("unexpect url: %s", request.URL.String())
|
||||||
})
|
})
|
||||||
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
|
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
|
||||||
func(request *source.Request) (io.ReadCloser, error) {
|
func(request *source.Request) (*source.Response, error) {
|
||||||
if request.URL.String() == url {
|
if request.URL.String() == url {
|
||||||
return io.NopCloser(bytes.NewBuffer(testBytes)), nil
|
return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unexpect url: %s", request.URL.String())
|
return nil, fmt.Errorf("unexpect url: %s", request.URL.String())
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -222,8 +222,8 @@ func TestStreamPeerTask_BackSource_Partial_WithContentLength(t *testing.T) {
|
||||||
return int64(len(testBytes)), nil
|
return int64(len(testBytes)), nil
|
||||||
})
|
})
|
||||||
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
|
sourceClient.EXPECT().Download(gomock.Any()).DoAndReturn(
|
||||||
func(request *source.Request) (io.ReadCloser, error) {
|
func(request *source.Request) (*source.Response, error) {
|
||||||
return io.NopCloser(bytes.NewBuffer(testBytes)), nil
|
return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
|
||||||
})
|
})
|
||||||
|
|
||||||
pm := &pieceManager{
|
pm := &pieceManager{
|
||||||
|
|
|
||||||
|
|
@ -90,8 +90,8 @@ func TestStreamPeerTask_BackSource_WithContentLength(t *testing.T) {
|
||||||
return int64(len(testBytes)), nil
|
return int64(len(testBytes)), nil
|
||||||
})
|
})
|
||||||
sourceClient.EXPECT().Download(source.RequestEq(request.URL.String())).DoAndReturn(
|
sourceClient.EXPECT().Download(source.RequestEq(request.URL.String())).DoAndReturn(
|
||||||
func(request *source.Request) (io.ReadCloser, error) {
|
func(request *source.Request) (*source.Response, error) {
|
||||||
return io.NopCloser(bytes.NewBuffer(testBytes)), nil
|
return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
|
||||||
})
|
})
|
||||||
|
|
||||||
ptm := &peerTaskManager{
|
ptm := &peerTaskManager{
|
||||||
|
|
@ -190,8 +190,8 @@ func TestStreamPeerTask_BackSource_WithoutContentLength(t *testing.T) {
|
||||||
return -1, nil
|
return -1, nil
|
||||||
})
|
})
|
||||||
sourceClient.EXPECT().Download(source.RequestEq(request.URL.String())).DoAndReturn(
|
sourceClient.EXPECT().Download(source.RequestEq(request.URL.String())).DoAndReturn(
|
||||||
func(request *source.Request) (io.ReadCloser, error) {
|
func(request *source.Request) (*source.Response, error) {
|
||||||
return io.NopCloser(bytes.NewBuffer(testBytes)), nil
|
return source.NewResponse(io.NopCloser(bytes.NewBuffer(testBytes))), nil
|
||||||
})
|
})
|
||||||
|
|
||||||
ptm := &peerTaskManager{
|
ptm := &peerTaskManager{
|
||||||
|
|
|
||||||
|
|
@ -376,16 +376,16 @@ func (pm *pieceManager) DownloadSource(ctx context.Context, pt Task, request *sc
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
body, err := source.Download(downloadRequest)
|
response, err := source.Download(downloadRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer body.Close()
|
defer response.Body.Close()
|
||||||
reader := body.(io.Reader)
|
reader := response.Body.(io.Reader)
|
||||||
|
|
||||||
// calc total md5
|
// calc total md5
|
||||||
if pm.calculateDigest && request.UrlMeta.Digest != "" {
|
if pm.calculateDigest && request.UrlMeta.Digest != "" {
|
||||||
reader = digestutils.NewDigestReader(pt.Log(), body, request.UrlMeta.Digest)
|
reader = digestutils.NewDigestReader(pt.Log(), response.Body, request.UrlMeta.Digest)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. save to storage
|
// 2. save to storage
|
||||||
|
|
|
||||||
|
|
@ -143,7 +143,7 @@ func downloadFromSource(ctx context.Context, cfg *config.DfgetConfig, hdr map[st
|
||||||
wLog = logger.With("url", cfg.URL)
|
wLog = logger.With("url", cfg.URL)
|
||||||
start = time.Now()
|
start = time.Now()
|
||||||
target *os.File
|
target *os.File
|
||||||
response io.ReadCloser
|
response *source.Response
|
||||||
err error
|
err error
|
||||||
written int64
|
written int64
|
||||||
)
|
)
|
||||||
|
|
@ -164,9 +164,9 @@ func downloadFromSource(ctx context.Context, cfg *config.DfgetConfig, hdr map[st
|
||||||
if response, err = source.Download(downloadRequest); err != nil {
|
if response, err = source.Download(downloadRequest); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer response.Close()
|
defer response.Body.Close()
|
||||||
|
|
||||||
if written, err = io.Copy(target, response); err != nil {
|
if written, err = io.Copy(target, response.Body); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ func Test_downloadFromSource(t *testing.T) {
|
||||||
}
|
}
|
||||||
request, err := source.NewRequest(cfg.URL)
|
request, err := source.NewRequest(cfg.URL)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
sourceClient.EXPECT().Download(request).Return(io.NopCloser(strings.NewReader(content)), nil)
|
sourceClient.EXPECT().Download(request).Return(source.NewResponse(io.NopCloser(strings.NewReader(content))), nil)
|
||||||
|
|
||||||
err = downloadFromSource(context.Background(), cfg, nil)
|
err = downloadFromSource(context.Background(), cfg, nil)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
|
||||||
|
|
@ -111,21 +111,25 @@ func (h *hdfsSourceClient) IsExpired(request *source.Request, info *source.Expir
|
||||||
return fileInfo.ModTime().Format(source.LastModifiedLayout) != info.LastModified, nil
|
return fileInfo.ModTime().Format(source.LastModifiedLayout) != info.LastModified, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *hdfsSourceClient) Download(request *source.Request) (io.ReadCloser, error) {
|
func (h *hdfsSourceClient) Download(request *source.Request) (*source.Response, error) {
|
||||||
hdfsClient, path, err := h.getHDFSClientAndPath(request.URL)
|
hdfsClient, path, err := h.getHDFSClientAndPath(request.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
hdfsFile, err := hdfsClient.Open(path)
|
hdfsFile, err := hdfsClient.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fileInfo := hdfsFile.Stat()
|
||||||
|
|
||||||
// default read all data when rang is nil
|
// default read all data when rang is nil
|
||||||
var limitReadN = hdfsFile.Stat().Size()
|
var limitReadN = fileInfo.Size()
|
||||||
if limitReadN < 0 {
|
if limitReadN < 0 {
|
||||||
return nil, errors.Errorf("file length is illegal, length: %d", limitReadN)
|
return nil, errors.Errorf("file length is illegal, length: %d", limitReadN)
|
||||||
}
|
}
|
||||||
|
|
||||||
if request.Header.Get(source.Range) != "" {
|
if request.Header.Get(source.Range) != "" {
|
||||||
requestRange, err := rangeutils.ParseRange(request.Header.Get(source.Range), uint64(limitReadN))
|
requestRange, err := rangeutils.ParseRange(request.Header.Get(source.Range), uint64(limitReadN))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -139,44 +143,12 @@ func (h *hdfsSourceClient) Download(request *source.Request) (io.ReadCloser, err
|
||||||
limitReadN = int64(requestRange.Length())
|
limitReadN = int64(requestRange.Length())
|
||||||
}
|
}
|
||||||
|
|
||||||
return newHdfsFileReaderClose(hdfsFile, limitReadN, hdfsFile), nil
|
response := source.NewResponse(
|
||||||
}
|
newHdfsFileReaderClose(hdfsFile, limitReadN, hdfsFile),
|
||||||
|
source.WithExpireInfo(source.ExpireInfo{
|
||||||
func (h *hdfsSourceClient) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
|
|
||||||
|
|
||||||
hdfsClient, path, err := h.getHDFSClientAndPath(request.URL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hdfsFile, err := hdfsClient.Open(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
fileInfo := hdfsFile.Stat()
|
|
||||||
|
|
||||||
// default read all data when rang is nil
|
|
||||||
var limitReadN = fileInfo.Size()
|
|
||||||
if limitReadN < 0 {
|
|
||||||
return nil, nil, errors.Errorf("file length is illegal, length: %d", limitReadN)
|
|
||||||
}
|
|
||||||
|
|
||||||
if request.Header.Get(source.Range) != "" {
|
|
||||||
requestRange, err := rangeutils.ParseRange(request.Header.Get(source.Range), uint64(limitReadN))
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
_, err = hdfsFile.Seek(int64(requestRange.StartIndex), 0)
|
|
||||||
if err != nil {
|
|
||||||
hdfsFile.Close()
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
limitReadN = int64(requestRange.EndIndex - requestRange.StartIndex)
|
|
||||||
}
|
|
||||||
return newHdfsFileReaderClose(hdfsFile, limitReadN, hdfsFile), &source.ExpireInfo{
|
|
||||||
LastModified: timeutils.Format(fileInfo.ModTime()),
|
LastModified: timeutils.Format(fileInfo.ModTime()),
|
||||||
}, nil
|
}))
|
||||||
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *hdfsSourceClient) GetLastModified(request *source.Request) (int64, error) {
|
func (h *hdfsSourceClient) GetLastModified(request *source.Request) (int64, error) {
|
||||||
|
|
|
||||||
|
|
@ -228,9 +228,9 @@ func Test_Download_FileExist_ByRange(t *testing.T) {
|
||||||
})
|
})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
download, err := sourceClient.Download(request)
|
response, err := sourceClient.Download(request)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
data, _ := io.ReadAll(download)
|
data, _ := io.ReadAll(response.Body)
|
||||||
|
|
||||||
assert.Equal(t, hdfsExistFileContent, string(data))
|
assert.Equal(t, hdfsExistFileContent, string(data))
|
||||||
}
|
}
|
||||||
|
|
@ -285,11 +285,11 @@ func Test_DownloadWithResponseHeader_FileExist_ByRange(t *testing.T) {
|
||||||
request, err := source.NewRequest(hdfsExistFileURL)
|
request, err := source.NewRequest(hdfsExistFileURL)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
request.Header.Add(source.Range, rang.String())
|
request.Header.Add(source.Range, rang.String())
|
||||||
body, expireInfo, err := sourceClient.DownloadWithExpireInfo(request)
|
response, err := sourceClient.Download(request)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, hdfsExistFileLastModified, expireInfo.LastModified)
|
assert.Equal(t, hdfsExistFileLastModified, response.ExpireInfo().LastModified)
|
||||||
|
|
||||||
data, _ := io.ReadAll(body)
|
data, _ := io.ReadAll(response.Body)
|
||||||
assert.Equal(t, string(data), string([]byte(hdfsExistFileContent)[hdfsExistFileRangeStart:hdfsExistFileRangeEnd]))
|
assert.Equal(t, string(data), string([]byte(hdfsExistFileContent)[hdfsExistFileRangeStart:hdfsExistFileRangeEnd]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -303,10 +303,9 @@ func TestDownloadWithResponseHeader_FileNotExist(t *testing.T) {
|
||||||
request, err := source.NewRequest(hdfsNotExistFileURL)
|
request, err := source.NewRequest(hdfsNotExistFileURL)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
request.Header.Add(source.Range, rang.String())
|
request.Header.Add(source.Range, rang.String())
|
||||||
reader, expireInfo, err := sourceClient.DownloadWithExpireInfo(request)
|
response, err := sourceClient.Download(request)
|
||||||
assert.EqualError(t, err, "open /user/root/input/f3.txt: file does not exist")
|
assert.EqualError(t, err, "open /user/root/input/f3.txt: file does not exist")
|
||||||
assert.Nil(t, reader)
|
assert.Nil(t, response)
|
||||||
assert.Nil(t, expireInfo)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetLastModified_FileExist(t *testing.T) {
|
func TestGetLastModified_FileExist(t *testing.T) {
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@ package httpprotocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
@ -170,7 +169,7 @@ func (client *httpSourceClient) IsExpired(request *source.Request, info *source.
|
||||||
LastModified)), nil
|
LastModified)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *httpSourceClient) Download(request *source.Request) (io.ReadCloser, error) {
|
func (client *httpSourceClient) Download(request *source.Request) (*source.Response, error) {
|
||||||
resp, err := client.doRequest(http.MethodGet, request)
|
resp, err := client.doRequest(http.MethodGet, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -180,23 +179,15 @@ func (client *httpSourceClient) Download(request *source.Request) (io.ReadCloser
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return resp.Body, nil
|
response := source.NewResponse(
|
||||||
}
|
resp.Body,
|
||||||
|
source.WithExpireInfo(
|
||||||
func (client *httpSourceClient) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
|
source.ExpireInfo{
|
||||||
resp, err := client.doRequest(http.MethodGet, request)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
err = source.CheckResponseCode(resp.StatusCode, []int{http.StatusOK, http.StatusPartialContent})
|
|
||||||
if err != nil {
|
|
||||||
resp.Body.Close()
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
return resp.Body, &source.ExpireInfo{
|
|
||||||
LastModified: resp.Header.Get(headers.LastModified),
|
LastModified: resp.Header.Get(headers.LastModified),
|
||||||
ETag: resp.Header.Get(headers.ETag),
|
ETag: resp.Header.Get(headers.ETag),
|
||||||
}, nil
|
},
|
||||||
|
))
|
||||||
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *httpSourceClient) GetLastModified(request *source.Request) (int64, error) {
|
func (client *httpSourceClient) GetLastModified(request *source.Request) (int64, error) {
|
||||||
|
|
|
||||||
|
|
@ -131,7 +131,6 @@ func (suite *HTTPSourceClientTestSuite) SetupTest() {
|
||||||
httpmock.RegisterResponder(http.MethodGet, notfoundRawURL, httpmock.NewStringResponder(http.StatusNotFound, "not found"))
|
httpmock.RegisterResponder(http.MethodGet, notfoundRawURL, httpmock.NewStringResponder(http.StatusNotFound, "not found"))
|
||||||
httpmock.RegisterResponder(http.MethodGet, normalNotSupportRangeRawURL, httpmock.NewStringResponder(http.StatusOK, testContent))
|
httpmock.RegisterResponder(http.MethodGet, normalNotSupportRangeRawURL, httpmock.NewStringResponder(http.StatusOK, testContent))
|
||||||
httpmock.RegisterResponder(http.MethodGet, errorRawURL, httpmock.NewErrorResponder(fmt.Errorf("error")))
|
httpmock.RegisterResponder(http.MethodGet, errorRawURL, httpmock.NewErrorResponder(fmt.Errorf("error")))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *HTTPSourceClientTestSuite) TestNewHTTPSourceClient() {
|
func (suite *HTTPSourceClientTestSuite) TestNewHTTPSourceClient() {
|
||||||
|
|
@ -150,12 +149,11 @@ func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDownloadWithResponse
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
timeoutRequest, err := source.NewRequestWithContext(ctx, timeoutRawURL, nil)
|
timeoutRequest, err := source.NewRequestWithContext(ctx, timeoutRawURL, nil)
|
||||||
suite.Nil(err)
|
suite.Nil(err)
|
||||||
reader, expireInfo, err := suite.httpClient.DownloadWithExpireInfo(timeoutRequest)
|
response, err := suite.httpClient.Download(timeoutRequest)
|
||||||
cancel()
|
cancel()
|
||||||
suite.NotNil(err)
|
suite.NotNil(err)
|
||||||
suite.Equal("Get \"https://timeout.com\": context deadline exceeded", err.Error())
|
suite.Equal("Get \"https://timeout.com\": context deadline exceeded", err.Error())
|
||||||
suite.Nil(reader)
|
suite.Nil(response)
|
||||||
suite.Nil(expireInfo)
|
|
||||||
|
|
||||||
normalRequest, _ := source.NewRequest(normalRawURL)
|
normalRequest, _ := source.NewRequest(normalRawURL)
|
||||||
normalRangeRequest, _ := source.NewRequest(normalRawURL)
|
normalRangeRequest, _ := source.NewRequest(normalRawURL)
|
||||||
|
|
@ -203,15 +201,16 @@ func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDownloadWithResponse
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
suite.Run(tt.name, func() {
|
suite.Run(tt.name, func() {
|
||||||
reader, expireInfo, err := suite.httpClient.DownloadWithExpireInfo(tt.request)
|
response, err := suite.httpClient.Download(tt.request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
suite.True(tt.wantErr.Error() == err.Error())
|
suite.True(tt.wantErr.Error() == err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
bytes, err := io.ReadAll(reader)
|
bytes, err := io.ReadAll(response.Body)
|
||||||
suite.Nil(err)
|
suite.Nil(err)
|
||||||
suite.Equal(tt.content, string(bytes))
|
suite.Equal(tt.content, string(bytes))
|
||||||
suite.Equal(tt.expireInfo, expireInfo)
|
expireInfo := response.ExpireInfo()
|
||||||
|
suite.Equal(tt.expireInfo, &expireInfo)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@
|
||||||
package mock
|
package mock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
io "io"
|
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
source "d7y.io/dragonfly/v2/pkg/source"
|
source "d7y.io/dragonfly/v2/pkg/source"
|
||||||
|
|
@ -36,10 +35,10 @@ func (m *MockResourceClient) EXPECT() *MockResourceClientMockRecorder {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Download mocks base method.
|
// Download mocks base method.
|
||||||
func (m *MockResourceClient) Download(arg0 *source.Request) (io.ReadCloser, error) {
|
func (m *MockResourceClient) Download(arg0 *source.Request) (*source.Response, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Download", arg0)
|
ret := m.ctrl.Call(m, "Download", arg0)
|
||||||
ret0, _ := ret[0].(io.ReadCloser)
|
ret0, _ := ret[0].(*source.Response)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
@ -50,22 +49,6 @@ func (mr *MockResourceClientMockRecorder) Download(arg0 interface{}) *gomock.Cal
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Download", reflect.TypeOf((*MockResourceClient)(nil).Download), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Download", reflect.TypeOf((*MockResourceClient)(nil).Download), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DownloadWithExpireInfo mocks base method.
|
|
||||||
func (m *MockResourceClient) DownloadWithExpireInfo(arg0 *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "DownloadWithExpireInfo", arg0)
|
|
||||||
ret0, _ := ret[0].(io.ReadCloser)
|
|
||||||
ret1, _ := ret[1].(*source.ExpireInfo)
|
|
||||||
ret2, _ := ret[2].(error)
|
|
||||||
return ret0, ret1, ret2
|
|
||||||
}
|
|
||||||
|
|
||||||
// DownloadWithExpireInfo indicates an expected call of DownloadWithExpireInfo.
|
|
||||||
func (mr *MockResourceClientMockRecorder) DownloadWithExpireInfo(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DownloadWithExpireInfo", reflect.TypeOf((*MockResourceClient)(nil).DownloadWithExpireInfo), arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetContentLength mocks base method.
|
// GetContentLength mocks base method.
|
||||||
func (m *MockResourceClient) GetContentLength(arg0 *source.Request) (int64, error) {
|
func (m *MockResourceClient) GetContentLength(arg0 *source.Request) (int64, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@ package ossprotocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
@ -147,44 +146,33 @@ func (osc *ossSourceClient) IsExpired(request *source.Request, info *source.Expi
|
||||||
return !(resHeader.Get(oss.HTTPHeaderEtag) == info.ETag || resHeader.Get(oss.HTTPHeaderLastModified) == info.LastModified), nil
|
return !(resHeader.Get(oss.HTTPHeaderEtag) == info.ETag || resHeader.Get(oss.HTTPHeaderLastModified) == info.LastModified), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (osc *ossSourceClient) Download(request *source.Request) (io.ReadCloser, error) {
|
func (osc *ossSourceClient) Download(request *source.Request) (*source.Response, error) {
|
||||||
client, err := osc.getClient(request.Header)
|
client, err := osc.getClient(request.Header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "get oss client")
|
return nil, errors.Wrapf(err, "get oss client")
|
||||||
}
|
}
|
||||||
bucket, err := client.Bucket(request.URL.Host)
|
bucket, err := client.Bucket(request.URL.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "get oss bucket %s", request.URL.Host)
|
return nil, errors.Wrapf(err, "get oss bucket: %s", request.URL.Host)
|
||||||
}
|
|
||||||
resp, err := bucket.GetObject(request.URL.Path, getOptions(request.Header)...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "get oss object %s", request.URL.Path)
|
|
||||||
}
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (osc *ossSourceClient) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
|
|
||||||
client, err := osc.getClient(request.Header)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, errors.Wrapf(err, "get oss client")
|
|
||||||
}
|
|
||||||
bucket, err := client.Bucket(request.URL.Host)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, errors.Wrapf(err, "get oss bucket: %s", request.URL.Host)
|
|
||||||
}
|
}
|
||||||
objectResult, err := bucket.DoGetObject(&oss.GetObjectRequest{ObjectKey: request.URL.Path}, getOptions(request.Header))
|
objectResult, err := bucket.DoGetObject(&oss.GetObjectRequest{ObjectKey: request.URL.Path}, getOptions(request.Header))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, errors.Wrapf(err, "get oss Object: %s", request.URL.Path)
|
return nil, errors.Wrapf(err, "get oss Object: %s", request.URL.Path)
|
||||||
}
|
}
|
||||||
err = source.CheckResponseCode(objectResult.Response.StatusCode, []int{http.StatusOK, http.StatusPartialContent})
|
err = source.CheckResponseCode(objectResult.Response.StatusCode, []int{http.StatusOK, http.StatusPartialContent})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
objectResult.Response.Body.Close()
|
objectResult.Response.Body.Close()
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return objectResult.Response.Body, &source.ExpireInfo{
|
response := source.NewResponse(
|
||||||
|
objectResult.Response.Body,
|
||||||
|
source.WithExpireInfo(
|
||||||
|
source.ExpireInfo{
|
||||||
LastModified: objectResult.Response.Headers.Get(headers.LastModified),
|
LastModified: objectResult.Response.Headers.Get(headers.LastModified),
|
||||||
ETag: objectResult.Response.Headers.Get(headers.ETag),
|
ETag: objectResult.Response.Headers.Get(headers.ETag),
|
||||||
}, nil
|
},
|
||||||
|
))
|
||||||
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (osc *ossSourceClient) GetLastModified(request *source.Request) (int64, error) {
|
func (osc *ossSourceClient) GetLastModified(request *source.Request) (int64, error) {
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,11 @@
|
||||||
|
|
||||||
package source
|
package source
|
||||||
|
|
||||||
import "io"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
type Response struct {
|
type Response struct {
|
||||||
Status string
|
Status string
|
||||||
|
|
@ -25,3 +29,70 @@ type Response struct {
|
||||||
Body io.ReadCloser
|
Body io.ReadCloser
|
||||||
ContentLength int64
|
ContentLength int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewResponse(rc io.ReadCloser, opts ...func(*Response)) *Response {
|
||||||
|
if rc == nil {
|
||||||
|
// for custom plugin, return an error body
|
||||||
|
rc = &errorBody{
|
||||||
|
fmt.Errorf("empty io.ReadCloser, please check resource plugin implement")}
|
||||||
|
}
|
||||||
|
resp := &Response{
|
||||||
|
Header: make(Header),
|
||||||
|
Status: "OK",
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: rc,
|
||||||
|
ContentLength: -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(resp)
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithStatus(code int, status string) func(*Response) {
|
||||||
|
return func(resp *Response) {
|
||||||
|
resp.StatusCode = code
|
||||||
|
resp.Status = status
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithContentLength(length int64) func(*Response) {
|
||||||
|
return func(resp *Response) {
|
||||||
|
resp.ContentLength = length
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithHeader(header map[string]string) func(*Response) {
|
||||||
|
return func(resp *Response) {
|
||||||
|
for k, v := range header {
|
||||||
|
resp.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithExpireInfo(info ExpireInfo) func(*Response) {
|
||||||
|
return func(resp *Response) {
|
||||||
|
resp.Header.Set(LastModified, info.LastModified)
|
||||||
|
resp.Header.Set(ETag, info.ETag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (resp *Response) ExpireInfo() ExpireInfo {
|
||||||
|
return ExpireInfo{
|
||||||
|
LastModified: resp.Header.Get(LastModified),
|
||||||
|
ETag: resp.Header.Get(ETag),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type errorBody struct {
|
||||||
|
error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errorBody) Read(p []byte) (n int, err error) {
|
||||||
|
return 0, e.error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errorBody) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@ package source
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
@ -105,10 +104,7 @@ type ResourceClient interface {
|
||||||
IsExpired(request *Request, info *ExpireInfo) (bool, error)
|
IsExpired(request *Request, info *ExpireInfo) (bool, error)
|
||||||
|
|
||||||
// Download downloads from source
|
// Download downloads from source
|
||||||
Download(request *Request) (io.ReadCloser, error)
|
Download(request *Request) (*Response, error)
|
||||||
|
|
||||||
// DownloadWithExpireInfo download from source with expireInfo
|
|
||||||
DownloadWithExpireInfo(request *Request) (io.ReadCloser, *ExpireInfo, error)
|
|
||||||
|
|
||||||
// GetLastModified gets last modified timestamp milliseconds of resource
|
// GetLastModified gets last modified timestamp milliseconds of resource
|
||||||
GetLastModified(request *Request) (int64, error)
|
GetLastModified(request *Request) (int64, error)
|
||||||
|
|
@ -251,14 +247,10 @@ func (c *clientWrapper) IsSupportRange(request *Request) (bool, error) {
|
||||||
func (c *clientWrapper) IsExpired(request *Request, info *ExpireInfo) (bool, error) {
|
func (c *clientWrapper) IsExpired(request *Request, info *ExpireInfo) (bool, error) {
|
||||||
return c.rc.IsExpired(c.adapter(request), info)
|
return c.rc.IsExpired(c.adapter(request), info)
|
||||||
}
|
}
|
||||||
func (c *clientWrapper) Download(request *Request) (io.ReadCloser, error) {
|
func (c *clientWrapper) Download(request *Request) (*Response, error) {
|
||||||
return c.rc.Download(c.adapter(request))
|
return c.rc.Download(c.adapter(request))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientWrapper) DownloadWithExpireInfo(request *Request) (io.ReadCloser, *ExpireInfo, error) {
|
|
||||||
return c.rc.DownloadWithExpireInfo(c.adapter(request))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientWrapper) GetLastModified(request *Request) (int64, error) {
|
func (c *clientWrapper) GetLastModified(request *Request) (int64, error) {
|
||||||
return c.rc.GetLastModified(c.adapter(request))
|
return c.rc.GetLastModified(c.adapter(request))
|
||||||
}
|
}
|
||||||
|
|
@ -318,7 +310,7 @@ func GetLastModified(request *Request) (int64, error) {
|
||||||
return client.GetLastModified(request)
|
return client.GetLastModified(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Download(request *Request) (io.ReadCloser, error) {
|
func Download(request *Request) (*Response, error) {
|
||||||
client, ok := _defaultManager.GetClient(request.URL.Scheme)
|
client, ok := _defaultManager.GetClient(request.URL.Scheme)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.Wrapf(ErrNoClientFound, "scheme: %s", request.URL.Scheme)
|
return nil, errors.Wrapf(ErrNoClientFound, "scheme: %s", request.URL.Scheme)
|
||||||
|
|
@ -326,14 +318,6 @@ func Download(request *Request) (io.ReadCloser, error) {
|
||||||
return client.Download(request)
|
return client.Download(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func DownloadWithExpireInfo(request *Request) (io.ReadCloser, *ExpireInfo, error) {
|
|
||||||
client, ok := _defaultManager.GetClient(request.URL.Scheme)
|
|
||||||
if !ok {
|
|
||||||
return nil, nil, errors.Wrapf(ErrNoClientFound, "scheme: %s", request.URL.Scheme)
|
|
||||||
}
|
|
||||||
return client.DownloadWithExpireInfo(request)
|
|
||||||
}
|
|
||||||
|
|
||||||
func List(request *Request) ([]*url.URL, error) {
|
func List(request *Request) ([]*url.URL, error) {
|
||||||
client, ok := _defaultManager.GetClient(request.URL.Scheme)
|
client, ok := _defaultManager.GetClient(request.URL.Scheme)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|
|
||||||
|
|
@ -39,13 +39,13 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
request, err = source.NewRequest("")
|
request, err = source.NewRequest("")
|
||||||
rc, err := client.Download(request)
|
response, err := client.Download(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("download error: %s\n", err)
|
fmt.Printf("download error: %s\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := io.ReadAll(rc)
|
data, err := io.ReadAll(response.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("read error: %s\n", err)
|
fmt.Printf("read error: %s\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
|
@ -56,7 +56,7 @@ func main() {
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = rc.Close()
|
err = response.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("close error: %s\n", err)
|
fmt.Printf("close error: %s\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
|
|
||||||
|
|
@ -48,12 +48,8 @@ func (c *client) IsExpired(request *source.Request, info *source.ExpireInfo) (bo
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) Download(request *source.Request) (io.ReadCloser, error) {
|
func (c *client) Download(request *source.Request) (*source.Response, error) {
|
||||||
return io.NopCloser(bytes.NewBufferString(data)), nil
|
return source.NewResponse(io.NopCloser(bytes.NewBufferString(data))), nil
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) DownloadWithExpireInfo(request *source.Request) (io.ReadCloser, *source.ExpireInfo, error) {
|
|
||||||
return io.NopCloser(bytes.NewBufferString(data)), nil, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) GetLastModified(request *source.Request) (int64, error) {
|
func (c *client) GetLastModified(request *source.Request) (int64, error) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue