move genericapiserver/server/filters to apiserver

This commit is contained in:
deads2k 2017-01-26 14:39:54 -05:00
parent 848a905661
commit 32ddb5c9d2
10 changed files with 1126 additions and 0 deletions

3
pkg/server/filters/OWNERS Executable file
View File

@ -0,0 +1,3 @@
reviewers:
- sttts
- dims

View File

@ -0,0 +1,98 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"net/http"
"regexp"
"strings"
"github.com/golang/glog"
)
// TODO: use restful.CrossOriginResourceSharing
// See github.com/emicklei/go-restful/blob/master/examples/restful-CORS-filter.go, and
// github.com/emicklei/go-restful/blob/master/examples/restful-basic-authentication.go
// Or, for a more detailed implementation use https://github.com/martini-contrib/cors
// or implement CORS at your proxy layer.
// WithCORS is a simple CORS implementation that wraps an http Handler.
// Pass nil for allowedMethods and allowedHeaders to use the defaults. If allowedOriginPatterns
// is empty or nil, no CORS support is installed.
func WithCORS(handler http.Handler, allowedOriginPatterns []string, allowedMethods []string, allowedHeaders []string, exposedHeaders []string, allowCredentials string) http.Handler {
if len(allowedOriginPatterns) == 0 {
return handler
}
allowedOriginPatternsREs := allowedOriginRegexps(allowedOriginPatterns)
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
origin := req.Header.Get("Origin")
if origin != "" {
allowed := false
for _, re := range allowedOriginPatternsREs {
if allowed = re.MatchString(origin); allowed {
break
}
}
if allowed {
w.Header().Set("Access-Control-Allow-Origin", origin)
// Set defaults for methods and headers if nothing was passed
if allowedMethods == nil {
allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE", "PATCH"}
}
if allowedHeaders == nil {
allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"}
}
if exposedHeaders == nil {
exposedHeaders = []string{"Date"}
}
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", "))
w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", "))
w.Header().Set("Access-Control-Expose-Headers", strings.Join(exposedHeaders, ", "))
w.Header().Set("Access-Control-Allow-Credentials", allowCredentials)
// Stop here if its a preflight OPTIONS request
if req.Method == "OPTIONS" {
w.WriteHeader(http.StatusNoContent)
return
}
}
}
// Dispatch to the next handler
handler.ServeHTTP(w, req)
})
}
func allowedOriginRegexps(allowedOrigins []string) []*regexp.Regexp {
res, err := compileRegexps(allowedOrigins)
if err != nil {
glog.Fatalf("Invalid CORS allowed origin, --cors-allowed-origins flag was set to %v - %v", strings.Join(allowedOrigins, ","), err)
}
return res
}
// Takes a list of strings and compiles them into a list of regular expressions
func compileRegexps(regexpStrings []string) ([]*regexp.Regexp, error) {
regexps := []*regexp.Regexp{}
for _, regexpStr := range regexpStrings {
r, err := regexp.Compile(regexpStr)
if err != nil {
return []*regexp.Regexp{}, err
}
regexps = append(regexps, r)
}
return regexps, nil
}

View File

