apiserver: plumb context with request deadline

- as soon as a request is received by the apiserver, determine the
  timeout of the request and set a new request context with the deadline.
- the timeout filter that times out non-long-running requests should
  use the request context as opposed to a fixed 60s wait today.
- admission and storage layer uses the same request context with the
  deadline specified.

we use the default timeout enforced by the apiserver:
- if the user has specified a timeout of 0s, this implies no timeout on the user's part.
- if the user has specified a timeout that exceeds the maximum deadline allowed by the apiserver.

Kubernetes-commit: e416c9e574c49fd0190c8cdac58322aa33a935cf
This commit is contained in:
Abu Kashem 2020-11-26 23:53:20 -05:00 committed by Kubernetes Publisher
parent d4c9a19592
commit 026eb846a4
13 changed files with 743 additions and 91 deletions

View File

@ -282,10 +282,13 @@ func handleInternal(storage map[string]rest.Storage, admissionControl admission.
panic(fmt.Sprintf("unable to install container %s: %v", group.GroupVersion, err))
}
}
handler := genericapifilters.WithAudit(mux, auditSink, auditpolicy.FakeChecker(auditinternal.LevelRequestResponse, nil), func(r *http.Request, requestInfo *request.RequestInfo) bool {
longRunningCheck := func(r *http.Request, requestInfo *request.RequestInfo) bool {
// simplified long-running check
return requestInfo.Verb == "watch" || requestInfo.Verb == "proxy"
})
}
fakeChecker := auditpolicy.FakeChecker(auditinternal.LevelRequestResponse, nil)
handler := genericapifilters.WithAudit(mux, auditSink, fakeChecker, longRunningCheck)
handler = genericapifilters.WithRequestDeadline(handler, auditSink, fakeChecker, longRunningCheck, codecs, 60*time.Second)
handler = genericapifilters.WithRequestInfo(handler, testRequestInfoResolver())
return &defaultAPIServer{handler, container}

View File

@ -0,0 +1,172 @@
/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"context"
"errors"
"fmt"
"net/http"
"time"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
utilclock "k8s.io/apimachinery/pkg/util/clock"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/audit/policy"
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/klog/v2"
)
const (
// The 'timeout' query parameter in the request URL has an invalid duration specifier
invalidTimeoutInURL = "invalid timeout specified in the request URL"
)
// WithRequestDeadline determines the timeout duration applicable to the given request and sets a new context
// with the appropriate deadline.
// auditWrapper provides an http.Handler that audits a failed request.
// longRunning returns true if he given request is a long running request.
// requestTimeoutMaximum specifies the default request timeout value.
func WithRequestDeadline(handler http.Handler, sink audit.Sink, policy policy.Checker, longRunning request.LongRunningRequestCheck,
negotiatedSerializer runtime.NegotiatedSerializer, requestTimeoutMaximum time.Duration) http.Handler {
return withRequestDeadline(handler, sink, policy, longRunning, negotiatedSerializer, requestTimeoutMaximum, utilclock.RealClock{})
}
func withRequestDeadline(handler http.Handler, sink audit.Sink, policy policy.Checker, longRunning request.LongRunningRequestCheck,
negotiatedSerializer runtime.NegotiatedSerializer, requestTimeoutMaximum time.Duration, clock utilclock.PassiveClock) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
requestInfo, ok := request.RequestInfoFrom(ctx)
if !ok {
handleError(w, req, http.StatusInternalServerError, fmt.Errorf("no RequestInfo found in context, handler chain must be wrong"))
return
}
if longRunning(req, requestInfo) {
handler.ServeHTTP(w, req)
return
}
userSpecifiedTimeout, ok, err := parseTimeout(req)
if err != nil {
statusErr := apierrors.NewBadRequest(fmt.Sprintf("%s", err.Error()))
klog.Errorf("Error - %s: %#v", err.Error(), req.RequestURI)
failed := failedErrorHandler(negotiatedSerializer, statusErr)
failWithAudit := withFailedRequestAudit(failed, statusErr, sink, policy)
failWithAudit.ServeHTTP(w, req)
return
}
timeout := requestTimeoutMaximum
if ok {
// we use the default timeout enforced by the apiserver:
// - if the user has specified a timeout of 0s, this implies no timeout on the user's part.
// - if the user has specified a timeout that exceeds the maximum deadline allowed by the apiserver.
if userSpecifiedTimeout > 0 && userSpecifiedTimeout < requestTimeoutMaximum {
timeout = userSpecifiedTimeout
}
}
started := clock.Now()
if requestStartedTimestamp, ok := request.ReceivedTimestampFrom(ctx); ok {
started = requestStartedTimestamp
}
ctx, cancel := context.WithDeadline(ctx, started.Add(timeout))
defer cancel()
req = req.WithContext(ctx)
handler.ServeHTTP(w, req)
})
}
// withFailedRequestAudit decorates a failed http.Handler and is used to audit a failed request.
// statusErr is used to populate the Message property of ResponseStatus.
func withFailedRequestAudit(failedHandler http.Handler, statusErr *apierrors.StatusError, sink audit.Sink, policy policy.Checker) http.Handler {
if sink == nil || policy == nil {
return failedHandler
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req, ev, omitStages, err := createAuditEventAndAttachToContext(req, policy)
if err != nil {
utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err))
responsewriters.InternalError(w, req, errors.New("failed to create audit event"))
return
}
if ev == nil {
failedHandler.ServeHTTP(w, req)
return
}
ev.ResponseStatus = &metav1.Status{}
ev.Stage = auditinternal.StageResponseStarted
if statusErr != nil {
ev.ResponseStatus.Message = statusErr.Error()
}
rw := decorateResponseWriter(w, ev, sink, omitStages)
failedHandler.ServeHTTP(rw, req)
})
}
// failedErrorHandler returns an http.Handler that uses the specified StatusError object
// to render an error response to the request.
func failedErrorHandler(s runtime.NegotiatedSerializer, statusError *apierrors.StatusError) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
requestInfo, found := request.RequestInfoFrom(ctx)
if !found {
responsewriters.InternalError(w, req, errors.New("no RequestInfo found in the context"))
return
}
gv := schema.GroupVersion{Group: requestInfo.APIGroup, Version: requestInfo.APIVersion}
responsewriters.ErrorNegotiated(statusError, s, gv, w, req)
})
}
// parseTimeout parses the given HTTP request URL and extracts the timeout query parameter
// value if specified by the user.
// If a timeout is not specified the function returns false and err is set to nil
// If the value specified is malformed then the function returns false and err is set
func parseTimeout(req *http.Request) (time.Duration, bool, error) {
value := req.URL.Query().Get("timeout")
if value == "" {
return 0, false, nil
}
timeout, err := time.ParseDuration(value)
if err != nil {
return 0, false, fmt.Errorf("%s - %s", invalidTimeoutInURL, err.Error())
}
return timeout, true, nil
}
func handleError(w http.ResponseWriter, r *http.Request, code int, err error) {
errorMsg := fmt.Sprintf("Error - %s: %#v", err.Error(), r.RequestURI)
http.Error(w, errorMsg, code)
klog.Errorf(errorMsg)
}

