From 026eb846a4a66a0a3444ad166066beb11a1b2fcc Mon Sep 17 00:00:00 2001 From: Abu Kashem Date: Thu, 26 Nov 2020 23:53:20 -0500 Subject: [PATCH] apiserver: plumb context with request deadline - as soon as a request is received by the apiserver, determine the timeout of the request and set a new request context with the deadline. - the timeout filter that times out non-long-running requests should use the request context as opposed to a fixed 60s wait today. - admission and storage layer uses the same request context with the deadline specified. we use the default timeout enforced by the apiserver: - if the user has specified a timeout of 0s, this implies no timeout on the user's part. - if the user has specified a timeout that exceeds the maximum deadline allowed by the apiserver. Kubernetes-commit: e416c9e574c49fd0190c8cdac58322aa33a935cf --- pkg/endpoints/apiserver_test.go | 7 +- pkg/endpoints/filters/request_deadline.go | 172 +++++++ .../filters/request_deadline_test.go | 478 ++++++++++++++++++ pkg/endpoints/handlers/create.go | 9 +- pkg/endpoints/handlers/delete.go | 21 +- pkg/endpoints/handlers/patch.go | 14 +- pkg/endpoints/handlers/rest.go | 22 +- pkg/endpoints/handlers/rest_test.go | 42 +- pkg/endpoints/handlers/update.go | 11 +- pkg/server/config.go | 8 +- .../filters/priority-and-fairness_test.go | 2 +- pkg/server/filters/timeout.go | 26 +- pkg/server/filters/timeout_test.go | 22 +- 13 files changed, 743 insertions(+), 91 deletions(-) create mode 100644 pkg/endpoints/filters/request_deadline.go create mode 100644 pkg/endpoints/filters/request_deadline_test.go diff --git a/pkg/endpoints/apiserver_test.go b/pkg/endpoints/apiserver_test.go index a7b977ee4..f8b393fc7 100644 --- a/pkg/endpoints/apiserver_test.go +++ b/pkg/endpoints/apiserver_test.go @@ -282,10 +282,13 @@ func handleInternal(storage map[string]rest.Storage, admissionControl admission. panic(fmt.Sprintf("unable to install container %s: %v", group.GroupVersion, err)) } } - handler := genericapifilters.WithAudit(mux, auditSink, auditpolicy.FakeChecker(auditinternal.LevelRequestResponse, nil), func(r *http.Request, requestInfo *request.RequestInfo) bool { + longRunningCheck := func(r *http.Request, requestInfo *request.RequestInfo) bool { // simplified long-running check return requestInfo.Verb == "watch" || requestInfo.Verb == "proxy" - }) + } + fakeChecker := auditpolicy.FakeChecker(auditinternal.LevelRequestResponse, nil) + handler := genericapifilters.WithAudit(mux, auditSink, fakeChecker, longRunningCheck) + handler = genericapifilters.WithRequestDeadline(handler, auditSink, fakeChecker, longRunningCheck, codecs, 60*time.Second) handler = genericapifilters.WithRequestInfo(handler, testRequestInfoResolver()) return &defaultAPIServer{handler, container} diff --git a/pkg/endpoints/filters/request_deadline.go b/pkg/endpoints/filters/request_deadline.go new file mode 100644 index 000000000..1e43cdab2 --- /dev/null +++ b/pkg/endpoints/filters/request_deadline.go @@ -0,0 +1,172 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + utilclock "k8s.io/apimachinery/pkg/util/clock" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + auditinternal "k8s.io/apiserver/pkg/apis/audit" + "k8s.io/apiserver/pkg/audit" + "k8s.io/apiserver/pkg/audit/policy" + "k8s.io/apiserver/pkg/endpoints/handlers/responsewriters" + "k8s.io/apiserver/pkg/endpoints/request" + "k8s.io/klog/v2" +) + +const ( + // The 'timeout' query parameter in the request URL has an invalid duration specifier + invalidTimeoutInURL = "invalid timeout specified in the request URL" +) + +// WithRequestDeadline determines the timeout duration applicable to the given request and sets a new context +// with the appropriate deadline. +// auditWrapper provides an http.Handler that audits a failed request. +// longRunning returns true if he given request is a long running request. +// requestTimeoutMaximum specifies the default request timeout value. +func WithRequestDeadline(handler http.Handler, sink audit.Sink, policy policy.Checker, longRunning request.LongRunningRequestCheck, + negotiatedSerializer runtime.NegotiatedSerializer, requestTimeoutMaximum time.Duration) http.Handler { + return withRequestDeadline(handler, sink, policy, longRunning, negotiatedSerializer, requestTimeoutMaximum, utilclock.RealClock{}) +} + +func withRequestDeadline(handler http.Handler, sink audit.Sink, policy policy.Checker, longRunning request.LongRunningRequestCheck, + negotiatedSerializer runtime.NegotiatedSerializer, requestTimeoutMaximum time.Duration, clock utilclock.PassiveClock) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + requestInfo, ok := request.RequestInfoFrom(ctx) + if !ok { + handleError(w, req, http.StatusInternalServerError, fmt.Errorf("no RequestInfo found in context, handler chain must be wrong")) + return + } + if longRunning(req, requestInfo) { + handler.ServeHTTP(w, req) + return + } + + userSpecifiedTimeout, ok, err := parseTimeout(req) + if err != nil { + statusErr := apierrors.NewBadRequest(fmt.Sprintf("%s", err.Error())) + + klog.Errorf("Error - %s: %#v", err.Error(), req.RequestURI) + + failed := failedErrorHandler(negotiatedSerializer, statusErr) + failWithAudit := withFailedRequestAudit(failed, statusErr, sink, policy) + failWithAudit.ServeHTTP(w, req) + return + } + + timeout := requestTimeoutMaximum + if ok { + // we use the default timeout enforced by the apiserver: + // - if the user has specified a timeout of 0s, this implies no timeout on the user's part. + // - if the user has specified a timeout that exceeds the maximum deadline allowed by the apiserver. + if userSpecifiedTimeout > 0 && userSpecifiedTimeout < requestTimeoutMaximum { + timeout = userSpecifiedTimeout + } + } + + started := clock.Now() + if requestStartedTimestamp, ok := request.ReceivedTimestampFrom(ctx); ok { + started = requestStartedTimestamp + } + + ctx, cancel := context.WithDeadline(ctx, started.Add(timeout)) + defer cancel() + + req = req.WithContext(ctx) + handler.ServeHTTP(w, req) + }) +} + +// withFailedRequestAudit decorates a failed http.Handler and is used to audit a failed request. +// statusErr is used to populate the Message property of ResponseStatus. +func withFailedRequestAudit(failedHandler http.Handler, statusErr *apierrors.StatusError, sink audit.Sink, policy policy.Checker) http.Handler { + if sink == nil || policy == nil { + return failedHandler + } + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + req, ev, omitStages, err := createAuditEventAndAttachToContext(req, policy) + if err != nil { + utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err)) + responsewriters.InternalError(w, req, errors.New("failed to create audit event")) + return + } + if ev == nil { + failedHandler.ServeHTTP(w, req) + return + } + + ev.ResponseStatus = &metav1.Status{} + ev.Stage = auditinternal.StageResponseStarted + if statusErr != nil { + ev.ResponseStatus.Message = statusErr.Error() + } + + rw := decorateResponseWriter(w, ev, sink, omitStages) + failedHandler.ServeHTTP(rw, req) + }) +} + +// failedErrorHandler returns an http.Handler that uses the specified StatusError object +// to render an error response to the request. +func failedErrorHandler(s runtime.NegotiatedSerializer, statusError *apierrors.StatusError) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx := req.Context() + requestInfo, found := request.RequestInfoFrom(ctx) + if !found { + responsewriters.InternalError(w, req, errors.New("no RequestInfo found in the context")) + return + } + + gv := schema.GroupVersion{Group: requestInfo.APIGroup, Version: requestInfo.APIVersion} + responsewriters.ErrorNegotiated(statusError, s, gv, w, req) + }) +} + +// parseTimeout parses the given HTTP request URL and extracts the timeout query parameter +// value if specified by the user. +// If a timeout is not specified the function returns false and err is set to nil +// If the value specified is malformed then the function returns false and err is set +func parseTimeout(req *http.Request) (time.Duration, bool, error) { + value := req.URL.Query().Get("timeout") + if value == "" { + return 0, false, nil + } + + timeout, err := time.ParseDuration(value) + if err != nil { + return 0, false, fmt.Errorf("%s - %s", invalidTimeoutInURL, err.Error()) + } + + return timeout, true, nil +} + +func handleError(w http.ResponseWriter, r *http.Request, code int, err error) { + errorMsg := fmt.Sprintf("Error - %s: %#v", err.Error(), r.RequestURI) + http.Error(w, errorMsg, code) + klog.Errorf(errorMsg) +} diff --git a/pkg/endpoints/filters/request_deadline_test.go b/pkg/endpoints/filters/request_deadline_test.go new file mode 100644 index 000000000..55ff4a8f0 --- /dev/null +++ b/pkg/endpoints/filters/request_deadline_test.go @@ -0,0 +1,478 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/serializer" + utilclock "k8s.io/apimachinery/pkg/util/clock" + auditinternal "k8s.io/apiserver/pkg/apis/audit" + "k8s.io/apiserver/pkg/audit/policy" + "k8s.io/apiserver/pkg/endpoints/request" +) + +func TestParseTimeout(t *testing.T) { + tests := []struct { + name string + url string + expected bool + timeoutExpected time.Duration + message string + }{ + { + name: "the user does not specify a timeout", + url: "/api/v1/namespaces?timeout=", + }, + { + name: "the user specifies a valid timeout", + url: "/api/v1/namespaces?timeout=10s", + expected: true, + timeoutExpected: 10 * time.Second, + }, + { + name: "the user specifies a timeout of 0s", + url: "/api/v1/namespaces?timeout=0s", + expected: true, + }, + { + name: "the user specifies an invalid timeout", + url: "/api/v1/namespaces?timeout=foo", + message: invalidTimeoutInURL, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, test.url, nil) + if err != nil { + t.Fatalf("failed to create new http request - %v", err) + } + + timeoutGot, ok, err := parseTimeout(request) + + if test.expected != ok { + t.Errorf("expected: %t, but got: %t", test.expected, ok) + } + if test.timeoutExpected != timeoutGot { + t.Errorf("expected timeout: %s, but got: %s", test.timeoutExpected, timeoutGot) + } + + errMessageGot := message(err) + if !strings.Contains(errMessageGot, test.message) { + t.Errorf("expected error message to contain: %s, but got: %s", test.message, errMessageGot) + } + }) + } +} + +func TestWithRequestDeadline(t *testing.T) { + const requestTimeoutMaximum = 60 * time.Second + tests := []struct { + name string + requestURL string + longRunning bool + hasDeadlineExpected bool + deadlineExpected time.Duration + handlerCallCountExpected int + statusCodeExpected int + }{ + { + name: "the user specifies a valid request timeout", + requestURL: "/api/v1/namespaces?timeout=15s", + longRunning: false, + handlerCallCountExpected: 1, + hasDeadlineExpected: true, + deadlineExpected: 14 * time.Second, // to account for the delay in verification + statusCodeExpected: http.StatusOK, + }, + { + name: "the user specifies a valid request timeout", + requestURL: "/api/v1/namespaces?timeout=15s", + longRunning: false, + handlerCallCountExpected: 1, + hasDeadlineExpected: true, + deadlineExpected: 14 * time.Second, // to account for the delay in verification + statusCodeExpected: http.StatusOK, + }, + { + name: "the specified timeout is 0s, default deadline is expected to be set", + requestURL: "/api/v1/namespaces?timeout=0s", + longRunning: false, + handlerCallCountExpected: 1, + hasDeadlineExpected: true, + deadlineExpected: requestTimeoutMaximum - time.Second, // to account for the delay in verification + statusCodeExpected: http.StatusOK, + }, + { + name: "the user does not specify any request timeout, default deadline is expected to be set", + requestURL: "/api/v1/namespaces?timeout=", + longRunning: false, + handlerCallCountExpected: 1, + hasDeadlineExpected: true, + deadlineExpected: requestTimeoutMaximum - time.Second, // to account for the delay in verification + statusCodeExpected: http.StatusOK, + }, + { + name: "the request is long running, no deadline is expected to be set", + requestURL: "/api/v1/namespaces?timeout=10s", + longRunning: true, + hasDeadlineExpected: false, + handlerCallCountExpected: 1, + statusCodeExpected: http.StatusOK, + }, + { + name: "the timeout specified is malformed, the request is aborted with HTTP 400", + requestURL: "/api/v1/namespaces?timeout=foo", + longRunning: false, + statusCodeExpected: http.StatusBadRequest, + }, + { + name: "the timeout specified exceeds the maximum deadline allowed, the default deadline is used", + requestURL: fmt.Sprintf("/api/v1/namespaces?timeout=%s", requestTimeoutMaximum+time.Second), + longRunning: false, + statusCodeExpected: http.StatusOK, + handlerCallCountExpected: 1, + hasDeadlineExpected: true, + deadlineExpected: requestTimeoutMaximum - time.Second, // to account for the delay in verification + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var ( + callCount int + hasDeadlineGot bool + deadlineGot time.Duration + ) + handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { + callCount++ + deadlineGot, hasDeadlineGot = deadline(req) + }) + + fakeSink := &fakeAuditSink{} + fakeChecker := policy.FakeChecker(auditinternal.LevelRequestResponse, nil) + withDeadline := WithRequestDeadline(handler, fakeSink, fakeChecker, + func(_ *http.Request, _ *request.RequestInfo) bool { return test.longRunning }, + newSerializer(), requestTimeoutMaximum) + withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{}) + + testRequest := newRequest(t, test.requestURL) + + // make sure a default request does not have any deadline set + remaning, ok := deadline(testRequest) + if ok { + t.Fatalf("test setup failed, expected the new HTTP request context to have no deadline but got: %s", remaning) + } + + w := httptest.NewRecorder() + withDeadline.ServeHTTP(w, testRequest) + + if test.handlerCallCountExpected != callCount { + t.Errorf("expected the request handler to be invoked %d times, but was actually invoked %d times", test.handlerCallCountExpected, callCount) + } + + if test.hasDeadlineExpected != hasDeadlineGot { + t.Errorf("expected the request context to have deadline set: %t but got: %t", test.hasDeadlineExpected, hasDeadlineGot) + } + + deadlineGot = deadlineGot.Truncate(time.Second) + if test.deadlineExpected != deadlineGot { + t.Errorf("expected a request context with a deadline of %s but got: %s", test.deadlineExpected, deadlineGot) + } + + statusCodeGot := w.Result().StatusCode + if test.statusCodeExpected != statusCodeGot { + t.Errorf("expected status code %d but got: %d", test.statusCodeExpected, statusCodeGot) + } + }) + } +} + +func TestWithRequestDeadlineWithClock(t *testing.T) { + var ( + hasDeadlineGot bool + deadlineGot time.Duration + ) + handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { + deadlineGot, hasDeadlineGot = deadline(req) + }) + + // if the deadline filter uses the clock instead of using the request started timestamp from the context + // then we will see a request deadline of about a minute. + receivedTimestampExpected := time.Now().Add(time.Minute) + fakeClock := utilclock.NewFakeClock(receivedTimestampExpected) + + fakeSink := &fakeAuditSink{} + fakeChecker := policy.FakeChecker(auditinternal.LevelRequestResponse, nil) + withDeadline := withRequestDeadline(handler, fakeSink, fakeChecker, + func(_ *http.Request, _ *request.RequestInfo) bool { return false }, newSerializer(), time.Minute, fakeClock) + withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{}) + + testRequest := newRequest(t, "/api/v1/namespaces?timeout=1s") + // the request has arrived just now. + testRequest = testRequest.WithContext(request.WithReceivedTimestamp(testRequest.Context(), time.Now())) + + w := httptest.NewRecorder() + withDeadline.ServeHTTP(w, testRequest) + + if !hasDeadlineGot { + t.Error("expected the request context to have deadline set") + } + + // we expect a deadline <= 1s since the filter should use the request started timestamp from the context. + if deadlineGot > time.Second { + t.Errorf("expected a request context with a deadline <= %s, but got: %s", time.Second, deadlineGot) + } +} + +func TestWithRequestDeadlineWithFailedRequestIsAudited(t *testing.T) { + var handlerInvoked bool + handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { + handlerInvoked = true + }) + + fakeSink := &fakeAuditSink{} + fakeChecker := policy.FakeChecker(auditinternal.LevelRequestResponse, nil) + withDeadline := WithRequestDeadline(handler, fakeSink, fakeChecker, + func(_ *http.Request, _ *request.RequestInfo) bool { return false }, newSerializer(), time.Minute) + withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{}) + + testRequest := newRequest(t, "/api/v1/namespaces?timeout=foo") + w := httptest.NewRecorder() + withDeadline.ServeHTTP(w, testRequest) + + if handlerInvoked { + t.Error("expected the request to fail and the handler to be skipped") + } + + statusCodeGot := w.Result().StatusCode + if statusCodeGot != http.StatusBadRequest { + t.Errorf("expected status code %d, but got: %d", http.StatusBadRequest, statusCodeGot) + } + // verify that the audit event from the request context is written to the audit sink. + if len(fakeSink.events) != 1 { + t.Fatalf("expected audit sink to have 1 event, but got: %d", len(fakeSink.events)) + } +} + +func TestWithRequestDeadlineWithPanic(t *testing.T) { + var ( + panicErrGot interface{} + ctxGot context.Context + ) + + panicErrExpected := errors.New("apiserver panic'd") + handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { + ctxGot = req.Context() + panic(panicErrExpected) + }) + + fakeSink := &fakeAuditSink{} + fakeChecker := policy.FakeChecker(auditinternal.LevelRequestResponse, nil) + withDeadline := WithRequestDeadline(handler, fakeSink, fakeChecker, + func(_ *http.Request, _ *request.RequestInfo) bool { return false }, newSerializer(), 1*time.Minute) + withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{}) + withPanicRecovery := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + defer func() { + panicErrGot = recover() + }() + withDeadline.ServeHTTP(w, req) + }) + + testRequest := newRequest(t, "/api/v1/namespaces?timeout=1s") + w := httptest.NewRecorder() + withPanicRecovery.ServeHTTP(w, testRequest) + + if panicErrExpected != panicErrGot { + t.Errorf("expected panic error: %#v, but got: %#v", panicErrExpected, panicErrGot) + } + if ctxGot.Err() != context.Canceled { + t.Error("expected the request context to be canceled on handler panic") + } +} + +func TestWithRequestDeadlineWithRequestTimesOut(t *testing.T) { + timeout := 100 * time.Millisecond + var errGot error + handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { + ctx := req.Context() + select { + case <-time.After(timeout + time.Second): + errGot = fmt.Errorf("expected the request context to have timed out in %s", timeout) + case <-ctx.Done(): + errGot = ctx.Err() + } + }) + + fakeSink := &fakeAuditSink{} + fakeChecker := policy.FakeChecker(auditinternal.LevelRequestResponse, nil) + withDeadline := WithRequestDeadline(handler, fakeSink, fakeChecker, + func(_ *http.Request, _ *request.RequestInfo) bool { return false }, newSerializer(), 1*time.Minute) + withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{}) + + testRequest := newRequest(t, fmt.Sprintf("/api/v1/namespaces?timeout=%s", timeout)) + w := httptest.NewRecorder() + withDeadline.ServeHTTP(w, testRequest) + + if errGot != context.DeadlineExceeded { + t.Errorf("expected error: %#v, but got: %#v", context.DeadlineExceeded, errGot) + } +} + +func TestWithFailedRequestAudit(t *testing.T) { + tests := []struct { + name string + statusErr *apierrors.StatusError + errorHandlerCallCountExpected int + statusCodeExpected int + auditExpected bool + }{ + { + name: "bad request, the error handler is invoked and the request is audited", + statusErr: apierrors.NewBadRequest("error serving request"), + errorHandlerCallCountExpected: 1, + statusCodeExpected: http.StatusBadRequest, + auditExpected: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var ( + errorHandlerCallCountGot int + rwGot http.ResponseWriter + requestGot *http.Request + ) + + errorHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + http.Error(rw, "error serving request", http.StatusBadRequest) + + errorHandlerCallCountGot++ + requestGot = req + rwGot = rw + }) + + fakeSink := &fakeAuditSink{} + fakeChecker := policy.FakeChecker(auditinternal.LevelRequestResponse, nil) + + withAudit := withFailedRequestAudit(errorHandler, test.statusErr, fakeSink, fakeChecker) + + w := httptest.NewRecorder() + testRequest, err := http.NewRequest(http.MethodGet, "/apis/v1/namespaces/default/pods", nil) + if err != nil { + t.Fatalf("failed to create new http testRequest - %v", err) + } + info := request.RequestInfo{} + testRequest = testRequest.WithContext(request.WithRequestInfo(testRequest.Context(), &info)) + + withAudit.ServeHTTP(w, testRequest) + + if test.errorHandlerCallCountExpected != errorHandlerCallCountGot { + t.Errorf("expected the testRequest handler to be invoked %d times, but was actually invoked %d times", test.errorHandlerCallCountExpected, errorHandlerCallCountGot) + } + + statusCodeGot := w.Result().StatusCode + if test.statusCodeExpected != statusCodeGot { + t.Errorf("expected status code %d, but got: %d", test.statusCodeExpected, statusCodeGot) + } + + if test.auditExpected { + // verify that the right http.ResponseWriter is passed to the error handler + _, ok := rwGot.(*auditResponseWriter) + if !ok { + t.Errorf("expected an http.ResponseWriter of type: %T but got: %T", &auditResponseWriter{}, rwGot) + } + + auditEventGot := request.AuditEventFrom(requestGot.Context()) + if auditEventGot == nil { + t.Fatal("expected an audit event object but got nil") + } + if auditEventGot.Stage != auditinternal.StageResponseStarted { + t.Errorf("expected audit event Stage: %s, but got: %s", auditinternal.StageResponseStarted, auditEventGot.Stage) + } + if auditEventGot.ResponseStatus == nil { + t.Fatal("expected a ResponseStatus field of the audit event object, but got nil") + } + if test.statusCodeExpected != int(auditEventGot.ResponseStatus.Code) { + t.Errorf("expected audit event ResponseStatus.Code: %d, but got: %d", test.statusCodeExpected, auditEventGot.ResponseStatus.Code) + } + if test.statusErr.Error() != auditEventGot.ResponseStatus.Message { + t.Errorf("expected audit event ResponseStatus.Message: %s, but got: %s", test.statusErr, auditEventGot.ResponseStatus.Message) + } + + // verify that the audit event from the request context is written to the audit sink. + if len(fakeSink.events) != 1 { + t.Fatalf("expected audit sink to have 1 event, but got: %d", len(fakeSink.events)) + } + auditEventFromSink := fakeSink.events[0] + if !reflect.DeepEqual(auditEventGot, auditEventFromSink) { + t.Errorf("expected the audit event from the request context to be written to the audit sink, but got diffs: %s", cmp.Diff(auditEventGot, auditEventFromSink)) + } + } + }) + } +} + +func newRequest(t *testing.T, requestURL string) *http.Request { + req, err := http.NewRequest(http.MethodGet, requestURL, nil) + if err != nil { + t.Fatalf("failed to create new http request - %v", err) + } + + return req +} + +func message(err error) string { + if err != nil { + return err.Error() + } + + return "" +} + +func newSerializer() runtime.NegotiatedSerializer { + scheme := runtime.NewScheme() + return serializer.NewCodecFactory(scheme).WithoutConversion() +} + +type fakeRequestResolver struct{} + +func (r fakeRequestResolver) NewRequestInfo(req *http.Request) (*request.RequestInfo, error) { + return &request.RequestInfo{}, nil +} + +func deadline(r *http.Request) (time.Duration, bool) { + if deadline, ok := r.Context().Deadline(); ok { + remaining := time.Until(deadline) + return remaining, ok + } + + return 0, false +} diff --git a/pkg/endpoints/handlers/create.go b/pkg/endpoints/handlers/create.go index 0b950c72d..e914eaeea 100644 --- a/pkg/endpoints/handlers/create.go +++ b/pkg/endpoints/handlers/create.go @@ -57,9 +57,6 @@ func createHandler(r rest.NamedCreater, scope *RequestScope, admit admission.Int return } - // TODO: we either want to remove timeout or document it (if we document, move timeout out of this function and declare it in api_installer) - timeout := parseTimeout(req.URL.Query().Get("timeout")) - namespace, name, err := scope.Namer.Name(req) if err != nil { if includeName { @@ -76,7 +73,9 @@ func createHandler(r rest.NamedCreater, scope *RequestScope, admit admission.Int } } - ctx, cancel := context.WithTimeout(req.Context(), timeout) + // enforce a timeout of at most requestTimeoutUpperBound (34s) or less if the user-provided + // timeout inside the parent context is lower than requestTimeoutUpperBound. + ctx, cancel := context.WithTimeout(req.Context(), requestTimeoutUpperBound) defer cancel() outputMediaType, _, err := negotiation.NegotiateOutputMediaType(req, scope.Serializer, scope) if err != nil { @@ -157,7 +156,7 @@ func createHandler(r rest.NamedCreater, scope *RequestScope, admit admission.Int } // Dedup owner references before updating managed fields dedupOwnerReferencesAndAddWarning(obj, req.Context(), false) - result, err := finishRequest(timeout, func() (runtime.Object, error) { + result, err := finishRequest(ctx, func() (runtime.Object, error) { if scope.FieldManager != nil { liveObj, err := scope.Creater.New(scope.Kind) if err != nil { diff --git a/pkg/endpoints/handlers/delete.go b/pkg/endpoints/handlers/delete.go index 498eeee5f..67fcd91f1 100644 --- a/pkg/endpoints/handlers/delete.go +++ b/pkg/endpoints/handlers/delete.go @@ -54,16 +54,17 @@ func DeleteResource(r rest.GracefulDeleter, allowsOptions bool, scope *RequestSc return } - // TODO: we either want to remove timeout or document it (if we document, move timeout out of this function and declare it in api_installer) - timeout := parseTimeout(req.URL.Query().Get("timeout")) - namespace, name, err := scope.Namer.Name(req) if err != nil { scope.err(err, w, req) return } - ctx, cancel := context.WithTimeout(req.Context(), timeout) + + // enforce a timeout of at most requestTimeoutUpperBound (34s) or less if the user-provided + // timeout inside the parent context is lower than requestTimeoutUpperBound. + ctx, cancel := context.WithTimeout(req.Context(), requestTimeoutUpperBound) defer cancel() + ctx = request.WithNamespace(ctx, namespace) ae := request.AuditEventFrom(ctx) admit = admission.WithAudit(admit, ae) @@ -123,7 +124,7 @@ func DeleteResource(r rest.GracefulDeleter, allowsOptions bool, scope *RequestSc wasDeleted := true userInfo, _ := request.UserFrom(ctx) staticAdmissionAttrs := admission.NewAttributesRecord(nil, nil, scope.Kind, namespace, name, scope.Resource, scope.Subresource, admission.Delete, options, dryrun.IsDryRun(options.DryRun), userInfo) - result, err := finishRequest(timeout, func() (runtime.Object, error) { + result, err := finishRequest(ctx, func() (runtime.Object, error) { obj, deleted, err := r.Delete(ctx, name, rest.AdmissionToValidateObjectDeleteFunc(admit, staticAdmissionAttrs, scope), options) wasDeleted = deleted return obj, err @@ -172,17 +173,17 @@ func DeleteCollection(r rest.CollectionDeleter, checkBody bool, scope *RequestSc return } - // TODO: we either want to remove timeout or document it (if we document, move timeout out of this function and declare it in api_installer) - timeout := parseTimeout(req.URL.Query().Get("timeout")) - namespace, err := scope.Namer.Namespace(req) if err != nil { scope.err(err, w, req) return } - ctx, cancel := context.WithTimeout(req.Context(), timeout) + // enforce a timeout of at most requestTimeoutUpperBound (34s) or less if the user-provided + // timeout inside the parent context is lower than requestTimeoutUpperBound. + ctx, cancel := context.WithTimeout(req.Context(), requestTimeoutUpperBound) defer cancel() + ctx = request.WithNamespace(ctx, namespace) ae := request.AuditEventFrom(ctx) @@ -265,7 +266,7 @@ func DeleteCollection(r rest.CollectionDeleter, checkBody bool, scope *RequestSc admit = admission.WithAudit(admit, ae) userInfo, _ := request.UserFrom(ctx) staticAdmissionAttrs := admission.NewAttributesRecord(nil, nil, scope.Kind, namespace, "", scope.Resource, scope.Subresource, admission.Delete, options, dryrun.IsDryRun(options.DryRun), userInfo) - result, err := finishRequest(timeout, func() (runtime.Object, error) { + result, err := finishRequest(ctx, func() (runtime.Object, error) { return r.DeleteCollection(ctx, rest.AdmissionToValidateObjectDeleteFunc(admit, staticAdmissionAttrs, scope), options, &listOptions) }) if err != nil { diff --git a/pkg/endpoints/handlers/patch.go b/pkg/endpoints/handlers/patch.go index 5fe8cedc2..4c0d10230 100644 --- a/pkg/endpoints/handlers/patch.go +++ b/pkg/endpoints/handlers/patch.go @@ -84,19 +84,17 @@ func PatchResource(r rest.Patcher, scope *RequestScope, admit admission.Interfac return } - // TODO: we either want to remove timeout or document it (if we - // document, move timeout out of this function and declare it in - // api_installer) - timeout := parseTimeout(req.URL.Query().Get("timeout")) - namespace, name, err := scope.Namer.Name(req) if err != nil { scope.err(err, w, req) return } - ctx, cancel := context.WithTimeout(req.Context(), timeout) + // enforce a timeout of at most requestTimeoutUpperBound (34s) or less if the user-provided + // timeout inside the parent context is lower than requestTimeoutUpperBound. + ctx, cancel := context.WithTimeout(req.Context(), requestTimeoutUpperBound) defer cancel() + ctx = request.WithNamespace(ctx, namespace) outputMediaType, _, err := negotiation.NegotiateOutputMediaType(req, scope.Serializer, scope) @@ -208,7 +206,6 @@ func PatchResource(r rest.Patcher, scope *RequestScope, admit admission.Interfac codec: codec, - timeout: timeout, options: options, restPatcher: r, @@ -271,7 +268,6 @@ type patcher struct { codec runtime.Codec - timeout time.Duration options *metav1.PatchOptions // Operation information @@ -591,7 +587,7 @@ func (p *patcher) patchResource(ctx context.Context, scope *RequestScope) (runti wasCreated = created return updateObject, updateErr } - result, err := finishRequest(p.timeout, func() (runtime.Object, error) { + result, err := finishRequest(ctx, func() (runtime.Object, error) { result, err := requestFunc() // If the object wasn't committed to storage because it's serialized size was too large, // it is safe to remove managedFields (which can be large) and try again. diff --git a/pkg/endpoints/handlers/rest.go b/pkg/endpoints/handlers/rest.go index 4818b9f60..783ab96b9 100644 --- a/pkg/endpoints/handlers/rest.go +++ b/pkg/endpoints/handlers/rest.go @@ -53,6 +53,9 @@ import ( ) const ( + // 34 chose as a number close to 30 that is likely to be unique enough to jump out at me the next time I see a timeout. + // Everyone chooses 30. + requestTimeoutUpperBound = 34 * time.Second // DuplicateOwnerReferencesWarningFormat is the warning that a client receives when a create/update request contains // duplicate owner reference entries. DuplicateOwnerReferencesWarningFormat = ".metadata.ownerReferences contains duplicate entries; API server dedups owner references in 1.20+, and may reject such requests as early as 1.24; please fix your requests; duplicate UID(s) observed: %v" @@ -227,7 +230,7 @@ type resultFunc func() (runtime.Object, error) // finishRequest makes a given resultFunc asynchronous and handles errors returned by the response. // An api.Status object with status != success is considered an "error", which interrupts the normal response flow. -func finishRequest(timeout time.Duration, fn resultFunc) (result runtime.Object, err error) { +func finishRequest(ctx context.Context, fn resultFunc) (result runtime.Object, err error) { // these channels need to be buffered to prevent the goroutine below from hanging indefinitely // when the select statement reads something other than the one the goroutine sends on. ch := make(chan runtime.Object, 1) @@ -271,8 +274,8 @@ func finishRequest(timeout time.Duration, fn resultFunc) (result runtime.Object, return nil, err case p := <-panicCh: panic(p) - case <-time.After(timeout): - return nil, errors.NewTimeoutError(fmt.Sprintf("request did not complete within requested timeout %s", timeout), 0) + case <-ctx.Done(): + return nil, errors.NewTimeoutError(fmt.Sprintf("request did not complete within requested timeout %s", ctx.Err()), 0) } } @@ -487,19 +490,6 @@ func limitedReadBody(req *http.Request, limit int64) ([]byte, error) { return data, nil } -func parseTimeout(str string) time.Duration { - if str != "" { - timeout, err := time.ParseDuration(str) - if err == nil { - return timeout - } - - klog.ErrorS(err, "Failed to parse", "time", str) - } - // 34 chose as a number close to 30 that is likely to be unique enough to jump out at me the next time I see a timeout. Everyone chooses 30. - return 34 * time.Second -} - func isDryRun(url *url.URL) bool { return len(url.Query()["dryRun"]) != 0 } diff --git a/pkg/endpoints/handlers/rest_test.go b/pkg/endpoints/handlers/rest_test.go index 5e9dbb07c..5bd34f004 100644 --- a/pkg/endpoints/handlers/rest_test.go +++ b/pkg/endpoints/handlers/rest_test.go @@ -456,8 +456,6 @@ func (tc *patchTestCase) Run(t *testing.T) { codec: codec, - timeout: 1 * time.Second, - restPatcher: testPatcher, name: name, patchType: patchType, @@ -466,7 +464,10 @@ func (tc *patchTestCase) Run(t *testing.T) { trace: utiltrace.New("Patch", utiltrace.Field{"name", name}), } + ctx, cancel := context.WithTimeout(ctx, time.Second) resultObj, _, err := p.patchResource(ctx, &RequestScope{}) + cancel() + if len(tc.expectedError) != 0 { if err == nil || err.Error() != tc.expectedError { t.Errorf("%s: expected error %v, but got %v", tc.name, tc.expectedError, err) @@ -825,26 +826,18 @@ func TestHasUID(t *testing.T) { } } -func TestParseTimeout(t *testing.T) { - if d := parseTimeout(""); d != 34*time.Second { - t.Errorf("blank timeout produces %v", d) - } - if d := parseTimeout("not a timeout"); d != 34*time.Second { - t.Errorf("bad timeout produces %v", d) - } - if d := parseTimeout("10s"); d != 10*time.Second { - t.Errorf("10s timeout produced: %v", d) - } -} - func TestFinishRequest(t *testing.T) { exampleObj := &example.Pod{} exampleErr := fmt.Errorf("error") successStatusObj := &metav1.Status{Status: metav1.StatusSuccess, Message: "success message"} errorStatusObj := &metav1.Status{Status: metav1.StatusFailure, Message: "error message"} + timeoutFunc := func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.TODO(), time.Second) + } + testcases := []struct { name string - timeout time.Duration + timeout func() (context.Context, context.CancelFunc) fn resultFunc expectedObj runtime.Object expectedErr error @@ -854,7 +847,7 @@ func TestFinishRequest(t *testing.T) { }{ { name: "Expected obj is returned", - timeout: time.Second, + timeout: timeoutFunc, fn: func() (runtime.Object, error) { return exampleObj, nil }, @@ -863,7 +856,7 @@ func TestFinishRequest(t *testing.T) { }, { name: "Expected error is returned", - timeout: time.Second, + timeout: timeoutFunc, fn: func() (runtime.Object, error) { return nil, exampleErr }, @@ -872,7 +865,7 @@ func TestFinishRequest(t *testing.T) { }, { name: "Successful status object is returned as expected", - timeout: time.Second, + timeout: timeoutFunc, fn: func() (runtime.Object, error) { return successStatusObj, nil }, @@ -881,7 +874,7 @@ func TestFinishRequest(t *testing.T) { }, { name: "Error status object is converted to StatusError", - timeout: time.Second, + timeout: timeoutFunc, fn: func() (runtime.Object, error) { return errorStatusObj, nil }, @@ -890,7 +883,7 @@ func TestFinishRequest(t *testing.T) { }, { name: "Panic is propagated up", - timeout: time.Second, + timeout: timeoutFunc, fn: func() (runtime.Object, error) { panic("my panic") }, @@ -900,7 +893,7 @@ func TestFinishRequest(t *testing.T) { }, { name: "Panic is propagated with stack", - timeout: time.Second, + timeout: timeoutFunc, fn: func() (runtime.Object, error) { panic("my panic") }, @@ -910,7 +903,7 @@ func TestFinishRequest(t *testing.T) { }, { name: "http.ErrAbortHandler panic is propagated without wrapping with stack", - timeout: time.Second, + timeout: timeoutFunc, fn: func() (runtime.Object, error) { panic(http.ErrAbortHandler) }, @@ -922,7 +915,10 @@ func TestFinishRequest(t *testing.T) { } for i, tc := range testcases { t.Run(tc.name, func(t *testing.T) { + ctx, cancel := tc.timeout() defer func() { + cancel() + r := recover() switch { case r == nil && len(tc.expectedPanic) > 0: @@ -937,7 +933,7 @@ func TestFinishRequest(t *testing.T) { t.Errorf("expected panic obj %#v, got %#v", tc.expectedPanicObj, r) } }() - obj, err := finishRequest(tc.timeout, tc.fn) + obj, err := finishRequest(ctx, tc.fn) if (err == nil && tc.expectedErr != nil) || (err != nil && tc.expectedErr == nil) || (err != nil && tc.expectedErr != nil && err.Error() != tc.expectedErr.Error()) { t.Errorf("%d: unexpected err. expected: %v, got: %v", i, tc.expectedErr, err) } diff --git a/pkg/endpoints/handlers/update.go b/pkg/endpoints/handlers/update.go index 1c46e7f0d..fd215bb38 100644 --- a/pkg/endpoints/handlers/update.go +++ b/pkg/endpoints/handlers/update.go @@ -54,16 +54,17 @@ func UpdateResource(r rest.Updater, scope *RequestScope, admit admission.Interfa return } - // TODO: we either want to remove timeout or document it (if we document, move timeout out of this function and declare it in api_installer) - timeout := parseTimeout(req.URL.Query().Get("timeout")) - namespace, name, err := scope.Namer.Name(req) if err != nil { scope.err(err, w, req) return } - ctx, cancel := context.WithTimeout(req.Context(), timeout) + + // enforce a timeout of at most requestTimeoutUpperBound (34s) or less if the user-provided + // timeout inside the parent context is lower than requestTimeoutUpperBound. + ctx, cancel := context.WithTimeout(req.Context(), requestTimeoutUpperBound) defer cancel() + ctx = request.WithNamespace(ctx, namespace) outputMediaType, _, err := negotiation.NegotiateOutputMediaType(req, scope.Serializer, scope) @@ -195,7 +196,7 @@ func UpdateResource(r rest.Updater, scope *RequestScope, admit admission.Interfa } // Dedup owner references before updating managed fields dedupOwnerReferencesAndAddWarning(obj, req.Context(), false) - result, err := finishRequest(timeout, func() (runtime.Object, error) { + result, err := finishRequest(ctx, func() (runtime.Object, error) { result, err := requestFunc() // If the object wasn't committed to storage because it's serialized size was too large, // it is safe to remove managedFields (which can be large) and try again. diff --git a/pkg/server/config.go b/pkg/server/config.go index 92b0fd30a..52ee3b3dd 100644 --- a/pkg/server/config.go +++ b/pkg/server/config.go @@ -746,7 +746,13 @@ func DefaultBuildHandlerChain(apiHandler http.Handler, c *Config) http.Handler { handler = filterlatency.TrackStarted(handler, "authentication") handler = genericfilters.WithCORS(handler, c.CorsAllowedOriginList, nil, nil, nil, "true") - handler = genericfilters.WithTimeoutForNonLongRunningRequests(handler, c.LongRunningFunc, c.RequestTimeout) + + // WithTimeoutForNonLongRunningRequests will call the rest of the request handling in a go-routine with the + // context with deadline. The go-routine can keep running, while the timeout logic will return a timeout to the client. + handler = genericfilters.WithTimeoutForNonLongRunningRequests(handler, c.LongRunningFunc) + + handler = genericapifilters.WithRequestDeadline(handler, c.AuditBackend, c.AuditPolicyChecker, + c.LongRunningFunc, c.Serializer, c.RequestTimeout) handler = genericfilters.WithWaitGroup(handler, c.LongRunningFunc, c.HandlerChainWaitGroup) handler = genericapifilters.WithRequestInfo(handler, c.RequestInfoResolver) if c.SecureServing != nil && !c.SecureServing.DisableHTTP2 && c.GoawayChance > 0 { diff --git a/pkg/server/filters/priority-and-fairness_test.go b/pkg/server/filters/priority-and-fairness_test.go index 8cd8c867a..fa33c5dc5 100644 --- a/pkg/server/filters/priority-and-fairness_test.go +++ b/pkg/server/filters/priority-and-fairness_test.go @@ -509,7 +509,7 @@ func newHandlerChain(t *testing.T, handler http.Handler, filter utilflowcontrol. apfHandler.ServeHTTP(w, r) }) - handler = WithTimeoutForNonLongRunningRequests(handler, longRunningRequestCheck, requestTimeout) + handler = WithTimeoutForNonLongRunningRequests(handler, longRunningRequestCheck) handler = apifilters.WithRequestInfo(handler, requestInfoFactory) handler = WithPanicRecovery(handler, requestInfoFactory) return handler diff --git a/pkg/server/filters/timeout.go b/pkg/server/filters/timeout.go index 2405bfd1f..ccbed60db 100644 --- a/pkg/server/filters/timeout.go +++ b/pkg/server/filters/timeout.go @@ -18,14 +18,12 @@ package filters import ( "bufio" - "context" "encoding/json" "fmt" "net" "net/http" "runtime" "sync" - "time" apierrors "k8s.io/apimachinery/pkg/api/errors" utilruntime "k8s.io/apimachinery/pkg/util/runtime" @@ -34,37 +32,33 @@ import ( ) // WithTimeoutForNonLongRunningRequests times out non-long-running requests after the time given by timeout. -func WithTimeoutForNonLongRunningRequests(handler http.Handler, longRunning apirequest.LongRunningRequestCheck, timeout time.Duration) http.Handler { +func WithTimeoutForNonLongRunningRequests(handler http.Handler, longRunning apirequest.LongRunningRequestCheck) http.Handler { if longRunning == nil { return handler } - timeoutFunc := func(req *http.Request) (*http.Request, <-chan time.Time, func(), *apierrors.StatusError) { + timeoutFunc := func(req *http.Request) (*http.Request, bool, func(), *apierrors.StatusError) { // TODO unify this with apiserver.MaxInFlightLimit ctx := req.Context() requestInfo, ok := apirequest.RequestInfoFrom(ctx) if !ok { // if this happens, the handler chain isn't setup correctly because there is no request info - return req, time.After(timeout), func() {}, apierrors.NewInternalError(fmt.Errorf("no request info found for request during timeout")) + return req, false, func() {}, apierrors.NewInternalError(fmt.Errorf("no request info found for request during timeout")) } if longRunning(req, requestInfo) { - return req, nil, nil, nil + return req, true, nil, nil } - ctx, cancel := context.WithCancel(ctx) - req = req.WithContext(ctx) - postTimeoutFn := func() { - cancel() metrics.RecordRequestTermination(req, requestInfo, metrics.APIServerComponent, http.StatusGatewayTimeout) } - return req, time.After(timeout), postTimeoutFn, apierrors.NewTimeoutError(fmt.Sprintf("request did not complete within %s", timeout), 0) + return req, false, postTimeoutFn, apierrors.NewTimeoutError("request did not complete within the allotted timeout", 0) } return WithTimeout(handler, timeoutFunc) } -type timeoutFunc = func(*http.Request) (req *http.Request, timeout <-chan time.Time, postTimeoutFunc func(), err *apierrors.StatusError) +type timeoutFunc = func(*http.Request) (req *http.Request, longRunning bool, postTimeoutFunc func(), err *apierrors.StatusError) // WithTimeout returns an http.Handler that runs h with a timeout // determined by timeoutFunc. The new http.Handler calls h.ServeHTTP to handle @@ -85,12 +79,14 @@ type timeoutHandler struct { } func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - r, after, postTimeoutFn, err := t.timeout(r) - if after == nil { + r, longRunning, postTimeoutFn, err := t.timeout(r) + if longRunning { t.handler.ServeHTTP(w, r) return } + timeoutCh := r.Context().Done() + // resultCh is used as both errCh and stopCh resultCh := make(chan interface{}) tw := newTimeoutWriter(w) @@ -117,7 +113,7 @@ func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { panic(err) } return - case <-after: + case <-timeoutCh: defer func() { // resultCh needs to have a reader, since the function doing // the work needs to send to it. This is defer'd to ensure it runs diff --git a/pkg/server/filters/timeout_test.go b/pkg/server/filters/timeout_test.go index 15767fec6..faf8c1ad8 100644 --- a/pkg/server/filters/timeout_test.go +++ b/pkg/server/filters/timeout_test.go @@ -18,6 +18,7 @@ package filters import ( "bytes" + "context" "crypto/tls" "crypto/x509" "encoding/json" @@ -92,18 +93,27 @@ func TestTimeout(t *testing.T) { timeoutErr := apierrors.NewServerTimeout(schema.GroupResource{Group: "foo", Resource: "bar"}, "get", 0) record := &recorder{} + var ctx context.Context + withDeadline := func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + req = req.WithContext(ctx) + handler.ServeHTTP(w, req) + }) + } + handler := newHandler(sendResponse, doPanic, writeErrors) - ts := httptest.NewServer(withPanicRecovery( - WithTimeout(handler, func(req *http.Request) (*http.Request, <-chan time.Time, func(), *apierrors.StatusError) { - return req, timeout, record.Record, timeoutErr + ts := httptest.NewServer(withDeadline(withPanicRecovery( + WithTimeout(handler, func(req *http.Request) (*http.Request, bool, func(), *apierrors.StatusError) { + return req, false, record.Record, timeoutErr }), func(w http.ResponseWriter, req *http.Request, err interface{}) { gotPanic <- err http.Error(w, "This request caused apiserver to panic. Look in the logs for details.", http.StatusInternalServerError) }), - ) + )) defer ts.Close() // No timeouts + ctx = context.Background() sendResponse <- resp res, err := http.Get(ts.URL) if err != nil { @@ -124,6 +134,8 @@ func TestTimeout(t *testing.T) { } // Times out + ctx, cancel := context.WithCancel(context.Background()) + cancel() timeout <- time.Time{} res, err = http.Get(ts.URL) if err != nil { @@ -145,6 +157,7 @@ func TestTimeout(t *testing.T) { } // Now try to send a response + ctx = context.Background() sendResponse <- resp if err := <-writeErrors; err != http.ErrHandlerTimeout { t.Errorf("got Write error of %v; expected %v", err, http.ErrHandlerTimeout) @@ -170,6 +183,7 @@ func TestTimeout(t *testing.T) { } // Panics with http.ErrAbortHandler + ctx = context.Background() doPanic <- http.ErrAbortHandler res, err = http.Get(ts.URL) if err != nil {