@ -0,0 +1,183 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
)
func TestCORSAllowedOrigins(t *testing.T) {
table := []struct {
allowedOrigins []string
origin string
allowed bool
}{
{[]string{}, "example.com", false},
{[]string{"example.com"}, "example.com", true},
{[]string{"example.com"}, "not-allowed.com", false},
{[]string{"not-matching.com", "example.com"}, "example.com", true},
{[]string{".*"}, "example.com", true},
}
for _, item := range table {
handler := WithCORS(
http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}),
item.allowedOrigins, nil, nil, nil, "true",
)
server := httptest.NewServer(handler)
defer server.Close()
client := http.Client{}
request, err := http.NewRequest("GET", server.URL+"/version", nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
request.Header.Set("Origin", item.origin)
response, err := client.Do(request)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if item.allowed {
if !reflect.DeepEqual(item.origin, response.Header.Get("Access-Control-Allow-Origin")) {
t.Errorf("Expected %#v, Got %#v", item.origin, response.Header.Get("Access-Control-Allow-Origin"))
}
if response.Header.Get("Access-Control-Allow-Credentials") == "" {
t.Errorf("Expected Access-Control-Allow-Credentials header to be set")
}
if response.Header.Get("Access-Control-Allow-Headers") == "" {
t.Errorf("Expected Access-Control-Allow-Headers header to be set")
}
if response.Header.Get("Access-Control-Allow-Methods") == "" {
t.Errorf("Expected Access-Control-Allow-Methods header to be set")
}
if response.Header.Get("Access-Control-Expose-Headers") != "Date" {
t.Errorf("Expected Date in Access-Control-Expose-Headers header")
}
} else {
if response.Header.Get("Access-Control-Allow-Origin") != "" {
t.Errorf("Expected Access-Control-Allow-Origin header to not be set")
}
if response.Header.Get("Access-Control-Allow-Credentials") != "" {
t.Errorf("Expected Access-Control-Allow-Credentials header to not be set")
}
if response.Header.Get("Access-Control-Allow-Headers") != "" {
t.Errorf("Expected Access-Control-Allow-Headers header to not be set")
}
if response.Header.Get("Access-Control-Allow-Methods") != "" {
t.Errorf("Expected Access-Control-Allow-Methods header to not be set")
}
if response.Header.Get("Access-Control-Expose-Headers") == "Date" {
t.Errorf("Expected Date in Access-Control-Expose-Headers header")
}
}
}
}
func TestCORSAllowedMethods(t *testing.T) {
tests := []struct {
allowedMethods []string
method string
allowed bool
}{
{nil, "POST", true},
{nil, "GET", true},
{nil, "OPTIONS", true},
{nil, "PUT", true},
{nil, "DELETE", true},
{nil, "PATCH", true},
{[]string{"GET", "POST"}, "PATCH", false},
}
allowsMethod := func(res *http.Response, method string) bool {
allowedMethods := strings.Split(res.Header.Get("Access-Control-Allow-Methods"), ",")
for _, allowedMethod := range allowedMethods {
if strings.TrimSpace(allowedMethod) == method {
return true
}
}
return false
}
for _, test := range tests {
handler := WithCORS(
http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}),
[]string{".*"}, test.allowedMethods, nil, nil, "true",
)
server := httptest.NewServer(handler)
defer server.Close()
client := http.Client{}
request, err := http.NewRequest(test.method, server.URL+"/version", nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
request.Header.Set("Origin", "allowed.com")
response, err := client.Do(request)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
methodAllowed := allowsMethod(response, test.method)
switch {
case test.allowed && !methodAllowed:
t.Errorf("Expected %v to be allowed, Got only %#v", test.method, response.Header.Get("Access-Control-Allow-Methods"))
case !test.allowed && methodAllowed:
t.Errorf("Unexpected allowed method %v, Expected only %#v", test.method, response.Header.Get("Access-Control-Allow-Methods"))
}
}
}
func TestCompileRegex(t *testing.T) {
uncompiledRegexes := []string{"endsWithMe$", "^startingWithMe"}
regexes, err := compileRegexps(uncompiledRegexes)
if err != nil {
t.Errorf("Failed to compile legal regexes: '%v': %v", uncompiledRegexes, err)
}
if len(regexes) != len(uncompiledRegexes) {
t.Errorf("Wrong number of regexes returned: '%v': %v", uncompiledRegexes, regexes)
}
if !regexes[0].MatchString("Something that endsWithMe") {
t.Errorf("Wrong regex returned: '%v': %v", uncompiledRegexes[0], regexes[0])
}
if regexes[0].MatchString("Something that doesn't endsWithMe.") {
t.Errorf("Wrong regex returned: '%v': %v", uncompiledRegexes[0], regexes[0])
}
if !regexes[1].MatchString("startingWithMe is very important") {
t.Errorf("Wrong regex returned: '%v': %v", uncompiledRegexes[1], regexes[1])
}
if regexes[1].MatchString("not startingWithMe should fail") {
t.Errorf("Wrong regex returned: '%v': %v", uncompiledRegexes[1], regexes[1])
}
}

19
pkg/server/filters/doc.go Normal file
View File

@ -0,0 +1,19 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package filters contains all the http handler chain filters which
// are not api related.
package filters // import "k8s.io/apiserver/pkg/server/filters"

View File

