http-add-on/interceptor/request_forwarder_test.go

240 lines
6.8 KiB
Go

package main
import (
"log"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/kedacore/http-add-on/interceptor/config"
kedanet "github.com/kedacore/http-add-on/pkg/net"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/wait"
)
func newRoundTripper(
dialCtxFunc kedanet.DialContextFunc,
httpRespHeaderTimeout time.Duration,
) http.RoundTripper {
return &http.Transport{
DialContext: dialCtxFunc,
ResponseHeaderTimeout: httpRespHeaderTimeout,
}
}
func defaultTimeouts() config.Timeouts {
return config.Timeouts{
Connect: 100 * time.Millisecond,
KeepAlive: 100 * time.Millisecond,
ResponseHeader: 500 * time.Millisecond,
DeploymentReplicas: 1 * time.Second,
}
}
// returns a kedanet.DialContextFunc by calling kedanet.DialContextWithRetry. if you pass nil for the
// timeoutConfig, it uses standard values. otherwise it uses the one you passed.
//
// the returned config.Timeouts is what was passed to the DialContextWithRetry function
func retryDialContextFunc(
timeouts config.Timeouts,
backoff wait.Backoff,
) kedanet.DialContextFunc {
dialer := kedanet.NewNetDialer(
timeouts.Connect,
timeouts.KeepAlive,
)
return kedanet.DialContextWithRetry(dialer, backoff)
}
func reqAndRes(path string) (*httptest.ResponseRecorder, *http.Request, error) {
req, err := http.NewRequest("GET", path, nil)
if err != nil {
return nil, nil, err
}
resRecorder := httptest.NewRecorder()
return resRecorder, req, nil
}
func TestForwarderSuccess(t *testing.T) {
r := require.New(t)
// this channel will be closed after the request was received, but
// before the response was sent
reqRecvCh := make(chan struct{})
const respCode = 302
const respBody = "TestForwardingHandler"
originHdl := kedanet.NewTestHTTPHandlerWrapper(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(reqRecvCh)
w.WriteHeader(respCode)
w.Write([]byte(respBody))
}),
)
testServer := httptest.NewServer(originHdl)
defer testServer.Close()
forwardURL, err := url.Parse(testServer.URL)
r.NoError(err)
const path = "/testfwd"
res, req, err := reqAndRes(path)
r.NoError(err)
timeouts := defaultTimeouts()
dialCtxFunc := retryDialContextFunc(timeouts, timeouts.DefaultBackoff())
forwardRequest(
res,
req,
newRoundTripper(dialCtxFunc, timeouts.ResponseHeader),
forwardURL,
)
r.True(
ensureSignalBeforeTimeout(reqRecvCh, 100*time.Millisecond),
"request was not received within %s",
100*time.Millisecond,
)
forwardedRequests := originHdl.IncomingRequests()
r.Equal(1, len(forwardedRequests), "number of requests forwarded")
forwardedRequest := forwardedRequests[0]
r.Equal(path, forwardedRequest.URL.Path)
r.Equal(
302,
res.Code,
"Proxied status code was wrong. Response body was %s",
res.Body.String(),
)
r.Equal(respBody, res.Body.String())
}
// Test to make sure that the request forwarder times out if headers aren't returned in time
func TestForwarderHeaderTimeout(t *testing.T) {
r := require.New(t)
// the origin will wait until this channel receives or is closed
originWaitCh := make(chan struct{})
hdl := kedanet.NewTestHTTPHandlerWrapper(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-originWaitCh
w.WriteHeader(200)
}),
)
srv, originURL, err := kedanet.StartTestServer(hdl)
r.NoError(err)
defer srv.Close()
timeouts := defaultTimeouts()
timeouts.Connect = 10 * time.Millisecond
timeouts.ResponseHeader = 10 * time.Millisecond
backoff := timeouts.Backoff(2, 2, 1)
dialCtxFunc := retryDialContextFunc(timeouts, backoff)
res, req, err := reqAndRes("/testfwd")
r.NoError(err)
forwardRequest(
res,
req,
newRoundTripper(dialCtxFunc, timeouts.ResponseHeader),
originURL,
)
forwardedRequests := hdl.IncomingRequests()
r.Equal(0, len(forwardedRequests))
r.Equal(502, res.Code)
r.Contains(res.Body.String(), "error on backend")
// the proxy has bailed out, so tell the origin to stop
close(originWaitCh)
}
// Test to ensure that the request forwarder waits for an origin that is slow
func TestForwarderWaitsForSlowOrigin(t *testing.T) {
r := require.New(t)
// the origin will wait until this channel receives or is closed
originWaitCh := make(chan struct{})
const originRespCode = 200
const originRespBodyStr = "Hello World!"
hdl := kedanet.NewTestHTTPHandlerWrapper(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-originWaitCh
w.WriteHeader(originRespCode)
w.Write([]byte(originRespBodyStr))
}),
)
srv, originURL, err := kedanet.StartTestServer(hdl)
r.NoError(err)
defer srv.Close()
// the origin is gonna wait this long, and we'll make the proxy
// have a much longer timeout than this to account for timing issues
const originDelay = 5 * time.Millisecond
timeouts := config.Timeouts{
Connect: originDelay,
KeepAlive: 2 * time.Second,
// the handler is going to take 500 milliseconds to respond, so make the
// forwarder wait much longer than that
ResponseHeader: originDelay * 4,
}
dialCtxFunc := retryDialContextFunc(timeouts, timeouts.DefaultBackoff())
go func() {
time.Sleep(originDelay)
close(originWaitCh)
}()
const path = "/testfwd"
res, req, err := reqAndRes(path)
r.NoError(err)
forwardRequest(
res,
req,
newRoundTripper(dialCtxFunc, timeouts.ResponseHeader),
originURL,
)
// wait for the goroutine above to finish, with a little cusion
ensureSignalBeforeTimeout(originWaitCh, originDelay*2)
r.Equal(originRespCode, res.Code)
r.Equal(originRespBodyStr, res.Body.String())
}
func TestForwarderConnectionRetryAndTimeout(t *testing.T) {
r := require.New(t)
noSuchURL, err := url.Parse("https://localhost:65533")
r.NoError(err)
timeouts := config.Timeouts{
Connect: 10 * time.Millisecond,
KeepAlive: 1 * time.Millisecond,
ResponseHeader: 50 * time.Millisecond,
}
dialCtxFunc := retryDialContextFunc(timeouts, timeouts.DefaultBackoff())
res, req, err := reqAndRes("/test")
r.NoError(err)
start := time.Now()
forwardRequest(
res,
req,
newRoundTripper(dialCtxFunc, timeouts.ResponseHeader),
noSuchURL,
)
elapsed := time.Since(start)
log.Printf("forwardRequest took %s", elapsed)
// forwardDoneSignal should close _after_ the total timeout of forwardRequest.
//
// forwardRequest uses dialCtxFunc to establish network connections, and dialCtxFunc does
// exponential backoff. It starts at 2ms (timeouts.Connect above), doubles every time, and stops after 5 tries,
// so that's 2ms + 4ms + 8ms + 16ms + 32ms, or SUM(2^N) where N is in [1, 5]
expectedForwardTimeout := kedanet.MinTotalBackoffDuration(timeouts.DefaultBackoff())
r.GreaterOrEqualf(
elapsed,
expectedForwardTimeout,
"proxy returned after %s, expected not to return until %s",
time.Since(start),
expectedForwardTimeout,
)
r.Equal(
502,
res.Code,
"unexpected code (response body was '%s')",
res.Body.String(),
)
r.Contains(res.Body.String(), "error on backend")
}