diff --git a/pkg/audit/request.go b/pkg/audit/request.go index 205bf25c8..960ec9321 100644 --- a/pkg/audit/request.go +++ b/pkg/audit/request.go @@ -23,9 +23,6 @@ import ( "reflect" "time" - "github.com/google/uuid" - "k8s.io/klog/v2" - authnv1 "k8s.io/api/authentication/v1" "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -36,6 +33,10 @@ import ( auditinternal "k8s.io/apiserver/pkg/apis/audit" "k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/authorization/authorizer" + "k8s.io/apiserver/pkg/endpoints/request" + "k8s.io/klog/v2" + + "github.com/google/uuid" ) const ( @@ -52,14 +53,11 @@ func NewEventFromRequest(req *http.Request, requestReceivedTimestamp time.Time, Level: level, } - // prefer the id from the headers. If not available, create a new one. - // TODO(audit): do we want to forbid the header for non-front-proxy users? - ids := req.Header.Get(auditinternal.HeaderAuditID) - if ids != "" { - ev.AuditID = types.UID(ids) - } else { - ev.AuditID = types.UID(uuid.New().String()) + auditID, found := request.AuditIDFrom(req.Context()) + if !found { + auditID = types.UID(uuid.New().String()) } + ev.AuditID = auditID ips := utilnet.SourceIPs(req) ev.SourceIPs = make([]string, len(ips)) diff --git a/pkg/endpoints/filters/audit.go b/pkg/endpoints/filters/audit.go index 2f78ff1de..853d1da9f 100644 --- a/pkg/endpoints/filters/audit.go +++ b/pkg/endpoints/filters/audit.go @@ -195,10 +195,6 @@ type auditResponseWriter struct { omitStages []auditinternal.Stage } -func (a *auditResponseWriter) setHttpHeader() { - a.ResponseWriter.Header().Set(auditinternal.HeaderAuditID, string(a.event.AuditID)) -} - func (a *auditResponseWriter) processCode(code int) { a.once.Do(func() { if a.event.ResponseStatus == nil { @@ -216,13 +212,11 @@ func (a *auditResponseWriter) processCode(code int) { func (a *auditResponseWriter) Write(bs []byte) (int, error) { // the Go library calls WriteHeader internally if no code was written yet. But this will go unnoticed for us a.processCode(http.StatusOK) - a.setHttpHeader() return a.ResponseWriter.Write(bs) } func (a *auditResponseWriter) WriteHeader(code int) { a.processCode(code) - a.setHttpHeader() a.ResponseWriter.WriteHeader(code) } @@ -245,12 +239,6 @@ func (f *fancyResponseWriterDelegator) Hijack() (net.Conn, *bufio.ReadWriter, er // fake a response status before protocol switch happens f.processCode(http.StatusSwitchingProtocols) - // This will be ignored if WriteHeader() function has already been called. - // It's not guaranteed Audit-ID http header is sent for all requests. - // For example, when user run "kubectl exec", apiserver uses a proxy handler - // to deal with the request, users can only get http headers returned by kubelet node. - f.setHttpHeader() - return f.ResponseWriter.(http.Hijacker).Hijack() } diff --git a/pkg/endpoints/filters/audit_test.go b/pkg/endpoints/filters/audit_test.go index 673e326ce..e326712e1 100644 --- a/pkg/endpoints/filters/audit_test.go +++ b/pkg/endpoints/filters/audit_test.go @@ -673,6 +673,7 @@ func TestAudit(t *testing.T) { // simplified long-running check return ri.Verb == "watch" }) + handler = WithAuditID(handler) req, _ := http.NewRequest(test.verb, test.path, nil) req = withTestContext(req, &user.DefaultInfo{Name: "admin"}, nil) @@ -772,16 +773,20 @@ func TestAuditIDHttpHeader(t *testing.T) { expectedHeader bool }{ { - "no http header when there is no audit", + // we always want an audit ID since it can appear in logging/tracing and it is propagated + // to the aggregated apiserver(s) to improve correlation. + "http header when there is no audit", "", auditinternal.LevelNone, - false, + true, }, { - "no http header when there is no audit even the request header specified", + // we always want an audit ID since it can appear in logging/tracing and it is propagated + // to the aggregated apiserver(s) to improve correlation. + "http header when there is no audit even the request header specified", uuid.New().String(), auditinternal.LevelNone, - false, + true, }, { "server generated header", @@ -796,38 +801,42 @@ func TestAuditIDHttpHeader(t *testing.T) { true, }, } { - sink := &fakeAuditSink{} - var handler http.Handler - handler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(200) + t.Run(test.desc, func(t *testing.T) { + sink := &fakeAuditSink{} + var handler http.Handler + handler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(200) + }) + policyChecker := policy.FakeChecker(test.level, nil) + + handler = WithAudit(handler, sink, policyChecker, nil) + handler = WithAuditID(handler) + + req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil) + req.RemoteAddr = "127.0.0.1" + req = withTestContext(req, &user.DefaultInfo{Name: "admin"}, nil) + if test.requestHeader != "" { + req.Header.Add("Audit-ID", test.requestHeader) + } + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + resp := w.Result() + if test.expectedHeader { + if resp.Header.Get("Audit-ID") == "" { + t.Errorf("[%s] expected Audit-ID http header returned, but not returned", test.desc) + return + } + // if get Audit-ID returned, it should be the same with the requested one + if test.requestHeader != "" && resp.Header.Get("Audit-ID") != test.requestHeader { + t.Errorf("[%s] returned audit http header is not the same with the requested http header, expected: %s, get %s", test.desc, test.requestHeader, resp.Header.Get("Audit-ID")) + } + } else { + if resp.Header.Get("Audit-ID") != "" { + t.Errorf("[%s] expected no Audit-ID http header returned, but got %s", test.desc, resp.Header.Get("Audit-ID")) + } + } }) - policyChecker := policy.FakeChecker(test.level, nil) - handler = WithAudit(handler, sink, policyChecker, nil) - - req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil) - req.RemoteAddr = "127.0.0.1" - req = withTestContext(req, &user.DefaultInfo{Name: "admin"}, nil) - if test.requestHeader != "" { - req.Header.Add("Audit-ID", test.requestHeader) - } - - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - resp := w.Result() - if test.expectedHeader { - if resp.Header.Get("Audit-ID") == "" { - t.Errorf("[%s] expected Audit-ID http header returned, but not returned", test.desc) - continue - } - // if get Audit-ID returned, it should be the same with the requested one - if test.requestHeader != "" && resp.Header.Get("Audit-ID") != test.requestHeader { - t.Errorf("[%s] returned audit http header is not the same with the requested http header, expected: %s, get %s", test.desc, test.requestHeader, resp.Header.Get("Audit-ID")) - } - } else { - if resp.Header.Get("Audit-ID") != "" { - t.Errorf("[%s] expected no Audit-ID http header returned, but got %s", test.desc, resp.Header.Get("Audit-ID")) - } - } } } diff --git a/pkg/endpoints/filters/with_auditid.go b/pkg/endpoints/filters/with_auditid.go new file mode 100644 index 000000000..a7e8c7e4a --- /dev/null +++ b/pkg/endpoints/filters/with_auditid.go @@ -0,0 +1,68 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "net/http" + + "k8s.io/apimachinery/pkg/types" + auditinternal "k8s.io/apiserver/pkg/apis/audit" + "k8s.io/apiserver/pkg/endpoints/request" + + "github.com/google/uuid" +) + +// WithAuditID attaches the Audit-ID associated with a request to the context. +// +// a. If the caller does not specify a value for Audit-ID in the request header, we generate a new audit ID +// b. We echo the Audit-ID value to the caller via the response Header 'Audit-ID'. +func WithAuditID(handler http.Handler) http.Handler { + return withAuditID(handler, func() string { + return uuid.New().String() + }) +} + +func withAuditID(handler http.Handler, newAuditIDFunc func() string) http.Handler { + if newAuditIDFunc == nil { + return handler + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + auditID := r.Header.Get(auditinternal.HeaderAuditID) + if len(auditID) == 0 { + auditID = newAuditIDFunc() + } + + // Note: we save the user specified value of the Audit-ID header as is, no truncation is performed. + r = r.WithContext(request.WithAuditID(ctx, types.UID(auditID))) + + // We echo the Audit-ID in to the response header. + // It's not guaranteed Audit-ID http header is sent for all requests. + // For example, when user run "kubectl exec", apiserver uses a proxy handler + // to deal with the request, users can only get http headers returned by kubelet node. + // + // This filter will also be used by other aggregated api server(s). For an aggregated API + // we don't want to see the same audit ID appearing more than once. + if value := w.Header().Get(auditinternal.HeaderAuditID); len(value) == 0 { + w.Header().Set(auditinternal.HeaderAuditID, auditID) + } + + handler.ServeHTTP(w, r) + }) +} diff --git a/pkg/endpoints/filters/with_auditid_test.go b/pkg/endpoints/filters/with_auditid_test.go new file mode 100644 index 000000000..29c8c9c8b --- /dev/null +++ b/pkg/endpoints/filters/with_auditid_test.go @@ -0,0 +1,113 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "k8s.io/apiserver/pkg/endpoints/request" +) + +func TestWithAuditID(t *testing.T) { + largeAuditID := fmt.Sprintf("%s-%s", uuid.New().String(), uuid.New().String()) + tests := []struct { + name string + newAuditIDFunc func() string + auditIDSpecified string + auditIDExpected string + }{ + { + name: "user specifies a value for Audit-ID in the request header", + auditIDSpecified: "foo-bar-baz", + auditIDExpected: "foo-bar-baz", + }, + { + name: "user does not specify a value for Audit-ID in the request header", + newAuditIDFunc: func() string { + return "foo-bar-baz" + }, + auditIDExpected: "foo-bar-baz", + }, + { + name: "the value in Audit-ID request header is too large, should not be truncated", + auditIDSpecified: largeAuditID, + auditIDExpected: largeAuditID, + }, + { + name: "the generated Audit-ID is too large, should not be truncated", + newAuditIDFunc: func() string { + return largeAuditID + }, + auditIDExpected: largeAuditID, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const auditKey = "Audit-ID" + var ( + innerHandlerCallCount int + auditIDGot string + found bool + ) + handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { + innerHandlerCallCount++ + + // does the inner handler see the audit ID? + v, ok := request.AuditIDFrom(req.Context()) + + found = ok + auditIDGot = string(v) + }) + + wrapped := WithAuditID(handler) + if test.newAuditIDFunc != nil { + wrapped = withAuditID(handler, test.newAuditIDFunc) + } + + testRequest, err := http.NewRequest(http.MethodGet, "/api/v1/namespaces", nil) + if err != nil { + t.Fatalf("failed to create new http request - %v", err) + } + if len(test.auditIDSpecified) > 0 { + testRequest.Header.Set(auditKey, test.auditIDSpecified) + } + + w := httptest.NewRecorder() + wrapped.ServeHTTP(w, testRequest) + + if innerHandlerCallCount != 1 { + t.Errorf("WithAuditID: expected the inner handler to be invoked once, but was invoked %d times", innerHandlerCallCount) + } + if !found { + t.Error("WithAuditID: expected request.AuditIDFrom to return true, but got false") + } + if test.auditIDExpected != auditIDGot { + t.Errorf("WithAuditID: expected the request context to have: %q, but got=%q", test.auditIDExpected, auditIDGot) + } + + auditIDEchoed := w.Header().Get(auditKey) + if test.auditIDExpected != auditIDEchoed { + t.Errorf("WithAuditID: expected Audit-ID response header: %q, but got: %q", test.auditIDExpected, auditIDEchoed) + } + }) + } +} diff --git a/pkg/endpoints/request/auditid.go b/pkg/endpoints/request/auditid.go new file mode 100644 index 000000000..a7b3d84ad --- /dev/null +++ b/pkg/endpoints/request/auditid.go @@ -0,0 +1,66 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package request + +import ( + "context" + "net/http" + + "k8s.io/apimachinery/pkg/types" +) + +type auditIDKeyType int + +// auditIDKey is the key to associate the Audit-ID value of a request. +const auditIDKey auditIDKeyType = iota + +// WithAuditID returns a copy of the parent context into which the Audit-ID +// associated with the request is set. +// +// If the specified auditID is empty, no value is set and the parent context is returned as is. +func WithAuditID(parent context.Context, auditID types.UID) context.Context { + if auditID == "" { + return parent + } + return WithValue(parent, auditIDKey, auditID) +} + +// AuditIDFrom returns the value of the audit ID from the request context. +func AuditIDFrom(ctx context.Context) (types.UID, bool) { + auditID, ok := ctx.Value(auditIDKey).(types.UID) + return auditID, ok +} + +// GetAuditIDTruncated returns the audit ID (truncated) associated with a request. +// If the length of the Audit-ID value exceeds the limit, we truncate it to keep +// the first N (maxAuditIDLength) characters. +// This is intended to be used in logging only. +func GetAuditIDTruncated(req *http.Request) string { + auditID, ok := AuditIDFrom(req.Context()) + if !ok { + return "" + } + + // if the user has specified a very long audit ID then we will use the first N characters + // Note: assuming Audit-ID header is in ASCII + const maxAuditIDLength = 64 + if len(auditID) > maxAuditIDLength { + auditID = auditID[0:maxAuditIDLength] + } + + return string(auditID) +} diff --git a/pkg/endpoints/request/auditid_test.go b/pkg/endpoints/request/auditid_test.go new file mode 100644 index 000000000..1583bdb21 --- /dev/null +++ b/pkg/endpoints/request/auditid_test.go @@ -0,0 +1,68 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package request + +import ( + "context" + "testing" + + "k8s.io/apimachinery/pkg/types" +) + +func TestAuditIDFrom(t *testing.T) { + tests := []struct { + name string + auditID string + auditIDExpected string + expected bool + }{ + { + name: "empty audit ID", + auditID: "", + auditIDExpected: "", + expected: false, + }, + { + name: "non empty audit ID", + auditID: "foo-bar-baz", + auditIDExpected: "foo-bar-baz", + expected: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + parent := context.TODO() + ctx := WithAuditID(parent, types.UID(test.auditID)) + + // for an empty audit ID we don't expect a copy of the parent context. + if len(test.auditID) == 0 && parent != ctx { + t.Error("expected no copy of the parent context with an empty audit ID") + } + + value, ok := AuditIDFrom(ctx) + if test.expected != ok { + t.Errorf("expected AuditIDFrom to return: %t, but got: %t", test.expected, ok) + } + + auditIDGot := string(value) + if test.auditIDExpected != auditIDGot { + t.Errorf("expected audit ID: %q, but got: %q", test.auditIDExpected, auditIDGot) + } + }) + } +} diff --git a/pkg/server/config.go b/pkg/server/config.go index f2f303f47..12d18796e 100644 --- a/pkg/server/config.go +++ b/pkg/server/config.go @@ -768,6 +768,7 @@ func DefaultBuildHandlerChain(apiHandler http.Handler, c *Config) http.Handler { handler = genericfilters.WithHTTPLogging(handler) handler = genericapifilters.WithRequestReceivedTimestamp(handler) handler = genericfilters.WithPanicRecovery(handler, c.RequestInfoResolver) + handler = genericapifilters.WithAuditID(handler) return handler } diff --git a/pkg/server/filters/priority-and-fairness_test.go b/pkg/server/filters/priority-and-fairness_test.go index 51f689313..0ef6345f3 100644 --- a/pkg/server/filters/priority-and-fairness_test.go +++ b/pkg/server/filters/priority-and-fairness_test.go @@ -854,6 +854,7 @@ func newHandlerChain(t *testing.T, handler http.Handler, filter utilflowcontrol. handler = apifilters.WithRequestDeadline(handler, nil, nil, longRunningRequestCheck, nil, requestTimeout) handler = apifilters.WithRequestInfo(handler, requestInfoFactory) handler = WithPanicRecovery(handler, requestInfoFactory) + handler = apifilters.WithAuditID(handler) return handler }