@ -0,0 +1,40 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"net/http"
"k8s.io/apimachinery/pkg/util/sets"
apirequest "k8s.io/apiserver/pkg/endpoints/request"
)
// LongRunningRequestCheck is a predicate which is true for long-running http requests.
type LongRunningRequestCheck func(r *http.Request, requestInfo *apirequest.RequestInfo) bool
// BasicLongRunningRequestCheck returns true if the given request has one of the specified verbs or one of the specified subresources
func BasicLongRunningRequestCheck(longRunningVerbs, longRunningSubresources sets.String) LongRunningRequestCheck {
return func(r *http.Request, requestInfo *apirequest.RequestInfo) bool {
if longRunningVerbs.Has(requestInfo.Verb) {
return true
}
if requestInfo.IsResourceRequest && longRunningSubresources.Has(requestInfo.Subresource) {
return true
}
return false
}
}

View File

@ -0,0 +1,111 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"fmt"
"net/http"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/util/sets"
apirequest "k8s.io/apiserver/pkg/endpoints/request"
genericapirequest "k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/server/httplog"
"github.com/golang/glog"
)
// Constant for the retry-after interval on rate limiting.
// TODO: maybe make this dynamic? or user-adjustable?
const retryAfter = "1"
var nonMutatingRequestVerbs = sets.NewString("get", "list", "watch")
func handleError(w http.ResponseWriter, r *http.Request, err error) {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "Internal Server Error: %#v", r.RequestURI)
glog.Errorf(err.Error())
}
// WithMaxInFlightLimit limits the number of in-flight requests to buffer size of the passed in channel.
func WithMaxInFlightLimit(
handler http.Handler,
nonMutatingLimit int,
mutatingLimit int,
requestContextMapper genericapirequest.RequestContextMapper,
longRunningRequestCheck LongRunningRequestCheck,
) http.Handler {
if nonMutatingLimit == 0 && mutatingLimit == 0 {
return handler
}
var nonMutatingChan chan bool
var mutatingChan chan bool
if nonMutatingLimit != 0 {
nonMutatingChan = make(chan bool, nonMutatingLimit)
}
if mutatingLimit != 0 {
mutatingChan = make(chan bool, mutatingLimit)
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, ok := requestContextMapper.Get(r)
if !ok {
handleError(w, r, fmt.Errorf("no context found for request, handler chain must be wrong"))
return
}
requestInfo, ok := apirequest.RequestInfoFrom(ctx)
if !ok {
handleError(w, r, fmt.Errorf("no RequestInfo found in context, handler chain must be wrong"))
return
}
// Skip tracking long running events.
if longRunningRequestCheck != nil && longRunningRequestCheck(r, requestInfo) {
handler.ServeHTTP(w, r)
return
}
var c chan bool
if !nonMutatingRequestVerbs.Has(requestInfo.Verb) {
c = mutatingChan
} else {
c = nonMutatingChan
}
if c == nil {
handler.ServeHTTP(w, r)
} else {
select {
case c <- true:
defer func() { <-c }()
handler.ServeHTTP(w, r)
default:
tooManyRequests(r, w)
}
}
})
}
func tooManyRequests(req *http.Request, w http.ResponseWriter) {
// "Too Many Requests" response is returned before logger is setup for the request.
// So we need to explicitly log it here.
defer httplog.NewLogged(req, &w).Log()
// Return a 429 status indicating "Too Many Requests"
w.Header().Set("Retry-After", retryAfter)
http.Error(w, "Too many requests, please try again later.", errors.StatusTooManyRequests)
}

View File

