diff --git a/pkg/authentication/request/headerrequest/requestheader.go b/pkg/authentication/request/headerrequest/requestheader.go new file mode 100644 index 000000000..7b515c351 --- /dev/null +++ b/pkg/authentication/request/headerrequest/requestheader.go @@ -0,0 +1,178 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package headerrequest + +import ( + "crypto/x509" + "fmt" + "io/ioutil" + "net/http" + "strings" + + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apiserver/pkg/authentication/authenticator" + x509request "k8s.io/apiserver/pkg/authentication/request/x509" + "k8s.io/apiserver/pkg/authentication/user" + utilcert "k8s.io/client-go/pkg/util/cert" +) + +type requestHeaderAuthRequestHandler struct { + // nameHeaders are the headers to check (in order, case-insensitively) for an identity. The first header with a value wins. + nameHeaders []string + + // groupHeaders are the headers to check (case-insensitively) for group membership. All values of all headers will be added. + groupHeaders []string + + // extraHeaderPrefixes are the head prefixes to check (case-insensitively) for filling in + // the user.Info.Extra. All values of all matching headers will be added. + extraHeaderPrefixes []string +} + +func New(nameHeaders []string, groupHeaders []string, extraHeaderPrefixes []string) (authenticator.Request, error) { + trimmedNameHeaders, err := trimHeaders(nameHeaders...) 
+ if err != nil { + return nil, err + } + trimmedGroupHeaders, err := trimHeaders(groupHeaders...) + if err != nil { + return nil, err + } + trimmedExtraHeaderPrefixes, err := trimHeaders(extraHeaderPrefixes...) + if err != nil { + return nil, err + } + + return &requestHeaderAuthRequestHandler{ + nameHeaders: trimmedNameHeaders, + groupHeaders: trimmedGroupHeaders, + extraHeaderPrefixes: trimmedExtraHeaderPrefixes, + }, nil +} + +func trimHeaders(headerNames ...string) ([]string, error) { + ret := []string{} + for _, headerName := range headerNames { + trimmedHeader := strings.TrimSpace(headerName) + if len(trimmedHeader) == 0 { + return nil, fmt.Errorf("empty header %q", headerName) + } + ret = append(ret, trimmedHeader) + } + + return ret, nil +} + +func NewSecure(clientCA string, proxyClientNames []string, nameHeaders []string, groupHeaders []string, extraHeaderPrefixes []string) (authenticator.Request, error) { + headerAuthenticator, err := New(nameHeaders, groupHeaders, extraHeaderPrefixes) + if err != nil { + return nil, err + } + + if len(clientCA) == 0 { + return nil, fmt.Errorf("missing clientCA file") + } + + // Wrap with an x509 verifier + caData, err := ioutil.ReadFile(clientCA) + if err != nil { + return nil, fmt.Errorf("error reading %s: %v", clientCA, err) + } + opts := x509request.DefaultVerifyOptions() + opts.Roots = x509.NewCertPool() + certs, err := utilcert.ParseCertsPEM(caData) + if err != nil { + return nil, fmt.Errorf("error loading certs from %s: %v", clientCA, err) + } + for _, cert := range certs { + opts.Roots.AddCert(cert) + } + + return x509request.NewVerifier(opts, headerAuthenticator, sets.NewString(proxyClientNames...)), nil +} + +func (a *requestHeaderAuthRequestHandler) AuthenticateRequest(req *http.Request) (user.Info, bool, error) { + name := headerValue(req.Header, a.nameHeaders) + if len(name) == 0 { + return nil, false, nil + } + groups := allHeaderValues(req.Header, a.groupHeaders) + extra := newExtra(req.Header, 
a.extraHeaderPrefixes) + + // clear headers used for authentication + for _, headerName := range a.nameHeaders { + req.Header.Del(headerName) + } + for _, headerName := range a.groupHeaders { + req.Header.Del(headerName) + } + for k := range extra { + for _, prefix := range a.extraHeaderPrefixes { + req.Header.Del(prefix + k) + } + } + + return &user.DefaultInfo{ + Name: name, + Groups: groups, + Extra: extra, + }, true, nil +} + +func headerValue(h http.Header, headerNames []string) string { + for _, headerName := range headerNames { + headerValue := h.Get(headerName) + if len(headerValue) > 0 { + return headerValue + } + } + return "" +} + +func allHeaderValues(h http.Header, headerNames []string) []string { + ret := []string{} + for _, headerName := range headerNames { + values, ok := h[headerName] + if !ok { + continue + } + + for _, headerValue := range values { + if len(headerValue) > 0 { + ret = append(ret, headerValue) + } + } + } + return ret +} + +func newExtra(h http.Header, headerPrefixes []string) map[string][]string { + ret := map[string][]string{} + + // we have to iterate over prefixes first in order to have proper ordering inside the value slices + for _, prefix := range headerPrefixes { + for headerName, vv := range h { + if !strings.HasPrefix(strings.ToLower(headerName), strings.ToLower(prefix)) { + continue + } + + extraKey := strings.ToLower(headerName[len(prefix):]) + ret[extraKey] = append(ret[extraKey], vv...) + } + } + + return ret +} diff --git a/pkg/authentication/request/headerrequest/requestheader_test.go b/pkg/authentication/request/headerrequest/requestheader_test.go new file mode 100644 index 000000000..33e5afcac --- /dev/null +++ b/pkg/authentication/request/headerrequest/requestheader_test.go @@ -0,0 +1,159 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package headerrequest + +import ( + "net/http" + "reflect" + "testing" + + "k8s.io/apiserver/pkg/authentication/user" +) + +func TestRequestHeader(t *testing.T) { + testcases := map[string]struct { + nameHeaders []string + groupHeaders []string + extraPrefixHeaders []string + requestHeaders http.Header + + expectedUser user.Info + expectedOk bool + }{ + "empty": {}, + "user no match": { + nameHeaders: []string{"X-Remote-User"}, + }, + "user match": { + nameHeaders: []string{"X-Remote-User"}, + requestHeaders: http.Header{"X-Remote-User": {"Bob"}}, + expectedUser: &user.DefaultInfo{ + Name: "Bob", + Groups: []string{}, + Extra: map[string][]string{}, + }, + expectedOk: true, + }, + "user exact match": { + nameHeaders: []string{"X-Remote-User"}, + requestHeaders: http.Header{ + "Prefixed-X-Remote-User-With-Suffix": {"Bob"}, + "X-Remote-User-With-Suffix": {"Bob"}, + }, + }, + "user first match": { + nameHeaders: []string{ + "X-Remote-User", + "A-Second-X-Remote-User", + "Another-X-Remote-User", + }, + requestHeaders: http.Header{ + "X-Remote-User": {"", "First header, second value"}, + "A-Second-X-Remote-User": {"Second header, first value", "Second header, second value"}, + "Another-X-Remote-User": {"Third header, first value"}}, + expectedUser: &user.DefaultInfo{ + Name: "Second header, first value", + Groups: []string{}, + Extra: map[string][]string{}, + }, + expectedOk: true, + }, + "user case-insensitive": { + nameHeaders: []string{"x-REMOTE-user"}, // configured headers can be case-insensitive + requestHeaders: http.Header{"X-Remote-User": {"Bob"}}, // the 
parsed headers are normalized by the http package + expectedUser: &user.DefaultInfo{ + Name: "Bob", + Groups: []string{}, + Extra: map[string][]string{}, + }, + expectedOk: true, + }, + + "groups none": { + nameHeaders: []string{"X-Remote-User"}, + groupHeaders: []string{"X-Remote-Group"}, + requestHeaders: http.Header{ + "X-Remote-User": {"Bob"}, + }, + expectedUser: &user.DefaultInfo{ + Name: "Bob", + Groups: []string{}, + Extra: map[string][]string{}, + }, + expectedOk: true, + }, + "groups all matches": { + nameHeaders: []string{"X-Remote-User"}, + groupHeaders: []string{"X-Remote-Group-1", "X-Remote-Group-2"}, + requestHeaders: http.Header{ + "X-Remote-User": {"Bob"}, + "X-Remote-Group-1": {"one-a", "one-b"}, + "X-Remote-Group-2": {"two-a", "two-b"}, + }, + expectedUser: &user.DefaultInfo{ + Name: "Bob", + Groups: []string{"one-a", "one-b", "two-a", "two-b"}, + Extra: map[string][]string{}, + }, + expectedOk: true, + }, + + "extra prefix matches case-insensitive": { + nameHeaders: []string{"X-Remote-User"}, + groupHeaders: []string{"X-Remote-Group-1", "X-Remote-Group-2"}, + extraPrefixHeaders: []string{"X-Remote-Extra-1-", "X-Remote-Extra-2-"}, + requestHeaders: http.Header{ + "X-Remote-User": {"Bob"}, + "X-Remote-Group-1": {"one-a", "one-b"}, + "X-Remote-Group-2": {"two-a", "two-b"}, + "X-Remote-extra-1-key1": {"alfa", "bravo"}, + "X-Remote-Extra-1-Key2": {"charlie", "delta"}, + "X-Remote-Extra-1-": {"india", "juliet"}, + "X-Remote-extra-2-": {"kilo", "lima"}, + "X-Remote-extra-2-Key1": {"echo", "foxtrot"}, + "X-Remote-Extra-2-key2": {"golf", "hotel"}, + }, + expectedUser: &user.DefaultInfo{ + Name: "Bob", + Groups: []string{"one-a", "one-b", "two-a", "two-b"}, + Extra: map[string][]string{ + "key1": {"alfa", "bravo", "echo", "foxtrot"}, + "key2": {"charlie", "delta", "golf", "hotel"}, + "": {"india", "juliet", "kilo", "lima"}, + }, + }, + expectedOk: true, + }, + } + + for k, testcase := range testcases { + auth, err := New(testcase.nameHeaders, 
testcase.groupHeaders, testcase.extraPrefixHeaders) + if err != nil { + t.Fatal(err) + } + req := &http.Request{Header: testcase.requestHeaders} + + user, ok, _ := auth.AuthenticateRequest(req) + if testcase.expectedOk != ok { + t.Errorf("%v: expected %v, got %v", k, testcase.expectedOk, ok) + } + if e, a := testcase.expectedUser, user; !reflect.DeepEqual(e, a) { + t.Errorf("%v: expected %#v, got %#v", k, e, a) + + } + } +} diff --git a/pkg/handlers/negotiation/doc.go b/pkg/handlers/negotiation/doc.go new file mode 100644 index 000000000..059af8315 --- /dev/null +++ b/pkg/handlers/negotiation/doc.go @@ -0,0 +1,18 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package negotation contains media type negotiation logic. +package negotiation // import "k8s.io/kubernetes/pkg/genericapiserver/api/handlers/negotiation" diff --git a/pkg/handlers/negotiation/errors.go b/pkg/handlers/negotiation/errors.go new file mode 100644 index 000000000..cd262706c --- /dev/null +++ b/pkg/handlers/negotiation/errors.go @@ -0,0 +1,61 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package negotiation + +import ( + "fmt" + "net/http" + "strings" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// errNotAcceptable indicates Accept negotiation has failed +type errNotAcceptable struct { + accepted []string +} + +func (e errNotAcceptable) Error() string { + return fmt.Sprintf("only the following media types are accepted: %v", strings.Join(e.accepted, ", ")) +} + +func (e errNotAcceptable) Status() metav1.Status { + return metav1.Status{ + Status: metav1.StatusFailure, + Code: http.StatusNotAcceptable, + Reason: metav1.StatusReason("NotAcceptable"), + Message: e.Error(), + } +} + +// errUnsupportedMediaType indicates Content-Type is not recognized +type errUnsupportedMediaType struct { + accepted []string +} + +func (e errUnsupportedMediaType) Error() string { + return fmt.Sprintf("the body of the request was in an unknown format - accepted media types include: %v", strings.Join(e.accepted, ", ")) +} + +func (e errUnsupportedMediaType) Status() metav1.Status { + return metav1.Status{ + Status: metav1.StatusFailure, + Code: http.StatusUnsupportedMediaType, + Reason: metav1.StatusReason("UnsupportedMediaType"), + Message: e.Error(), + } +} diff --git a/pkg/handlers/negotiation/negotiate.go b/pkg/handlers/negotiation/negotiate.go new file mode 100644 index 000000000..c3948d4cd --- /dev/null +++ b/pkg/handlers/negotiation/negotiate.go @@ -0,0 +1,305 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package negotiation + +import ( + "mime" + "net/http" + "strconv" + "strings" + + "bitbucket.org/ww/goautoneg" + + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" +) + +// MediaTypesForSerializer returns a list of media and stream media types for the server. +func MediaTypesForSerializer(ns runtime.NegotiatedSerializer) (mediaTypes, streamMediaTypes []string) { + for _, info := range ns.SupportedMediaTypes() { + mediaTypes = append(mediaTypes, info.MediaType) + if info.StreamSerializer != nil { + // stream=watch is the existing mime-type parameter for watch + streamMediaTypes = append(streamMediaTypes, info.MediaType+";stream=watch") + } + } + return mediaTypes, streamMediaTypes +} + +func NegotiateOutputSerializer(req *http.Request, ns runtime.NegotiatedSerializer) (runtime.SerializerInfo, error) { + mediaType, ok := negotiateMediaTypeOptions(req.Header.Get("Accept"), acceptedMediaTypesForEndpoint(ns), defaultEndpointRestrictions) + if !ok { + supported, _ := MediaTypesForSerializer(ns) + return runtime.SerializerInfo{}, errNotAcceptable{supported} + } + // TODO: move into resthandler + info := mediaType.accepted.Serializer + if (mediaType.pretty || isPrettyPrint(req)) && info.PrettySerializer != nil { + info.Serializer = info.PrettySerializer + } + return info, nil +} + +func NegotiateOutputStreamSerializer(req *http.Request, ns runtime.NegotiatedSerializer) (runtime.SerializerInfo, error) { + mediaType, ok := negotiateMediaTypeOptions(req.Header.Get("Accept"), acceptedMediaTypesForEndpoint(ns), defaultEndpointRestrictions) + if !ok 
|| mediaType.accepted.Serializer.StreamSerializer == nil { + _, supported := MediaTypesForSerializer(ns) + return runtime.SerializerInfo{}, errNotAcceptable{supported} + } + return mediaType.accepted.Serializer, nil +} + +func NegotiateInputSerializer(req *http.Request, ns runtime.NegotiatedSerializer) (runtime.SerializerInfo, error) { + mediaTypes := ns.SupportedMediaTypes() + mediaType := req.Header.Get("Content-Type") + if len(mediaType) == 0 { + mediaType = mediaTypes[0].MediaType + } + mediaType, _, err := mime.ParseMediaType(mediaType) + if err != nil { + _, supported := MediaTypesForSerializer(ns) + return runtime.SerializerInfo{}, errUnsupportedMediaType{supported} + } + + for _, info := range mediaTypes { + if info.MediaType != mediaType { + continue + } + return info, nil + } + + _, supported := MediaTypesForSerializer(ns) + return runtime.SerializerInfo{}, errUnsupportedMediaType{supported} +} + +// isPrettyPrint returns true if the "pretty" query parameter is true or if the User-Agent +// matches known "human" clients. +func isPrettyPrint(req *http.Request) bool { + // DEPRECATED: should be part of the content type + if req.URL != nil { + pp := req.URL.Query().Get("pretty") + if len(pp) > 0 { + pretty, _ := strconv.ParseBool(pp) + return pretty + } + } + userAgent := req.UserAgent() + // This covers basic all browers and cli http tools + if strings.HasPrefix(userAgent, "curl") || strings.HasPrefix(userAgent, "Wget") || strings.HasPrefix(userAgent, "Mozilla/5.0") { + return true + } + return false +} + +// negotiate the most appropriate content type given the accept header and a list of +// alternatives. 
+func negotiate(header string, alternatives []string) (goautoneg.Accept, bool) { + alternates := make([][]string, 0, len(alternatives)) + for _, alternate := range alternatives { + alternates = append(alternates, strings.SplitN(alternate, "/", 2)) + } + for _, clause := range goautoneg.ParseAccept(header) { + for _, alternate := range alternates { + if clause.Type == alternate[0] && clause.SubType == alternate[1] { + return clause, true + } + if clause.Type == alternate[0] && clause.SubType == "*" { + clause.SubType = alternate[1] + return clause, true + } + if clause.Type == "*" && clause.SubType == "*" { + clause.Type = alternate[0] + clause.SubType = alternate[1] + return clause, true + } + } + } + return goautoneg.Accept{}, false +} + +// endpointRestrictions is an interface that allows content-type negotiation +// to verify server support for specific options +type endpointRestrictions interface { + // AllowsConversion should return true if the specified group version kind + // is an allowed target object. + AllowsConversion(schema.GroupVersionKind) bool + // AllowsServerVersion should return true if the specified version is valid + // for the server group. + AllowsServerVersion(version string) bool + // AllowsStreamSchema should return true if the specified stream schema is + // valid for the server group. + AllowsStreamSchema(schema string) bool +} + +var defaultEndpointRestrictions = emptyEndpointRestrictions{} + +type emptyEndpointRestrictions struct{} + +func (emptyEndpointRestrictions) AllowsConversion(schema.GroupVersionKind) bool { return false } +func (emptyEndpointRestrictions) AllowsServerVersion(string) bool { return false } +func (emptyEndpointRestrictions) AllowsStreamSchema(s string) bool { return s == "watch" } + +// acceptedMediaType contains information about a valid media type that the +// server can serialize. 
+type acceptedMediaType struct { + // Type is the first part of the media type ("application") + Type string + // SubType is the second part of the media type ("json") + SubType string + // Serializer is the serialization info this object accepts + Serializer runtime.SerializerInfo +} + +// mediaTypeOptions describes information for a given media type that may alter +// the server response +type mediaTypeOptions struct { + // pretty is true if the requested representation should be formatted for human + // viewing + pretty bool + + // stream, if set, indicates that a streaming protocol variant of this encoding + // is desired. The only currently supported value is watch which returns versioned + // events. In the future, this may refer to other stream protocols. + stream string + + // convert is a request to alter the type of object returned by the server from the + // normal response + convert *schema.GroupVersionKind + // useServerVersion is an optional version for the server group + useServerVersion string + + // export is true if the representation requested should exclude fields the server + // has set + export bool + + // unrecognized is a list of all unrecognized keys + unrecognized []string + + // the accepted media type from the client + accepted *acceptedMediaType +} + +// acceptMediaTypeOptions returns an options object that matches the provided media type params. If +// it returns false, the provided options are not allowed and the media type must be skipped. These +// parameters are unversioned and may not be changed. 
+func acceptMediaTypeOptions(params map[string]string, accepts *acceptedMediaType, endpoint endpointRestrictions) (mediaTypeOptions, bool) { + var options mediaTypeOptions + + // extract all known parameters + for k, v := range params { + switch k { + + // controls transformation of the object when returned + case "as": + if options.convert == nil { + options.convert = &schema.GroupVersionKind{} + } + options.convert.Kind = v + case "g": + if options.convert == nil { + options.convert = &schema.GroupVersionKind{} + } + options.convert.Group = v + case "v": + if options.convert == nil { + options.convert = &schema.GroupVersionKind{} + } + options.convert.Version = v + + // controls the streaming schema + case "stream": + if len(v) > 0 && (accepts.Serializer.StreamSerializer == nil || !endpoint.AllowsStreamSchema(v)) { + return mediaTypeOptions{}, false + } + options.stream = v + + // controls the version of the server API group used + // for generic output + case "sv": + if len(v) > 0 && !endpoint.AllowsServerVersion(v) { + return mediaTypeOptions{}, false + } + options.useServerVersion = v + + // if specified, the server should transform the returned + // output and remove fields that are always server specified, + // or which fit the default behavior. + case "export": + options.export = v == "1" + + // if specified, the pretty serializer will be used + case "pretty": + options.pretty = v == "1" + + default: + options.unrecognized = append(options.unrecognized, k) + } + } + + if options.convert != nil && !endpoint.AllowsConversion(*options.convert) { + return mediaTypeOptions{}, false + } + + options.accepted = accepts + + return options, true +} + +// negotiateMediaTypeOptions returns the most appropriate content type given the accept header and +// a list of alternatives along with the accepted media type parameters. 
+func negotiateMediaTypeOptions(header string, accepted []acceptedMediaType, endpoint endpointRestrictions) (mediaTypeOptions, bool) { + if len(header) == 0 && len(accepted) > 0 { + return mediaTypeOptions{ + accepted: &accepted[0], + }, true + } + + clauses := goautoneg.ParseAccept(header) + for _, clause := range clauses { + for i := range accepted { + accepts := &accepted[i] + switch { + case clause.Type == accepts.Type && clause.SubType == accepts.SubType, + clause.Type == accepts.Type && clause.SubType == "*", + clause.Type == "*" && clause.SubType == "*": + // TODO: should we prefer the first type with no unrecognized options? Do we need to ignore unrecognized + // parameters. + return acceptMediaTypeOptions(clause.Params, accepts, endpoint) + } + } + } + return mediaTypeOptions{}, false +} + +// acceptedMediaTypesForEndpoint returns an array of structs that are used to efficiently check which +// allowed media types the server exposes. +func acceptedMediaTypesForEndpoint(ns runtime.NegotiatedSerializer) []acceptedMediaType { + var acceptedMediaTypes []acceptedMediaType + for _, info := range ns.SupportedMediaTypes() { + segments := strings.SplitN(info.MediaType, "/", 2) + if len(segments) == 1 { + segments = append(segments, "*") + } + t := acceptedMediaType{ + Type: segments[0], + SubType: segments[1], + Serializer: info, + } + acceptedMediaTypes = append(acceptedMediaTypes, t) + } + return acceptedMediaTypes +} diff --git a/pkg/handlers/negotiation/negotiate_test.go b/pkg/handlers/negotiation/negotiate_test.go new file mode 100644 index 000000000..8a747ff73 --- /dev/null +++ b/pkg/handlers/negotiation/negotiate_test.go @@ -0,0 +1,245 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package negotiation + +import ( + "net/http" + "net/url" + "testing" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" +) + +// statusError is an object that can be converted into an metav1.Status +type statusError interface { + Status() metav1.Status +} + +type fakeNegotiater struct { + serializer, streamSerializer runtime.Serializer + framer runtime.Framer + types, streamTypes []string +} + +func (n *fakeNegotiater) SupportedMediaTypes() []runtime.SerializerInfo { + var out []runtime.SerializerInfo + for _, s := range n.types { + info := runtime.SerializerInfo{Serializer: n.serializer, MediaType: s, EncodesAsText: true} + for _, t := range n.streamTypes { + if t == s { + info.StreamSerializer = &runtime.StreamSerializerInfo{ + EncodesAsText: true, + Framer: n.framer, + Serializer: n.streamSerializer, + } + } + } + out = append(out, info) + } + return out +} + +func (n *fakeNegotiater) EncoderForVersion(serializer runtime.Encoder, gv runtime.GroupVersioner) runtime.Encoder { + return n.serializer +} + +func (n *fakeNegotiater) DecoderToVersion(serializer runtime.Decoder, gv runtime.GroupVersioner) runtime.Decoder { + return n.serializer +} + +var fakeCodec = runtime.NewCodec(runtime.NoopEncoder{}, runtime.NoopDecoder{}) + +func TestNegotiate(t *testing.T) { + testCases := []struct { + accept string + req *http.Request + ns *fakeNegotiater + serializer runtime.Serializer + contentType string + params map[string]string + errFn func(error) bool + }{ + // pick a default + { + req: &http.Request{}, + contentType: "application/json", 
+ ns: &fakeNegotiater{serializer: fakeCodec, types: []string{"application/json"}}, + serializer: fakeCodec, + }, + { + accept: "", + contentType: "application/json", + ns: &fakeNegotiater{serializer: fakeCodec, types: []string{"application/json"}}, + serializer: fakeCodec, + }, + { + accept: "*/*", + contentType: "application/json", + ns: &fakeNegotiater{serializer: fakeCodec, types: []string{"application/json"}}, + serializer: fakeCodec, + }, + { + accept: "application/*", + contentType: "application/json", + ns: &fakeNegotiater{serializer: fakeCodec, types: []string{"application/json"}}, + serializer: fakeCodec, + }, + { + accept: "application/json", + contentType: "application/json", + ns: &fakeNegotiater{serializer: fakeCodec, types: []string{"application/json"}}, + serializer: fakeCodec, + }, + { + accept: "application/json", + contentType: "application/json", + ns: &fakeNegotiater{serializer: fakeCodec, types: []string{"application/json", "application/protobuf"}}, + serializer: fakeCodec, + }, + { + accept: "application/protobuf", + contentType: "application/protobuf", + ns: &fakeNegotiater{serializer: fakeCodec, types: []string{"application/json", "application/protobuf"}}, + serializer: fakeCodec, + }, + { + accept: "application/json; pretty=1", + contentType: "application/json", + ns: &fakeNegotiater{serializer: fakeCodec, types: []string{"application/json"}}, + serializer: fakeCodec, + params: map[string]string{"pretty": "1"}, + }, + { + accept: "unrecognized/stuff,application/json; pretty=1", + contentType: "application/json", + ns: &fakeNegotiater{serializer: fakeCodec, types: []string{"application/json"}}, + serializer: fakeCodec, + params: map[string]string{"pretty": "1"}, + }, + + // query param triggers pretty + { + req: &http.Request{ + Header: http.Header{"Accept": []string{"application/json"}}, + URL: &url.URL{RawQuery: "pretty=1"}, + }, + contentType: "application/json", + ns: &fakeNegotiater{serializer: fakeCodec, types: 
[]string{"application/json"}}, + serializer: fakeCodec, + params: map[string]string{"pretty": "1"}, + }, + + // certain user agents trigger pretty + { + req: &http.Request{ + Header: http.Header{ + "Accept": []string{"application/json"}, + "User-Agent": []string{"curl"}, + }, + }, + contentType: "application/json", + ns: &fakeNegotiater{serializer: fakeCodec, types: []string{"application/json"}}, + serializer: fakeCodec, + params: map[string]string{"pretty": "1"}, + }, + { + req: &http.Request{ + Header: http.Header{ + "Accept": []string{"application/json"}, + "User-Agent": []string{"Wget"}, + }, + }, + contentType: "application/json", + ns: &fakeNegotiater{serializer: fakeCodec, types: []string{"application/json"}}, + serializer: fakeCodec, + params: map[string]string{"pretty": "1"}, + }, + { + req: &http.Request{ + Header: http.Header{ + "Accept": []string{"application/json"}, + "User-Agent": []string{"Mozilla/5.0"}, + }, + }, + contentType: "application/json", + ns: &fakeNegotiater{serializer: fakeCodec, types: []string{"application/json"}}, + serializer: fakeCodec, + params: map[string]string{"pretty": "1"}, + }, + + // "application" is not a valid media type, so the server will reject the response during + // negotiation (the server, in error, has specified an invalid media type) + { + accept: "application", + ns: &fakeNegotiater{serializer: fakeCodec, types: []string{"application"}}, + errFn: func(err error) bool { + return err.Error() == "only the following media types are accepted: application" + }, + }, + { + ns: &fakeNegotiater{}, + errFn: func(err error) bool { + return err.Error() == "only the following media types are accepted: " + }, + }, + { + accept: "*/*", + ns: &fakeNegotiater{}, + errFn: func(err error) bool { + return err.Error() == "only the following media types are accepted: " + }, + }, + } + + for i, test := range testCases { + req := test.req + if req == nil { + req = &http.Request{Header: http.Header{}} + req.Header.Set("Accept", 
test.accept) + } + s, err := NegotiateOutputSerializer(req, test.ns) + switch { + case err == nil && test.errFn != nil: + t.Errorf("%d: failed: expected error", i) + continue + case err != nil && test.errFn == nil: + t.Errorf("%d: failed: %v", i, err) + continue + case err != nil: + if !test.errFn(err) { + t.Errorf("%d: failed: %v", i, err) + } + status, ok := err.(statusError) + if !ok { + t.Errorf("%d: failed, error should be statusError: %v", i, err) + continue + } + if status.Status().Status != metav1.StatusFailure || status.Status().Code != http.StatusNotAcceptable { + t.Errorf("%d: failed: %v", i, err) + continue + } + continue + } + if test.contentType != s.MediaType { + t.Errorf("%d: unexpected %s %s", i, test.contentType, s.MediaType) + } + if s.Serializer != test.serializer { + t.Errorf("%d: unexpected %s %s", i, test.serializer, s.Serializer) + } + } +} diff --git a/pkg/metrics/OWNERS b/pkg/metrics/OWNERS new file mode 100755 index 000000000..f0706b3f1 --- /dev/null +++ b/pkg/metrics/OWNERS @@ -0,0 +1,3 @@ +reviewers: +- wojtek-t +- jimmidyson diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go new file mode 100644 index 000000000..6df018921 --- /dev/null +++ b/pkg/metrics/metrics.go @@ -0,0 +1,245 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package metrics + +import ( + "bufio" + "net" + "net/http" + "strconv" + "time" + + utilnet "k8s.io/apimachinery/pkg/util/net" + + "github.com/emicklei/go-restful" + "github.com/prometheus/client_golang/prometheus" +) + +var ( + // TODO(a-robinson): Add unit tests for the handling of these metrics once + // the upstream library supports it. + requestCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "apiserver_request_count", + Help: "Counter of apiserver requests broken out for each verb, API resource, client, and HTTP response contentType and code.", + }, + []string{"verb", "resource", "client", "contentType", "code"}, + ) + requestLatencies = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "apiserver_request_latencies", + Help: "Response latency distribution in microseconds for each verb, resource and client.", + // Use buckets ranging from 125 ms to 8 seconds. + Buckets: prometheus.ExponentialBuckets(125000, 2.0, 7), + }, + []string{"verb", "resource"}, + ) + requestLatenciesSummary = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Name: "apiserver_request_latencies_summary", + Help: "Response latency summary in microseconds for each verb and resource.", + // Make the sliding window of 1h. + MaxAge: time.Hour, + }, + []string{"verb", "resource"}, + ) +) + +// Register all metrics. 
+func Register() { + prometheus.MustRegister(requestCounter) + prometheus.MustRegister(requestLatencies) + prometheus.MustRegister(requestLatenciesSummary) +} + +func Monitor(verb, resource *string, client, contentType string, httpCode int, reqStart time.Time) { + elapsed := float64((time.Since(reqStart)) / time.Microsecond) + requestCounter.WithLabelValues(*verb, *resource, client, contentType, codeToString(httpCode)).Inc() + requestLatencies.WithLabelValues(*verb, *resource).Observe(elapsed) + requestLatenciesSummary.WithLabelValues(*verb, *resource).Observe(elapsed) +} + +func Reset() { + requestCounter.Reset() + requestLatencies.Reset() + requestLatenciesSummary.Reset() +} + +// InstrumentRouteFunc works like Prometheus' InstrumentHandlerFunc but wraps +// the go-restful RouteFunction instead of a HandlerFunc +func InstrumentRouteFunc(verb, resource string, routeFunc restful.RouteFunction) restful.RouteFunction { + return restful.RouteFunction(func(request *restful.Request, response *restful.Response) { + now := time.Now() + + delegate := &responseWriterDelegator{ResponseWriter: response.ResponseWriter} + + _, cn := response.ResponseWriter.(http.CloseNotifier) + _, fl := response.ResponseWriter.(http.Flusher) + _, hj := response.ResponseWriter.(http.Hijacker) + var rw http.ResponseWriter + if cn && fl && hj { + rw = &fancyResponseWriterDelegator{delegate} + } else { + rw = delegate + } + response.ResponseWriter = rw + + routeFunc(request, response) + Monitor(&verb, &resource, utilnet.GetHTTPClient(request.Request), rw.Header().Get("Content-Type"), delegate.status, now) + }) +} + +type responseWriterDelegator struct { + http.ResponseWriter + + status int + written int64 + wroteHeader bool +} + +func (r *responseWriterDelegator) WriteHeader(code int) { + r.status = code + r.wroteHeader = true + r.ResponseWriter.WriteHeader(code) +} + +func (r *responseWriterDelegator) Write(b []byte) (int, error) { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + n, err 
:= r.ResponseWriter.Write(b) + r.written += int64(n) + return n, err +} + +type fancyResponseWriterDelegator struct { + *responseWriterDelegator +} + +func (f *fancyResponseWriterDelegator) CloseNotify() <-chan bool { + return f.ResponseWriter.(http.CloseNotifier).CloseNotify() +} + +func (f *fancyResponseWriterDelegator) Flush() { + f.ResponseWriter.(http.Flusher).Flush() +} + +func (f *fancyResponseWriterDelegator) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return f.ResponseWriter.(http.Hijacker).Hijack() +} + +// Small optimization over Itoa +func codeToString(s int) string { + switch s { + case 100: + return "100" + case 101: + return "101" + + case 200: + return "200" + case 201: + return "201" + case 202: + return "202" + case 203: + return "203" + case 204: + return "204" + case 205: + return "205" + case 206: + return "206" + + case 300: + return "300" + case 301: + return "301" + case 302: + return "302" + case 304: + return "304" + case 305: + return "305" + case 307: + return "307" + + case 400: + return "400" + case 401: + return "401" + case 402: + return "402" + case 403: + return "403" + case 404: + return "404" + case 405: + return "405" + case 406: + return "406" + case 407: + return "407" + case 408: + return "408" + case 409: + return "409" + case 410: + return "410" + case 411: + return "411" + case 412: + return "412" + case 413: + return "413" + case 414: + return "414" + case 415: + return "415" + case 416: + return "416" + case 417: + return "417" + case 418: + return "418" + + case 500: + return "500" + case 501: + return "501" + case 502: + return "502" + case 503: + return "503" + case 504: + return "504" + case 505: + return "505" + + case 428: + return "428" + case 429: + return "429" + case 431: + return "431" + case 511: + return "511" + + default: + return strconv.Itoa(s) + } +} diff --git a/pkg/request/OWNERS b/pkg/request/OWNERS new file mode 100755 index 000000000..9d268c4d1 --- /dev/null +++ b/pkg/request/OWNERS @@ -0,0 
+1,2 @@ +reviewers: +- sttts diff --git a/pkg/request/context.go b/pkg/request/context.go new file mode 100644 index 000000000..b6e7d0dba --- /dev/null +++ b/pkg/request/context.go @@ -0,0 +1,145 @@ +/* +Copyright 2014 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package request + +import ( + stderrs "errors" + "time" + + "golang.org/x/net/context" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apiserver/pkg/authentication/user" +) + +// Context carries values across API boundaries. +// This context matches the context.Context interface +// (https://blog.golang.org/context), for the purposes +// of passing the api.Context through to the storage tier. +// TODO: Determine the extent that this abstraction+interface +// is used by the api, and whether we can remove. +type Context interface { + // Value returns the value associated with key or nil if none. + Value(key interface{}) interface{} + + // Deadline returns the time when this Context will be canceled, if any. + Deadline() (deadline time.Time, ok bool) + + // Done returns a channel that is closed when this Context is canceled + // or times out. + Done() <-chan struct{} + + // Err indicates why this context was canceled, after the Done channel + // is closed. + Err() error +} + +// The key type is unexported to prevent collisions +type key int + +const ( + // namespaceKey is the context key for the request namespace. + namespaceKey key = iota + + // userKey is the context key for the request user. 
+ userKey + + // uidKey is the context key for the uid to assign to an object on create. + uidKey + + // userAgentKey is the context key for the request user agent. + userAgentKey + + namespaceDefault = "default" // TODO(sttts): solve import cycle when using api.NamespaceDefault +) + +// NewContext instantiates a base context object for request flows. +func NewContext() Context { + return context.TODO() +} + +// NewDefaultContext instantiates a base context object for request flows in the default namespace +func NewDefaultContext() Context { + return WithNamespace(NewContext(), namespaceDefault) +} + +// WithValue returns a copy of parent in which the value associated with key is val. +func WithValue(parent Context, key interface{}, val interface{}) Context { + internalCtx, ok := parent.(context.Context) + if !ok { + panic(stderrs.New("Invalid context type")) + } + return context.WithValue(internalCtx, key, val) +} + +// WithNamespace returns a copy of parent in which the namespace value is set +func WithNamespace(parent Context, namespace string) Context { + return WithValue(parent, namespaceKey, namespace) +} + +// NamespaceFrom returns the value of the namespace key on the ctx +func NamespaceFrom(ctx Context) (string, bool) { + namespace, ok := ctx.Value(namespaceKey).(string) + return namespace, ok +} + +// NamespaceValue returns the value of the namespace key on the ctx, or the empty string if none +func NamespaceValue(ctx Context) string { + namespace, _ := NamespaceFrom(ctx) + return namespace +} + +// WithNamespaceDefaultIfNone returns a context whose namespace is the default if and only if the parent context has no namespace value +func WithNamespaceDefaultIfNone(parent Context) Context { + namespace, ok := NamespaceFrom(parent) + if !ok || len(namespace) == 0 { + return WithNamespace(parent, namespaceDefault) + } + return parent +} + +// WithUser returns a copy of parent in which the user value is set +func WithUser(parent Context, user user.Info) Context 
{ + return WithValue(parent, userKey, user) +} + +// UserFrom returns the value of the user key on the ctx +func UserFrom(ctx Context) (user.Info, bool) { + user, ok := ctx.Value(userKey).(user.Info) + return user, ok +} + +// WithUID returns a copy of parent in which the uid value is set +func WithUID(parent Context, uid types.UID) Context { + return WithValue(parent, uidKey, uid) +} + +// UIDFrom returns the value of the uid key on the ctx +func UIDFrom(ctx Context) (types.UID, bool) { + uid, ok := ctx.Value(uidKey).(types.UID) + return uid, ok +} + +// WithUserAgent returns a copy of parent in which the userAgent value is set +func WithUserAgent(parent Context, userAgent string) Context { + return WithValue(parent, userAgentKey, userAgent) +} + +// UserAgentFrom returns the value of the userAgent key on the ctx +func UserAgentFrom(ctx Context) (string, bool) { + userAgent, ok := ctx.Value(userAgentKey).(string) + return userAgent, ok +} diff --git a/pkg/request/context_test.go b/pkg/request/context_test.go new file mode 100644 index 000000000..5d2608f5c --- /dev/null +++ b/pkg/request/context_test.go @@ -0,0 +1,134 @@ +/* +Copyright 2014 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package request_test + +import ( + "testing" + + "k8s.io/apimachinery/pkg/types" + "k8s.io/apiserver/pkg/authentication/user" + "k8s.io/kubernetes/pkg/api" + genericapirequest "k8s.io/kubernetes/pkg/genericapiserver/api/request" +) + +// TestNamespaceContext validates that a namespace can be get/set on a context object +func TestNamespaceContext(t *testing.T) { + ctx := genericapirequest.NewDefaultContext() + result, ok := genericapirequest.NamespaceFrom(ctx) + if !ok { + t.Fatalf("Error getting namespace") + } + if api.NamespaceDefault != result { + t.Fatalf("Expected: %s, Actual: %s", api.NamespaceDefault, result) + } + + ctx = genericapirequest.NewContext() + result, ok = genericapirequest.NamespaceFrom(ctx) + if ok { + t.Fatalf("Should not be ok because there is no namespace on the context") + } +} + +//TestUserContext validates that a userinfo can be get/set on a context object +func TestUserContext(t *testing.T) { + ctx := genericapirequest.NewContext() + _, ok := genericapirequest.UserFrom(ctx) + if ok { + t.Fatalf("Should not be ok because there is no user.Info on the context") + } + ctx = genericapirequest.WithUser( + ctx, + &user.DefaultInfo{ + Name: "bob", + UID: "123", + Groups: []string{"group1"}, + Extra: map[string][]string{"foo": {"bar"}}, + }, + ) + + result, ok := genericapirequest.UserFrom(ctx) + if !ok { + t.Fatalf("Error getting user info") + } + + expectedName := "bob" + if result.GetName() != expectedName { + t.Fatalf("Get user name error, Expected: %s, Actual: %s", expectedName, result.GetName()) + } + + expectedUID := "123" + if result.GetUID() != expectedUID { + t.Fatalf("Get UID error, Expected: %s, Actual: %s", expectedUID, result.GetName()) + } + + expectedGroup := "group1" + actualGroup := result.GetGroups() + if len(actualGroup) != 1 { + t.Fatalf("Get user group number error, Expected: 1, Actual: %d", len(actualGroup)) + } else if actualGroup[0] != expectedGroup { + t.Fatalf("Get user group error, Expected: %s, Actual: %s", 
expectedGroup, actualGroup[0]) + } + + expectedExtraKey := "foo" + expectedExtraValue := "bar" + actualExtra := result.GetExtra() + if len(actualExtra[expectedExtraKey]) != 1 { + t.Fatalf("Get user extra map number error, Expected: 1, Actual: %d", len(actualExtra[expectedExtraKey])) + } else if actualExtra[expectedExtraKey][0] != expectedExtraValue { + t.Fatalf("Get user extra map value error, Expected: %s, Actual: %s", expectedExtraValue, actualExtra[expectedExtraKey]) + } + +} + +//TestUIDContext validates that a UID can be get/set on a context object +func TestUIDContext(t *testing.T) { + ctx := genericapirequest.NewContext() + _, ok := genericapirequest.UIDFrom(ctx) + if ok { + t.Fatalf("Should not be ok because there is no UID on the context") + } + ctx = genericapirequest.WithUID( + ctx, + types.UID("testUID"), + ) + _, ok = genericapirequest.UIDFrom(ctx) + if !ok { + t.Fatalf("Error getting UID") + } +} + +//TestUserAgentContext validates that a useragent can be get/set on a context object +func TestUserAgentContext(t *testing.T) { + ctx := genericapirequest.NewContext() + _, ok := genericapirequest.UserAgentFrom(ctx) + if ok { + t.Fatalf("Should not be ok because there is no UserAgent on the context") + } + + ctx = genericapirequest.WithUserAgent( + ctx, + "TestUserAgent", + ) + result, ok := genericapirequest.UserAgentFrom(ctx) + if !ok { + t.Fatalf("Error getting UserAgent") + } + expectedResult := "TestUserAgent" + if result != expectedResult { + t.Fatalf("Get user agent error, Expected: %s, Actual: %s", expectedResult, result) + } +} diff --git a/pkg/request/doc.go b/pkg/request/doc.go new file mode 100644 index 000000000..a2d8b3cc5 --- /dev/null +++ b/pkg/request/doc.go @@ -0,0 +1,20 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package request contains everything around extracting info from +// a http request object. +// TODO: this package is temporary. Handlers must move into pkg/apiserver/handlers to avoid dependency cycle +package request // import "k8s.io/kubernetes/pkg/genericapiserver/api/request" diff --git a/pkg/request/requestcontext.go b/pkg/request/requestcontext.go new file mode 100644 index 000000000..32fa9215f --- /dev/null +++ b/pkg/request/requestcontext.go @@ -0,0 +1,117 @@ +/* +Copyright 2014 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package request + +import ( + "errors" + "net/http" + "sync" + + "github.com/golang/glog" +) + +// RequestContextMapper keeps track of the context associated with a particular request +type RequestContextMapper interface { + // Get returns the context associated with the given request (if any), and true if the request has an associated context, and false if it does not. + Get(req *http.Request) (Context, bool) + // Update maps the request to the given context. 
If no context was previously associated with the request, an error is returned. + // Update should only be called with a descendant context of the previously associated context. + // Updating to an unrelated context may return an error in the future. + // The context associated with a request should only be updated by a limited set of callers. + // Valid examples include the authentication layer, or an audit/tracing layer. + Update(req *http.Request, context Context) error +} + +type requestContextMap struct { + contexts map[*http.Request]Context + lock sync.Mutex +} + +// NewRequestContextMapper returns a new RequestContextMapper. +// The returned mapper must be added as a request filter using NewRequestContextFilter. +func NewRequestContextMapper() RequestContextMapper { + return &requestContextMap{ + contexts: make(map[*http.Request]Context), + } +} + +// Get returns the context associated with the given request (if any), and true if the request has an associated context, and false if it does not. +// Get will only return a valid context when called from inside the filter chain set up by NewRequestContextFilter() +func (c *requestContextMap) Get(req *http.Request) (Context, bool) { + c.lock.Lock() + defer c.lock.Unlock() + context, ok := c.contexts[req] + return context, ok +} + +// Update maps the request to the given context. +// If no context was previously associated with the request, an error is returned and the context is ignored. +func (c *requestContextMap) Update(req *http.Request, context Context) error { + c.lock.Lock() + defer c.lock.Unlock() + if _, ok := c.contexts[req]; !ok { + return errors.New("No context associated") + } + // TODO: ensure the new context is a descendant of the existing one + c.contexts[req] = context + return nil +} + +// init maps the request to the given context and returns true if there was no context associated with the request already. 
+// if a context was already associated with the request, it ignores the given context and returns false. +// init is intentionally unexported to ensure that all init calls are paired with a remove after a request is handled +func (c *requestContextMap) init(req *http.Request, context Context) bool { + c.lock.Lock() + defer c.lock.Unlock() + if _, exists := c.contexts[req]; exists { + return false + } + c.contexts[req] = context + return true +} + +// remove is intentionally unexported to ensure that the context is not removed until a request is handled +func (c *requestContextMap) remove(req *http.Request) { + c.lock.Lock() + defer c.lock.Unlock() + delete(c.contexts, req) +} + +// WithRequestContext ensures there is a Context object associated with the request before calling the passed handler. +// After the passed handler runs, the context is cleaned up. +func WithRequestContext(handler http.Handler, mapper RequestContextMapper) http.Handler { + rcMap, ok := mapper.(*requestContextMap) + if !ok { + glog.Fatal("Unknown RequestContextMapper implementation.") + } + + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if rcMap.init(req, NewContext()) { + // If we were the ones to successfully initialize, pair with a remove + defer rcMap.remove(req) + } + handler.ServeHTTP(w, req) + }) +} + +// IsEmpty returns true if there are no contexts registered, or an error if it could not be determined. Intended for use by tests. +func IsEmpty(requestsToContexts RequestContextMapper) (bool, error) { + if requestsToContexts, ok := requestsToContexts.(*requestContextMap); ok { + return len(requestsToContexts.contexts) == 0, nil + } + return true, errors.New("Unknown RequestContextMapper implementation") +} diff --git a/pkg/request/requestinfo.go b/pkg/request/requestinfo.go new file mode 100644 index 000000000..4f231319e --- /dev/null +++ b/pkg/request/requestinfo.go @@ -0,0 +1,241 @@ +/* +Copyright 2016 The Kubernetes Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package request + +import ( + "fmt" + "net/http" + "strings" + + "k8s.io/apimachinery/pkg/util/sets" +) + +// RequestInfo holds information parsed from the http.Request +type RequestInfo struct { + // IsResourceRequest indicates whether or not the request is for an API resource or subresource + IsResourceRequest bool + // Path is the URL path of the request + Path string + // Verb is the kube verb associated with the request for API requests, not the http verb. This includes things like list and watch. + // for non-resource requests, this is the lowercase http verb + Verb string + + APIPrefix string + APIGroup string + APIVersion string + Namespace string + // Resource is the name of the resource being requested. This is not the kind. For example: pods + Resource string + // Subresource is the name of the subresource being requested. This is a different resource, scoped to the parent resource, but it may have a different kind. + // For instance, /pods has the resource "pods" and the kind "Pod", while /pods/foo/status has the resource "pods", the sub resource "status", and the kind "Pod" + // (because status operates on pods). The binding resource for a pod though may be /pods/foo/binding, which has resource "pods", subresource "binding", and kind "Binding". + Subresource string + // Name is empty for some verbs, but if the request directly indicates a name (not in body content) then this field is filled in. 
+ Name string + // Parts are the path parts for the request, always starting with /{resource}/{name} + Parts []string +} + +// specialVerbs contains just strings which are used in REST paths for special actions that don't fall under the normal +// CRUDdy GET/POST/PUT/DELETE actions on REST objects. +// TODO: find a way to keep this up to date automatically. Maybe dynamically populate list as handlers added to +// master's Mux. +var specialVerbs = sets.NewString("proxy", "redirect", "watch") + +// specialVerbsNoSubresources contains root verbs which do not allow subresources +var specialVerbsNoSubresources = sets.NewString("proxy", "redirect") + +// namespaceSubresources contains subresources of namespace +// this list allows the parser to distinguish between a namespace subresource, and a namespaced resource +var namespaceSubresources = sets.NewString("status", "finalize") + +// NamespaceSubResourcesForTest exports namespaceSubresources for testing in pkg/master/master_test.go, so we never drift +var NamespaceSubResourcesForTest = sets.NewString(namespaceSubresources.List()...) + +type RequestInfoFactory struct { + APIPrefixes sets.String // without leading and trailing slashes + GrouplessAPIPrefixes sets.String // without leading and trailing slashes +} + +// TODO write an integration test against the swagger doc to test the RequestInfo and match up behavior to responses +// NewRequestInfo returns the information from the http request. If error is not nil, RequestInfo holds the information as best it is known before the failure +// It handles both resource and non-resource requests and fills in all the pertinent information for each. 
+// Valid Inputs: +// Resource paths +// /apis/{api-group}/{version}/namespaces +// /api/{version}/namespaces +// /api/{version}/namespaces/{namespace} +// /api/{version}/namespaces/{namespace}/{resource} +// /api/{version}/namespaces/{namespace}/{resource}/{resourceName} +// /api/{version}/{resource} +// /api/{version}/{resource}/{resourceName} +// +// Special verbs without subresources: +// /api/{version}/proxy/{resource}/{resourceName} +// /api/{version}/proxy/namespaces/{namespace}/{resource}/{resourceName} +// /api/{version}/redirect/namespaces/{namespace}/{resource}/{resourceName} +// /api/{version}/redirect/{resource}/{resourceName} +// +// Special verbs with subresources: +// /api/{version}/watch/{resource} +// /api/{version}/watch/namespaces/{namespace}/{resource} +// +// NonResource paths +// /apis/{api-group}/{version} +// /apis/{api-group} +// /apis +// /api/{version} +// /api +// /healthz +// / +func (r *RequestInfoFactory) NewRequestInfo(req *http.Request) (*RequestInfo, error) { + // start with a non-resource request until proven otherwise + requestInfo := RequestInfo{ + IsResourceRequest: false, + Path: req.URL.Path, + Verb: strings.ToLower(req.Method), + } + + currentParts := splitPath(req.URL.Path) + if len(currentParts) < 3 { + // return a non-resource request + return &requestInfo, nil + } + + if !r.APIPrefixes.Has(currentParts[0]) { + // return a non-resource request + return &requestInfo, nil + } + requestInfo.APIPrefix = currentParts[0] + currentParts = currentParts[1:] + + if !r.GrouplessAPIPrefixes.Has(requestInfo.APIPrefix) { + // one part (APIPrefix) has already been consumed, so this is actually "do we have four parts?" 
+ if len(currentParts) < 3 { + // return a non-resource request + return &requestInfo, nil + } + + requestInfo.APIGroup = currentParts[0] + currentParts = currentParts[1:] + } + + requestInfo.IsResourceRequest = true + requestInfo.APIVersion = currentParts[0] + currentParts = currentParts[1:] + + // handle input of form /{specialVerb}/* + if specialVerbs.Has(currentParts[0]) { + if len(currentParts) < 2 { + return &requestInfo, fmt.Errorf("unable to determine kind and namespace from url, %v", req.URL) + } + + requestInfo.Verb = currentParts[0] + currentParts = currentParts[1:] + + } else { + switch req.Method { + case "POST": + requestInfo.Verb = "create" + case "GET", "HEAD": + requestInfo.Verb = "get" + case "PUT": + requestInfo.Verb = "update" + case "PATCH": + requestInfo.Verb = "patch" + case "DELETE": + requestInfo.Verb = "delete" + default: + requestInfo.Verb = "" + } + } + + // URL forms: /namespaces/{namespace}/{kind}/*, where parts are adjusted to be relative to kind + if currentParts[0] == "namespaces" { + if len(currentParts) > 1 { + requestInfo.Namespace = currentParts[1] + + // if there is another step after the namespace name and it is not a known namespace subresource + // move currentParts to include it as a resource in its own right + if len(currentParts) > 2 && !namespaceSubresources.Has(currentParts[2]) { + currentParts = currentParts[2:] + } + } + } else { + requestInfo.Namespace = "" // TODO(sttts): solve import cycle when using api.NamespaceNone + } + + // parsing successful, so we now know the proper value for .Parts + requestInfo.Parts = currentParts + + // parts look like: resource/resourceName/subresource/other/stuff/we/don't/interpret + switch { + case len(requestInfo.Parts) >= 3 && !specialVerbsNoSubresources.Has(requestInfo.Verb): + requestInfo.Subresource = requestInfo.Parts[2] + fallthrough + case len(requestInfo.Parts) >= 2: + requestInfo.Name = requestInfo.Parts[1] + fallthrough + case len(requestInfo.Parts) >= 1: + 
requestInfo.Resource = requestInfo.Parts[0] + } + + // if there's no name on the request and we thought it was a get before, then the actual verb is a list or a watch + if len(requestInfo.Name) == 0 && requestInfo.Verb == "get" { + // Assumes v1.ListOptions + // Duplicates logic of Convert_Slice_string_To_bool + switch strings.ToLower(req.URL.Query().Get("watch")) { + case "false", "0", "": + requestInfo.Verb = "list" + default: + requestInfo.Verb = "watch" + } + } + // if there's no name on the request and we thought it was a delete before, then the actual verb is deletecollection + if len(requestInfo.Name) == 0 && requestInfo.Verb == "delete" { + requestInfo.Verb = "deletecollection" + } + + return &requestInfo, nil +} + +type requestInfoKeyType int + +// requestInfoKey is the RequestInfo key for the context. It's of private type here. Because +// keys are interfaces and interfaces are equal when the type and the value is equal, this +// does not conflict with the keys defined in pkg/api. +const requestInfoKey requestInfoKeyType = iota + +// WithRequestInfo returns a copy of parent in which the request info value is set +func WithRequestInfo(parent Context, info *RequestInfo) Context { + return WithValue(parent, requestInfoKey, info) +} + +// RequestInfoFrom returns the value of the RequestInfo key on the ctx +func RequestInfoFrom(ctx Context) (*RequestInfo, bool) { + info, ok := ctx.Value(requestInfoKey).(*RequestInfo) + return info, ok +} + +// splitPath returns the segments for a URL path. +func splitPath(path string) []string { + path = strings.Trim(path, "/") + if path == "" { + return []string{} + } + return strings.Split(path, "/") +} diff --git a/pkg/request/requestinfo_test.go b/pkg/request/requestinfo_test.go new file mode 100644 index 000000000..6e8550f42 --- /dev/null +++ b/pkg/request/requestinfo_test.go @@ -0,0 +1,196 @@ +/* +Copyright 2016 The Kubernetes Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package request + +import ( + "net/http" + "reflect" + "testing" + + "k8s.io/apimachinery/pkg/util/sets" +) + +type fakeRL bool + +func (fakeRL) Stop() {} +func (f fakeRL) TryAccept() bool { return bool(f) } +func (f fakeRL) Accept() {} + +func TestGetAPIRequestInfo(t *testing.T) { + namespaceAll := "" // TODO(sttts): solve import cycle when using api.NamespaceAll + successCases := []struct { + method string + url string + expectedVerb string + expectedAPIPrefix string + expectedAPIGroup string + expectedAPIVersion string + expectedNamespace string + expectedResource string + expectedSubresource string + expectedName string + expectedParts []string + }{ + + // resource paths + {"GET", "/api/v1/namespaces", "list", "api", "", "v1", "", "namespaces", "", "", []string{"namespaces"}}, + {"GET", "/api/v1/namespaces/other", "get", "api", "", "v1", "other", "namespaces", "", "other", []string{"namespaces", "other"}}, + + {"GET", "/api/v1/namespaces/other/pods", "list", "api", "", "v1", "other", "pods", "", "", []string{"pods"}}, + {"GET", "/api/v1/namespaces/other/pods/foo", "get", "api", "", "v1", "other", "pods", "", "foo", []string{"pods", "foo"}}, + {"HEAD", "/api/v1/namespaces/other/pods/foo", "get", "api", "", "v1", "other", "pods", "", "foo", []string{"pods", "foo"}}, + {"GET", "/api/v1/pods", "list", "api", "", "v1", namespaceAll, "pods", "", "", []string{"pods"}}, + {"HEAD", "/api/v1/pods", "list", "api", "", "v1", namespaceAll, 
"pods", "", "", []string{"pods"}}, + {"GET", "/api/v1/namespaces/other/pods/foo", "get", "api", "", "v1", "other", "pods", "", "foo", []string{"pods", "foo"}}, + {"GET", "/api/v1/namespaces/other/pods", "list", "api", "", "v1", "other", "pods", "", "", []string{"pods"}}, + + // special verbs + {"GET", "/api/v1/proxy/namespaces/other/pods/foo", "proxy", "api", "", "v1", "other", "pods", "", "foo", []string{"pods", "foo"}}, + {"GET", "/api/v1/proxy/namespaces/other/pods/foo/subpath/not/a/subresource", "proxy", "api", "", "v1", "other", "pods", "", "foo", []string{"pods", "foo", "subpath", "not", "a", "subresource"}}, + {"GET", "/api/v1/redirect/namespaces/other/pods/foo", "redirect", "api", "", "v1", "other", "pods", "", "foo", []string{"pods", "foo"}}, + {"GET", "/api/v1/redirect/namespaces/other/pods/foo/subpath/not/a/subresource", "redirect", "api", "", "v1", "other", "pods", "", "foo", []string{"pods", "foo", "subpath", "not", "a", "subresource"}}, + {"GET", "/api/v1/watch/pods", "watch", "api", "", "v1", namespaceAll, "pods", "", "", []string{"pods"}}, + {"GET", "/api/v1/pods?watch=true", "watch", "api", "", "v1", namespaceAll, "pods", "", "", []string{"pods"}}, + {"GET", "/api/v1/pods?watch=false", "list", "api", "", "v1", namespaceAll, "pods", "", "", []string{"pods"}}, + {"GET", "/api/v1/watch/namespaces/other/pods", "watch", "api", "", "v1", "other", "pods", "", "", []string{"pods"}}, + {"GET", "/api/v1/namespaces/other/pods?watch=1", "watch", "api", "", "v1", "other", "pods", "", "", []string{"pods"}}, + {"GET", "/api/v1/namespaces/other/pods?watch=0", "list", "api", "", "v1", "other", "pods", "", "", []string{"pods"}}, + + // subresource identification + {"GET", "/api/v1/namespaces/other/pods/foo/status", "get", "api", "", "v1", "other", "pods", "status", "foo", []string{"pods", "foo", "status"}}, + {"GET", "/api/v1/namespaces/other/pods/foo/proxy/subpath", "get", "api", "", "v1", "other", "pods", "proxy", "foo", []string{"pods", "foo", "proxy", 
"subpath"}}, + {"PUT", "/api/v1/namespaces/other/finalize", "update", "api", "", "v1", "other", "namespaces", "finalize", "other", []string{"namespaces", "other", "finalize"}}, + {"PUT", "/api/v1/namespaces/other/status", "update", "api", "", "v1", "other", "namespaces", "status", "other", []string{"namespaces", "other", "status"}}, + + // verb identification + {"PATCH", "/api/v1/namespaces/other/pods/foo", "patch", "api", "", "v1", "other", "pods", "", "foo", []string{"pods", "foo"}}, + {"DELETE", "/api/v1/namespaces/other/pods/foo", "delete", "api", "", "v1", "other", "pods", "", "foo", []string{"pods", "foo"}}, + {"POST", "/api/v1/namespaces/other/pods", "create", "api", "", "v1", "other", "pods", "", "", []string{"pods"}}, + + // deletecollection verb identification + {"DELETE", "/api/v1/nodes", "deletecollection", "api", "", "v1", "", "nodes", "", "", []string{"nodes"}}, + {"DELETE", "/api/v1/namespaces", "deletecollection", "api", "", "v1", "", "namespaces", "", "", []string{"namespaces"}}, + {"DELETE", "/api/v1/namespaces/other/pods", "deletecollection", "api", "", "v1", "other", "pods", "", "", []string{"pods"}}, + {"DELETE", "/apis/extensions/v1/namespaces/other/pods", "deletecollection", "api", "extensions", "v1", "other", "pods", "", "", []string{"pods"}}, + + // api group identification + {"POST", "/apis/extensions/v1/namespaces/other/pods", "create", "api", "extensions", "v1", "other", "pods", "", "", []string{"pods"}}, + + // api version identification + {"POST", "/apis/extensions/v1beta3/namespaces/other/pods", "create", "api", "extensions", "v1beta3", "other", "pods", "", "", []string{"pods"}}, + } + + resolver := newTestRequestInfoResolver() + + for _, successCase := range successCases { + req, _ := http.NewRequest(successCase.method, successCase.url, nil) + + apiRequestInfo, err := resolver.NewRequestInfo(req) + if err != nil { + t.Errorf("Unexpected error for url: %s %v", successCase.url, err) + } + if !apiRequestInfo.IsResourceRequest { + 
t.Errorf("Expected resource request") + } + if successCase.expectedVerb != apiRequestInfo.Verb { + t.Errorf("Unexpected verb for url: %s, expected: %s, actual: %s", successCase.url, successCase.expectedVerb, apiRequestInfo.Verb) + } + if successCase.expectedAPIVersion != apiRequestInfo.APIVersion { + t.Errorf("Unexpected apiVersion for url: %s, expected: %s, actual: %s", successCase.url, successCase.expectedAPIVersion, apiRequestInfo.APIVersion) + } + if successCase.expectedNamespace != apiRequestInfo.Namespace { + t.Errorf("Unexpected namespace for url: %s, expected: %s, actual: %s", successCase.url, successCase.expectedNamespace, apiRequestInfo.Namespace) + } + if successCase.expectedResource != apiRequestInfo.Resource { + t.Errorf("Unexpected resource for url: %s, expected: %s, actual: %s", successCase.url, successCase.expectedResource, apiRequestInfo.Resource) + } + if successCase.expectedSubresource != apiRequestInfo.Subresource { + t.Errorf("Unexpected resource for url: %s, expected: %s, actual: %s", successCase.url, successCase.expectedSubresource, apiRequestInfo.Subresource) + } + if successCase.expectedName != apiRequestInfo.Name { + t.Errorf("Unexpected name for url: %s, expected: %s, actual: %s", successCase.url, successCase.expectedName, apiRequestInfo.Name) + } + if !reflect.DeepEqual(successCase.expectedParts, apiRequestInfo.Parts) { + t.Errorf("Unexpected parts for url: %s, expected: %v, actual: %v", successCase.url, successCase.expectedParts, apiRequestInfo.Parts) + } + } + + errorCases := map[string]string{ + "no resource path": "/", + "just apiversion": "/api/version/", + "just prefix, group, version": "/apis/group/version/", + "apiversion with no resource": "/api/version/", + "bad prefix": "/badprefix/version/resource", + "missing api group": "/apis/version/resource", + } + for k, v := range errorCases { + req, err := http.NewRequest("GET", v, nil) + if err != nil { + t.Errorf("Unexpected error %v", err) + } + apiRequestInfo, err := 
resolver.NewRequestInfo(req) + if err != nil { + t.Errorf("%s: Unexpected error %v", k, err) + } + if apiRequestInfo.IsResourceRequest { + t.Errorf("%s: expected non-resource request", k) + } + } +} + +func TestGetNonAPIRequestInfo(t *testing.T) { + tests := map[string]struct { + url string + expected bool + }{ + "simple groupless": {"/api/version/resource", true}, + "simple group": {"/apis/group/version/resource/name/subresource", true}, + "more steps": {"/api/version/resource/name/subresource", true}, + "group list": {"/apis/batch/v1/job", true}, + "group get": {"/apis/batch/v1/job/foo", true}, + "group subresource": {"/apis/batch/v1/job/foo/scale", true}, + + "bad root": {"/not-api/version/resource", false}, + "group without enough steps": {"/apis/extensions/v1beta1", false}, + "group without enough steps 2": {"/apis/extensions/v1beta1/", false}, + "not enough steps": {"/api/version", false}, + "one step": {"/api", false}, + "zero step": {"/", false}, + "empty": {"", false}, + } + + resolver := newTestRequestInfoResolver() + + for testName, tc := range tests { + req, _ := http.NewRequest("GET", tc.url, nil) + + apiRequestInfo, err := resolver.NewRequestInfo(req) + if err != nil { + t.Errorf("%s: Unexpected error %v", testName, err) + } + if e, a := tc.expected, apiRequestInfo.IsResourceRequest; e != a { + t.Errorf("%s: expected %v, actual %v", testName, e, a) + } + } +} + +func newTestRequestInfoResolver() *RequestInfoFactory { + return &RequestInfoFactory{ + APIPrefixes: sets.NewString("api", "apis"), + GrouplessAPIPrefixes: sets.NewString("api"), + } +} diff --git a/pkg/util/wsstream/conn.go b/pkg/util/wsstream/conn.go new file mode 100644 index 000000000..f01638ad6 --- /dev/null +++ b/pkg/util/wsstream/conn.go @@ -0,0 +1,349 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wsstream + +import ( + "encoding/base64" + "fmt" + "io" + "net/http" + "regexp" + "strings" + "time" + + "github.com/golang/glog" + "golang.org/x/net/websocket" + + "k8s.io/apimachinery/pkg/util/runtime" +) + +// The Websocket subprotocol "channel.k8s.io" prepends each binary message with a byte indicating +// the channel number (zero indexed) the message was sent on. Messages in both directions should +// prefix their messages with this channel byte. When used for remote execution, the channel numbers +// are by convention defined to match the POSIX file-descriptors assigned to STDIN, STDOUT, and STDERR +// (0, 1, and 2). No other conversion is performed on the raw subprotocol - writes are sent as they +// are received by the server. +// +// Example client session: +// +// CONNECT http://server.com with subprotocol "channel.k8s.io" +// WRITE []byte{0, 102, 111, 111, 10} # send "foo\n" on channel 0 (STDIN) +// READ []byte{1, 10} # receive "\n" on channel 1 (STDOUT) +// CLOSE +// +const ChannelWebSocketProtocol = "channel.k8s.io" + +// The Websocket subprotocol "base64.channel.k8s.io" base64 encodes each message with a character +// indicating the channel number (zero indexed) the message was sent on. Messages in both directions +// should prefix their messages with this channel char. When used for remote execution, the channel +// numbers are by convention defined to match the POSIX file-descriptors assigned to STDIN, STDOUT, +// and STDERR ('0', '1', and '2'). 
The data received on the server is base64 decoded (and must be +// be valid) and data written by the server to the client is base64 encoded. +// +// Example client session: +// +// CONNECT http://server.com with subprotocol "base64.channel.k8s.io" +// WRITE []byte{48, 90, 109, 57, 118, 67, 103, 111, 61} # send "foo\n" (base64: "Zm9vCgo=") on channel '0' (STDIN) +// READ []byte{49, 67, 103, 61, 61} # receive "\n" (base64: "Cg==") on channel '1' (STDOUT) +// CLOSE +// +const Base64ChannelWebSocketProtocol = "base64.channel.k8s.io" + +type codecType int + +const ( + rawCodec codecType = iota + base64Codec +) + +type ChannelType int + +const ( + IgnoreChannel ChannelType = iota + ReadChannel + WriteChannel + ReadWriteChannel +) + +var ( + // connectionUpgradeRegex matches any Connection header value that includes upgrade + connectionUpgradeRegex = regexp.MustCompile("(^|.*,\\s*)upgrade($|\\s*,)") +) + +// IsWebSocketRequest returns true if the incoming request contains connection upgrade headers +// for WebSockets. +func IsWebSocketRequest(req *http.Request) bool { + return connectionUpgradeRegex.MatchString(strings.ToLower(req.Header.Get("Connection"))) && strings.ToLower(req.Header.Get("Upgrade")) == "websocket" +} + +// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the +// read and write deadlines are pushed every time a new message is received. +func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) { + defer runtime.HandleCrash() + var data []byte + for { + resetTimeout(ws, timeout) + if err := websocket.Message.Receive(ws, &data); err != nil { + return + } + } +} + +// handshake ensures the provided user protocol matches one of the allowed protocols. It returns +// no error if no protocol is specified. 
+func handshake(config *websocket.Config, req *http.Request, allowed []string) error { + protocols := config.Protocol + if len(protocols) == 0 { + protocols = []string{""} + } + + for _, protocol := range protocols { + for _, allow := range allowed { + if allow == protocol { + config.Protocol = []string{protocol} + return nil + } + } + } + + return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed) +} + +// ChannelProtocolConfig describes a websocket subprotocol with channels. +type ChannelProtocolConfig struct { + Binary bool + Channels []ChannelType +} + +// NewDefaultChannelProtocols returns a channel protocol map with the +// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io" and the given +// channels. +func NewDefaultChannelProtocols(channels []ChannelType) map[string]ChannelProtocolConfig { + return map[string]ChannelProtocolConfig{ + "": {Binary: true, Channels: channels}, + ChannelWebSocketProtocol: {Binary: true, Channels: channels}, + Base64ChannelWebSocketProtocol: {Binary: false, Channels: channels}, + } +} + +// Conn supports sending multiple binary channels over a websocket connection. +type Conn struct { + protocols map[string]ChannelProtocolConfig + selectedProtocol string + channels []*websocketChannel + codec codecType + ready chan struct{} + ws *websocket.Conn + timeout time.Duration +} + +// NewConn creates a WebSocket connection that supports a set of channels. Channels begin each +// web socket message with a single byte indicating the channel number (0-N). 255 is reserved for +// future use. The channel types for each channel are passed as an array, supporting the different +// duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer. +// +// The protocols parameter maps subprotocol names to ChannelProtocols. The empty string subprotocol +// name is used if websocket.Config.Protocol is empty. 
+func NewConn(protocols map[string]ChannelProtocolConfig) *Conn { + return &Conn{ + ready: make(chan struct{}), + protocols: protocols, + } +} + +// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified, +// there is no timeout on the connection. +func (conn *Conn) SetIdleTimeout(duration time.Duration) { + conn.timeout = duration +} + +// Open the connection and create channels for reading and writing. It returns +// the selected subprotocol, a slice of channels and an error. +func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) { + go func() { + defer runtime.HandleCrash() + defer conn.Close() + websocket.Server{Handshake: conn.handshake, Handler: conn.handle}.ServeHTTP(w, req) + }() + <-conn.ready + rwc := make([]io.ReadWriteCloser, len(conn.channels)) + for i := range conn.channels { + rwc[i] = conn.channels[i] + } + return conn.selectedProtocol, rwc, nil +} + +func (conn *Conn) initialize(ws *websocket.Conn) { + negotiated := ws.Config().Protocol + conn.selectedProtocol = negotiated[0] + p := conn.protocols[conn.selectedProtocol] + if p.Binary { + conn.codec = rawCodec + } else { + conn.codec = base64Codec + } + conn.ws = ws + conn.channels = make([]*websocketChannel, len(p.Channels)) + for i, t := range p.Channels { + switch t { + case ReadChannel: + conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false) + case WriteChannel: + conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true) + case ReadWriteChannel: + conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true) + case IgnoreChannel: + conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false) + } + } + + close(conn.ready) +} + +func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error { + supportedProtocols := make([]string, 0, len(conn.protocols)) + for p := range conn.protocols { + supportedProtocols = append(supportedProtocols, p) + } + return 
handshake(config, req, supportedProtocols) +} + +func (conn *Conn) resetTimeout() { + if conn.timeout > 0 { + conn.ws.SetDeadline(time.Now().Add(conn.timeout)) + } +} + +// Close is only valid after Open has been called +func (conn *Conn) Close() error { + <-conn.ready + for _, s := range conn.channels { + s.Close() + } + conn.ws.Close() + return nil +} + +// handle implements a websocket handler. +func (conn *Conn) handle(ws *websocket.Conn) { + defer conn.Close() + conn.initialize(ws) + + for { + conn.resetTimeout() + var data []byte + if err := websocket.Message.Receive(ws, &data); err != nil { + if err != io.EOF { + glog.Errorf("Error on socket receive: %v", err) + } + break + } + if len(data) == 0 { + continue + } + channel := data[0] + if conn.codec == base64Codec { + channel = channel - '0' + } + data = data[1:] + if int(channel) >= len(conn.channels) { + glog.V(6).Infof("Frame is targeted for a reader %d that is not valid, possible protocol error", channel) + continue + } + if _, err := conn.channels[channel].DataFromSocket(data); err != nil { + glog.Errorf("Unable to write frame to %d: %v\n%s", channel, err, string(data)) + continue + } + } +} + +// write multiplexes the specified channel onto the websocket +func (conn *Conn) write(num byte, data []byte) (int, error) { + conn.resetTimeout() + switch conn.codec { + case rawCodec: + frame := make([]byte, len(data)+1) + frame[0] = num + copy(frame[1:], data) + if err := websocket.Message.Send(conn.ws, frame); err != nil { + return 0, err + } + case base64Codec: + frame := string('0'+num) + base64.StdEncoding.EncodeToString(data) + if err := websocket.Message.Send(conn.ws, frame); err != nil { + return 0, err + } + } + return len(data), nil +} + +// websocketChannel represents a channel in a connection +type websocketChannel struct { + conn *Conn + num byte + r io.Reader + w io.WriteCloser + + read, write bool +} + +// newWebsocketChannel creates a pipe for writing to a websocket. 
Do not write to this pipe +// prior to the connection being opened. It may be no, half, or full duplex depending on +// read and write. +func newWebsocketChannel(conn *Conn, num byte, read, write bool) *websocketChannel { + r, w := io.Pipe() + return &websocketChannel{conn, num, r, w, read, write} +} + +func (p *websocketChannel) Write(data []byte) (int, error) { + if !p.write { + return len(data), nil + } + return p.conn.write(p.num, data) +} + +// DataFromSocket is invoked by the connection receiver to move data from the connection +// into a specific channel. +func (p *websocketChannel) DataFromSocket(data []byte) (int, error) { + if !p.read { + return len(data), nil + } + + switch p.conn.codec { + case rawCodec: + return p.w.Write(data) + case base64Codec: + dst := make([]byte, len(data)) + n, err := base64.StdEncoding.Decode(dst, data) + if err != nil { + return 0, err + } + return p.w.Write(dst[:n]) + } + return 0, nil +} + +func (p *websocketChannel) Read(data []byte) (int, error) { + if !p.read { + return 0, io.EOF + } + return p.r.Read(data) +} + +func (p *websocketChannel) Close() error { + return p.w.Close() +} diff --git a/pkg/util/wsstream/conn_test.go b/pkg/util/wsstream/conn_test.go new file mode 100644 index 000000000..1c049aad7 --- /dev/null +++ b/pkg/util/wsstream/conn_test.go @@ -0,0 +1,272 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package wsstream + +import ( + "encoding/base64" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "reflect" + "sync" + "testing" + + "golang.org/x/net/websocket" +) + +func newServer(handler http.Handler) (*httptest.Server, string) { + server := httptest.NewServer(handler) + serverAddr := server.Listener.Addr().String() + return server, serverAddr +} + +func TestRawConn(t *testing.T) { + channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel} + conn := NewConn(NewDefaultChannelProtocols(channels)) + + s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conn.Open(w, req) + })) + defer s.Close() + + client, err := websocket.Dial("ws://"+addr, "", "http://localhost/") + if err != nil { + t.Fatal(err) + } + defer client.Close() + + <-conn.ready + wg := sync.WaitGroup{} + + // verify we can read a client write + wg.Add(1) + go func() { + defer wg.Done() + data, err := ioutil.ReadAll(conn.channels[0]) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(data, []byte("client")) { + t.Errorf("unexpected server read: %v", data) + } + }() + + if n, err := client.Write(append([]byte{0}, []byte("client")...)); err != nil || n != 7 { + t.Fatalf("%d: %v", n, err) + } + + // verify we can read a server write + wg.Add(1) + go func() { + defer wg.Done() + if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 { + t.Fatalf("%d: %v", n, err) + } + }() + + data := make([]byte, 1024) + if n, err := io.ReadAtLeast(client, data, 6); n != 7 || err != nil { + t.Fatalf("%d: %v", n, err) + } + if !reflect.DeepEqual(data[:7], append([]byte{1}, []byte("server")...)) { + t.Errorf("unexpected client read: %v", data[:7]) + } + + // verify that an ignore channel is empty in both directions. 
+ if n, err := conn.channels[2].Write([]byte("test")); n != 4 || err != nil { + t.Errorf("writes should be ignored") + } + data = make([]byte, 1024) + if n, err := conn.channels[2].Read(data); n != 0 || err != io.EOF { + t.Errorf("reads should be ignored") + } + + // verify that a write to a Read channel doesn't block + if n, err := conn.channels[3].Write([]byte("test")); n != 4 || err != nil { + t.Errorf("writes should be ignored") + } + + // verify that a read from a Write channel doesn't block + data = make([]byte, 1024) + if n, err := conn.channels[4].Read(data); n != 0 || err != io.EOF { + t.Errorf("reads should be ignored") + } + + // verify that a client write to a Write channel doesn't block (is dropped) + if n, err := client.Write(append([]byte{4}, []byte("ignored")...)); err != nil || n != 8 { + t.Fatalf("%d: %v", n, err) + } + + client.Close() + wg.Wait() +} + +func TestBase64Conn(t *testing.T) { + conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel})) + s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conn.Open(w, req) + })) + defer s.Close() + + config, err := websocket.NewConfig("ws://"+addr, "http://localhost/") + if err != nil { + t.Fatal(err) + } + config.Protocol = []string{"base64.channel.k8s.io"} + client, err := websocket.DialConfig(config) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + <-conn.ready + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + data, err := ioutil.ReadAll(conn.channels[0]) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(data, []byte("client")) { + t.Errorf("unexpected server read: %s", string(data)) + } + }() + + clientData := base64.StdEncoding.EncodeToString([]byte("client")) + if n, err := client.Write(append([]byte{'0'}, clientData...)); err != nil || n != len(clientData)+1 { + t.Fatalf("%d: %v", n, err) + } + + wg.Add(1) + go func() { + defer wg.Done() + if n, err := 
conn.channels[1].Write([]byte("server")); err != nil && n != 6 { + t.Fatalf("%d: %v", n, err) + } + }() + + data := make([]byte, 1024) + if n, err := io.ReadAtLeast(client, data, 9); n != 9 || err != nil { + t.Fatalf("%d: %v", n, err) + } + expect := []byte(base64.StdEncoding.EncodeToString([]byte("server"))) + + if !reflect.DeepEqual(data[:9], append([]byte{'1'}, expect...)) { + t.Errorf("unexpected client read: %v", data[:9]) + } + + client.Close() + wg.Wait() +} + +type versionTest struct { + supported map[string]bool // protocol -> binary + requested []string + error bool + expected string +} + +func versionTests() []versionTest { + const ( + binary = true + base64 = false + ) + return []versionTest{ + { + supported: nil, + requested: []string{"raw"}, + error: true, + }, + { + supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, + requested: nil, + expected: "", + }, + { + supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, + requested: []string{"v1.raw"}, + error: true, + }, + { + supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, + requested: []string{"v1.raw", "v1.base64"}, + error: true, + }, { + supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, + requested: []string{"v1.raw", "raw"}, + expected: "raw", + }, + { + supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64}, + requested: []string{"v1.raw"}, + expected: "v1.raw", + }, + { + supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64}, + requested: []string{"v2.base64"}, + expected: "v2.base64", + }, + } +} + +func TestVersionedConn(t *testing.T) { + for i, test := range versionTests() { + func() { + supportedProtocols := map[string]ChannelProtocolConfig{} + for p, binary := range test.supported { + supportedProtocols[p] = ChannelProtocolConfig{ + Binary: binary, + Channels: 
[]ChannelType{ReadWriteChannel}, + } + } + conn := NewConn(supportedProtocols) + // note that it's not enough to wait for conn.ready to avoid a race here. Hence, + // we use a channel. + selectedProtocol := make(chan string, 0) + s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + p, _, _ := conn.Open(w, req) + selectedProtocol <- p + })) + defer s.Close() + + config, err := websocket.NewConfig("ws://"+addr, "http://localhost/") + if err != nil { + t.Fatal(err) + } + config.Protocol = test.requested + client, err := websocket.DialConfig(config) + if err != nil { + if !test.error { + t.Fatalf("test %d: didn't expect error: %v", i, err) + } else { + return + } + } + defer client.Close() + if test.error && err == nil { + t.Fatalf("test %d: expected an error", i) + } + + <-conn.ready + if got, expected := <-selectedProtocol, test.expected; got != expected { + t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected) + } + }() + } +} diff --git a/pkg/util/wsstream/doc.go b/pkg/util/wsstream/doc.go new file mode 100644 index 000000000..3bda93e26 --- /dev/null +++ b/pkg/util/wsstream/doc.go @@ -0,0 +1,21 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package wsstream contains utilities for streaming content over WebSockets. +// The Conn type allows callers to multiplex multiple read/write channels over +// a single websocket. 
The Reader type allows an io.Reader to be copied over +// a websocket channel as binary content. +package wsstream // import "k8s.io/kubernetes/pkg/util/wsstream" diff --git a/pkg/util/wsstream/stream.go b/pkg/util/wsstream/stream.go new file mode 100644 index 000000000..9dd165bfa --- /dev/null +++ b/pkg/util/wsstream/stream.go @@ -0,0 +1,177 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wsstream + +import ( + "encoding/base64" + "io" + "net/http" + "sync" + "time" + + "golang.org/x/net/websocket" + + "k8s.io/apimachinery/pkg/util/runtime" +) + +// The WebSocket subprotocol "binary.k8s.io" will only send messages to the +// client and ignore messages sent to the server. The received messages are +// the exact bytes written to the stream. Zero byte messages are possible. +const binaryWebSocketProtocol = "binary.k8s.io" + +// The WebSocket subprotocol "base64.binary.k8s.io" will only send messages to the +// client and ignore messages sent to the server. The received messages are +// a base64 version of the bytes written to the stream. Zero byte messages are +// possible. +const base64BinaryWebSocketProtocol = "base64.binary.k8s.io" + +// ReaderProtocolConfig describes a websocket subprotocol with one stream. +type ReaderProtocolConfig struct { + Binary bool +} + +// NewDefaultReaderProtocols returns a stream protocol map with the +// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io". 
+func NewDefaultReaderProtocols() map[string]ReaderProtocolConfig { + return map[string]ReaderProtocolConfig{ + "": {Binary: true}, + binaryWebSocketProtocol: {Binary: true}, + base64BinaryWebSocketProtocol: {Binary: false}, + } +} + +// Reader supports returning an arbitrary byte stream over a websocket channel. +type Reader struct { + err chan error + r io.Reader + ping bool + timeout time.Duration + protocols map[string]ReaderProtocolConfig + selectedProtocol string + + handleCrash func() // overridable for testing +} + +// NewReader creates a WebSocket pipe that will copy the contents of r to a provided +// WebSocket connection. If ping is true, a zero length message will be sent to the client +// before the stream begins reading. +// +// The protocols parameter maps subprotocol names to StreamProtocols. The empty string +// subprotocol name is used if websocket.Config.Protocol is empty. +func NewReader(r io.Reader, ping bool, protocols map[string]ReaderProtocolConfig) *Reader { + return &Reader{ + r: r, + err: make(chan error), + ping: ping, + protocols: protocols, + handleCrash: func() { runtime.HandleCrash() }, + } +} + +// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified, +// there is no timeout on the reader. +func (r *Reader) SetIdleTimeout(duration time.Duration) { + r.timeout = duration +} + +func (r *Reader) handshake(config *websocket.Config, req *http.Request) error { + supportedProtocols := make([]string, 0, len(r.protocols)) + for p := range r.protocols { + supportedProtocols = append(supportedProtocols, p) + } + return handshake(config, req, supportedProtocols) +} + +// Copy the reader to the response. The created WebSocket is closed after this +// method completes. 
+func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error { + go func() { + defer r.handleCrash() + websocket.Server{Handshake: r.handshake, Handler: r.handle}.ServeHTTP(w, req) + }() + return <-r.err +} + +// handle implements a WebSocket handler. +func (r *Reader) handle(ws *websocket.Conn) { + // Close the connection when the client requests it, or when we finish streaming, whichever happens first + closeConnOnce := &sync.Once{} + closeConn := func() { + closeConnOnce.Do(func() { + ws.Close() + }) + } + + negotiated := ws.Config().Protocol + r.selectedProtocol = negotiated[0] + defer close(r.err) + defer closeConn() + + go func() { + defer runtime.HandleCrash() + // This blocks until the connection is closed. + // Client should not send anything. + IgnoreReceives(ws, r.timeout) + // Once the client closes, we should also close + closeConn() + }() + + r.err <- messageCopy(ws, r.r, !r.protocols[r.selectedProtocol].Binary, r.ping, r.timeout) +} + +func resetTimeout(ws *websocket.Conn, timeout time.Duration) { + if timeout > 0 { + ws.SetDeadline(time.Now().Add(timeout)) + } +} + +func messageCopy(ws *websocket.Conn, r io.Reader, base64Encode, ping bool, timeout time.Duration) error { + buf := make([]byte, 2048) + if ping { + resetTimeout(ws, timeout) + if base64Encode { + if err := websocket.Message.Send(ws, ""); err != nil { + return err + } + } else { + if err := websocket.Message.Send(ws, []byte{}); err != nil { + return err + } + } + } + for { + resetTimeout(ws, timeout) + n, err := r.Read(buf) + if err != nil { + if err == io.EOF { + return nil + } + return err + } + if n > 0 { + if base64Encode { + if err := websocket.Message.Send(ws, base64.StdEncoding.EncodeToString(buf[:n])); err != nil { + return err + } + } else { + if err := websocket.Message.Send(ws, buf[:n]); err != nil { + return err + } + } + } + } +} diff --git a/pkg/util/wsstream/stream_test.go b/pkg/util/wsstream/stream_test.go new file mode 100644 index 000000000..09dda761f --- 
/dev/null +++ b/pkg/util/wsstream/stream_test.go @@ -0,0 +1,294 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wsstream + +import ( + "bytes" + "encoding/base64" + "fmt" + "io" + "io/ioutil" + "net/http" + "reflect" + "strings" + "testing" + "time" + + "golang.org/x/net/websocket" +) + +func TestStream(t *testing.T) { + input := "some random text" + r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols()) + r.SetIdleTimeout(time.Second) + data, err := readWebSocket(r, t, nil) + if !reflect.DeepEqual(data, []byte(input)) { + t.Errorf("unexpected server read: %v", data) + } + if err != nil { + t.Fatal(err) + } +} + +func TestStreamPing(t *testing.T) { + input := "some random text" + r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols()) + r.SetIdleTimeout(time.Second) + err := expectWebSocketFrames(r, t, nil, [][]byte{ + {}, + []byte(input), + }) + if err != nil { + t.Fatal(err) + } +} + +func TestStreamBase64(t *testing.T) { + input := "some random text" + encoded := base64.StdEncoding.EncodeToString([]byte(input)) + r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols()) + data, err := readWebSocket(r, t, nil, "base64.binary.k8s.io") + if !reflect.DeepEqual(data, []byte(encoded)) { + t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded)) + } + if err != nil { + t.Fatal(err) + } +} + +func TestStreamVersionedBase64(t *testing.T) { + input 
:= "some random text" + encoded := base64.StdEncoding.EncodeToString([]byte(input)) + r := NewReader(bytes.NewBuffer([]byte(input)), true, map[string]ReaderProtocolConfig{ + "": {Binary: true}, + "binary.k8s.io": {Binary: true}, + "base64.binary.k8s.io": {Binary: false}, + "v1.binary.k8s.io": {Binary: true}, + "v1.base64.binary.k8s.io": {Binary: false}, + "v2.binary.k8s.io": {Binary: true}, + "v2.base64.binary.k8s.io": {Binary: false}, + }) + data, err := readWebSocket(r, t, nil, "v2.base64.binary.k8s.io") + if !reflect.DeepEqual(data, []byte(encoded)) { + t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded)) + } + if err != nil { + t.Fatal(err) + } +} + +func TestStreamVersionedCopy(t *testing.T) { + for i, test := range versionTests() { + func() { + supportedProtocols := map[string]ReaderProtocolConfig{} + for p, binary := range test.supported { + supportedProtocols[p] = ReaderProtocolConfig{ + Binary: binary, + } + } + input := "some random text" + r := NewReader(bytes.NewBuffer([]byte(input)), true, supportedProtocols) + s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + err := r.Copy(w, req) + if err != nil { + w.WriteHeader(503) + } + })) + defer s.Close() + + config, err := websocket.NewConfig("ws://"+addr, "http://localhost/") + if err != nil { + t.Error(err) + return + } + config.Protocol = test.requested + client, err := websocket.DialConfig(config) + if err != nil { + if !test.error { + t.Errorf("test %d: didn't expect error: %v", i, err) + } + return + } + defer client.Close() + if test.error && err == nil { + t.Errorf("test %d: expected an error", i) + return + } + + <-r.err + if got, expected := r.selectedProtocol, test.expected; got != expected { + t.Errorf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected) + } + }() + } +} + +func TestStreamError(t *testing.T) { + input := "some random text" + errs := &errorReader{ + reads: [][]byte{ + []byte("some random"), + []byte(" 
text"), + }, + err: fmt.Errorf("bad read"), + } + r := NewReader(errs, false, NewDefaultReaderProtocols()) + + data, err := readWebSocket(r, t, nil) + if !reflect.DeepEqual(data, []byte(input)) { + t.Errorf("unexpected server read: %v", data) + } + if err == nil || err.Error() != "bad read" { + t.Fatal(err) + } +} + +func TestStreamSurvivesPanic(t *testing.T) { + input := "some random text" + errs := &errorReader{ + reads: [][]byte{ + []byte("some random"), + []byte(" text"), + }, + panicMessage: "bad read", + } + r := NewReader(errs, false, NewDefaultReaderProtocols()) + + // do not call runtime.HandleCrash() in handler. Otherwise, the tests are interrupted. + r.handleCrash = func() { recover() } + + data, err := readWebSocket(r, t, nil) + if !reflect.DeepEqual(data, []byte(input)) { + t.Errorf("unexpected server read: %v", data) + } + if err != nil { + t.Fatal(err) + } +} + +func TestStreamClosedDuringRead(t *testing.T) { + for i := 0; i < 25; i++ { + ch := make(chan struct{}) + input := "some random text" + errs := &errorReader{ + reads: [][]byte{ + []byte("some random"), + []byte(" text"), + }, + err: fmt.Errorf("stuff"), + pause: ch, + } + r := NewReader(errs, false, NewDefaultReaderProtocols()) + + data, err := readWebSocket(r, t, func(c *websocket.Conn) { + c.Close() + close(ch) + }) + // verify that the data returned by the server on an early close always has a specific error + if err == nil || !strings.Contains(err.Error(), "use of closed network connection") { + t.Fatal(err) + } + // verify that the data returned is a strict subset of the input + if !bytes.HasPrefix([]byte(input), data) && len(data) != 0 { + t.Fatalf("unexpected server read: %q", string(data)) + } + } +} + +type errorReader struct { + reads [][]byte + err error + panicMessage string + pause chan struct{} +} + +func (r *errorReader) Read(p []byte) (int, error) { + if len(r.reads) == 0 { + if r.pause != nil { + <-r.pause + } + if len(r.panicMessage) != 0 { + panic(r.panicMessage) + } + 
return 0, r.err + } + next := r.reads[0] + r.reads = r.reads[1:] + copy(p, next) + return len(next), nil +} + +func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) { + errCh := make(chan error, 1) + s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + errCh <- r.Copy(w, req) + })) + defer s.Close() + + config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr) + config.Protocol = protocols + client, err := websocket.DialConfig(config) + if err != nil { + return nil, err + } + defer client.Close() + + if fn != nil { + fn(client) + } + + data, err := ioutil.ReadAll(client) + if err != nil { + return data, err + } + return data, <-errCh +} + +func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error { + errCh := make(chan error, 1) + s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + errCh <- r.Copy(w, req) + })) + defer s.Close() + + config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr) + config.Protocol = protocols + ws, err := websocket.DialConfig(config) + if err != nil { + return err + } + defer ws.Close() + + if fn != nil { + fn(ws) + } + + for i := range frames { + var data []byte + if err := websocket.Message.Receive(ws, &data); err != nil { + return err + } + if !reflect.DeepEqual(frames[i], data) { + return fmt.Errorf("frame %d did not match expected: %v", data, err) + } + } + var data []byte + if err := websocket.Message.Receive(ws, &data); err != io.EOF { + return fmt.Errorf("expected no more frames: %v (%v)", err, data) + } + return <-errCh +} diff --git a/pkg/webhook/webhook.go b/pkg/webhook/webhook.go new file mode 100755 index 000000000..1d05d4440 --- /dev/null +++ b/pkg/webhook/webhook.go @@ -0,0 +1,106 @@ +/* +Copyright 2016 The Kubernetes Authors. 
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package webhook implements a generic HTTP webhook plugin.
package webhook

import (
	"fmt"
	"time"

	"k8s.io/apimachinery/pkg/runtime"
	"k8s.io/apimachinery/pkg/runtime/schema"
	runtimeserializer "k8s.io/apimachinery/pkg/runtime/serializer"
	"k8s.io/apimachinery/pkg/util/wait"
	"k8s.io/client-go/pkg/api"
	apierrors "k8s.io/client-go/pkg/api/errors"
	"k8s.io/client-go/rest"
	"k8s.io/client-go/tools/clientcmd"

	// Registers the authorization API group so its group/versions can be
	// found via api.Registry in NewGenericWebhook below.
	_ "k8s.io/client-go/pkg/apis/authorization/install"
)

// GenericWebhook calls a remote service described by a kubeconfig file and
// retries retryable failures with exponential backoff.
type GenericWebhook struct {
	// RestClient is the configured client for the remote webhook endpoint.
	RestClient *rest.RESTClient
	// initialBackoff seeds the retry policy used by WithExponentialBackoff.
	initialBackoff time.Duration
}

// NewGenericWebhook creates a new GenericWebhook from the provided kubeconfig file.
// Every group/version in groupVersions must already be enabled in the API
// registry; request/response bodies are serialized with the legacy codec
// for those versions.
func NewGenericWebhook(kubeConfigFile string, groupVersions []schema.GroupVersion, initialBackoff time.Duration) (*GenericWebhook, error) {
	for _, groupVersion := range groupVersions {
		if !api.Registry.IsEnabledVersion(groupVersion) {
			return nil, fmt.Errorf("webhook plugin requires enabling extension resource: %s", groupVersion)
		}
	}

	// Load the connection details (server, credentials, TLS) from the
	// explicit kubeconfig path using the standard client-go loading rules.
	loadingRules := clientcmd.NewDefaultClientConfigLoadingRules()
	loadingRules.ExplicitPath = kubeConfigFile
	loader := clientcmd.NewNonInteractiveDeferredLoadingClientConfig(loadingRules, &clientcmd.ConfigOverrides{})

	clientConfig, err := loader.ClientConfig()
	if err != nil {
		return nil, err
	}
	// Force the client to speak the legacy codec for the requested
	// group/versions regardless of what the server negotiates.
	codec := api.Codecs.LegacyCodec(groupVersions...)
	clientConfig.ContentConfig.NegotiatedSerializer = runtimeserializer.NegotiatedSerializerWrapper(runtime.SerializerInfo{Serializer: codec})

	restClient, err := rest.UnversionedRESTClientFor(clientConfig)
	if err != nil {
		return nil, err
	}

	// TODO(ericchiang): Can we ensure remote service is reachable?

	return &GenericWebhook{restClient, initialBackoff}, nil
}

// WithExponentialBackoff invokes webhookFn using this webhook's configured
// initial backoff; see the package-level WithExponentialBackoff for the
// retry policy. The rest.Result of the final attempt is returned.
func (g *GenericWebhook) WithExponentialBackoff(webhookFn func() rest.Result) rest.Result {
	var result rest.Result
	WithExponentialBackoff(g.initialBackoff, func() error {
		result = webhookFn()
		return result.Error()
	})
	return result
}

// WithExponentialBackoff will retry webhookFn() up to 5 times with exponentially increasing backoff when
// it returns an error for which apierrors.SuggestsClientDelay() or apierrors.IsInternalError() returns true.
// Any other error (or success) stops the retries immediately; the error of
// the final attempt, if any, is returned.
func WithExponentialBackoff(initialBackoff time.Duration, webhookFn func() error) error {
	backoff := wait.Backoff{
		Duration: initialBackoff,
		Factor:   1.5,
		Jitter:   0.2,
		Steps:    5,
	}

	var err error
	// The error returned by ExponentialBackoff itself is deliberately
	// discarded: err already holds the outcome of the last webhookFn call,
	// which is the value callers care about.
	wait.ExponentialBackoff(backoff, func() (bool, error) {
		err = webhookFn()
		// A server Retry-After hint or an internal server error suggests the
		// remote may recover; back off and retry.
		if _, shouldRetry := apierrors.SuggestsClientDelay(err); shouldRetry {
			return false, nil
		}
		if apierrors.IsInternalError(err) {
			return false, nil
		}
		if err != nil {
			return false, err
		}
		return true, nil
	})
	return err
}