dragonfly/pkg/source/clients/httpprotocol/http_source_client_test.go

334 lines
11 KiB
Go

/*
* Copyright 2020 The Dragonfly Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package httpprotocol
import (
"context"
"fmt"
"io"
"net/http"
"testing"
"time"
"github.com/go-http-utils/headers"
"github.com/jarcoal/httpmock"
"github.com/stretchr/testify/suite"
nethttp "d7y.io/dragonfly/v2/pkg/net/http"
"d7y.io/dragonfly/v2/pkg/source"
)
func TestHTTPSourceClientTestSuite(t *testing.T) {
suite.Run(t, new(HTTPSourceClientTestSuite))
}
// HTTPSourceClientTestSuite groups the HTTP source client tests. It embeds
// testify's suite.Suite and holds the client under test.
type HTTPSourceClientTestSuite struct {
suite.Suite
// httpClient is the client under test; it is assigned in SetupSuite.
httpClient *httpSourceClient
}
// SetupSuite creates the client under test and activates httpmock on
// _defaultHTTPClient.
// NOTE(review): presumably the client returned by newHTTPSourceClient sends
// its requests through _defaultHTTPClient, which is why the mock is attached
// to that client - confirm against the constructor.
func (suite *HTTPSourceClientTestSuite) SetupSuite() {
suite.httpClient = newHTTPSourceClient()
httpmock.ActivateNonDefault(_defaultHTTPClient)
}
// TearDownSuite deactivates httpmock and resets its registered responders.
func (suite *HTTPSourceClientTestSuite) TearDownSuite() {
httpmock.DeactivateAndReset()
}
// Fixture URLs. They are only used as keys for httpmock responders and as
// request targets in the tests of this file.
var (
timeoutRawURL = "https://timeout.com"
normalRawURL = "https://normal.com"
expireRawURL = "https://expired.com"
errorRawURL = "https://error.com"
forbiddenRawURL = "https://forbidden.com"
notfoundRawURL = "https://notfound.com"
normalNotSupportRangeRawURL = "https://notsuppertrange.com"
)
// Fixture payload and cache validators.
var (
// testContent is the body served by the mock responders.
testContent = "l am test case"
// lastModified and etag are the validators the mock responders send back.
lastModified = "Sun, 06 Jun 2021 12:52:30 GMT"
// expireLastModified and expireEtag differ from the served validators; the
// "expired" case of the IsExpired test passes them as the cached values.
expireLastModified = "Sun, 06 Jun 2021 11:52:30 GMT"
etag = "UMiJT4h7MCEAEgnqCLA2CdAaABnK"
expireEtag = "UMiJ2T4h7MCEAEgnqCLA2CdAaABnK"
)
// SetupTest resets httpmock and registers a fresh set of GET responders for
// the fixture URLs, so a responder overridden inside one test is restored
// before the next test runs.
func (suite *HTTPSourceClientTestSuite) SetupTest() {
	httpmock.Reset()

	// newHeader builds the validator headers for one response. A new map is
	// created on every call so that no http.Header is shared between responses.
	newHeader := func() http.Header {
		header := http.Header{}
		header.Set(headers.LastModified, lastModified)
		header.Set(headers.ETag, etag)
		return header
	}

	httpmock.RegisterResponder(http.MethodGet, timeoutRawURL, func(request *http.Request) (*http.Response, error) {
		// To simulate the timeout
		time.Sleep(5 * time.Second)
		return httpmock.NewStringResponse(http.StatusOK, "ok"), nil
	})

	httpmock.RegisterResponder(http.MethodGet, normalRawURL, func(request *http.Request) (*http.Response, error) {
		// Ranged request: answer 206 with the requested slice.
		if rang := request.Header.Get(headers.Range); rang != "" {
			// Strip the leading "bytes=" unit before parsing the range spec.
			r, err := nethttp.GetRange(rang[len("bytes="):])
			if err != nil {
				// Surface a malformed range as a transport error instead of
				// using an invalid range value below.
				return nil, err
			}
			// NOTE(review): the body holds EndIndex-StartIndex bytes while
			// ContentLength reports one more. The "range download" case
			// expects exactly this body, so both must be changed together.
			return &http.Response{
				StatusCode:    http.StatusPartialContent,
				ContentLength: int64(r.EndIndex) - int64(r.StartIndex) + int64(1),
				Body:          httpmock.NewRespBodyFromString(testContent[r.StartIndex:r.EndIndex]),
				Header:        newHeader(),
			}, nil
		}

		// Conditional request: always answer 304.
		if expire := request.Header.Get(headers.IfModifiedSince); expire != "" {
			return &http.Response{
				StatusCode:    http.StatusNotModified,
				ContentLength: int64(len(testContent)),
				Body:          httpmock.NewRespBodyFromString(testContent),
				Header:        newHeader(),
			}, nil
		}

		// Plain request: answer 200 with the full content.
		return &http.Response{
			StatusCode:    http.StatusOK,
			ContentLength: int64(len(testContent)),
			Body:          httpmock.NewRespBodyFromString(testContent),
			Header:        newHeader(),
		}, nil
	})

	// The expired URL ignores request headers and always answers 200.
	httpmock.RegisterResponder(http.MethodGet, expireRawURL, func(request *http.Request) (*http.Response, error) {
		return &http.Response{
			StatusCode:    http.StatusOK,
			ContentLength: int64(len(testContent)),
			Body:          httpmock.NewRespBodyFromString(testContent),
			Header:        newHeader(),
		}, nil
	})

	httpmock.RegisterResponder(http.MethodGet, forbiddenRawURL, httpmock.NewStringResponder(http.StatusForbidden, "forbidden"))
	httpmock.RegisterResponder(http.MethodGet, notfoundRawURL, httpmock.NewStringResponder(http.StatusNotFound, "not found"))
	httpmock.RegisterResponder(http.MethodGet, normalNotSupportRangeRawURL, httpmock.NewStringResponder(http.StatusOK, testContent))
	httpmock.RegisterResponder(http.MethodGet, errorRawURL, httpmock.NewErrorResponder(fmt.Errorf("error")))
}
// TestNewHTTPSourceClient checks which *http.Client the constructor wires in:
// _defaultHTTPClient when no option is given, and the supplied client when
// WithHTTPClient is used.
func (suite *HTTPSourceClientTestSuite) TestNewHTTPSourceClient() {
	// Without options the default client is expected.
	var defaultClient source.ResourceClient = NewHTTPSourceClient()
	defaultImpl := defaultClient.(*httpSourceClient)
	suite.Equal(_defaultHTTPClient, defaultImpl.httpClient)
	suite.EqualValues(*_defaultHTTPClient, *defaultImpl.httpClient)

	// With WithHTTPClient the caller-provided client is expected.
	expectedHTTPClient := &http.Client{}
	var customClient source.ResourceClient = NewHTTPSourceClient(WithHTTPClient(expectedHTTPClient))
	customImpl := customClient.(*httpSourceClient)
	suite.Equal(expectedHTTPClient, customImpl.httpClient)
	suite.EqualValues(*expectedHTTPClient, *customImpl.httpClient)
}
// TestHttpSourceClientDownloadWithResponseHeader first checks that Download
// fails with a deadline error when the request context expires, then
// table-tests the body and expire info for normal, ranged, not-found and
// transport-error requests.
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDownloadWithResponseHeader() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	timeoutRequest, err := source.NewRequestWithContext(ctx, timeoutRawURL, nil)
	suite.Require().NoError(err)
	response, err := suite.httpClient.Download(timeoutRequest)
	// Require aborts here so a nil error is reported as a failure rather than
	// dereferenced by the message check below.
	suite.Require().Error(err)
	suite.EqualError(err, "Get \"https://timeout.com\": context deadline exceeded")
	suite.Nil(response)

	// mustRequest builds a request and aborts the test if construction fails.
	mustRequest := func(rawURL string) *source.Request {
		request, err := source.NewRequest(rawURL)
		suite.Require().NoError(err)
		return request
	}
	normalRequest := mustRequest(normalRawURL)
	normalRangeRequest := mustRequest(normalRawURL)
	normalRangeRequest.Header.Add(headers.Range, "bytes=0-3")
	notfoundRequest := mustRequest(notfoundRawURL)
	errorRequest := mustRequest(errorRawURL)

	tests := []struct {
		name       string
		request    *source.Request
		content    string
		expireInfo *source.ExpireInfo
		wantErr    error
	}{
		{
			name:    "normal download",
			request: normalRequest,
			content: testContent,
			expireInfo: &source.ExpireInfo{
				LastModified: lastModified,
				ETag:         etag,
			},
			wantErr: nil,
		}, {
			name:    "range download",
			request: normalRangeRequest,
			content: testContent[0:3],
			expireInfo: &source.ExpireInfo{
				LastModified: lastModified,
				ETag:         etag,
			},
			wantErr: nil,
		}, {
			name:       "not found download",
			request:    notfoundRequest,
			content:    "",
			expireInfo: nil,
			wantErr:    source.CheckResponseCode(404, []int{200, 206}),
		}, {
			name:       "error download",
			request:    errorRequest,
			content:    "",
			expireInfo: nil,
			wantErr:    fmt.Errorf("Get \"https://error.com\": error"),
		},
	}
	for _, tt := range tests {
		suite.Run(tt.name, func() {
			// A failure may come from Download itself or from the response's
			// Validate; either one is compared against wantErr.
			response, err := suite.httpClient.Download(tt.request)
			if err == nil {
				err = response.Validate()
			}
			if tt.wantErr != nil {
				suite.Require().Error(err)
				suite.EqualError(err, tt.wantErr.Error())
				return
			}
			// An unexpected error fails the case cleanly instead of panicking
			// on a nil wantErr.
			suite.Require().NoError(err)

			body, err := io.ReadAll(response.Body)
			suite.NoError(err)
			suite.Equal(tt.content, string(body))
			expireInfo := response.ExpireInfo()
			suite.Equal(tt.expireInfo, &expireInfo)
		})
	}
}
// TestHttpSourceClientGetContentLength table-tests GetContentLength for a
// plain request and for a ranged request against the normal fixture URL.
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientGetContentLength() {
	// Request construction errors abort the test; otherwise Header.Add below
	// would dereference a nil request.
	normalRequest, err := source.NewRequest(normalRawURL)
	suite.Require().NoError(err)
	normalRangeRequest, err := source.NewRequest(normalRawURL)
	suite.Require().NoError(err)
	normalRangeRequest.Header.Add(headers.Range, "bytes=0-3")

	tests := []struct {
		name    string
		request *source.Request
		want    int64
		wantErr error
	}{
		{
			name:    "support content length",
			request: normalRequest,
			want:    int64(len(testContent)),
			wantErr: nil,
		},
		{
			// NOTE(review): despite its name this case sends a ranged request
			// and expects the length of the bytes=0-3 range.
			name:    "not support content length",
			request: normalRangeRequest,
			want:    4,
			wantErr: nil,
		},
	}
	for _, tt := range tests {
		suite.Run(tt.name, func() {
			got, err := suite.httpClient.GetContentLength(tt.request)
			suite.Equal(tt.wantErr, err)
			suite.Equal(tt.want, got)
		})
	}
}
// TestHttpSourceClientIsExpired table-tests IsExpired with matching
// validators, with a transport error, and with validators that differ from
// the ones the mock serves.
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientIsExpired() {
	// mustRequest builds a request and aborts the test if construction fails,
	// so no case runs with a nil request.
	mustRequest := func(rawURL string) *source.Request {
		request, err := source.NewRequest(rawURL)
		suite.Require().NoError(err)
		return request
	}

	// currentInfo carries the validators the mock serves; staleInfo carries
	// the differing "expire" fixture values.
	currentInfo := func() *source.ExpireInfo {
		return &source.ExpireInfo{LastModified: lastModified, ETag: etag}
	}
	staleInfo := &source.ExpireInfo{LastModified: expireLastModified, ETag: expireEtag}

	tests := []struct {
		name       string
		request    *source.Request
		expireInfo *source.ExpireInfo
		want       bool
		wantErr    bool
	}{
		{name: "not expire", request: mustRequest(normalRawURL), expireInfo: currentInfo(), want: false, wantErr: false},
		{name: "error not expire", request: mustRequest(errorRawURL), expireInfo: currentInfo(), want: false, wantErr: true},
		{name: "expired", request: mustRequest(expireRawURL), expireInfo: staleInfo, want: true, wantErr: false},
	}
	for _, tt := range tests {
		suite.Run(tt.name, func() {
			got, err := suite.httpClient.IsExpired(tt.request, tt.expireInfo)
			suite.Equal(tt.want, got)
			suite.Equal(tt.wantErr, err != nil)
		})
	}
}
// TestHttpSourceClientIsSupportRange first checks that IsSupportRange fails
// with a deadline error when the request context expires, then table-tests it
// against responders answering 206, 200 and a transport error.
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientIsSupportRange() {
	// Override the timeout responder for this test with a 3 second delay,
	// longer than the 1 second deadline used below.
	httpmock.RegisterResponder(http.MethodGet, timeoutRawURL, func(request *http.Request) (*http.Response, error) {
		time.Sleep(3 * time.Second)
		return httpmock.NewStringResponse(http.StatusOK, "ok"), nil
	})
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()
	request, err := source.NewRequestWithContext(ctx, timeoutRawURL, nil)
	suite.Require().NoError(err)
	support, err := suite.httpClient.IsSupportRange(request)
	// Require aborts here so a nil error is reported as a failure rather than
	// dereferenced by the message check below.
	suite.Require().Error(err)
	suite.EqualError(err, "Get \"https://timeout.com\": context deadline exceeded")
	suite.False(support)

	// Replace the fixture responders with fixed status codes for the table.
	httpmock.RegisterResponder(http.MethodGet, normalRawURL, httpmock.NewStringResponder(http.StatusPartialContent, ""))
	httpmock.RegisterResponder(http.MethodGet, normalNotSupportRangeRawURL, httpmock.NewStringResponder(http.StatusOK, ""))
	httpmock.RegisterResponder(http.MethodGet, errorRawURL, httpmock.NewErrorResponder(fmt.Errorf("xxx")))

	// mustRequest builds a request and aborts the test if construction fails.
	mustRequest := func(rawURL string) *source.Request {
		req, err := source.NewRequest(rawURL)
		suite.Require().NoError(err)
		return req
	}
	supportRangeRequest := mustRequest(normalRawURL)
	supportRangeRequest.Header.Add(headers.Range, "bytes=0-3")
	notSupportRangeRequest := mustRequest(normalNotSupportRangeRawURL)
	notSupportRangeRequest.Header.Add(headers.Range, "bytes=0-3")
	errRequest := mustRequest(errorRawURL)

	tests := []struct {
		name    string
		request *source.Request
		want    bool
		wantErr bool
	}{
		{name: "support", request: supportRangeRequest, want: true, wantErr: false},
		{name: "notSupport", request: notSupportRangeRequest, want: false, wantErr: false},
		{name: "error", request: errRequest, want: false, wantErr: true},
	}
	for _, tt := range tests {
		suite.Run(tt.name, func() {
			got, err := suite.httpClient.IsSupportRange(tt.request)
			suite.Equal(tt.wantErr, err != nil)
			suite.Equal(tt.want, got)
		})
	}
}
// TestHttpSourceClientDoRequest registers a 200 "ok" responder for a dedicated
// URL and checks that doRequest returns that body for a GET request.
func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDoRequest() {
	const testURL = "https://www.hackhttp.com"
	httpmock.RegisterResponder(http.MethodGet, testURL, httpmock.NewStringResponder(http.StatusOK, "ok"))

	request, err := source.NewRequest(testURL)
	suite.Require().NoError(err)

	res, err := suite.httpClient.doRequest(http.MethodGet, request)
	// Abort on failure: res is dereferenced on the next line.
	suite.Require().NoError(err)

	body, err := io.ReadAll(res.Body)
	suite.NoError(err)
	suite.Equal("ok", string(body))
}