@ -0,0 +1,240 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/util/sets"
apirequest "k8s.io/apiserver/pkg/endpoints/request"
apifilters "k8s.io/kubernetes/pkg/genericapiserver/endpoints/filters"
)
func createMaxInflightServer(callsWg, blockWg *sync.WaitGroup, disableCallsWg *bool, disableCallsWgMutex *sync.Mutex, nonMutating, mutating int) *httptest.Server {
longRunningRequestCheck := BasicLongRunningRequestCheck(sets.NewString("watch"), sets.NewString("proxy"))
requestContextMapper := apirequest.NewRequestContextMapper()
requestInfoFactory := &apirequest.RequestInfoFactory{APIPrefixes: sets.NewString("apis", "api"), GrouplessAPIPrefixes: sets.NewString("api")}
handler := WithMaxInFlightLimit(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// A short, accounted request that does not wait for block WaitGroup.
if strings.Contains(r.URL.Path, "dontwait") {
return
}
disableCallsWgMutex.Lock()
waitForCalls := *disableCallsWg
disableCallsWgMutex.Unlock()
if waitForCalls {
callsWg.Done()
}
blockWg.Wait()
}),
nonMutating,
mutating,
requestContextMapper,
longRunningRequestCheck,
)
handler = apifilters.WithRequestInfo(handler, requestInfoFactory, requestContextMapper)
handler = apirequest.WithRequestContext(handler, requestContextMapper)
return httptest.NewServer(handler)
}
// Tests that MaxInFlightLimit works, i.e.
// - "long" requests such as proxy or watch, identified by regexp are not accounted despite
// hanging for the long time,
// - "short" requests are correctly accounted, i.e. there can be only size of channel passed to the
// constructor in flight at any given moment,
// - subsequent "short" requests are rejected instantly with appropriate error,
// - subsequent "long" requests are handled normally,
// - we correctly recover after some "short" requests finish, i.e. we can process new ones.
func TestMaxInFlightNonMutating(t *testing.T) {
const AllowedNonMutatingInflightRequestsNo = 3
// Calls is used to wait until all server calls are received. We are sending
// AllowedNonMutatingInflightRequestsNo of 'long' not-accounted requests and the same number of
// 'short' accounted ones.
calls := &sync.WaitGroup{}
calls.Add(AllowedNonMutatingInflightRequestsNo * 2)
// Responses is used to wait until all responses are
// received. This prevents some async requests getting EOF
// errors from prematurely closing the server
responses := &sync.WaitGroup{}
responses.Add(AllowedNonMutatingInflightRequestsNo * 2)
// Block is used to keep requests in flight for as long as we need to. All requests will
// be unblocked at the same time.
block := &sync.WaitGroup{}
block.Add(1)
waitForCalls := true
waitForCallsMutex := sync.Mutex{}
server := createMaxInflightServer(calls, block, &waitForCalls, &waitForCallsMutex, AllowedNonMutatingInflightRequestsNo, 1)
defer server.Close()
// These should hang, but not affect accounting. use a query param match
for i := 0; i < AllowedNonMutatingInflightRequestsNo; i++ {
// These should hang waiting on block...
go func() {
if err := expectHTTPGet(server.URL+"/api/v1/namespaces/default/wait?watch=true", http.StatusOK); err != nil {
t.Error(err)
}
responses.Done()
}()
}
// Check that sever is not saturated by not-accounted calls
if err := expectHTTPGet(server.URL+"/dontwait", http.StatusOK); err != nil {
t.Error(err)
}
// These should hang and be accounted, i.e. saturate the server
for i := 0; i < AllowedNonMutatingInflightRequestsNo; i++ {
// These should hang waiting on block...
go func() {
if err := expectHTTPGet(server.URL, http.StatusOK); err != nil {
t.Error(err)
}
responses.Done()
}()
}
// We wait for all calls to be received by the server
calls.Wait()
// Disable calls notifications in the server
waitForCallsMutex.Lock()
waitForCalls = false
waitForCallsMutex.Unlock()
// Do this multiple times to show that rate limit rejected requests don't block.
for i := 0; i < 2; i++ {
if err := expectHTTPGet(server.URL, errors.StatusTooManyRequests); err != nil {
t.Error(err)
}
}
// Validate that non-accounted URLs still work. use a path regex match
if err := expectHTTPGet(server.URL+"/api/v1/watch/namespaces/default/dontwait", http.StatusOK); err != nil {
t.Error(err)
}
// We should allow a single mutating request.
if err := expectHTTPPost(server.URL+"/dontwait", http.StatusOK); err != nil {
t.Error(err)
}
// Let all hanging requests finish
block.Done()
// Show that we recover from being blocked up.
// Too avoid flakyness we need to wait until at least one of the requests really finishes.
responses.Wait()
if err := expectHTTPGet(server.URL, http.StatusOK); err != nil {
t.Error(err)
}
}
func TestMaxInFlightMutating(t *testing.T) {
const AllowedMutatingInflightRequestsNo = 3
calls := &sync.WaitGroup{}
calls.Add(AllowedMutatingInflightRequestsNo)
responses := &sync.WaitGroup{}
responses.Add(AllowedMutatingInflightRequestsNo)
// Block is used to keep requests in flight for as long as we need to. All requests will
// be unblocked at the same time.
block := &sync.WaitGroup{}
block.Add(1)
waitForCalls := true
waitForCallsMutex := sync.Mutex{}
server := createMaxInflightServer(calls, block, &waitForCalls, &waitForCallsMutex, 1, AllowedMutatingInflightRequestsNo)
defer server.Close()
// These should hang and be accounted, i.e. saturate the server
for i := 0; i < AllowedMutatingInflightRequestsNo; i++ {
// These should hang waiting on block...
go func() {
if err := expectHTTPPost(server.URL+"/foo/bar", http.StatusOK); err != nil {
t.Error(err)
}
responses.Done()
}()
}
// We wait for all calls to be received by the server
calls.Wait()
// Disable calls notifications in the server
// Disable calls notifications in the server
waitForCallsMutex.Lock()
waitForCalls = false
waitForCallsMutex.Unlock()
// Do this multiple times to show that rate limit rejected requests don't block.
for i := 0; i < 2; i++ {
if err := expectHTTPPost(server.URL+"/foo/bar/", errors.StatusTooManyRequests); err != nil {
t.Error(err)
}
}
// Validate that non-mutating URLs still work. use a path regex match
if err := expectHTTPGet(server.URL+"/dontwait", http.StatusOK); err != nil {
t.Error(err)
}
// Let all hanging requests finish
block.Done()
// Show that we recover from being blocked up.
// Too avoid flakyness we need to wait until at least one of the requests really finishes.
responses.Wait()
if err := expectHTTPPost(server.URL+"/foo/bar", http.StatusOK); err != nil {
t.Error(err)
}
}
// We use GET as a sample non-mutating request.
func expectHTTPGet(url string, code int) error {
r, err := http.Get(url)
if err != nil {
return fmt.Errorf("unexpected error: %v", err)
}
if r.StatusCode != code {
return fmt.Errorf("unexpected response: %v", r.StatusCode)
}
return nil
}
// We use POST as a sample mutating request.
func expectHTTPPost(url string, code int) error {
r, err := http.Post(url, "text/html", strings.NewReader("foo bar"))
if err != nil {
return fmt.Errorf("unexpected error: %v", err)
}
if r.StatusCode != code {
return fmt.Errorf("unexpected response: %v", r.StatusCode)
}
return nil
}

