240 lines
6.8 KiB
Go
240 lines
6.8 KiB
Go
package main
|
|
|
|
import (
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/kedacore/http-add-on/interceptor/config"
|
|
kedanet "github.com/kedacore/http-add-on/pkg/net"
|
|
"github.com/stretchr/testify/require"
|
|
"k8s.io/apimachinery/pkg/util/wait"
|
|
)
|
|
|
|
func newRoundTripper(
|
|
dialCtxFunc kedanet.DialContextFunc,
|
|
httpRespHeaderTimeout time.Duration,
|
|
) http.RoundTripper {
|
|
return &http.Transport{
|
|
DialContext: dialCtxFunc,
|
|
ResponseHeaderTimeout: httpRespHeaderTimeout,
|
|
}
|
|
}
|
|
|
|
func defaultTimeouts() config.Timeouts {
|
|
return config.Timeouts{
|
|
Connect: 100 * time.Millisecond,
|
|
KeepAlive: 100 * time.Millisecond,
|
|
ResponseHeader: 500 * time.Millisecond,
|
|
DeploymentReplicas: 1 * time.Second,
|
|
}
|
|
}
|
|
|
|
// returns a kedanet.DialContextFunc by calling kedanet.DialContextWithRetry. if you pass nil for the
|
|
// timeoutConfig, it uses standard values. otherwise it uses the one you passed.
|
|
//
|
|
// the returned config.Timeouts is what was passed to the DialContextWithRetry function
|
|
func retryDialContextFunc(
|
|
timeouts config.Timeouts,
|
|
backoff wait.Backoff,
|
|
) kedanet.DialContextFunc {
|
|
dialer := kedanet.NewNetDialer(
|
|
timeouts.Connect,
|
|
timeouts.KeepAlive,
|
|
)
|
|
return kedanet.DialContextWithRetry(dialer, backoff)
|
|
}
|
|
|
|
func reqAndRes(path string) (*httptest.ResponseRecorder, *http.Request, error) {
|
|
req, err := http.NewRequest("GET", path, nil)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
resRecorder := httptest.NewRecorder()
|
|
return resRecorder, req, nil
|
|
}
|
|
|
|
func TestForwarderSuccess(t *testing.T) {
|
|
r := require.New(t)
|
|
// this channel will be closed after the request was received, but
|
|
// before the response was sent
|
|
reqRecvCh := make(chan struct{})
|
|
const respCode = 302
|
|
const respBody = "TestForwardingHandler"
|
|
originHdl := kedanet.NewTestHTTPHandlerWrapper(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
close(reqRecvCh)
|
|
w.WriteHeader(respCode)
|
|
w.Write([]byte(respBody))
|
|
}),
|
|
)
|
|
testServer := httptest.NewServer(originHdl)
|
|
defer testServer.Close()
|
|
forwardURL, err := url.Parse(testServer.URL)
|
|
r.NoError(err)
|
|
|
|
const path = "/testfwd"
|
|
res, req, err := reqAndRes(path)
|
|
r.NoError(err)
|
|
timeouts := defaultTimeouts()
|
|
dialCtxFunc := retryDialContextFunc(timeouts, timeouts.DefaultBackoff())
|
|
forwardRequest(
|
|
res,
|
|
req,
|
|
newRoundTripper(dialCtxFunc, timeouts.ResponseHeader),
|
|
forwardURL,
|
|
)
|
|
|
|
r.True(
|
|
ensureSignalBeforeTimeout(reqRecvCh, 100*time.Millisecond),
|
|
"request was not received within %s",
|
|
100*time.Millisecond,
|
|
)
|
|
forwardedRequests := originHdl.IncomingRequests()
|
|
r.Equal(1, len(forwardedRequests), "number of requests forwarded")
|
|
forwardedRequest := forwardedRequests[0]
|
|
r.Equal(path, forwardedRequest.URL.Path)
|
|
r.Equal(
|
|
302,
|
|
res.Code,
|
|
"Proxied status code was wrong. Response body was %s",
|
|
res.Body.String(),
|
|
)
|
|
r.Equal(respBody, res.Body.String())
|
|
}
|
|
|
|
// Test to make sure that the request forwarder times out if headers aren't returned in time
|
|
func TestForwarderHeaderTimeout(t *testing.T) {
|
|
r := require.New(t)
|
|
// the origin will wait until this channel receives or is closed
|
|
originWaitCh := make(chan struct{})
|
|
hdl := kedanet.NewTestHTTPHandlerWrapper(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
<-originWaitCh
|
|
w.WriteHeader(200)
|
|
}),
|
|
)
|
|
srv, originURL, err := kedanet.StartTestServer(hdl)
|
|
r.NoError(err)
|
|
defer srv.Close()
|
|
|
|
timeouts := defaultTimeouts()
|
|
timeouts.Connect = 10 * time.Millisecond
|
|
timeouts.ResponseHeader = 10 * time.Millisecond
|
|
backoff := timeouts.Backoff(2, 2, 1)
|
|
dialCtxFunc := retryDialContextFunc(timeouts, backoff)
|
|
res, req, err := reqAndRes("/testfwd")
|
|
r.NoError(err)
|
|
forwardRequest(
|
|
res,
|
|
req,
|
|
newRoundTripper(dialCtxFunc, timeouts.ResponseHeader),
|
|
originURL,
|
|
)
|
|
|
|
forwardedRequests := hdl.IncomingRequests()
|
|
r.Equal(0, len(forwardedRequests))
|
|
r.Equal(502, res.Code)
|
|
r.Contains(res.Body.String(), "error on backend")
|
|
// the proxy has bailed out, so tell the origin to stop
|
|
close(originWaitCh)
|
|
}
|
|
|
|
// Test to ensure that the request forwarder waits for an origin that is slow
|
|
func TestForwarderWaitsForSlowOrigin(t *testing.T) {
|
|
r := require.New(t)
|
|
// the origin will wait until this channel receives or is closed
|
|
originWaitCh := make(chan struct{})
|
|
const originRespCode = 200
|
|
const originRespBodyStr = "Hello World!"
|
|
hdl := kedanet.NewTestHTTPHandlerWrapper(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
<-originWaitCh
|
|
w.WriteHeader(originRespCode)
|
|
w.Write([]byte(originRespBodyStr))
|
|
}),
|
|
)
|
|
srv, originURL, err := kedanet.StartTestServer(hdl)
|
|
r.NoError(err)
|
|
defer srv.Close()
|
|
// the origin is gonna wait this long, and we'll make the proxy
|
|
// have a much longer timeout than this to account for timing issues
|
|
const originDelay = 5 * time.Millisecond
|
|
timeouts := config.Timeouts{
|
|
Connect: originDelay,
|
|
KeepAlive: 2 * time.Second,
|
|
// the handler is going to take 500 milliseconds to respond, so make the
|
|
// forwarder wait much longer than that
|
|
ResponseHeader: originDelay * 4,
|
|
}
|
|
|
|
dialCtxFunc := retryDialContextFunc(timeouts, timeouts.DefaultBackoff())
|
|
go func() {
|
|
time.Sleep(originDelay)
|
|
close(originWaitCh)
|
|
}()
|
|
const path = "/testfwd"
|
|
res, req, err := reqAndRes(path)
|
|
r.NoError(err)
|
|
forwardRequest(
|
|
res,
|
|
req,
|
|
newRoundTripper(dialCtxFunc, timeouts.ResponseHeader),
|
|
originURL,
|
|
)
|
|
// wait for the goroutine above to finish, with a little cusion
|
|
ensureSignalBeforeTimeout(originWaitCh, originDelay*2)
|
|
r.Equal(originRespCode, res.Code)
|
|
r.Equal(originRespBodyStr, res.Body.String())
|
|
|
|
}
|
|
|
|
func TestForwarderConnectionRetryAndTimeout(t *testing.T) {
|
|
r := require.New(t)
|
|
noSuchURL, err := url.Parse("https://localhost:65533")
|
|
r.NoError(err)
|
|
|
|
timeouts := config.Timeouts{
|
|
Connect: 10 * time.Millisecond,
|
|
KeepAlive: 1 * time.Millisecond,
|
|
ResponseHeader: 50 * time.Millisecond,
|
|
}
|
|
dialCtxFunc := retryDialContextFunc(timeouts, timeouts.DefaultBackoff())
|
|
res, req, err := reqAndRes("/test")
|
|
r.NoError(err)
|
|
|
|
start := time.Now()
|
|
forwardRequest(
|
|
res,
|
|
req,
|
|
newRoundTripper(dialCtxFunc, timeouts.ResponseHeader),
|
|
noSuchURL,
|
|
)
|
|
elapsed := time.Since(start)
|
|
log.Printf("forwardRequest took %s", elapsed)
|
|
|
|
// forwardDoneSignal should close _after_ the total timeout of forwardRequest.
|
|
//
|
|
// forwardRequest uses dialCtxFunc to establish network connections, and dialCtxFunc does
|
|
// exponential backoff. It starts at 2ms (timeouts.Connect above), doubles every time, and stops after 5 tries,
|
|
// so that's 2ms + 4ms + 8ms + 16ms + 32ms, or SUM(2^N) where N is in [1, 5]
|
|
expectedForwardTimeout := kedanet.MinTotalBackoffDuration(timeouts.DefaultBackoff())
|
|
r.GreaterOrEqualf(
|
|
elapsed,
|
|
expectedForwardTimeout,
|
|
"proxy returned after %s, expected not to return until %s",
|
|
time.Since(start),
|
|
expectedForwardTimeout,
|
|
)
|
|
r.Equal(
|
|
502,
|
|
res.Code,
|
|
"unexpected code (response body was '%s')",
|
|
res.Body.String(),
|
|
)
|
|
r.Contains(res.Body.String(), "error on backend")
|
|
}
|