From 32ddb5c9d2f6b7dfff06e984d23db81b1cd78a2c Mon Sep 17 00:00:00 2001 From: deads2k Date: Thu, 26 Jan 2017 14:39:54 -0500 Subject: [PATCH] move genericapiserver/server/filters to apiserver --- pkg/server/filters/OWNERS | 3 + pkg/server/filters/cors.go | 98 +++++++++ pkg/server/filters/cors_test.go | 183 +++++++++++++++++ pkg/server/filters/doc.go | 19 ++ pkg/server/filters/longrunning.go | 40 ++++ pkg/server/filters/maxinflight.go | 111 ++++++++++ pkg/server/filters/maxinflight_test.go | 240 ++++++++++++++++++++++ pkg/server/filters/timeout.go | 271 +++++++++++++++++++++++++ pkg/server/filters/timeout_test.go | 85 ++++++++ pkg/server/filters/wrap.go | 76 +++++++ 10 files changed, 1126 insertions(+) create mode 100755 pkg/server/filters/OWNERS create mode 100644 pkg/server/filters/cors.go create mode 100644 pkg/server/filters/cors_test.go create mode 100644 pkg/server/filters/doc.go create mode 100644 pkg/server/filters/longrunning.go create mode 100644 pkg/server/filters/maxinflight.go create mode 100644 pkg/server/filters/maxinflight_test.go create mode 100644 pkg/server/filters/timeout.go create mode 100644 pkg/server/filters/timeout_test.go create mode 100644 pkg/server/filters/wrap.go diff --git a/pkg/server/filters/OWNERS b/pkg/server/filters/OWNERS new file mode 100755 index 000000000..121af9571 --- /dev/null +++ b/pkg/server/filters/OWNERS @@ -0,0 +1,3 @@ +reviewers: +- sttts +- dims diff --git a/pkg/server/filters/cors.go b/pkg/server/filters/cors.go new file mode 100644 index 000000000..2c6e66ed6 --- /dev/null +++ b/pkg/server/filters/cors.go @@ -0,0 +1,98 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "net/http" + "regexp" + "strings" + + "github.com/golang/glog" +) + +// TODO: use restful.CrossOriginResourceSharing +// See github.com/emicklei/go-restful/blob/master/examples/restful-CORS-filter.go, and +// github.com/emicklei/go-restful/blob/master/examples/restful-basic-authentication.go +// Or, for a more detailed implementation use https://github.com/martini-contrib/cors +// or implement CORS at your proxy layer. + +// WithCORS is a simple CORS implementation that wraps an http Handler. +// Pass nil for allowedMethods and allowedHeaders to use the defaults. If allowedOriginPatterns +// is empty or nil, no CORS support is installed. +func WithCORS(handler http.Handler, allowedOriginPatterns []string, allowedMethods []string, allowedHeaders []string, exposedHeaders []string, allowCredentials string) http.Handler { + if len(allowedOriginPatterns) == 0 { + return handler + } + allowedOriginPatternsREs := allowedOriginRegexps(allowedOriginPatterns) + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + origin := req.Header.Get("Origin") + if origin != "" { + allowed := false + for _, re := range allowedOriginPatternsREs { + if allowed = re.MatchString(origin); allowed { + break + } + } + if allowed { + w.Header().Set("Access-Control-Allow-Origin", origin) + // Set defaults for methods and headers if nothing was passed + if allowedMethods == nil { + allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE", "PATCH"} + } + if allowedHeaders == nil { + allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"} + } + if exposedHeaders == nil { + exposedHeaders = []string{"Date"} + } + w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", ")) + w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", ")) + w.Header().Set("Access-Control-Expose-Headers", strings.Join(exposedHeaders, ", ")) + w.Header().Set("Access-Control-Allow-Credentials", allowCredentials) + + // Stop here if its a preflight OPTIONS request + if req.Method == "OPTIONS" { + w.WriteHeader(http.StatusNoContent) + return + } + } + } + // Dispatch to the next handler + handler.ServeHTTP(w, req) + }) +} + +func allowedOriginRegexps(allowedOrigins []string) []*regexp.Regexp { + res, err := compileRegexps(allowedOrigins) + if err != nil { + glog.Fatalf("Invalid CORS allowed origin, --cors-allowed-origins flag was set to %v - %v", strings.Join(allowedOrigins, ","), err) + } + return res +} + +// Takes a list of strings and compiles them into a list of regular expressions +func compileRegexps(regexpStrings []string) ([]*regexp.Regexp, error) { + regexps := []*regexp.Regexp{} + for _, regexpStr := range regexpStrings { + r, err := regexp.Compile(regexpStr) + if err != nil { + return []*regexp.Regexp{}, err + } + regexps = append(regexps, r) + } + return regexps, nil +} diff --git a/pkg/server/filters/cors_test.go b/pkg/server/filters/cors_test.go new file mode 100644 index 000000000..6e8b46236 --- /dev/null +++ b/pkg/server/filters/cors_test.go @@ -0,0 +1,183 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" +) + +func TestCORSAllowedOrigins(t *testing.T) { + table := []struct { + allowedOrigins []string + origin string + allowed bool + }{ + {[]string{}, "example.com", false}, + {[]string{"example.com"}, "example.com", true}, + {[]string{"example.com"}, "not-allowed.com", false}, + {[]string{"not-matching.com", "example.com"}, "example.com", true}, + {[]string{".*"}, "example.com", true}, + } + + for _, item := range table { + handler := WithCORS( + http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}), + item.allowedOrigins, nil, nil, nil, "true", + ) + server := httptest.NewServer(handler) + defer server.Close() + client := http.Client{} + + request, err := http.NewRequest("GET", server.URL+"/version", nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + request.Header.Set("Origin", item.origin) + + response, err := client.Do(request) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if item.allowed { + if !reflect.DeepEqual(item.origin, response.Header.Get("Access-Control-Allow-Origin")) { + t.Errorf("Expected %#v, Got %#v", item.origin, response.Header.Get("Access-Control-Allow-Origin")) + } + + if response.Header.Get("Access-Control-Allow-Credentials") == "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to be set") + } + + if response.Header.Get("Access-Control-Allow-Headers") == "" { + t.Errorf("Expected Access-Control-Allow-Headers header to be set") + } + + if response.Header.Get("Access-Control-Allow-Methods") == "" { + t.Errorf("Expected Access-Control-Allow-Methods header to be set") + } + + if response.Header.Get("Access-Control-Expose-Headers") != "Date" { + t.Errorf("Expected Date in Access-Control-Expose-Headers header") + } + } else { + if response.Header.Get("Access-Control-Allow-Origin") != "" { + t.Errorf("Expected Access-Control-Allow-Origin header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Credentials") != "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Headers") != "" { + t.Errorf("Expected Access-Control-Allow-Headers header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Methods") != "" { + t.Errorf("Expected Access-Control-Allow-Methods header to not be set") + } + + if response.Header.Get("Access-Control-Expose-Headers") == "Date" { + t.Errorf("Expected Date in Access-Control-Expose-Headers header") + } + } + } +} + +func TestCORSAllowedMethods(t *testing.T) { + tests := []struct { + allowedMethods []string + method string + allowed bool + }{ + {nil, "POST", true}, + {nil, "GET", true}, + {nil, "OPTIONS", true}, + {nil, "PUT", true}, + {nil, "DELETE", true}, + {nil, "PATCH", true}, + {[]string{"GET", "POST"}, "PATCH", false}, + } + + allowsMethod := func(res *http.Response, method string) bool { + allowedMethods := strings.Split(res.Header.Get("Access-Control-Allow-Methods"), ",") + for _, allowedMethod := range allowedMethods { + if strings.TrimSpace(allowedMethod) == method { + return true + } + } + return false + } + + for _, test := range tests { + handler := WithCORS( + http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}), + []string{".*"}, test.allowedMethods, nil, nil, "true", + ) + server := httptest.NewServer(handler) + defer server.Close() + client := http.Client{} + + request, err := http.NewRequest(test.method, server.URL+"/version", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + request.Header.Set("Origin", "allowed.com") + + response, err := client.Do(request) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + methodAllowed := allowsMethod(response, test.method) + switch { + case test.allowed && !methodAllowed: + t.Errorf("Expected %v to be allowed, Got only %#v", test.method, response.Header.Get("Access-Control-Allow-Methods")) + case !test.allowed && methodAllowed: + t.Errorf("Unexpected allowed method %v, Expected only %#v", test.method, response.Header.Get("Access-Control-Allow-Methods")) + } + } + +} + +func TestCompileRegex(t *testing.T) { + uncompiledRegexes := []string{"endsWithMe$", "^startingWithMe"} + regexes, err := compileRegexps(uncompiledRegexes) + + if err != nil { + t.Errorf("Failed to compile legal regexes: '%v': %v", uncompiledRegexes, err) + } + if len(regexes) != len(uncompiledRegexes) { + t.Errorf("Wrong number of regexes returned: '%v': %v", uncompiledRegexes, regexes) + } + + if !regexes[0].MatchString("Something that endsWithMe") { + t.Errorf("Wrong regex returned: '%v': %v", uncompiledRegexes[0], regexes[0]) + } + if regexes[0].MatchString("Something that doesn't endsWithMe.") { + t.Errorf("Wrong regex returned: '%v': %v", uncompiledRegexes[0], regexes[0]) + } + if !regexes[1].MatchString("startingWithMe is very important") { + t.Errorf("Wrong regex returned: '%v': %v", uncompiledRegexes[1], regexes[1]) + } + if regexes[1].MatchString("not startingWithMe should fail") { + t.Errorf("Wrong regex returned: '%v': %v", uncompiledRegexes[1], regexes[1]) + } +} diff --git a/pkg/server/filters/doc.go b/pkg/server/filters/doc.go new file mode 100644 index 000000000..a90cc3b49 --- /dev/null +++ b/pkg/server/filters/doc.go @@ -0,0 +1,19 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package filters contains all the http handler chain filters which +// are not api related. +package filters // import "k8s.io/apiserver/pkg/server/filters" diff --git a/pkg/server/filters/longrunning.go b/pkg/server/filters/longrunning.go new file mode 100644 index 000000000..4ea58625b --- /dev/null +++ b/pkg/server/filters/longrunning.go @@ -0,0 +1,40 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "net/http" + + "k8s.io/apimachinery/pkg/util/sets" + apirequest "k8s.io/apiserver/pkg/endpoints/request" +) + +// LongRunningRequestCheck is a predicate which is true for long-running http requests. +type LongRunningRequestCheck func(r *http.Request, requestInfo *apirequest.RequestInfo) bool + +// BasicLongRunningRequestCheck returns true if the given request has one of the specified verbs or one of the specified subresources +func BasicLongRunningRequestCheck(longRunningVerbs, longRunningSubresources sets.String) LongRunningRequestCheck { + return func(r *http.Request, requestInfo *apirequest.RequestInfo) bool { + if longRunningVerbs.Has(requestInfo.Verb) { + return true + } + if requestInfo.IsResourceRequest && longRunningSubresources.Has(requestInfo.Subresource) { + return true + } + return false + } +} diff --git a/pkg/server/filters/maxinflight.go b/pkg/server/filters/maxinflight.go new file mode 100644 index 000000000..c480d7b6e --- /dev/null +++ b/pkg/server/filters/maxinflight.go @@ -0,0 +1,111 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "fmt" + "net/http" + + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/util/sets" + apirequest "k8s.io/apiserver/pkg/endpoints/request" + genericapirequest "k8s.io/apiserver/pkg/endpoints/request" + "k8s.io/apiserver/pkg/server/httplog" + + "github.com/golang/glog" +) + +// Constant for the retry-after interval on rate limiting. +// TODO: maybe make this dynamic? or user-adjustable? +const retryAfter = "1" + +var nonMutatingRequestVerbs = sets.NewString("get", "list", "watch") + +func handleError(w http.ResponseWriter, r *http.Request, err error) { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "Internal Server Error: %#v", r.RequestURI) + glog.Errorf(err.Error()) +} + +// WithMaxInFlightLimit limits the number of in-flight requests to buffer size of the passed in channel. +func WithMaxInFlightLimit( + handler http.Handler, + nonMutatingLimit int, + mutatingLimit int, + requestContextMapper genericapirequest.RequestContextMapper, + longRunningRequestCheck LongRunningRequestCheck, +) http.Handler { + if nonMutatingLimit == 0 && mutatingLimit == 0 { + return handler + } + var nonMutatingChan chan bool + var mutatingChan chan bool + if nonMutatingLimit != 0 { + nonMutatingChan = make(chan bool, nonMutatingLimit) + } + if mutatingLimit != 0 { + mutatingChan = make(chan bool, mutatingLimit) + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, ok := requestContextMapper.Get(r) + if !ok { + handleError(w, r, fmt.Errorf("no context found for request, handler chain must be wrong")) + return + } + requestInfo, ok := apirequest.RequestInfoFrom(ctx) + if !ok { + handleError(w, r, fmt.Errorf("no RequestInfo found in context, handler chain must be wrong")) + return + } + + // Skip tracking long running events. + if longRunningRequestCheck != nil && longRunningRequestCheck(r, requestInfo) { + handler.ServeHTTP(w, r) + return + } + + var c chan bool + if !nonMutatingRequestVerbs.Has(requestInfo.Verb) { + c = mutatingChan + } else { + c = nonMutatingChan + } + + if c == nil { + handler.ServeHTTP(w, r) + } else { + select { + case c <- true: + defer func() { <-c }() + handler.ServeHTTP(w, r) + default: + tooManyRequests(r, w) + } + } + }) +} + +func tooManyRequests(req *http.Request, w http.ResponseWriter) { + // "Too Many Requests" response is returned before logger is setup for the request. + // So we need to explicitly log it here. + defer httplog.NewLogged(req, &w).Log() + + // Return a 429 status indicating "Too Many Requests" + w.Header().Set("Retry-After", retryAfter) + http.Error(w, "Too many requests, please try again later.", errors.StatusTooManyRequests) +} diff --git a/pkg/server/filters/maxinflight_test.go b/pkg/server/filters/maxinflight_test.go new file mode 100644 index 000000000..872fe074c --- /dev/null +++ b/pkg/server/filters/maxinflight_test.go @@ -0,0 +1,240 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/util/sets" + apirequest "k8s.io/apiserver/pkg/endpoints/request" + apifilters "k8s.io/kubernetes/pkg/genericapiserver/endpoints/filters" +) + +func createMaxInflightServer(callsWg, blockWg *sync.WaitGroup, disableCallsWg *bool, disableCallsWgMutex *sync.Mutex, nonMutating, mutating int) *httptest.Server { + + longRunningRequestCheck := BasicLongRunningRequestCheck(sets.NewString("watch"), sets.NewString("proxy")) + + requestContextMapper := apirequest.NewRequestContextMapper() + requestInfoFactory := &apirequest.RequestInfoFactory{APIPrefixes: sets.NewString("apis", "api"), GrouplessAPIPrefixes: sets.NewString("api")} + handler := WithMaxInFlightLimit( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // A short, accounted request that does not wait for block WaitGroup. + if strings.Contains(r.URL.Path, "dontwait") { + return + } + disableCallsWgMutex.Lock() + waitForCalls := *disableCallsWg + disableCallsWgMutex.Unlock() + if waitForCalls { + callsWg.Done() + } + blockWg.Wait() + }), + nonMutating, + mutating, + requestContextMapper, + longRunningRequestCheck, + ) + handler = apifilters.WithRequestInfo(handler, requestInfoFactory, requestContextMapper) + handler = apirequest.WithRequestContext(handler, requestContextMapper) + + return httptest.NewServer(handler) +} + +// Tests that MaxInFlightLimit works, i.e. +// - "long" requests such as proxy or watch, identified by regexp are not accounted despite +// hanging for the long time, +// - "short" requests are correctly accounted, i.e. there can be only size of channel passed to the +// constructor in flight at any given moment, +// - subsequent "short" requests are rejected instantly with appropriate error, +// - subsequent "long" requests are handled normally, +// - we correctly recover after some "short" requests finish, i.e. we can process new ones. +func TestMaxInFlightNonMutating(t *testing.T) { + const AllowedNonMutatingInflightRequestsNo = 3 + + // Calls is used to wait until all server calls are received. We are sending + // AllowedNonMutatingInflightRequestsNo of 'long' not-accounted requests and the same number of + // 'short' accounted ones. + calls := &sync.WaitGroup{} + calls.Add(AllowedNonMutatingInflightRequestsNo * 2) + + // Responses is used to wait until all responses are + // received. This prevents some async requests getting EOF + // errors from prematurely closing the server + responses := &sync.WaitGroup{} + responses.Add(AllowedNonMutatingInflightRequestsNo * 2) + + // Block is used to keep requests in flight for as long as we need to. All requests will + // be unblocked at the same time. + block := &sync.WaitGroup{} + block.Add(1) + + waitForCalls := true + waitForCallsMutex := sync.Mutex{} + + server := createMaxInflightServer(calls, block, &waitForCalls, &waitForCallsMutex, AllowedNonMutatingInflightRequestsNo, 1) + defer server.Close() + + // These should hang, but not affect accounting. use a query param match + for i := 0; i < AllowedNonMutatingInflightRequestsNo; i++ { + // These should hang waiting on block... + go func() { + if err := expectHTTPGet(server.URL+"/api/v1/namespaces/default/wait?watch=true", http.StatusOK); err != nil { + t.Error(err) + } + responses.Done() + }() + } + + // Check that sever is not saturated by not-accounted calls + if err := expectHTTPGet(server.URL+"/dontwait", http.StatusOK); err != nil { + t.Error(err) + } + + // These should hang and be accounted, i.e. saturate the server + for i := 0; i < AllowedNonMutatingInflightRequestsNo; i++ { + // These should hang waiting on block... + go func() { + if err := expectHTTPGet(server.URL, http.StatusOK); err != nil { + t.Error(err) + } + responses.Done() + }() + } + // We wait for all calls to be received by the server + calls.Wait() + // Disable calls notifications in the server + waitForCallsMutex.Lock() + waitForCalls = false + waitForCallsMutex.Unlock() + + // Do this multiple times to show that rate limit rejected requests don't block. + for i := 0; i < 2; i++ { + if err := expectHTTPGet(server.URL, errors.StatusTooManyRequests); err != nil { + t.Error(err) + } + } + // Validate that non-accounted URLs still work. use a path regex match + if err := expectHTTPGet(server.URL+"/api/v1/watch/namespaces/default/dontwait", http.StatusOK); err != nil { + t.Error(err) + } + + // We should allow a single mutating request. + if err := expectHTTPPost(server.URL+"/dontwait", http.StatusOK); err != nil { + t.Error(err) + } + + // Let all hanging requests finish + block.Done() + + // Show that we recover from being blocked up. + // Too avoid flakyness we need to wait until at least one of the requests really finishes. + responses.Wait() + if err := expectHTTPGet(server.URL, http.StatusOK); err != nil { + t.Error(err) + } +} + +func TestMaxInFlightMutating(t *testing.T) { + const AllowedMutatingInflightRequestsNo = 3 + + calls := &sync.WaitGroup{} + calls.Add(AllowedMutatingInflightRequestsNo) + + responses := &sync.WaitGroup{} + responses.Add(AllowedMutatingInflightRequestsNo) + + // Block is used to keep requests in flight for as long as we need to. All requests will + // be unblocked at the same time. + block := &sync.WaitGroup{} + block.Add(1) + + waitForCalls := true + waitForCallsMutex := sync.Mutex{} + + server := createMaxInflightServer(calls, block, &waitForCalls, &waitForCallsMutex, 1, AllowedMutatingInflightRequestsNo) + defer server.Close() + + // These should hang and be accounted, i.e. saturate the server + for i := 0; i < AllowedMutatingInflightRequestsNo; i++ { + // These should hang waiting on block... + go func() { + if err := expectHTTPPost(server.URL+"/foo/bar", http.StatusOK); err != nil { + t.Error(err) + } + responses.Done() + }() + } + // We wait for all calls to be received by the server + calls.Wait() + // Disable calls notifications in the server + // Disable calls notifications in the server + waitForCallsMutex.Lock() + waitForCalls = false + waitForCallsMutex.Unlock() + + // Do this multiple times to show that rate limit rejected requests don't block. + for i := 0; i < 2; i++ { + if err := expectHTTPPost(server.URL+"/foo/bar/", errors.StatusTooManyRequests); err != nil { + t.Error(err) + } + } + // Validate that non-mutating URLs still work. use a path regex match + if err := expectHTTPGet(server.URL+"/dontwait", http.StatusOK); err != nil { + t.Error(err) + } + + // Let all hanging requests finish + block.Done() + + // Show that we recover from being blocked up. + // Too avoid flakyness we need to wait until at least one of the requests really finishes. + responses.Wait() + if err := expectHTTPPost(server.URL+"/foo/bar", http.StatusOK); err != nil { + t.Error(err) + } +} + +// We use GET as a sample non-mutating request. +func expectHTTPGet(url string, code int) error { + r, err := http.Get(url) + if err != nil { + return fmt.Errorf("unexpected error: %v", err) + } + if r.StatusCode != code { + return fmt.Errorf("unexpected response: %v", r.StatusCode) + } + return nil +} + +// We use POST as a sample mutating request. +func expectHTTPPost(url string, code int) error { + r, err := http.Post(url, "text/html", strings.NewReader("foo bar")) + if err != nil { + return fmt.Errorf("unexpected error: %v", err) + } + if r.StatusCode != code { + return fmt.Errorf("unexpected response: %v", r.StatusCode) + } + return nil +} diff --git a/pkg/server/filters/timeout.go b/pkg/server/filters/timeout.go new file mode 100644 index 000000000..9232fcb51 --- /dev/null +++ b/pkg/server/filters/timeout.go @@ -0,0 +1,271 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "bufio" + "encoding/json" + "fmt" + "net" + "net/http" + "sync" + "time" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime/schema" + apirequest "k8s.io/apiserver/pkg/endpoints/request" +) + +const globalTimeout = time.Minute + +var errConnKilled = fmt.Errorf("kill connection/stream") + +// WithTimeoutForNonLongRunningRequests times out non-long-running requests after the time given by globalTimeout. +func WithTimeoutForNonLongRunningRequests(handler http.Handler, requestContextMapper apirequest.RequestContextMapper, longRunning LongRunningRequestCheck) http.Handler { + if longRunning == nil { + return handler + } + timeoutFunc := func(req *http.Request) (<-chan time.Time, *apierrors.StatusError) { + // TODO unify this with apiserver.MaxInFlightLimit + ctx, ok := requestContextMapper.Get(req) + if !ok { + // if this happens, the handler chain isn't setup correctly because there is no context mapper + return time.After(globalTimeout), apierrors.NewInternalError(fmt.Errorf("no context found for request during timeout")) + } + + requestInfo, ok := apirequest.RequestInfoFrom(ctx) + if !ok { + // if this happens, the handler chain isn't setup correctly because there is no request info + return time.After(globalTimeout), apierrors.NewInternalError(fmt.Errorf("no request info found for request during timeout")) + } + + if longRunning(req, requestInfo) { + return nil, nil + } + return time.After(globalTimeout), apierrors.NewServerTimeout(schema.GroupResource{Group: requestInfo.APIGroup, Resource: requestInfo.Resource}, requestInfo.Verb, 0) + } + return WithTimeout(handler, timeoutFunc) +} + +// WithTimeout returns an http.Handler that runs h with a timeout +// determined by timeoutFunc. The new http.Handler calls h.ServeHTTP to handle +// each request, but if a call runs for longer than its time limit, the +// handler responds with a 503 Service Unavailable error and the message +// provided. (If msg is empty, a suitable default message will be sent.) After +// the handler times out, writes by h to its http.ResponseWriter will return +// http.ErrHandlerTimeout. If timeoutFunc returns a nil timeout channel, no +// timeout will be enforced. +func WithTimeout(h http.Handler, timeoutFunc func(*http.Request) (timeout <-chan time.Time, err *apierrors.StatusError)) http.Handler { + return &timeoutHandler{h, timeoutFunc} +} + +type timeoutHandler struct { + handler http.Handler + timeout func(*http.Request) (<-chan time.Time, *apierrors.StatusError) +} + +func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + after, err := t.timeout(r) + if after == nil { + t.handler.ServeHTTP(w, r) + return + } + + done := make(chan struct{}) + tw := newTimeoutWriter(w) + go func() { + t.handler.ServeHTTP(tw, r) + close(done) + }() + select { + case <-done: + return + case <-after: + tw.timeout(err) + } +} + +type timeoutWriter interface { + http.ResponseWriter + timeout(*apierrors.StatusError) +} + +func newTimeoutWriter(w http.ResponseWriter) timeoutWriter { + base := &baseTimeoutWriter{w: w} + + _, notifiable := w.(http.CloseNotifier) + _, hijackable := w.(http.Hijacker) + + switch { + case notifiable && hijackable: + return &closeHijackTimeoutWriter{base} + case notifiable: + return &closeTimeoutWriter{base} + case hijackable: + return &hijackTimeoutWriter{base} + default: + return base + } +} + +type baseTimeoutWriter struct { + w http.ResponseWriter + + mu sync.Mutex + // if the timeout handler has timedout + timedOut bool + // if this timeout writer has wrote header + wroteHeader bool + // if this timeout writer has been hijacked + hijacked bool +} + +func (tw *baseTimeoutWriter) Header() http.Header { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut { + return http.Header{} + } + + return tw.w.Header() +} + +func (tw *baseTimeoutWriter) Write(p []byte) (int, error) { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut { + return 0, http.ErrHandlerTimeout + } + if tw.hijacked { + return 0, http.ErrHijacked + } + + tw.wroteHeader = true + return tw.w.Write(p) +} + +func (tw *baseTimeoutWriter) Flush() { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut { + return + } + + if flusher, ok := tw.w.(http.Flusher); ok { + flusher.Flush() + } +} + +func (tw *baseTimeoutWriter) WriteHeader(code int) { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut || tw.wroteHeader || tw.hijacked { + return + } + + tw.wroteHeader = true + tw.w.WriteHeader(code) +} + +func (tw *baseTimeoutWriter) timeout(err *apierrors.StatusError) { + tw.mu.Lock() + defer tw.mu.Unlock() + + tw.timedOut = true + + // The timeout writer has not been used by the inner handler. + // We can safely timeout the HTTP request by sending by a timeout + // handler + if !tw.wroteHeader && !tw.hijacked { + tw.w.WriteHeader(http.StatusGatewayTimeout) + enc := json.NewEncoder(tw.w) + enc.Encode(err) + } else { + // The timeout writer has been used by the inner handler. There is + // no way to timeout the HTTP request at the point. We have to shutdown + // the connection for HTTP1 or reset stream for HTTP2. + // + // Note from: Brad Fitzpatrick + // if the ServeHTTP goroutine panics, that will do the best possible thing for both + // HTTP/1 and HTTP/2. In HTTP/1, assuming you're replying with at least HTTP/1.1 and + // you've already flushed the headers so it's using HTTP chunking, it'll kill the TCP + // connection immediately without a proper 0-byte EOF chunk, so the peer will recognize + // the response as bogus. In HTTP/2 the server will just RST_STREAM the stream, leaving + // the TCP connection open, but resetting the stream to the peer so it'll have an error, + // like the HTTP/1 case. + panic(errConnKilled) + } +} + +func (tw *baseTimeoutWriter) closeNotify() <-chan bool { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut { + done := make(chan bool) + close(done) + return done + } + + return tw.w.(http.CloseNotifier).CloseNotify() +} + +func (tw *baseTimeoutWriter) hijack() (net.Conn, *bufio.ReadWriter, error) { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut { + return nil, nil, http.ErrHandlerTimeout + } + conn, rw, err := tw.w.(http.Hijacker).Hijack() + if err == nil { + tw.hijacked = true + } + return conn, rw, err +} + +type closeTimeoutWriter struct { + *baseTimeoutWriter +} + +func (tw *closeTimeoutWriter) CloseNotify() <-chan bool { + return tw.closeNotify() +} + +type hijackTimeoutWriter struct { + *baseTimeoutWriter +} + +func (tw *hijackTimeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return tw.hijack() +} + +type closeHijackTimeoutWriter struct { + *baseTimeoutWriter +} + +func (tw *closeHijackTimeoutWriter) CloseNotify() <-chan bool { + return tw.closeNotify() +} + +func (tw *closeHijackTimeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return tw.hijack() +} diff --git a/pkg/server/filters/timeout_test.go b/pkg/server/filters/timeout_test.go new file mode 100644 index 000000000..449812241 --- /dev/null +++ b/pkg/server/filters/timeout_test.go @@ -0,0 +1,85 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime/schema" + "strings" +) + +func TestTimeout(t *testing.T) { + sendResponse := make(chan struct{}, 1) + writeErrors := make(chan error, 1) + timeout := make(chan time.Time, 1) + resp := "test response" + timeoutErr := apierrors.NewServerTimeout(schema.GroupResource{Group: "foo", Resource: "bar"}, "get", 0) + + ts := httptest.NewServer(WithTimeout(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + <-sendResponse + _, err := w.Write([]byte(resp)) + writeErrors <- err + }), + func(*http.Request) (<-chan time.Time, *apierrors.StatusError) { + return timeout, timeoutErr + })) + defer ts.Close() + + // No timeouts + sendResponse <- struct{}{} + res, err := http.Get(ts.URL) + if err != nil { + t.Error(err) + } + if res.StatusCode != http.StatusOK { + t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusOK) + } + body, _ := ioutil.ReadAll(res.Body) + if string(body) != resp { + t.Errorf("got body %q; expected %q", string(body), resp) + } + if err := <-writeErrors; err != nil { + t.Errorf("got unexpected Write error on first request: %v", err) + } + + // Times out + timeout <- time.Time{} + res, err = http.Get(ts.URL) + if err != nil { + t.Error(err) + } + if res.StatusCode != http.StatusGatewayTimeout { + t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusServiceUnavailable) + } + body, _ = ioutil.ReadAll(res.Body) + if !strings.Contains(string(body), timeoutErr.Error()) { + t.Errorf("got body %q; expected it to contain %q", string(body), timeoutErr.Error()) + } + + // Now try to send a response + sendResponse <- struct{}{} + if err := <-writeErrors; err != http.ErrHandlerTimeout { + t.Errorf("got Write error of %v; expected %v", err, http.ErrHandlerTimeout) + } +} diff --git a/pkg/server/filters/wrap.go b/pkg/server/filters/wrap.go new file mode 100644 index 000000000..4f651360c --- /dev/null +++ b/pkg/server/filters/wrap.go @@ -0,0 +1,76 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "net/http" + "runtime/debug" + + "github.com/golang/glog" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/util/runtime" + apirequest "k8s.io/apiserver/pkg/endpoints/request" + "k8s.io/apiserver/pkg/server/httplog" +) + +// WithPanicRecovery wraps an http Handler to recover and log panics. +func WithPanicRecovery(handler http.Handler, requestContextMapper apirequest.RequestContextMapper) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + defer runtime.HandleCrash(func(err interface{}) { + http.Error(w, "This request caused apisever to panic. Look in log for details.", http.StatusInternalServerError) + glog.Errorf("APIServer panic'd on %v %v: %v\n%s\n", req.Method, req.RequestURI, err, debug.Stack()) + }) + + logger := httplog.NewLogged(req, &w) + + var requestInfo *apirequest.RequestInfo + ctx, ok := requestContextMapper.Get(req) + if !ok { + glog.Errorf("no context found for request, handler chain must be wrong") + } else { + requestInfo, ok = apirequest.RequestInfoFrom(ctx) + if !ok { + glog.Errorf("no RequestInfo found in context, handler chain must be wrong") + } + } + + if !ok || requestInfo.Verb != "proxy" { + logger.StacktraceWhen( + httplog.StatusIsNot( + http.StatusOK, + http.StatusCreated, + http.StatusAccepted, + http.StatusBadRequest, + http.StatusMovedPermanently, + http.StatusTemporaryRedirect, + http.StatusConflict, + http.StatusNotFound, + http.StatusUnauthorized, + http.StatusForbidden, + http.StatusNotModified, + apierrors.StatusUnprocessableEntity, + http.StatusSwitchingProtocols, + ), + ) + } + defer logger.Log() + + // Dispatch to the internal handler + handler.ServeHTTP(w, req) + }) +}