View File

@ -0,0 +1,271 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"bufio"
"encoding/json"
"fmt"
"net"
"net/http"
"sync"
"time"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime/schema"
apirequest "k8s.io/apiserver/pkg/endpoints/request"
)
const globalTimeout = time.Minute
var errConnKilled = fmt.Errorf("kill connection/stream")
// WithTimeoutForNonLongRunningRequests times out non-long-running requests after the time given by globalTimeout.
func WithTimeoutForNonLongRunningRequests(handler http.Handler, requestContextMapper apirequest.RequestContextMapper, longRunning LongRunningRequestCheck) http.Handler {
if longRunning == nil {
return handler
}
timeoutFunc := func(req *http.Request) (<-chan time.Time, *apierrors.StatusError) {
// TODO unify this with apiserver.MaxInFlightLimit
ctx, ok := requestContextMapper.Get(req)
if !ok {
// if this happens, the handler chain isn't setup correctly because there is no context mapper
return time.After(globalTimeout), apierrors.NewInternalError(fmt.Errorf("no context found for request during timeout"))
}
requestInfo, ok := apirequest.RequestInfoFrom(ctx)
if !ok {
// if this happens, the handler chain isn't setup correctly because there is no request info
return time.After(globalTimeout), apierrors.NewInternalError(fmt.Errorf("no request info found for request during timeout"))
}
if longRunning(req, requestInfo) {
return nil, nil
}
return time.After(globalTimeout), apierrors.NewServerTimeout(schema.GroupResource{Group: requestInfo.APIGroup, Resource: requestInfo.Resource}, requestInfo.Verb, 0)
}
return WithTimeout(handler, timeoutFunc)
}
// WithTimeout returns an http.Handler that runs h with a timeout
// determined by timeoutFunc. The new http.Handler calls h.ServeHTTP to handle
// each request, but if a call runs for longer than its time limit, the
// handler responds with a 503 Service Unavailable error and the message
// provided. (If msg is empty, a suitable default message will be sent.) After
// the handler times out, writes by h to its http.ResponseWriter will return
// http.ErrHandlerTimeout. If timeoutFunc returns a nil timeout channel, no
// timeout will be enforced.
func WithTimeout(h http.Handler, timeoutFunc func(*http.Request) (timeout <-chan time.Time, err *apierrors.StatusError)) http.Handler {
return &timeoutHandler{h, timeoutFunc}
}
type timeoutHandler struct {
handler http.Handler
timeout func(*http.Request) (<-chan time.Time, *apierrors.StatusError)
}
func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
after, err := t.timeout(r)
if after == nil {
t.handler.ServeHTTP(w, r)
return
}
done := make(chan struct{})
tw := newTimeoutWriter(w)
go func() {
t.handler.ServeHTTP(tw, r)
close(done)
}()
select {
case <-done:
return
case <-after:
tw.timeout(err)
}
}
type timeoutWriter interface {
http.ResponseWriter
timeout(*apierrors.StatusError)
}
func newTimeoutWriter(w http.ResponseWriter) timeoutWriter {
base := &baseTimeoutWriter{w: w}
_, notifiable := w.(http.CloseNotifier)
_, hijackable := w.(http.Hijacker)
switch {
case notifiable && hijackable:
return &closeHijackTimeoutWriter{base}
case notifiable:
return &closeTimeoutWriter{base}
case hijackable:
return &hijackTimeoutWriter{base}
default:
return base
}
}
type baseTimeoutWriter struct {
w http.ResponseWriter
mu sync.Mutex
// if the timeout handler has timedout
timedOut bool
// if this timeout writer has wrote header
wroteHeader bool
// if this timeout writer has been hijacked
hijacked bool
}
func (tw *baseTimeoutWriter) Header() http.Header {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut {
return http.Header{}
}
return tw.w.Header()
}
func (tw *baseTimeoutWriter) Write(p []byte) (int, error) {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut {
return 0, http.ErrHandlerTimeout
}
if tw.hijacked {
return 0, http.ErrHijacked
}
tw.wroteHeader = true
return tw.w.Write(p)
}
func (tw *baseTimeoutWriter) Flush() {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut {
return
}
if flusher, ok := tw.w.(http.Flusher); ok {
flusher.Flush()
}
}
func (tw *baseTimeoutWriter) WriteHeader(code int) {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut || tw.wroteHeader || tw.hijacked {
return
}
tw.wroteHeader = true
tw.w.WriteHeader(code)
}
func (tw *baseTimeoutWriter) timeout(err *apierrors.StatusError) {
tw.mu.Lock()
defer tw.mu.Unlock()
tw.timedOut = true
// The timeout writer has not been used by the inner handler.
// We can safely timeout the HTTP request by sending by a timeout
// handler
if !tw.wroteHeader && !tw.hijacked {
tw.w.WriteHeader(http.StatusGatewayTimeout)
enc := json.NewEncoder(tw.w)
enc.Encode(err)
} else {
// The timeout writer has been used by the inner handler. There is
// no way to timeout the HTTP request at the point. We have to shutdown
// the connection for HTTP1 or reset stream for HTTP2.
//
// Note from: Brad Fitzpatrick
// if the ServeHTTP goroutine panics, that will do the best possible thing for both
// HTTP/1 and HTTP/2. In HTTP/1, assuming you're replying with at least HTTP/1.1 and
// you've already flushed the headers so it's using HTTP chunking, it'll kill the TCP
// connection immediately without a proper 0-byte EOF chunk, so the peer will recognize
// the response as bogus. In HTTP/2 the server will just RST_STREAM the stream, leaving
// the TCP connection open, but resetting the stream to the peer so it'll have an error,
// like the HTTP/1 case.
panic(errConnKilled)
}
}
func (tw *baseTimeoutWriter) closeNotify() <-chan bool {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut {
done := make(chan bool)
close(done)
return done
}
return tw.w.(http.CloseNotifier).CloseNotify()
}
func (tw *baseTimeoutWriter) hijack() (net.Conn, *bufio.ReadWriter, error) {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut {
return nil, nil, http.ErrHandlerTimeout
}
conn, rw, err := tw.w.(http.Hijacker).Hijack()
if err == nil {
tw.hijacked = true
}
return conn, rw, err
}
type closeTimeoutWriter struct {
*baseTimeoutWriter
}
func (tw *closeTimeoutWriter) CloseNotify() <-chan bool {
return tw.closeNotify()
}
type hijackTimeoutWriter struct {
*baseTimeoutWriter
}
func (tw *hijackTimeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return tw.hijack()
}
type closeHijackTimeoutWriter struct {
*baseTimeoutWriter
}
func (tw *closeHijackTimeoutWriter) CloseNotify() <-chan bool {
return tw.closeNotify()
}
func (tw *closeHijackTimeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return tw.hijack()
}

