325 lines
10 KiB
Go
325 lines
10 KiB
Go
/*
|
|
* Copyright 2020 The Dragonfly Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package httpprotocol
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"d7y.io/dragonfly/v2/pkg/source"
|
|
"d7y.io/dragonfly/v2/pkg/util/rangeutils"
|
|
"github.com/go-http-utils/headers"
|
|
"github.com/jarcoal/httpmock"
|
|
"github.com/stretchr/testify/suite"
|
|
)
|
|
|
|
// TestHTTPSourceClientTestSuite is the `go test` entry point: it hands a fresh
// HTTPSourceClientTestSuite to testify, which then runs every Test* method
// defined on it.
func TestHTTPSourceClientTestSuite(t *testing.T) {
	s := &HTTPSourceClientTestSuite{}
	suite.Run(t, s)
}
|
|
|
|
type HTTPSourceClientTestSuite struct {
|
|
suite.Suite
|
|
source.ResourceClient
|
|
}
|
|
|
|
func (suite *HTTPSourceClientTestSuite) SetupSuite() {
|
|
suite.ResourceClient = NewHTTPSourceClient()
|
|
httpmock.ActivateNonDefault(_defaultHTTPClient)
|
|
}
|
|
|
|
func (suite *HTTPSourceClientTestSuite) TearDownSuite() {
|
|
httpmock.DeactivateAndReset()
|
|
}
|
|
|
|
// Mock endpoints registered in SetupTest. Each URL selects a different mocked
// server behavior.
var (
	// timeoutURL answers only after a long delay, to exercise context deadlines.
	timeoutURL = "http://timeout.com"
	// normalURL serves testContent and honors Range / If-Modified-Since.
	normalURL = "http://normal.com"
	// errorURL fails at the transport level instead of returning a response.
	errorURL = "http://error.com"
	// forbiddenURL answers 403.
	forbiddenURL = "http://forbidden.com"
	// notfoundURL answers 404.
	notfoundURL = "http://notfound.com"
	// normalNotSupportRangeURL always answers 200 with the full body, ignoring Range.
	normalNotSupportRangeURL = "http://notsuppertrange.com"
)
|
|
|
|
// Fixture values returned by the normalURL mock responder.
var (
	// testContent is the complete body served by normalURL (14 bytes).
	testContent = "l am test case"
	// lastModified is sent back in the Last-Modified header of normalURL responses.
	lastModified = "Sun, 06 Jun 2021 12:52:30 GMT"
	// etag is sent back in the ETag header. It is deliberately empty for now.
	// TODO: the client code cannot obtain the ETag yet; once it can, switch to
	// the real value "UMiJT4h7MCEAEgnqCLA2CdAaABnK".
	etag = ""
)
|
|
|
|
// SetupTest runs before every test method. It drops every previously
// registered httpmock responder and installs a fresh set:
//   - timeoutURL answers 200 "ok" only after 5 seconds;
//   - normalURL serves testContent, honoring Range and If-Modified-Since;
//   - forbiddenURL and notfoundURL answer 403 and 404;
//   - normalNotSupportRangeURL always answers 200 with the full content;
//   - errorURL fails with a transport-level error.
func (suite *HTTPSourceClientTestSuite) SetupTest() {
	httpmock.Reset()
	httpmock.RegisterResponder(http.MethodGet, timeoutURL, func(request *http.Request) (*http.Response, error) {
		// To simulate the timeout
		time.Sleep(5 * time.Second)
		return httpmock.NewStringResponse(http.StatusOK, "ok"), nil
	})

	httpmock.RegisterResponder(http.MethodGet, normalURL, func(request *http.Request) (*http.Response, error) {
		if rang := request.Header.Get(headers.Range); rang != "" {
			return mockRangeResponse(rang), nil
		}
		if since := request.Header.Get(headers.IfModifiedSince); since != "" {
			// Any conditional request is answered with "not modified".
			return mockContentResponse(http.StatusNotModified, testContent), nil
		}
		return mockContentResponse(http.StatusOK, testContent), nil
	})

	httpmock.RegisterResponder(http.MethodGet, forbiddenURL, httpmock.NewStringResponder(http.StatusForbidden, "forbidden"))
	httpmock.RegisterResponder(http.MethodGet, notfoundURL, httpmock.NewStringResponder(http.StatusNotFound, "not found"))
	httpmock.RegisterResponder(http.MethodGet, normalNotSupportRangeURL, httpmock.NewStringResponder(http.StatusOK, testContent))
	httpmock.RegisterResponder(http.MethodGet, errorURL, httpmock.NewErrorResponder(fmt.Errorf("error")))
}

// mockContentResponse builds a mock normalURL response with the given status
// and body. Content-Length is derived from body so the two always agree, and
// the shared Last-Modified / ETag fixture headers are attached.
func mockContentResponse(status int, body string) *http.Response {
	return &http.Response{
		StatusCode:    status,
		ContentLength: int64(len(body)),
		Body:          httpmock.NewRespBodyFromString(body),
		Header: http.Header{
			headers.LastModified: []string{lastModified},
			headers.ETag:         []string{etag},
		},
	}
}

// mockRangeResponse answers a "Range: bytes=<start>-<end>" request against
// testContent with 206 and the selected bytes. A malformed or out-of-bounds
// range gets 416 instead of panicking inside the responder.
func mockRangeResponse(rangeHeader string) *http.Response {
	const unit = "bytes="
	if len(rangeHeader) < len(unit) || rangeHeader[:len(unit)] != unit {
		return httpmock.NewStringResponse(http.StatusRequestedRangeNotSatisfiable, "")
	}
	r, err := rangeutils.ParseRange(rangeHeader[len(unit):])
	if err != nil || r.StartIndex > r.EndIndex || uint64(r.EndIndex) >= uint64(len(testContent)) {
		return httpmock.NewStringResponse(http.StatusRequestedRangeNotSatisfiable, "")
	}
	// HTTP byte ranges are inclusive on both ends (RFC 7233 §2.1): "bytes=0-3"
	// selects 4 bytes, so the slice end is EndIndex+1.
	return mockContentResponse(http.StatusPartialContent, testContent[r.StartIndex:r.EndIndex+1])
}

// TestNewHTTPSourceClient checks that the constructor falls back to
// _defaultHTTPClient, and that the WithHTTPClient option replaces it.
func (suite *HTTPSourceClientTestSuite) TestNewHTTPSourceClient() {
	// Hold the result in the interface type so the type assertion below is
	// valid whatever concrete type the constructor returns.
	var sourceClient source.ResourceClient = NewHTTPSourceClient()
	defaultClient := sourceClient.(*httpSourceClient)
	suite.Equal(_defaultHTTPClient, defaultClient.httpClient)
	suite.EqualValues(*_defaultHTTPClient, *defaultClient.httpClient)

	expectedHTTPClient := &http.Client{}
	sourceClient = NewHTTPSourceClient(WithHTTPClient(expectedHTTPClient))
	customClient := sourceClient.(*httpSourceClient)
	suite.Equal(expectedHTTPClient, customClient.httpClient)
	suite.EqualValues(*expectedHTTPClient, *customClient.httpClient)
}

// TestHttpSourceClientDownloadWithResponseHeader checks that a download:
//   - honors the context deadline;
//   - returns the body and response headers for full and ranged requests;
//   - surfaces HTTP-status and transport errors.
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDownloadWithResponseHeader() {
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	reader, responseHeader, err := suite.DownloadWithResponseHeader(ctx, timeoutURL, source.RequestHeader{}, nil)
	cancel()
	// Require (fatal): calling Error() on a nil err would panic the suite.
	suite.Require().EqualError(err, "Get \"http://timeout.com\": context deadline exceeded")
	suite.Nil(reader)
	suite.Nil(responseHeader)

	type args struct {
		ctx    context.Context
		url    string
		header source.RequestHeader
	}
	tests := []struct {
		name       string
		args       args
		content    string
		expireInfo source.ResponseHeader
		wantErr    error
	}{
		{
			name: "normal download",
			args: args{
				ctx:    context.Background(),
				url:    normalURL,
				header: nil,
			},
			content: testContent,
			expireInfo: source.ResponseHeader{
				source.LastModified: lastModified,
				source.ETag:         etag,
			},
			wantErr: nil,
		}, {
			name: "range download",
			args: args{
				ctx:    context.Background(),
				url:    normalURL,
				header: source.RequestHeader{"Range": fmt.Sprintf("bytes=%s", "0-3")},
			},
			// "bytes=0-3" is inclusive, so four bytes are expected.
			content: testContent[0:4],
			expireInfo: source.ResponseHeader{
				headers.LastModified: lastModified,
				headers.ETag:         etag,
			},
			wantErr: nil,
		}, {
			name: "not found download",
			args: args{
				ctx:    context.Background(),
				url:    notfoundURL,
				header: nil,
			},
			content:    "",
			expireInfo: nil,
			wantErr:    fmt.Errorf("unexpected status code: %d", http.StatusNotFound),
		}, {
			name: "error download",
			args: args{
				ctx:    context.Background(),
				url:    errorURL,
				header: nil,
			},
			content:    "",
			expireInfo: nil,
			wantErr: &url.Error{
				Op:  "Get",
				URL: errorURL,
				Err: fmt.Errorf("error"),
			},
		},
	}
	for _, tt := range tests {
		suite.Run(tt.name, func() {
			reader, responseHeader, err := suite.DownloadWithResponseHeader(tt.args.ctx, tt.args.url, tt.args.header, nil)
			suite.Equal(tt.wantErr, err)
			if err != nil {
				return
			}
			bytes, err := ioutil.ReadAll(reader)
			suite.Require().NoError(err)
			suite.Equal(tt.content, string(bytes))
			suite.Equal(tt.expireInfo, responseHeader)
		})
	}
}
|
|
|
|
// TestHttpSourceClientGetContentLength checks that GetContentLength reports
// the full body size for a plain request. For a ranged request it reports the
// size of the requested window ("bytes=0-3" is inclusive, so 4).
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientGetContentLength() {
	type args struct {
		ctx    context.Context
		url    string
		header source.RequestHeader
	}
	tests := []struct {
		name    string
		args    args
		want    int64
		wantErr error
	}{
		{
			name:    "support content length",
			args:    args{ctx: context.Background(), url: normalURL, header: source.RequestHeader{}},
			want:    int64(len(testContent)),
			wantErr: nil,
		},
		{
			// Previously named "not support content length", although it
			// verifies the length of a ranged response.
			name:    "range content length",
			args:    args{ctx: context.Background(), url: normalURL, header: source.RequestHeader{"Range": fmt.Sprintf("bytes=%s", "0-3")}},
			want:    4,
			wantErr: nil,
		},
	}
	for _, tt := range tests {
		suite.Run(tt.name, func() {
			got, err := suite.GetContentLength(tt.args.ctx, tt.args.url, tt.args.header, nil)
			suite.Equal(tt.wantErr, err)
			suite.Equal(tt.want, got)
		})
	}
}
|
|
|
|
// TestHttpSourceClientIsExpired checks expiry detection against normalURL and
// errorURL:
//   - cached validators give "not expired";
//   - a transport failure gives an error (and "not expired");
//   - an empty validator map is reported as expired.
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientIsExpired() {
	type args struct {
		ctx        context.Context
		url        string
		header     source.RequestHeader
		expireInfo map[string]string
	}
	cachedValidators := func() map[string]string {
		return map[string]string{
			headers.LastModified: lastModified,
			headers.ETag:         etag,
		}
	}
	cases := []struct {
		name    string
		args    args
		want    bool
		wantErr bool
	}{
		{
			name: "not expire",
			args: args{
				ctx:        context.Background(),
				url:        normalURL,
				header:     source.RequestHeader{},
				expireInfo: cachedValidators(),
			},
			want:    false,
			wantErr: false,
		},
		{
			name: "error not expire",
			args: args{
				ctx:        context.Background(),
				url:        errorURL,
				header:     source.RequestHeader{},
				expireInfo: cachedValidators(),
			},
			want:    false,
			wantErr: true,
		},
		{
			name: "expired",
			args: args{
				ctx:        context.Background(),
				url:        normalURL,
				header:     source.RequestHeader{},
				expireInfo: map[string]string{},
			},
			want:    true,
			wantErr: false,
		},
	}
	for _, c := range cases {
		suite.Run(c.name, func() {
			expired, err := suite.IsExpired(c.args.ctx, c.args.url, c.args.header, c.args.expireInfo)
			suite.Equal(c.want, expired)
			suite.Equal(c.wantErr, err != nil)
		})
	}
}
|
|
|
|
// TestHttpSourceClientIsSupportRange checks that IsSupportRange:
//   - honors the context deadline;
//   - reports true for a 206 answer and false for a 200 answer;
//   - returns an error when the transport fails.
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientIsSupportRange() {
	// Use a shorter delay than the SetupTest responder; the 1s deadline below
	// must still fire first.
	httpmock.RegisterResponder(http.MethodGet, timeoutURL, func(request *http.Request) (*http.Response, error) {
		time.Sleep(3 * time.Second)
		return httpmock.NewStringResponse(http.StatusOK, "ok"), nil
	})
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	support, err := suite.IsSupportRange(ctx, timeoutURL, nil)
	cancel()
	// Require (fatal): calling Error() on a nil err would panic the suite.
	suite.Require().EqualError(err, "Get \"http://timeout.com\": context deadline exceeded")
	suite.False(support)

	// Make normalURL answer 206 regardless of the requested range.
	httpmock.RegisterResponder(http.MethodGet, normalURL, httpmock.NewStringResponder(http.StatusPartialContent, ""))

	type args struct {
		ctx    context.Context
		url    string
		header source.RequestHeader
	}
	rangeHeader := source.RequestHeader{"Range": fmt.Sprintf("bytes=%s", "0-3")}
	tests := []struct {
		name    string
		args    args
		want    bool
		wantErr bool
	}{
		{name: "support", args: args{ctx: context.Background(), url: normalURL, header: rangeHeader}, want: true, wantErr: false},
		{name: "notSupport", args: args{ctx: context.Background(), url: normalNotSupportRangeURL, header: rangeHeader}, want: false, wantErr: false},
		// errorURL fails at the transport level (responder registered in SetupTest).
		{name: "error", args: args{ctx: context.Background(), url: errorURL, header: rangeHeader}, want: false, wantErr: true},
	}
	for _, tt := range tests {
		suite.Run(tt.name, func() {
			got, err := suite.IsSupportRange(tt.args.ctx, tt.args.url, tt.args.header)
			suite.Equal(tt.wantErr, err != nil)
			suite.Equal(tt.want, got)
		})
	}
}
|
|
|
|
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDoRequest() {
|
|
var testURL = "http://www.hackhttp.com"
|
|
httpmock.RegisterResponder(http.MethodGet, testURL, httpmock.NewStringResponder(http.StatusOK, "ok"))
|
|
res, err := suite.ResourceClient.(*httpSourceClient).doRequest(context.Background(), http.MethodGet, "http://www.hackhttp.com", nil)
|
|
suite.Nil(err)
|
|
bytes, err := ioutil.ReadAll(res.Body)
|
|
suite.Nil(err)
|
|
suite.EqualValues("ok", string(bytes))
|
|
}
|