View File

@ -0,0 +1,478 @@
/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/serializer"
utilclock "k8s.io/apimachinery/pkg/util/clock"
auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/audit/policy"
"k8s.io/apiserver/pkg/endpoints/request"
)
func TestParseTimeout(t *testing.T) {
tests := []struct {
name string
url string
expected bool
timeoutExpected time.Duration
message string
}{
{
name: "the user does not specify a timeout",
url: "/api/v1/namespaces?timeout=",
},
{
name: "the user specifies a valid timeout",
url: "/api/v1/namespaces?timeout=10s",
expected: true,
timeoutExpected: 10 * time.Second,
},
{
name: "the user specifies a timeout of 0s",
url: "/api/v1/namespaces?timeout=0s",
expected: true,
},
{
name: "the user specifies an invalid timeout",
url: "/api/v1/namespaces?timeout=foo",
message: invalidTimeoutInURL,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, test.url, nil)
if err != nil {
t.Fatalf("failed to create new http request - %v", err)
}
timeoutGot, ok, err := parseTimeout(request)
if test.expected != ok {
t.Errorf("expected: %t, but got: %t", test.expected, ok)
}
if test.timeoutExpected != timeoutGot {
t.Errorf("expected timeout: %s, but got: %s", test.timeoutExpected, timeoutGot)
}
errMessageGot := message(err)
if !strings.Contains(errMessageGot, test.message) {
t.Errorf("expected error message to contain: %s, but got: %s", test.message, errMessageGot)
}
})
}
}
func TestWithRequestDeadline(t *testing.T) {
const requestTimeoutMaximum = 60 * time.Second
tests := []struct {
name string
requestURL string
longRunning bool
hasDeadlineExpected bool
deadlineExpected time.Duration
handlerCallCountExpected int
statusCodeExpected int
}{
{
name: "the user specifies a valid request timeout",
requestURL: "/api/v1/namespaces?timeout=15s",
longRunning: false,
handlerCallCountExpected: 1,
hasDeadlineExpected: true,
deadlineExpected: 14 * time.Second, // to account for the delay in verification
statusCodeExpected: http.StatusOK,
},
{
name: "the user specifies a valid request timeout",
requestURL: "/api/v1/namespaces?timeout=15s",
longRunning: false,
handlerCallCountExpected: 1,
hasDeadlineExpected: true,
deadlineExpected: 14 * time.Second, // to account for the delay in verification
statusCodeExpected: http.StatusOK,
},
{
name: "the specified timeout is 0s, default deadline is expected to be set",
requestURL: "/api/v1/namespaces?timeout=0s",
longRunning: false,
handlerCallCountExpected: 1,
hasDeadlineExpected: true,
deadlineExpected: requestTimeoutMaximum - time.Second, // to account for the delay in verification
statusCodeExpected: http.StatusOK,
},
{
name: "the user does not specify any request timeout, default deadline is expected to be set",
requestURL: "/api/v1/namespaces?timeout=",
longRunning: false,
handlerCallCountExpected: 1,
hasDeadlineExpected: true,
deadlineExpected: requestTimeoutMaximum - time.Second, // to account for the delay in verification
statusCodeExpected: http.StatusOK,
},
{
name: "the request is long running, no deadline is expected to be set",
requestURL: "/api/v1/namespaces?timeout=10s",
longRunning: true,
hasDeadlineExpected: false,
handlerCallCountExpected: 1,
statusCodeExpected: http.StatusOK,
},
{
name: "the timeout specified is malformed, the request is aborted with HTTP 400",
requestURL: "/api/v1/namespaces?timeout=foo",
longRunning: false,
statusCodeExpected: http.StatusBadRequest,
},
{
name: "the timeout specified exceeds the maximum deadline allowed, the default deadline is used",
requestURL: fmt.Sprintf("/api/v1/namespaces?timeout=%s", requestTimeoutMaximum+time.Second),
longRunning: false,
statusCodeExpected: http.StatusOK,
handlerCallCountExpected: 1,
hasDeadlineExpected: true,
deadlineExpected: requestTimeoutMaximum - time.Second, // to account for the delay in verification
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var (
callCount int
hasDeadlineGot bool
deadlineGot time.Duration
)
handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
callCount++
deadlineGot, hasDeadlineGot = deadline(req)
})
fakeSink := &fakeAuditSink{}
fakeChecker := policy.FakeChecker(auditinternal.LevelRequestResponse, nil)
withDeadline := WithRequestDeadline(handler, fakeSink, fakeChecker,
func(_ *http.Request, _ *request.RequestInfo) bool { return test.longRunning },
newSerializer(), requestTimeoutMaximum)
withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{})
testRequest := newRequest(t, test.requestURL)
// make sure a default request does not have any deadline set
remaning, ok := deadline(testRequest)
if ok {
t.Fatalf("test setup failed, expected the new HTTP request context to have no deadline but got: %s", remaning)
}
w := httptest.NewRecorder()
withDeadline.ServeHTTP(w, testRequest)
if test.handlerCallCountExpected != callCount {
t.Errorf("expected the request handler to be invoked %d times, but was actually invoked %d times", test.handlerCallCountExpected, callCount)
}
if test.hasDeadlineExpected != hasDeadlineGot {
t.Errorf("expected the request context to have deadline set: %t but got: %t", test.hasDeadlineExpected, hasDeadlineGot)
}
deadlineGot = deadlineGot.Truncate(time.Second)
if test.deadlineExpected != deadlineGot {
t.Errorf("expected a request context with a deadline of %s but got: %s", test.deadlineExpected, deadlineGot)
}
statusCodeGot := w.Result().StatusCode
if test.statusCodeExpected != statusCodeGot {
t.Errorf("expected status code %d but got: %d", test.statusCodeExpected, statusCodeGot)
}
})
}
}
func TestWithRequestDeadlineWithClock(t *testing.T) {
var (
hasDeadlineGot bool
deadlineGot time.Duration
)
handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
deadlineGot, hasDeadlineGot = deadline(req)
})
// if the deadline filter uses the clock instead of using the request started timestamp from the context
// then we will see a request deadline of about a minute.
receivedTimestampExpected := time.Now().Add(time.Minute)
fakeClock := utilclock.NewFakeClock(receivedTimestampExpected)
fakeSink := &fakeAuditSink{}
fakeChecker := policy.FakeChecker(auditinternal.LevelRequestResponse, nil)
withDeadline := withRequestDeadline(handler, fakeSink, fakeChecker,
func(_ *http.Request, _ *request.RequestInfo) bool { return false }, newSerializer(), time.Minute, fakeClock)
withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{})
testRequest := newRequest(t, "/api/v1/namespaces?timeout=1s")
// the request has arrived just now.
testRequest = testRequest.WithContext(request.WithReceivedTimestamp(testRequest.Context(), time.Now()))
w := httptest.NewRecorder()
withDeadline.ServeHTTP(w, testRequest)
if !hasDeadlineGot {
t.Error("expected the request context to have deadline set")
}
// we expect a deadline <= 1s since the filter should use the request started timestamp from the context.
if deadlineGot > time.Second {
t.Errorf("expected a request context with a deadline <= %s, but got: %s", time.Second, deadlineGot)
}
}
func TestWithRequestDeadlineWithFailedRequestIsAudited(t *testing.T) {
var handlerInvoked bool
handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
handlerInvoked = true
})
fakeSink := &fakeAuditSink{}
fakeChecker := policy.FakeChecker(auditinternal.LevelRequestResponse, nil)
withDeadline := WithRequestDeadline(handler, fakeSink, fakeChecker,
func(_ *http.Request, _ *request.RequestInfo) bool { return false }, newSerializer(), time.Minute)
withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{})
testRequest := newRequest(t, "/api/v1/namespaces?timeout=foo")
w := httptest.NewRecorder()
withDeadline.ServeHTTP(w, testRequest)
if handlerInvoked {
t.Error("expected the request to fail and the handler to be skipped")
}
statusCodeGot := w.Result().StatusCode
if statusCodeGot != http.StatusBadRequest {
t.Errorf("expected status code %d, but got: %d", http.StatusBadRequest, statusCodeGot)
}
// verify that the audit event from the request context is written to the audit sink.
if len(fakeSink.events) != 1 {
t.Fatalf("expected audit sink to have 1 event, but got: %d", len(fakeSink.events))
}
}
func TestWithRequestDeadlineWithPanic(t *testing.T) {
var (
panicErrGot interface{}
ctxGot context.Context
)
panicErrExpected := errors.New("apiserver panic'd")
handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
ctxGot = req.Context()
panic(panicErrExpected)
})
fakeSink := &fakeAuditSink{}
fakeChecker := policy.FakeChecker(auditinternal.LevelRequestResponse, nil)
withDeadline := WithRequestDeadline(handler, fakeSink, fakeChecker,
func(_ *http.Request, _ *request.RequestInfo) bool { return false }, newSerializer(), 1*time.Minute)
withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{})
withPanicRecovery := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
defer func() {
panicErrGot = recover()
}()
withDeadline.ServeHTTP(w, req)
})
testRequest := newRequest(t, "/api/v1/namespaces?timeout=1s")
w := httptest.NewRecorder()
withPanicRecovery.ServeHTTP(w, testRequest)
if panicErrExpected != panicErrGot {
t.Errorf("expected panic error: %#v, but got: %#v", panicErrExpected, panicErrGot)
}
if ctxGot.Err() != context.Canceled {
t.Error("expected the request context to be canceled on handler panic")
}
}
func TestWithRequestDeadlineWithRequestTimesOut(t *testing.T) {
timeout := 100 * time.Millisecond
var errGot error
handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
ctx := req.Context()
select {
case <-time.After(timeout + time.Second):
errGot = fmt.Errorf("expected the request context to have timed out in %s", timeout)
case <-ctx.Done():
errGot = ctx.Err()
}
})
fakeSink := &fakeAuditSink{}
fakeChecker := policy.FakeChecker(auditinternal.LevelRequestResponse, nil)
withDeadline := WithRequestDeadline(handler, fakeSink, fakeChecker,
func(_ *http.Request, _ *request.RequestInfo) bool { return false }, newSerializer(), 1*time.Minute)
withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{})
testRequest := newRequest(t, fmt.Sprintf("/api/v1/namespaces?timeout=%s", timeout))
w := httptest.NewRecorder()
withDeadline.ServeHTTP(w, testRequest)
if errGot != context.DeadlineExceeded {
t.Errorf("expected error: %#v, but got: %#v", context.DeadlineExceeded, errGot)
}
}
func TestWithFailedRequestAudit(t *testing.T) {
tests := []struct {
name string
statusErr *apierrors.StatusError
errorHandlerCallCountExpected int
statusCodeExpected int
auditExpected bool
}{
{
name: "bad request, the error handler is invoked and the request is audited",
statusErr: apierrors.NewBadRequest("error serving request"),
errorHandlerCallCountExpected: 1,
statusCodeExpected: http.StatusBadRequest,
auditExpected: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var (
errorHandlerCallCountGot int
rwGot http.ResponseWriter
requestGot *http.Request
)
errorHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
http.Error(rw, "error serving request", http.StatusBadRequest)
errorHandlerCallCountGot++
requestGot = req
rwGot = rw
})
fakeSink := &fakeAuditSink{}
fakeChecker := policy.FakeChecker(auditinternal.LevelRequestResponse, nil)
withAudit := withFailedRequestAudit(errorHandler, test.statusErr, fakeSink, fakeChecker)
w := httptest.NewRecorder()
testRequest, err := http.NewRequest(http.MethodGet, "/apis/v1/namespaces/default/pods", nil)
if err != nil {
t.Fatalf("failed to create new http testRequest - %v", err)
}
info := request.RequestInfo{}
testRequest = testRequest.WithContext(request.WithRequestInfo(testRequest.Context(), &info))
withAudit.ServeHTTP(w, testRequest)
if test.errorHandlerCallCountExpected != errorHandlerCallCountGot {
t.Errorf("expected the testRequest handler to be invoked %d times, but was actually invoked %d times", test.errorHandlerCallCountExpected, errorHandlerCallCountGot)
}
statusCodeGot := w.Result().StatusCode
if test.statusCodeExpected != statusCodeGot {
t.Errorf("expected status code %d, but got: %d", test.statusCodeExpected, statusCodeGot)
}
if test.auditExpected {
// verify that the right http.ResponseWriter is passed to the error handler
_, ok := rwGot.(*auditResponseWriter)
if !ok {
t.Errorf("expected an http.ResponseWriter of type: %T but got: %T", &auditResponseWriter{}, rwGot)
}
auditEventGot := request.AuditEventFrom(requestGot.Context())
if auditEventGot == nil {
t.Fatal("expected an audit event object but got nil")
}
if auditEventGot.Stage != auditinternal.StageResponseStarted {
t.Errorf("expected audit event Stage: %s, but got: %s", auditinternal.StageResponseStarted, auditEventGot.Stage)
}
if auditEventGot.ResponseStatus == nil {
t.Fatal("expected a ResponseStatus field of the audit event object, but got nil")
}
if test.statusCodeExpected != int(auditEventGot.ResponseStatus.Code) {
t.Errorf("expected audit event ResponseStatus.Code: %d, but got: %d", test.statusCodeExpected, auditEventGot.ResponseStatus.Code)
}
if test.statusErr.Error() != auditEventGot.ResponseStatus.Message {
t.Errorf("expected audit event ResponseStatus.Message: %s, but got: %s", test.statusErr, auditEventGot.ResponseStatus.Message)
}
// verify that the audit event from the request context is written to the audit sink.
if len(fakeSink.events) != 1 {
t.Fatalf("expected audit sink to have 1 event, but got: %d", len(fakeSink.events))
}
auditEventFromSink := fakeSink.events[0]
if !reflect.DeepEqual(auditEventGot, auditEventFromSink) {
t.Errorf("expected the audit event from the request context to be written to the audit sink, but got diffs: %s", cmp.Diff(auditEventGot, auditEventFromSink))
}
}
})
}
}
func newRequest(t *testing.T, requestURL string) *http.Request {
req, err := http.NewRequest(http.MethodGet, requestURL, nil)
if err != nil {
t.Fatalf("failed to create new http request - %v", err)
}
return req
}
func message(err error) string {
if err != nil {
return err.Error()
}
return ""
}
func newSerializer() runtime.NegotiatedSerializer {
scheme := runtime.NewScheme()
return serializer.NewCodecFactory(scheme).WithoutConversion()
}
type fakeRequestResolver struct{}
func (r fakeRequestResolver) NewRequestInfo(req *http.Request) (*request.RequestInfo, error) {
return &request.RequestInfo{}, nil
}
func deadline(r *http.Request) (time.Duration, bool) {
if deadline, ok := r.Context().Deadline(); ok {
remaining := time.Until(deadline)
return remaining, ok
}
return 0, false
}

