334 lines
11 KiB
Go
334 lines
11 KiB
Go
/*
|
|
* Copyright 2020 The Dragonfly Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package httpprotocol
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-http-utils/headers"
|
|
"github.com/jarcoal/httpmock"
|
|
"github.com/stretchr/testify/suite"
|
|
|
|
nethttp "d7y.io/dragonfly/v2/pkg/net/http"
|
|
"d7y.io/dragonfly/v2/pkg/source"
|
|
)
|
|
|
|
func TestHTTPSourceClientTestSuite(t *testing.T) {
|
|
suite.Run(t, new(HTTPSourceClientTestSuite))
|
|
}
|
|
|
|
type HTTPSourceClientTestSuite struct {
|
|
suite.Suite
|
|
httpClient *httpSourceClient
|
|
}
|
|
|
|
func (suite *HTTPSourceClientTestSuite) SetupSuite() {
|
|
suite.httpClient = newHTTPSourceClient()
|
|
httpmock.ActivateNonDefault(_defaultHTTPClient)
|
|
}
|
|
|
|
func (suite *HTTPSourceClientTestSuite) TearDownSuite() {
|
|
httpmock.DeactivateAndReset()
|
|
}
|
|
|
|
// Mock endpoints. Each URL is bound to a canned httpmock responder in
// SetupTest; no real host is ever contacted.
var (
	// Well-behaved endpoints.
	normalRawURL                = "https://normal.com"
	normalNotSupportRangeRawURL = "https://notsuppertrange.com"
	expireRawURL                = "https://expired.com"

	// Failure endpoints: a slow responder, HTTP error statuses, and a
	// transport-level error.
	timeoutRawURL   = "https://timeout.com"
	forbiddenRawURL = "https://forbidden.com"
	notfoundRawURL  = "https://notfound.com"
	errorRawURL     = "https://error.com"
)
|
|
|
|
// Fixture payload and HTTP cache validators used by the mock responders.
var (
	// testContent is the body served by the normal and expired endpoints.
	testContent = "l am test case"

	// Current validators, sent back in Last-Modified / ETag response headers.
	lastModified = "Sun, 06 Jun 2021 12:52:30 GMT"
	etag         = "UMiJT4h7MCEAEgnqCLA2CdAaABnK"

	// Older validators (one hour earlier, different ETag) that the tests pass
	// in as cached state for the "expired" case.
	expireLastModified = "Sun, 06 Jun 2021 11:52:30 GMT"
	expireEtag         = "UMiJ2T4h7MCEAEgnqCLA2CdAaABnK"
)
|
|
|
|
func (suite *HTTPSourceClientTestSuite) SetupTest() {
|
|
httpmock.Reset()
|
|
httpmock.RegisterResponder(http.MethodGet, timeoutRawURL, func(request *http.Request) (*http.Response, error) {
|
|
// To simulate the timeout
|
|
time.Sleep(5 * time.Second)
|
|
return httpmock.NewStringResponse(http.StatusOK, "ok"), nil
|
|
})
|
|
|
|
httpmock.RegisterResponder(http.MethodGet, normalRawURL, func(request *http.Request) (*http.Response, error) {
|
|
if rang := request.Header.Get(headers.Range); rang != "" {
|
|
r, _ := nethttp.GetRange(rang[6:])
|
|
header := http.Header{}
|
|
header.Set(headers.LastModified, lastModified)
|
|
header.Set(headers.ETag, etag)
|
|
res := &http.Response{
|
|
StatusCode: http.StatusPartialContent,
|
|
ContentLength: int64(r.EndIndex) - int64(r.StartIndex) + int64(1),
|
|
Body: httpmock.NewRespBodyFromString(testContent[r.StartIndex:r.EndIndex]),
|
|
Header: header,
|
|
}
|
|
return res, nil
|
|
}
|
|
if expire := request.Header.Get(headers.IfModifiedSince); expire != "" {
|
|
header := http.Header{}
|
|
header.Set(headers.LastModified, lastModified)
|
|
header.Set(headers.ETag, etag)
|
|
res := &http.Response{
|
|
StatusCode: http.StatusNotModified,
|
|
ContentLength: int64(len(testContent)),
|
|
Body: httpmock.NewRespBodyFromString(testContent),
|
|
Header: header,
|
|
}
|
|
return res, nil
|
|
}
|
|
header := http.Header{}
|
|
header.Set(headers.LastModified, lastModified)
|
|
header.Set(headers.ETag, etag)
|
|
res := &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
ContentLength: 14,
|
|
Body: httpmock.NewRespBodyFromString(testContent),
|
|
Header: header,
|
|
}
|
|
return res, nil
|
|
})
|
|
|
|
header := http.Header{}
|
|
header.Set(headers.LastModified, lastModified)
|
|
header.Set(headers.ETag, etag)
|
|
httpmock.RegisterResponder(http.MethodGet, expireRawURL, func(request *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
ContentLength: 14,
|
|
Body: httpmock.NewRespBodyFromString(testContent),
|
|
Header: header,
|
|
}, nil
|
|
})
|
|
|
|
httpmock.RegisterResponder(http.MethodGet, forbiddenRawURL, httpmock.NewStringResponder(http.StatusForbidden, "forbidden"))
|
|
httpmock.RegisterResponder(http.MethodGet, notfoundRawURL, httpmock.NewStringResponder(http.StatusNotFound, "not found"))
|
|
httpmock.RegisterResponder(http.MethodGet, normalNotSupportRangeRawURL, httpmock.NewStringResponder(http.StatusOK, testContent))
|
|
httpmock.RegisterResponder(http.MethodGet, errorRawURL, httpmock.NewErrorResponder(fmt.Errorf("error")))
|
|
}
|
|
|
|
func (suite *HTTPSourceClientTestSuite) TestNewHTTPSourceClient() {
|
|
var sourceClient source.ResourceClient
|
|
sourceClient = NewHTTPSourceClient()
|
|
suite.Equal(_defaultHTTPClient, sourceClient.(*httpSourceClient).httpClient)
|
|
suite.EqualValues(*_defaultHTTPClient, *sourceClient.(*httpSourceClient).httpClient)
|
|
|
|
expectedHTTPClient := &http.Client{}
|
|
sourceClient = NewHTTPSourceClient(WithHTTPClient(expectedHTTPClient))
|
|
suite.Equal(expectedHTTPClient, sourceClient.(*httpSourceClient).httpClient)
|
|
suite.EqualValues(*expectedHTTPClient, *sourceClient.(*httpSourceClient).httpClient)
|
|
}
|
|
|
|
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDownloadWithResponseHeader() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
timeoutRequest, err := source.NewRequestWithContext(ctx, timeoutRawURL, nil)
|
|
suite.Nil(err)
|
|
response, err := suite.httpClient.Download(timeoutRequest)
|
|
cancel()
|
|
suite.NotNil(err)
|
|
suite.Equal("Get \"https://timeout.com\": context deadline exceeded", err.Error())
|
|
suite.Nil(response)
|
|
|
|
normalRequest, _ := source.NewRequest(normalRawURL)
|
|
normalRangeRequest, _ := source.NewRequest(normalRawURL)
|
|
normalRangeRequest.Header.Add(headers.Range, fmt.Sprintf("bytes=%s", "0-3"))
|
|
notfoundRequest, _ := source.NewRequest(notfoundRawURL)
|
|
errorRequest, _ := source.NewRequest(errorRawURL)
|
|
tests := []struct {
|
|
name string
|
|
request *source.Request
|
|
content string
|
|
expireInfo *source.ExpireInfo
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "normal download",
|
|
request: normalRequest,
|
|
content: testContent,
|
|
expireInfo: &source.ExpireInfo{
|
|
LastModified: lastModified,
|
|
ETag: etag,
|
|
},
|
|
wantErr: nil,
|
|
}, {
|
|
name: "range download",
|
|
request: normalRangeRequest,
|
|
content: testContent[0:3],
|
|
expireInfo: &source.ExpireInfo{
|
|
LastModified: lastModified,
|
|
ETag: etag,
|
|
},
|
|
wantErr: nil,
|
|
}, {
|
|
name: "not found download",
|
|
request: notfoundRequest,
|
|
content: "",
|
|
expireInfo: nil,
|
|
wantErr: source.CheckResponseCode(404, []int{200, 206}),
|
|
}, {
|
|
name: "error download",
|
|
request: errorRequest,
|
|
content: "",
|
|
expireInfo: nil,
|
|
wantErr: fmt.Errorf("Get \"https://error.com\": error"),
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
suite.Run(tt.name, func() {
|
|
response, err := suite.httpClient.Download(tt.request)
|
|
if err != nil {
|
|
suite.True(tt.wantErr.Error() == err.Error())
|
|
return
|
|
}
|
|
if err = response.Validate(); err != nil {
|
|
suite.True(tt.wantErr.Error() == err.Error())
|
|
return
|
|
}
|
|
bytes, err := io.ReadAll(response.Body)
|
|
suite.Nil(err)
|
|
suite.Equal(tt.content, string(bytes))
|
|
expireInfo := response.ExpireInfo()
|
|
suite.Equal(tt.expireInfo, &expireInfo)
|
|
})
|
|
}
|
|
}
|
|
|
|
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientGetContentLength() {
|
|
normalRequest, _ := source.NewRequest(normalRawURL)
|
|
normalRangeRequest, _ := source.NewRequest(normalRawURL)
|
|
normalRangeRequest.Header.Add(headers.Range, fmt.Sprintf("bytes=%s", "0-3"))
|
|
tests := []struct {
|
|
name string
|
|
request *source.Request
|
|
want int64
|
|
wantErr error
|
|
}{
|
|
{name: "support content length", request: normalRequest,
|
|
want: int64(len(testContent)),
|
|
wantErr: nil},
|
|
{name: "not support content length", request: normalRangeRequest,
|
|
want: 4,
|
|
wantErr: nil},
|
|
}
|
|
for _, tt := range tests {
|
|
suite.Run(tt.name, func() {
|
|
got, err := suite.httpClient.GetContentLength(tt.request)
|
|
suite.Equal(tt.wantErr, err)
|
|
suite.Equal(tt.want, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientIsExpired() {
|
|
normalRequest, _ := source.NewRequest(normalRawURL)
|
|
errorRequest, _ := source.NewRequest(errorRawURL)
|
|
expireRequest, _ := source.NewRequest(expireRawURL)
|
|
tests := []struct {
|
|
name string
|
|
request *source.Request
|
|
expireInfo *source.ExpireInfo
|
|
want bool
|
|
wantErr bool
|
|
}{
|
|
{name: "not expire", request: normalRequest, expireInfo: &source.ExpireInfo{
|
|
LastModified: lastModified,
|
|
ETag: etag,
|
|
}, want: false, wantErr: false},
|
|
{name: "error not expire", request: errorRequest, expireInfo: &source.ExpireInfo{
|
|
LastModified: lastModified,
|
|
ETag: etag,
|
|
}, want: false, wantErr: true},
|
|
{name: "expired", request: expireRequest, expireInfo: &source.ExpireInfo{
|
|
LastModified: expireLastModified,
|
|
ETag: expireEtag,
|
|
}, want: true, wantErr: false},
|
|
}
|
|
for _, tt := range tests {
|
|
suite.Run(tt.name, func() {
|
|
got, err := suite.httpClient.IsExpired(tt.request, tt.expireInfo)
|
|
suite.Equal(tt.want, got)
|
|
suite.Equal(tt.wantErr, err != nil)
|
|
})
|
|
}
|
|
}
|
|
|
|
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientIsSupportRange() {
|
|
httpmock.RegisterResponder(http.MethodGet, timeoutRawURL, func(request *http.Request) (*http.Response, error) {
|
|
time.Sleep(3 * time.Second)
|
|
return httpmock.NewStringResponse(http.StatusOK, "ok"), nil
|
|
})
|
|
parent := context.Background()
|
|
ctx, cancel := context.WithTimeout(parent, 1*time.Second)
|
|
request, err := source.NewRequestWithContext(ctx, timeoutRawURL, nil)
|
|
suite.Nil(err)
|
|
support, err := suite.httpClient.IsSupportRange(request)
|
|
cancel()
|
|
suite.NotNil(err)
|
|
suite.Equal("Get \"https://timeout.com\": context deadline exceeded", err.Error())
|
|
suite.Equal(false, support)
|
|
httpmock.RegisterResponder(http.MethodGet, normalRawURL, httpmock.NewStringResponder(http.StatusPartialContent, ""))
|
|
httpmock.RegisterResponder(http.MethodGet, normalNotSupportRangeRawURL, httpmock.NewStringResponder(http.StatusOK, ""))
|
|
httpmock.RegisterResponder(http.MethodGet, errorRawURL, httpmock.NewErrorResponder(fmt.Errorf("xxx")))
|
|
|
|
supportRangeRequest, _ := source.NewRequest(normalRawURL)
|
|
supportRangeRequest.Header.Add(headers.Range, fmt.Sprintf("bytes=%s", "0-3"))
|
|
notSupportRangeURL, _ := source.NewRequest(normalNotSupportRangeRawURL)
|
|
notSupportRangeURL.Header.Add(headers.Range, fmt.Sprintf("bytes=%s", "0-3"))
|
|
errRequest, _ := source.NewRequest(errorRawURL)
|
|
tests := []struct {
|
|
name string
|
|
request *source.Request
|
|
want bool
|
|
wantErr bool
|
|
}{
|
|
{name: "support", request: supportRangeRequest, want: true, wantErr: false},
|
|
{name: "notSupport", request: notSupportRangeURL, want: false, wantErr: false},
|
|
{name: "error", request: errRequest, want: false, wantErr: true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
suite.Run(tt.name, func() {
|
|
got, err := suite.httpClient.IsSupportRange(tt.request)
|
|
suite.Equal(tt.wantErr, err != nil)
|
|
suite.Equal(tt.want, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDoRequest() {
|
|
var testURL = "https://www.hackhttp.com"
|
|
httpmock.RegisterResponder(http.MethodGet, testURL, httpmock.NewStringResponder(http.StatusOK, "ok"))
|
|
request, err := source.NewRequest(testURL)
|
|
suite.Nil(err)
|
|
res, err := suite.httpClient.doRequest(http.MethodGet, request)
|
|
suite.Nil(err)
|
|
bytes, err := io.ReadAll(res.Body)
|
|
suite.Nil(err)
|
|
suite.EqualValues("ok", string(bytes))
|
|
}
|