View File

@ -0,0 +1,85 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime/schema"
"strings"
)
func TestTimeout(t *testing.T) {
sendResponse := make(chan struct{}, 1)
writeErrors := make(chan error, 1)
timeout := make(chan time.Time, 1)
resp := "test response"
timeoutErr := apierrors.NewServerTimeout(schema.GroupResource{Group: "foo", Resource: "bar"}, "get", 0)
ts := httptest.NewServer(WithTimeout(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
<-sendResponse
_, err := w.Write([]byte(resp))
writeErrors <- err
}),
func(*http.Request) (<-chan time.Time, *apierrors.StatusError) {
return timeout, timeoutErr
}))
defer ts.Close()
// No timeouts
sendResponse <- struct{}{}
res, err := http.Get(ts.URL)
if err != nil {
t.Error(err)
}
if res.StatusCode != http.StatusOK {
t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusOK)
}
body, _ := ioutil.ReadAll(res.Body)
if string(body) != resp {
t.Errorf("got body %q; expected %q", string(body), resp)
}
if err := <-writeErrors; err != nil {
t.Errorf("got unexpected Write error on first request: %v", err)
}
// Times out
timeout <- time.Time{}
res, err = http.Get(ts.URL)
if err != nil {
t.Error(err)
}
if res.StatusCode != http.StatusGatewayTimeout {
t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusServiceUnavailable)
}
body, _ = ioutil.ReadAll(res.Body)
if !strings.Contains(string(body), timeoutErr.Error()) {
t.Errorf("got body %q; expected it to contain %q", string(body), timeoutErr.Error())
}
// Now try to send a response
sendResponse <- struct{}{}
if err := <-writeErrors; err != http.ErrHandlerTimeout {
t.Errorf("got Write error of %v; expected %v", err, http.ErrHandlerTimeout)
}
}

