/* * Copyright 2020 The Dragonfly Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package httpprotocol import ( "context" "fmt" "io/ioutil" "net/http" "net/url" "testing" "time" "d7y.io/dragonfly/v2/pkg/source" "d7y.io/dragonfly/v2/pkg/util/rangeutils" "github.com/go-http-utils/headers" "github.com/jarcoal/httpmock" "github.com/stretchr/testify/suite" ) func TestHTTPSourceClientTestSuite(t *testing.T) { suite.Run(t, new(HTTPSourceClientTestSuite)) } type HTTPSourceClientTestSuite struct { suite.Suite source.ResourceClient } func (suite *HTTPSourceClientTestSuite) SetupSuite() { suite.ResourceClient = NewHTTPSourceClient() httpmock.ActivateNonDefault(_defaultHTTPClient) } func (suite *HTTPSourceClientTestSuite) TearDownSuite() { httpmock.DeactivateAndReset() } var ( timeoutURL = "http://timeout.com" normalURL = "http://normal.com" errorURL = "http://error.com" forbiddenURL = "http://forbidden.com" notfoundURL = "http://notfound.com" normalNotSupportRangeURL = "http://notsuppertrange.com" ) var ( testContent = "l am test case" lastModified = "Sun, 06 Jun 2021 12:52:30 GMT" //etag = "UMiJT4h7MCEAEgnqCLA2CdAaABnK" // todo etag business code can not obtain etag = "" ) func (suite *HTTPSourceClientTestSuite) SetupTest() { httpmock.Reset() httpmock.RegisterResponder(http.MethodGet, timeoutURL, func(request *http.Request) (*http.Response, error) { // To simulate the timeout time.Sleep(5 * time.Second) return httpmock.NewStringResponse(http.StatusOK, "ok"), nil }) httpmock.RegisterResponder(http.MethodGet, normalURL, func(request *http.Request) (*http.Response, error) { if rang := request.Header.Get(headers.Range); rang != "" { r, _ := rangeutils.ParseRange(rang[6:]) res := &http.Response{ StatusCode: http.StatusPartialContent, ContentLength: int64(r.EndIndex) - int64(r.StartIndex) + int64(1), Body: httpmock.NewRespBodyFromString(testContent[r.StartIndex:r.EndIndex]), Header: http.Header{ headers.LastModified: []string{lastModified}, headers.ETag: []string{etag}, }, } return res, nil } if expire := request.Header.Get(headers.IfModifiedSince); expire != "" { res := &http.Response{ StatusCode: http.StatusNotModified, ContentLength: int64(len(testContent)), Body: httpmock.NewRespBodyFromString(testContent), Header: http.Header{ headers.LastModified: []string{lastModified}, headers.ETag: []string{etag}, }, } return res, nil } res := &http.Response{ StatusCode: http.StatusOK, ContentLength: 14, Body: httpmock.NewRespBodyFromString(testContent), Header: http.Header{ headers.LastModified: []string{lastModified}, headers.ETag: []string{etag}, }, } return res, nil }) httpmock.RegisterResponder(http.MethodGet, forbiddenURL, httpmock.NewStringResponder(http.StatusForbidden, "forbidden")) httpmock.RegisterResponder(http.MethodGet, notfoundURL, httpmock.NewStringResponder(http.StatusNotFound, "not found")) httpmock.RegisterResponder(http.MethodGet, normalNotSupportRangeURL, httpmock.NewStringResponder(http.StatusOK, testContent)) httpmock.RegisterResponder(http.MethodGet, errorURL, httpmock.NewErrorResponder(fmt.Errorf("error"))) } func (suite *HTTPSourceClientTestSuite) TestNewHTTPSourceClient() { var sourceClient source.ResourceClient sourceClient = NewHTTPSourceClient() suite.Equal(_defaultHTTPClient, sourceClient.(*httpSourceClient).httpClient) suite.EqualValues(*_defaultHTTPClient, *sourceClient.(*httpSourceClient).httpClient) expectedHTTPClient := &http.Client{} sourceClient = NewHTTPSourceClient(WithHTTPClient(expectedHTTPClient)) suite.Equal(expectedHTTPClient, sourceClient.(*httpSourceClient).httpClient) suite.EqualValues(*expectedHTTPClient, *sourceClient.(*httpSourceClient).httpClient) } func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDownloadWithResponseHeader() { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) reader, responseHeader, err := suite.DownloadWithResponseHeader(ctx, timeoutURL, source.RequestHeader{}, nil) cancel() suite.NotNil(err) suite.Equal("Get \"http://timeout.com\": context deadline exceeded", err.Error()) suite.Nil(reader) suite.Nil(responseHeader) type args struct { ctx context.Context url string header source.RequestHeader } tests := []struct { name string args args content string expireInfo source.ResponseHeader wantErr error }{ { name: "normal download", args: args{ ctx: context.Background(), url: normalURL, header: nil, }, content: testContent, expireInfo: source.ResponseHeader{ source.LastModified: lastModified, source.ETag: etag, }, wantErr: nil, }, { name: "range download", args: args{ ctx: context.Background(), url: normalURL, header: source.RequestHeader{"Range": fmt.Sprintf("bytes=%s", "0-3")}, }, content: testContent[0:3], expireInfo: source.ResponseHeader{ headers.LastModified: lastModified, headers.ETag: etag, }, wantErr: nil, }, { name: "not found download", args: args{ ctx: context.Background(), url: notfoundURL, header: nil, }, content: "", expireInfo: nil, wantErr: fmt.Errorf("unexpected status code: %d", http.StatusNotFound), }, { name: "error download", args: args{ ctx: context.Background(), url: errorURL, header: nil, }, content: "", expireInfo: nil, wantErr: &url.Error{ Op: "Get", URL: errorURL, Err: fmt.Errorf("error"), }, }, } for _, tt := range tests { suite.Run(tt.name, func() { reader, responseHeader, err := suite.DownloadWithResponseHeader(tt.args.ctx, tt.args.url, tt.args.header, nil) suite.Equal(tt.wantErr, err) if err != nil { return } bytes, err := ioutil.ReadAll(reader) suite.Nil(err) suite.Equal(tt.content, string(bytes)) suite.Equal(tt.expireInfo, responseHeader) }) } } func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientGetContentLength() { type args struct { ctx context.Context url string header source.RequestHeader } tests := []struct { name string args args want int64 wantErr error }{ {name: "support content length", args: args{ctx: context.Background(), url: normalURL, header: map[string]string{}}, want: int64(len(testContent)), wantErr: nil}, {name: "not support content length", args: args{ctx: context.Background(), url: normalURL, header: source.RequestHeader{"Range": fmt.Sprintf("bytes=%s", "0-3")}}, want: 4, wantErr: nil}, } for _, tt := range tests { suite.Run(tt.name, func() { got, err := suite.GetContentLength(tt.args.ctx, tt.args.url, tt.args.header, nil) suite.Equal(tt.wantErr, err) suite.Equal(tt.want, got) }) } } func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientIsExpired() { type args struct { ctx context.Context url string header source.RequestHeader expireInfo map[string]string } tests := []struct { name string args args want bool wantErr bool }{ {name: "not expire", args: args{context.Background(), normalURL, source.RequestHeader{}, map[string]string{headers.LastModified: lastModified, headers.ETag: etag}}, want: false, wantErr: false}, {name: "error not expire", args: args{context.Background(), errorURL, source.RequestHeader{}, map[string]string{headers.LastModified: lastModified, headers.ETag: etag}}, want: false, wantErr: true}, {name: "expired", args: args{context.Background(), normalURL, source.RequestHeader{}, map[string]string{}}, want: true, wantErr: false}, } for _, tt := range tests { suite.Run(tt.name, func() { got, err := suite.IsExpired(tt.args.ctx, tt.args.url, tt.args.header, tt.args.expireInfo) suite.Equal(tt.want, got) suite.Equal(tt.wantErr, err != nil) }) } } func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientIsSupportRange() { httpmock.RegisterResponder(http.MethodGet, timeoutURL, func(request *http.Request) (*http.Response, error) { time.Sleep(3 * time.Second) return httpmock.NewStringResponse(http.StatusOK, "ok"), nil }) parent := context.Background() ctx, cancel := context.WithTimeout(parent, 1*time.Second) support, err := suite.IsSupportRange(ctx, timeoutURL, nil) cancel() suite.NotNil(err) suite.Equal("Get \"http://timeout.com\": context deadline exceeded", err.Error()) suite.Equal(false, support) httpmock.RegisterResponder(http.MethodGet, normalURL, httpmock.NewStringResponder(http.StatusPartialContent, "")) httpmock.RegisterResponder(http.MethodGet, "http://notSupportRange.com", httpmock.NewStringResponder(http.StatusOK, "")) httpmock.RegisterResponder(http.MethodGet, "http://error.com", httpmock.NewErrorResponder(fmt.Errorf("xxx"))) type args struct { ctx context.Context url string header map[string]string } tests := []struct { name string args args want bool wantErr bool }{ {name: "support", args: args{ctx: context.Background(), url: normalURL, header: source.RequestHeader{"Range": fmt.Sprintf("bytes=%s", "0-3")}}, want: true, wantErr: false}, {name: "notSupport", args: args{ctx: context.Background(), url: normalNotSupportRangeURL, header: source.RequestHeader{"Range": fmt.Sprintf("bytes=%s", "0-3")}}, want: false, wantErr: false}, } for _, tt := range tests { suite.Run(tt.name, func() { got, err := suite.IsSupportRange(tt.args.ctx, tt.args.url, tt.args.header) suite.Equal(tt.wantErr, err != nil) suite.Equal(tt.want, got) }) } } func (suite *HTTPSourceClientTestSuite) TestHttpSourceClientDoRequest() { var testURL = "http://www.hackhttp.com" httpmock.RegisterResponder(http.MethodGet, testURL, httpmock.NewStringResponder(http.StatusOK, "ok")) res, err := suite.ResourceClient.(*httpSourceClient).doRequest(context.Background(), http.MethodGet, "http://www.hackhttp.com", nil) suite.Nil(err) bytes, err := ioutil.ReadAll(res.Body) suite.Nil(err) suite.EqualValues("ok", string(bytes)) }