Force token cache to support audit annotations

Signed-off-by: Monis Khan <mok@vmware.com>

Kubernetes-commit: 6039451d358c20b8161e08eb1d3626134195026f
This commit is contained in:
Monis Khan 2020-04-14 12:46:37 -04:00 committed by Kubernetes Publisher
parent 51732c2088
commit 09aff09e1a
2 changed files with 211 additions and 25 deletions

View File

@ -34,7 +34,10 @@ import (
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
utilclock "k8s.io/apimachinery/pkg/util/clock" utilclock "k8s.io/apimachinery/pkg/util/clock"
auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/klog/v2" "k8s.io/klog/v2"
) )
@ -47,6 +50,16 @@ type cacheRecord struct {
resp *authenticator.Response resp *authenticator.Response
ok bool ok bool
err error err error
// this cache assumes token authn has no side-effects or temporal dependence.
// neither of these are true for audit annotations set via AddAuditAnnotation.
//
// for audit annotations, the assumption is that for some period of time (cache TTL),
// all requests with the same API audiences and the same bearer token result in the
// same annotations. This may not be true if the authenticator sets an annotation
// based on the current time, but that may be okay since cache TTLs are generally
// small (seconds).
annotations map[string]string
} }
type cachedTokenAuthenticator struct { type cachedTokenAuthenticator struct {
@ -109,6 +122,17 @@ func newWithClock(authenticator authenticator.Token, cacheErrs bool, successTTL,
// AuthenticateToken implements authenticator.Token // AuthenticateToken implements authenticator.Token
func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token string) (*authenticator.Response, bool, error) { func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token string) (*authenticator.Response, bool, error) {
record := a.doAuthenticateToken(ctx, token)
if !record.ok || record.err != nil {
return nil, false, record.err
}
for key, value := range record.annotations {
audit.AddAuditAnnotation(ctx, key, value)
}
return record.resp, true, nil
}
func (a *cachedTokenAuthenticator) doAuthenticateToken(ctx context.Context, token string) *cacheRecord {
doneAuthenticating := stats.authenticating() doneAuthenticating := stats.authenticating()
auds, audsOk := authenticator.AudiencesFrom(ctx) auds, audsOk := authenticator.AudiencesFrom(ctx)
@ -117,7 +141,7 @@ func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token
if record, ok := a.cache.get(key); ok { if record, ok := a.cache.get(key); ok {
// Record cache hit // Record cache hit
doneAuthenticating(true) doneAuthenticating(true)
return record.resp, record.ok, record.err return record
} }
// Record cache miss // Record cache miss
@ -125,18 +149,19 @@ func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token
defer doneBlocking() defer doneBlocking()
defer doneAuthenticating(false) defer doneAuthenticating(false)
type lookup struct { c := a.group.DoChan(key, func() (val interface{}, _ error) {
resp *authenticator.Response // always use one place to read and write the output of AuthenticateToken
ok bool record := &cacheRecord{}
}
c := a.group.DoChan(key, func() (val interface{}, err error) {
doneFetching := stats.fetching() doneFetching := stats.fetching()
// We're leaving the request handling stack so we need to handle crashes // We're leaving the request handling stack so we need to handle crashes
// ourselves. Log a stack trace and return a 500 if something panics. // ourselves. Log a stack trace and return a 500 if something panics.
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
err = errAuthnCrash // make sure to always return a record
record.err = errAuthnCrash
val = record
// Same as stdlib http server code. Manually allocate stack // Same as stdlib http server code. Manually allocate stack
// trace buffer size to prevent excessively large logs // trace buffer size to prevent excessively large logs
const size = 64 << 10 const size = 64 << 10
@ -144,12 +169,12 @@ func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token
buf = buf[:runtime.Stack(buf, false)] buf = buf[:runtime.Stack(buf, false)]
klog.Errorf("%v\n%s", r, buf) klog.Errorf("%v\n%s", r, buf)
} }
doneFetching(err == nil) doneFetching(record.err == nil)
}() }()
// Check again for a cached record. We may have raced with a fetch. // Check again for a cached record. We may have raced with a fetch.
if record, ok := a.cache.get(key); ok { if record, ok := a.cache.get(key); ok {
return lookup{record.resp, record.ok}, record.err return record, nil
} }
// Detach the context because the lookup may be shared by multiple callers, // Detach the context because the lookup may be shared by multiple callers,
@ -161,29 +186,35 @@ func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token
ctx = authenticator.WithAudiences(ctx, auds) ctx = authenticator.WithAudiences(ctx, auds)
} }
resp, ok, err := a.authenticator.AuthenticateToken(ctx, token) // since this is shared work between multiple requests, we have no way of knowing if any
if !a.cacheErrs && err != nil { // particular request supports audit annotations. thus we always attempt to record them.
return nil, err ev := &auditinternal.Event{Level: auditinternal.LevelMetadata}
ctx = request.WithAuditEvent(ctx, ev)
record.resp, record.ok, record.err = a.authenticator.AuthenticateToken(ctx, token)
record.annotations = ev.Annotations
if !a.cacheErrs && record.err != nil {
return record, nil
} }
switch { switch {
case ok && a.successTTL > 0: case record.ok && a.successTTL > 0:
a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.successTTL) a.cache.set(key, record, a.successTTL)
case !ok && a.failureTTL > 0: case !record.ok && a.failureTTL > 0:
a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.failureTTL) a.cache.set(key, record, a.failureTTL)
} }
return lookup{resp, ok}, err
return record, nil
}) })
select { select {
case result := <-c: case result := <-c:
if result.Err != nil { // we always set Val and never set Err
return nil, false, result.Err return result.Val.(*cacheRecord)
}
lookup := result.Val.(lookup)
return lookup.resp, lookup.ok, nil
case <-ctx.Done(): case <-ctx.Done():
return nil, false, ctx.Err() // fake a record on context cancel
return &cacheRecord{err: ctx.Err()}
} }
} }