View File

@ -57,9 +57,6 @@ func createHandler(r rest.NamedCreater, scope *RequestScope, admit admission.Int
return
}
// TODO: we either want to remove timeout or document it (if we document, move timeout out of this function and declare it in api_installer)
timeout := parseTimeout(req.URL.Query().Get("timeout"))
namespace, name, err := scope.Namer.Name(req)
if err != nil {
if includeName {
@ -76,7 +73,9 @@ func createHandler(r rest.NamedCreater, scope *RequestScope, admit admission.Int
}
}
ctx, cancel := context.WithTimeout(req.Context(), timeout)
// enforce a timeout of at most requestTimeoutUpperBound (34s) or less if the user-provided
// timeout inside the parent context is lower than requestTimeoutUpperBound.
ctx, cancel := context.WithTimeout(req.Context(), requestTimeoutUpperBound)
defer cancel()
outputMediaType, _, err := negotiation.NegotiateOutputMediaType(req, scope.Serializer, scope)
if err != nil {
@ -157,7 +156,7 @@ func createHandler(r rest.NamedCreater, scope *RequestScope, admit admission.Int
}
// Dedup owner references before updating managed fields
dedupOwnerReferencesAndAddWarning(obj, req.Context(), false)
result, err := finishRequest(timeout, func() (runtime.Object, error) {
result, err := finishRequest(ctx, func() (runtime.Object, error) {
if scope.FieldManager != nil {
liveObj, err := scope.Creater.New(scope.Kind)
if err != nil {

View File

@ -54,16 +54,17 @@ func DeleteResource(r rest.GracefulDeleter, allowsOptions bool, scope *RequestSc
return
}
// TODO: we either want to remove timeout or document it (if we document, move timeout out of this function and declare it in api_installer)
timeout := parseTimeout(req.URL.Query().Get("timeout"))
namespace, name, err := scope.Namer.Name(req)
if err != nil {
scope.err(err, w, req)
return
}
ctx, cancel := context.WithTimeout(req.Context(), timeout)
// enforce a timeout of at most requestTimeoutUpperBound (34s) or less if the user-provided
// timeout inside the parent context is lower than requestTimeoutUpperBound.
ctx, cancel := context.WithTimeout(req.Context(), requestTimeoutUpperBound)
defer cancel()
ctx = request.WithNamespace(ctx, namespace)
ae := request.AuditEventFrom(ctx)
admit = admission.WithAudit(admit, ae)
@ -123,7 +124,7 @@ func DeleteResource(r rest.GracefulDeleter, allowsOptions bool, scope *RequestSc
wasDeleted := true
userInfo, _ := request.UserFrom(ctx)
staticAdmissionAttrs := admission.NewAttributesRecord(nil, nil, scope.Kind, namespace, name, scope.Resource, scope.Subresource, admission.Delete, options, dryrun.IsDryRun(options.DryRun), userInfo)
result, err := finishRequest(timeout, func() (runtime.Object, error) {
result, err := finishRequest(ctx, func() (runtime.Object, error) {
obj, deleted, err := r.Delete(ctx, name, rest.AdmissionToValidateObjectDeleteFunc(admit, staticAdmissionAttrs, scope), options)
wasDeleted = deleted
return obj, err
@ -172,17 +173,17 @@ func DeleteCollection(r rest.CollectionDeleter, checkBody bool, scope *RequestSc
return
}
// TODO: we either want to remove timeout or document it (if we document, move timeout out of this function and declare it in api_installer)
timeout := parseTimeout(req.URL.Query().Get("timeout"))
namespace, err := scope.Namer.Namespace(req)
if err != nil {
scope.err(err, w, req)
return
}
ctx, cancel := context.WithTimeout(req.Context(), timeout)
// enforce a timeout of at most requestTimeoutUpperBound (34s) or less if the user-provided
// timeout inside the parent context is lower than requestTimeoutUpperBound.
ctx, cancel := context.WithTimeout(req.Context(), requestTimeoutUpperBound)
defer cancel()
ctx = request.WithNamespace(ctx, namespace)
ae := request.AuditEventFrom(ctx)
@ -265,7 +266,7 @@ func DeleteCollection(r rest.CollectionDeleter, checkBody bool, scope *RequestSc
admit = admission.WithAudit(admit, ae)
userInfo, _ := request.UserFrom(ctx)
staticAdmissionAttrs := admission.NewAttributesRecord(nil, nil, scope.Kind, namespace, "", scope.Resource, scope.Subresource, admission.Delete, options, dryrun.IsDryRun(options.DryRun), userInfo)
result, err := finishRequest(timeout, func() (runtime.Object, error) {
result, err := finishRequest(ctx, func() (runtime.Object, error) {
return r.DeleteCollection(ctx, rest.AdmissionToValidateObjectDeleteFunc(admit, staticAdmissionAttrs, scope), options, &listOptions)
})
if err != nil {

View File

@ -84,19 +84,17 @@ func PatchResource(r rest.Patcher, scope *RequestScope, admit admission.Interfac
return
}
// TODO: we either want to remove timeout or document it (if we
// document, move timeout out of this function and declare it in
// api_installer)
timeout := parseTimeout(req.URL.Query().Get("timeout"))
namespace, name, err := scope.Namer.Name(req)
if err != nil {
scope.err(err, w, req)
return
}
ctx, cancel := context.WithTimeout(req.Context(), timeout)
// enforce a timeout of at most requestTimeoutUpperBound (34s) or less if the user-provided
// timeout inside the parent context is lower than requestTimeoutUpperBound.
ctx, cancel := context.WithTimeout(req.Context(), requestTimeoutUpperBound)
defer cancel()
ctx = request.WithNamespace(ctx, namespace)
outputMediaType, _, err := negotiation.NegotiateOutputMediaType(req, scope.Serializer, scope)
@ -208,7 +206,6 @@ func PatchResource(r rest.Patcher, scope *RequestScope, admit admission.Interfac
codec: codec,
timeout: timeout,
options: options,
restPatcher: r,
@ -271,7 +268,6 @@ type patcher struct {
codec runtime.Codec
timeout time.Duration
options *metav1.PatchOptions
// Operation information
@ -591,7 +587,7 @@ func (p *patcher) patchResource(ctx context.Context, scope *RequestScope) (runti
wasCreated = created
return updateObject, updateErr
}
result, err := finishRequest(p.timeout, func() (runtime.Object, error) {
result, err := finishRequest(ctx, func() (runtime.Object, error) {
result, err := requestFunc()
// If the object wasn't committed to storage because it's serialized size was too large,
// it is safe to remove managedFields (which can be large) and try again.

View File

@ -53,6 +53,9 @@ import (
)
const (
// 34 chose as a number close to 30 that is likely to be unique enough to jump out at me the next time I see a timeout.
// Everyone chooses 30.
requestTimeoutUpperBound = 34 * time.Second
// DuplicateOwnerReferencesWarningFormat is the warning that a client receives when a create/update request contains
// duplicate owner reference entries.
DuplicateOwnerReferencesWarningFormat = ".metadata.ownerReferences contains duplicate entries; API server dedups owner references in 1.20+, and may reject such requests as early as 1.24; please fix your requests; duplicate UID(s) observed: %v"
@ -227,7 +230,7 @@ type resultFunc func() (runtime.Object, error)
// finishRequest makes a given resultFunc asynchronous and handles errors returned by the response.
// An api.Status object with status != success is considered an "error", which interrupts the normal response flow.
func finishRequest(timeout time.Duration, fn resultFunc) (result runtime.Object, err error) {
func finishRequest(ctx context.Context, fn resultFunc) (result runtime.Object, err error) {
// these channels need to be buffered to prevent the goroutine below from hanging indefinitely
// when the select statement reads something other than the one the goroutine sends on.
ch := make(chan runtime.Object, 1)
@ -271,8 +274,8 @@ func finishRequest(timeout time.Duration, fn resultFunc) (result runtime.Object,
return nil, err
case p := <-panicCh:
panic(p)
case <-time.After(timeout):
return nil, errors.NewTimeoutError(fmt.Sprintf("request did not complete within requested timeout %s", timeout), 0)
case <-ctx.Done():
return nil, errors.NewTimeoutError(fmt.Sprintf("request did not complete within requested timeout %s", ctx.Err()), 0)
}
}
@ -487,19 +490,6 @@ func limitedReadBody(req *http.Request, limit int64) ([]byte, error) {
return data, nil
}
func parseTimeout(str string) time.Duration {
if str != "" {
timeout, err := time.ParseDuration(str)
if err == nil {
return timeout
}
klog.ErrorS(err, "Failed to parse", "time", str)
}
// 34 chose as a number close to 30 that is likely to be unique enough to jump out at me the next time I see a timeout. Everyone chooses 30.
return 34 * time.Second
}
func isDryRun(url *url.URL) bool {
return len(url.Query()["dryRun"]) != 0
}

View File

@ -456,8 +456,6 @@ func (tc *patchTestCase) Run(t *testing.T) {
codec: codec,
timeout: 1 * time.Second,
restPatcher: testPatcher,
name: name,
patchType: patchType,
@ -466,7 +464,10 @@ func (tc *patchTestCase) Run(t *testing.T) {
trace: utiltrace.New("Patch", utiltrace.Field{"name", name}),
}
ctx, cancel := context.WithTimeout(ctx, time.Second)
resultObj, _, err := p.patchResource(ctx, &RequestScope{})
cancel()
if len(tc.expectedError) != 0 {
if err == nil || err.Error() != tc.expectedError {
t.Errorf("%s: expected error %v, but got %v", tc.name, tc.expectedError, err)
@ -825,26 +826,18 @@ func TestHasUID(t *testing.T) {
}
}
func TestParseTimeout(t *testing.T) {
if d := parseTimeout(""); d != 34*time.Second {
t.Errorf("blank timeout produces %v", d)
}
if d := parseTimeout("not a timeout"); d != 34*time.Second {
t.Errorf("bad timeout produces %v", d)
}
if d := parseTimeout("10s"); d != 10*time.Second {
t.Errorf("10s timeout produced: %v", d)
}
}
func TestFinishRequest(t *testing.T) {
exampleObj := &example.Pod{}
exampleErr := fmt.Errorf("error")
successStatusObj := &metav1.Status{Status: metav1.StatusSuccess, Message: "success message"}
errorStatusObj := &metav1.Status{Status: metav1.StatusFailure, Message: "error message"}
timeoutFunc := func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.TODO(), time.Second)
}
testcases := []struct {
name string
timeout time.Duration
timeout func() (context.Context, context.CancelFunc)
fn resultFunc
expectedObj runtime.Object
expectedErr error
@ -854,7 +847,7 @@ func TestFinishRequest(t *testing.T) {
}{
{
name: "Expected obj is returned",
timeout: time.Second,
timeout: timeoutFunc,
fn: func() (runtime.Object, error) {
return exampleObj, nil
},
@ -863,7 +856,7 @@ func TestFinishRequest(t *testing.T) {
},
{
name: "Expected error is returned",
timeout: time.Second,
timeout: timeoutFunc,
fn: func() (runtime.Object, error) {
return nil, exampleErr
},
@ -872,7 +865,7 @@ func TestFinishRequest(t *testing.T) {
},
{
name: "Successful status object is returned as expected",
timeout: time.Second,
timeout: timeoutFunc,
fn: func() (runtime.Object, error) {
return successStatusObj, nil
},
@ -881,7 +874,7 @@ func TestFinishRequest(t *testing.T) {
},
{
name: "Error status object is converted to StatusError",
timeout: time.Second,
timeout: timeoutFunc,
fn: func() (runtime.Object, error) {
return errorStatusObj, nil
},
@ -890,7 +883,7 @@ func TestFinishRequest(t *testing.T) {
},
{
name: "Panic is propagated up",
timeout: time.Second,
timeout: timeoutFunc,
fn: func() (runtime.Object, error) {
panic("my panic")
},
@ -900,7 +893,7 @@ func TestFinishRequest(t *testing.T) {
},
{
name: "Panic is propagated with stack",
timeout: time.Second,
timeout: timeoutFunc,
fn: func() (runtime.Object, error) {
panic("my panic")
},
@ -910,7 +903,7 @@ func TestFinishRequest(t *testing.T) {
},
{
name: "http.ErrAbortHandler panic is propagated without wrapping with stack",
timeout: time.Second,
timeout: timeoutFunc,
fn: func() (runtime.Object, error) {
panic(http.ErrAbortHandler)
},
@ -922,7 +915,10 @@ func TestFinishRequest(t *testing.T) {
}
for i, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := tc.timeout()
defer func() {
cancel()
r := recover()
switch {
case r == nil && len(tc.expectedPanic) > 0:
@ -937,7 +933,7 @@ func TestFinishRequest(t *testing.T) {
t.Errorf("expected panic obj %#v, got %#v", tc.expectedPanicObj, r)
}
}()
obj, err := finishRequest(tc.timeout, tc.fn)
obj, err := finishRequest(ctx, tc.fn)
if (err == nil && tc.expectedErr != nil) || (err != nil && tc.expectedErr == nil) || (err != nil && tc.expectedErr != nil && err.Error() != tc.expectedErr.Error()) {
t.Errorf("%d: unexpected err. expected: %v, got: %v", i, tc.expectedErr, err)
}

View File

@ -54,16 +54,17 @@ func UpdateResource(r rest.Updater, scope *RequestScope, admit admission.Interfa
return
}
// TODO: we either want to remove timeout or document it (if we document, move timeout out of this function and declare it in api_installer)
timeout := parseTimeout(req.URL.Query().Get("timeout"))
namespace, name, err := scope.Namer.Name(req)
if err != nil {
scope.err(err, w, req)
return
}
ctx, cancel := context.WithTimeout(req.Context(), timeout)
// enforce a timeout of at most requestTimeoutUpperBound (34s) or less if the user-provided
// timeout inside the parent context is lower than requestTimeoutUpperBound.
ctx, cancel := context.WithTimeout(req.Context(), requestTimeoutUpperBound)
defer cancel()
ctx = request.WithNamespace(ctx, namespace)
outputMediaType, _, err := negotiation.NegotiateOutputMediaType(req, scope.Serializer, scope)
@ -195,7 +196,7 @@ func UpdateResource(r rest.Updater, scope *RequestScope, admit admission.Interfa
}
// Dedup owner references before updating managed fields
dedupOwnerReferencesAndAddWarning(obj, req.Context(), false)
result, err := finishRequest(timeout, func() (runtime.Object, error) {
result, err := finishRequest(ctx, func() (runtime.Object, error) {
result, err := requestFunc()
// If the object wasn't committed to storage because it's serialized size was too large,
// it is safe to remove managedFields (which can be large) and try again.

View File

@ -746,7 +746,13 @@ func DefaultBuildHandlerChain(apiHandler http.Handler, c *Config) http.Handler {
handler = filterlatency.TrackStarted(handler, "authentication")
handler = genericfilters.WithCORS(handler, c.CorsAllowedOriginList, nil, nil, nil, "true")
handler = genericfilters.WithTimeoutForNonLongRunningRequests(handler, c.LongRunningFunc, c.RequestTimeout)
// WithTimeoutForNonLongRunningRequests will call the rest of the request handling in a go-routine with the
// context with deadline. The go-routine can keep running, while the timeout logic will return a timeout to the client.
handler = genericfilters.WithTimeoutForNonLongRunningRequests(handler, c.LongRunningFunc)
handler = genericapifilters.WithRequestDeadline(handler, c.AuditBackend, c.AuditPolicyChecker,
c.LongRunningFunc, c.Serializer, c.RequestTimeout)
handler = genericfilters.WithWaitGroup(handler, c.LongRunningFunc, c.HandlerChainWaitGroup)
handler = genericapifilters.WithRequestInfo(handler, c.RequestInfoResolver)
if c.SecureServing != nil && !c.SecureServing.DisableHTTP2 && c.GoawayChance > 0 {

View File

@ -509,7 +509,7 @@ func newHandlerChain(t *testing.T, handler http.Handler, filter utilflowcontrol.
apfHandler.ServeHTTP(w, r)
})
handler = WithTimeoutForNonLongRunningRequests(handler, longRunningRequestCheck, requestTimeout)
handler = WithTimeoutForNonLongRunningRequests(handler, longRunningRequestCheck)
handler = apifilters.WithRequestInfo(handler, requestInfoFactory)
handler = WithPanicRecovery(handler, requestInfoFactory)
return handler

View File

@ -18,14 +18,12 @@ package filters
import (
"bufio"
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"runtime"
"sync"
"time"
apierrors "k8s.io/apimachinery/pkg/api/errors"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
@ -34,37 +32,33 @@ import (
)
// WithTimeoutForNonLongRunningRequests times out non-long-running requests after the time given by timeout.
func WithTimeoutForNonLongRunningRequests(handler http.Handler, longRunning apirequest.LongRunningRequestCheck, timeout time.Duration) http.Handler {
func WithTimeoutForNonLongRunningRequests(handler http.Handler, longRunning apirequest.LongRunningRequestCheck) http.Handler {
if longRunning == nil {
return handler
}
timeoutFunc := func(req *http.Request) (*http.Request, <-chan time.Time, func(), *apierrors.StatusError) {
timeoutFunc := func(req *http.Request) (*http.Request, bool, func(), *apierrors.StatusError) {
// TODO unify this with apiserver.MaxInFlightLimit
ctx := req.Context()
requestInfo, ok := apirequest.RequestInfoFrom(ctx)
if !ok {
// if this happens, the handler chain isn't setup correctly because there is no request info
return req, time.After(timeout), func() {}, apierrors.NewInternalError(fmt.Errorf("no request info found for request during timeout"))
return req, false, func() {}, apierrors.NewInternalError(fmt.Errorf("no request info found for request during timeout"))
}
if longRunning(req, requestInfo) {
return req, nil, nil, nil
return req, true, nil, nil
}
ctx, cancel := context.WithCancel(ctx)
req = req.WithContext(ctx)
postTimeoutFn := func() {
cancel()
metrics.RecordRequestTermination(req, requestInfo, metrics.APIServerComponent, http.StatusGatewayTimeout)
}
return req, time.After(timeout), postTimeoutFn, apierrors.NewTimeoutError(fmt.Sprintf("request did not complete within %s", timeout), 0)
return req, false, postTimeoutFn, apierrors.NewTimeoutError("request did not complete within the allotted timeout", 0)
}
return WithTimeout(handler, timeoutFunc)
}
type timeoutFunc = func(*http.Request) (req *http.Request, timeout <-chan time.Time, postTimeoutFunc func(), err *apierrors.StatusError)
type timeoutFunc = func(*http.Request) (req *http.Request, longRunning bool, postTimeoutFunc func(), err *apierrors.StatusError)
// WithTimeout returns an http.Handler that runs h with a timeout
// determined by timeoutFunc. The new http.Handler calls h.ServeHTTP to handle
@ -85,12 +79,14 @@ type timeoutHandler struct {
}
func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r, after, postTimeoutFn, err := t.timeout(r)
if after == nil {
r, longRunning, postTimeoutFn, err := t.timeout(r)
if longRunning {
t.handler.ServeHTTP(w, r)
return
}
timeoutCh := r.Context().Done()
// resultCh is used as both errCh and stopCh
resultCh := make(chan interface{})
tw := newTimeoutWriter(w)
@ -117,7 +113,7 @@ func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
panic(err)
}
return
case <-after:
case <-timeoutCh:
defer func() {
// resultCh needs to have a reader, since the function doing
// the work needs to send to it. This is defer'd to ensure it runs

View File

@ -18,6 +18,7 @@ package filters
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
@ -92,18 +93,27 @@ func TestTimeout(t *testing.T) {
timeoutErr := apierrors.NewServerTimeout(schema.GroupResource{Group: "foo", Resource: "bar"}, "get", 0)
record := &recorder{}
var ctx context.Context
withDeadline := func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req = req.WithContext(ctx)
handler.ServeHTTP(w, req)
})
}
handler := newHandler(sendResponse, doPanic, writeErrors)
ts := httptest.NewServer(withPanicRecovery(
WithTimeout(handler, func(req *http.Request) (*http.Request, <-chan time.Time, func(), *apierrors.StatusError) {
return req, timeout, record.Record, timeoutErr
ts := httptest.NewServer(withDeadline(withPanicRecovery(
WithTimeout(handler, func(req *http.Request) (*http.Request, bool, func(), *apierrors.StatusError) {
return req, false, record.Record, timeoutErr
}), func(w http.ResponseWriter, req *http.Request, err interface{}) {
gotPanic <- err
http.Error(w, "This request caused apiserver to panic. Look in the logs for details.", http.StatusInternalServerError)
}),
)
))
defer ts.Close()
// No timeouts
ctx = context.Background()
sendResponse <- resp
res, err := http.Get(ts.URL)
if err != nil {
@ -124,6 +134,8 @@ func TestTimeout(t *testing.T) {
}
// Times out
ctx, cancel := context.WithCancel(context.Background())
cancel()
timeout <- time.Time{}
res, err = http.Get(ts.URL)
if err != nil {
@ -145,6 +157,7 @@ func TestTimeout(t *testing.T) {
}
// Now try to send a response
ctx = context.Background()
sendResponse <- resp
if err := <-writeErrors; err != http.ErrHandlerTimeout {
t.Errorf("got Write error of %v; expected %v", err, http.ErrHandlerTimeout)
@ -170,6 +183,7 @@ func TestTimeout(t *testing.T) {
}
// Panics with http.ErrAbortHandler
ctx = context.Background()
doPanic <- http.ErrAbortHandler
res, err = http.Get(ts.URL)
if err != nil {