View File

@ -0,0 +1,76 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"net/http"
"runtime/debug"
"github.com/golang/glog"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/util/runtime"
apirequest "k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/apiserver/pkg/server/httplog"
)
// WithPanicRecovery wraps an http Handler to recover and log panics.
func WithPanicRecovery(handler http.Handler, requestContextMapper apirequest.RequestContextMapper) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
defer runtime.HandleCrash(func(err interface{}) {
http.Error(w, "This request caused apisever to panic. Look in log for details.", http.StatusInternalServerError)
glog.Errorf("APIServer panic'd on %v %v: %v\n%s\n", req.Method, req.RequestURI, err, debug.Stack())
})
logger := httplog.NewLogged(req, &w)
var requestInfo *apirequest.RequestInfo
ctx, ok := requestContextMapper.Get(req)
if !ok {
glog.Errorf("no context found for request, handler chain must be wrong")
} else {
requestInfo, ok = apirequest.RequestInfoFrom(ctx)
if !ok {
glog.Errorf("no RequestInfo found in context, handler chain must be wrong")
}
}
if !ok || requestInfo.Verb != "proxy" {
logger.StacktraceWhen(
httplog.StatusIsNot(
http.StatusOK,
http.StatusCreated,
http.StatusAccepted,
http.StatusBadRequest,
http.StatusMovedPermanently,
http.StatusTemporaryRedirect,
http.StatusConflict,
http.StatusNotFound,
http.StatusUnauthorized,
http.StatusForbidden,
http.StatusNotModified,
apierrors.StatusUnprocessableEntity,
http.StatusSwitchingProtocols,
),
)
}
defer logger.Log()
// Dispatch to the internal handler
handler.ServeHTTP(w, req)
})
}