View File

@ -33,8 +33,11 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
utilclock "k8s.io/apimachinery/pkg/util/clock" utilclock "k8s.io/apimachinery/pkg/util/clock"
"k8s.io/apimachinery/pkg/util/uuid" "k8s.io/apimachinery/pkg/util/uuid"
auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/authentication/user"
"k8s.io/apiserver/pkg/endpoints/request"
) )
func TestCachedTokenAuthenticator(t *testing.T) { func TestCachedTokenAuthenticator(t *testing.T) {
@ -274,6 +277,144 @@ func TestSharedLookup(t *testing.T) {
}) })
} }
func TestCachedAuditAnnotations(t *testing.T) {
snorlax := &authenticator.Response{User: &user.DefaultInfo{Name: "snorlax"}}
t.Run("annotations from cache", func(t *testing.T) {
var lookups uint32
c := make(chan struct{})
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
<-c
atomic.AddUint32(&lookups, 1)
audit.AddAuditAnnotation(ctx, "snorlax", "rocks")
audit.AddAuditAnnotation(ctx, "pandas", "are amazing")
return snorlax, true, nil
}), false, time.Minute, 0)
allAnnotations := make(chan map[string]string, 10)
defer close(allAnnotations)
var wg sync.WaitGroup
for i := 0; i < cap(allAnnotations); i++ {
wg.Add(1)
go func() {
defer wg.Done()
// exercise both ways of tracking audit annotations
r := mathrand.New(mathrand.NewSource(mathrand.Int63()))
randomChoice := r.Int()%2 == 0
ctx := context.Background()
if randomChoice {
ctx = audit.WithAuditAnnotations(ctx)
} else {
ctx = request.WithAuditEvent(ctx, &auditinternal.Event{Level: auditinternal.LevelMetadata})
}
_, _, _ = a.AuthenticateToken(ctx, "token")
if randomChoice {
allAnnotations <- extractAnnotations(ctx)
} else {
allAnnotations <- request.AuditEventFrom(ctx).Annotations
}
}()
}
// no good way to make sure that all the callers are queued so we sleep.
time.Sleep(1 * time.Second)
close(c)
wg.Wait()
want := map[string]string{"snorlax": "rocks", "pandas": "are amazing"}
for i := 0; i < cap(allAnnotations); i++ {
annotations := <-allAnnotations
if diff := cmp.Diff(want, annotations); diff != "" {
t.Errorf("%d: unexpected annotations (-want +got): %s", i, diff)
}
}
if queued := len(allAnnotations); queued != 0 {
t.Errorf("expected all annoations to be processed: %d", queued)
}
if lookups > 3 {
t.Errorf("unexpected number of lookups: got=%d, wanted less than 3", lookups)
}
})
t.Run("annotations do not change during cache TTL", func(t *testing.T) {
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
audit.AddAuditAnnotation(ctx, "timestamp", time.Now().String())
return snorlax, true, nil
}), false, time.Minute, 0)
allAnnotations := make([]map[string]string, 0, 10)
for i := 0; i < cap(allAnnotations); i++ {
ctx := audit.WithAuditAnnotations(context.Background())
_, _, _ = a.AuthenticateToken(ctx, "token")
allAnnotations = append(allAnnotations, extractAnnotations(ctx))
}
if len(allAnnotations) != cap(allAnnotations) {
t.Errorf("failed to process all annotations")
}
want := allAnnotations[0]
if ok := len(want) == 1 && len(want["timestamp"]) > 0; !ok {
t.Errorf("invalid annotations: %v", want)
}
for i, annotations := range allAnnotations[1:] {
if diff := cmp.Diff(want, annotations); diff != "" {
t.Errorf("%d: unexpected annotations (-want +got): %s", i, diff)
}
}
})
t.Run("different tokens can have different annotations", func(t *testing.T) {
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
audit.AddAuditAnnotation(ctx, "timestamp", time.Now().String())
return snorlax, true, nil
}), false, time.Minute, 0)
ctx1 := audit.WithAuditAnnotations(context.Background())
_, _, _ = a.AuthenticateToken(ctx1, "token1")
annotations1 := extractAnnotations(ctx1)
// guarantee different now times
time.Sleep(time.Second)
ctx2 := audit.WithAuditAnnotations(context.Background())
_, _, _ = a.AuthenticateToken(ctx2, "token2")
annotations2 := extractAnnotations(ctx2)
if ok := len(annotations1) == 1 && len(annotations1["timestamp"]) > 0; !ok {
t.Errorf("invalid annotations 1: %v", annotations1)
}
if ok := len(annotations2) == 1 && len(annotations2["timestamp"]) > 0; !ok {
t.Errorf("invalid annotations 2: %v", annotations2)
}
if annotations1["timestamp"] == annotations2["timestamp"] {
t.Errorf("annotations should have different timestamp value: %v", annotations1)
}
})
}
func extractAnnotations(ctx context.Context) map[string]string {
annotationsSlice := reflect.ValueOf(ctx).Elem().FieldByName("val").Elem().Elem()
annotations := map[string]string{}
for i := 0; i < annotationsSlice.Len(); i++ {
annotation := annotationsSlice.Index(i)
key := annotation.FieldByName("key").String()
val := annotation.FieldByName("value").String()
annotations[key] = val
}
return annotations
}
func BenchmarkCachedTokenAuthenticator(b *testing.B) { func BenchmarkCachedTokenAuthenticator(b *testing.B) {
tokenCount := []int{100, 500, 2500, 12500, 62500} tokenCount := []int{100, 500, 2500, 12500, 62500}
threadCount := []int{1, 16, 256} threadCount := []int{1, 16, 256}
@ -318,6 +459,8 @@ func (s *singleBenchmark) makeTokens() {
s.tokenToAuds = map[string]authenticator.Audiences{} s.tokenToAuds = map[string]authenticator.Audiences{}
s.tokens = []string{} s.tokens = []string{}
rr := mathrand.New(mathrand.NewSource(mathrand.Int63()))
for i := 0; i < s.tokenCount; i++ { for i := 0; i < s.tokenCount; i++ {
tok := fmt.Sprintf("%v-%v", jwtToken, i) tok := fmt.Sprintf("%v-%v", jwtToken, i)
r := cacheRecord{ r := cacheRecord{
@ -327,14 +470,23 @@ func (s *singleBenchmark) makeTokens() {
} }
// make different combinations of audience, failures, denies for the tokens. // make different combinations of audience, failures, denies for the tokens.
auds := []string{} auds := []string{}
for i := 0; i < mathrand.Intn(4); i++ { for i := 0; i < rr.Intn(4); i++ {
auds = append(auds, string(uuid.NewUUID())) auds = append(auds, string(uuid.NewUUID()))
} }
choice := mathrand.Float64() choice := rr.Float64()
switch { switch {
case choice < 0.9: case choice < 0.9:
r.ok = true r.ok = true
r.err = nil r.err = nil
// add some realistic annotations on ~20% of successful authentications
if f := rr.Float64(); f < 0.2 {
r.annotations = map[string]string{
"audience.authentication.kubernetes.io": "e8357258-88b1-11ea-bc55-0242ac130003",
"namespace.authentication.kubernetes.io": "kube-system",
"float.authentication.kubernetes.io": fmt.Sprint(f),
}
}
case choice < 0.99: case choice < 0.99:
r.ok = false r.ok = false
r.err = nil r.err = nil
@ -355,6 +507,9 @@ func (s *singleBenchmark) lookup(ctx context.Context, token string) (*authentica
if !ok { if !ok {
panic("test setup problem") panic("test setup problem")
} }
for key, val := range r.annotations {
audit.AddAuditAnnotation(ctx, key, val)
}
return r.resp, r.ok, r.err return r.resp, r.ok, r.err
} }