Force token cache to support audit annotations

Signed-off-by: Monis Khan <mok@vmware.com>

Kubernetes-commit: 6039451d358c20b8161e08eb1d3626134195026f
This commit is contained in:
Monis Khan 2020-04-14 12:46:37 -04:00 committed by Kubernetes Publisher
parent 51732c2088
commit 09aff09e1a
2 changed files with 211 additions and 25 deletions

View File

@ -34,7 +34,10 @@ import (
apierrors "k8s.io/apimachinery/pkg/api/errors"
utilclock "k8s.io/apimachinery/pkg/util/clock"
auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/klog/v2"
)
@ -47,6 +50,16 @@ type cacheRecord struct {
resp *authenticator.Response
ok bool
err error
// this cache assumes token authn has no side-effects or temporal dependence.
// neither of these are true for audit annotations set via AddAuditAnnotation.
//
// for audit annotations, the assumption is that for some period of time (cache TTL),
// all requests with the same API audiences and the same bearer token result in the
// same annotations. This may not be true if the authenticator sets an annotation
// based on the current time, but that may be okay since cache TTLs are generally
// small (seconds).
annotations map[string]string
}
type cachedTokenAuthenticator struct {
@ -109,6 +122,17 @@ func newWithClock(authenticator authenticator.Token, cacheErrs bool, successTTL,
// AuthenticateToken implements authenticator.Token
func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token string) (*authenticator.Response, bool, error) {
record := a.doAuthenticateToken(ctx, token)
if !record.ok || record.err != nil {
return nil, false, record.err
}
for key, value := range record.annotations {
audit.AddAuditAnnotation(ctx, key, value)
}
return record.resp, true, nil
}
func (a *cachedTokenAuthenticator) doAuthenticateToken(ctx context.Context, token string) *cacheRecord {
doneAuthenticating := stats.authenticating()
auds, audsOk := authenticator.AudiencesFrom(ctx)
@ -117,7 +141,7 @@ func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token
if record, ok := a.cache.get(key); ok {
// Record cache hit
doneAuthenticating(true)
return record.resp, record.ok, record.err
return record
}
// Record cache miss
@ -125,18 +149,19 @@ func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token
defer doneBlocking()
defer doneAuthenticating(false)
type lookup struct {
resp *authenticator.Response
ok bool
}
c := a.group.DoChan(key, func() (val interface{}, _ error) {
// always use one place to read and write the output of AuthenticateToken
record := &cacheRecord{}
c := a.group.DoChan(key, func() (val interface{}, err error) {
doneFetching := stats.fetching()
// We're leaving the request handling stack so we need to handle crashes
// ourselves. Log a stack trace and return a 500 if something panics.
defer func() {
if r := recover(); r != nil {
err = errAuthnCrash
// make sure to always return a record
record.err = errAuthnCrash
val = record
// Same as stdlib http server code. Manually allocate stack
// trace buffer size to prevent excessively large logs
const size = 64 << 10
@ -144,12 +169,12 @@ func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token
buf = buf[:runtime.Stack(buf, false)]
klog.Errorf("%v\n%s", r, buf)
}
doneFetching(err == nil)
doneFetching(record.err == nil)
}()
// Check again for a cached record. We may have raced with a fetch.
if record, ok := a.cache.get(key); ok {
return lookup{record.resp, record.ok}, record.err
return record, nil
}
// Detach the context because the lookup may be shared by multiple callers,
@ -161,29 +186,35 @@ func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token
ctx = authenticator.WithAudiences(ctx, auds)
}
resp, ok, err := a.authenticator.AuthenticateToken(ctx, token)
if !a.cacheErrs && err != nil {
return nil, err
// since this is shared work between multiple requests, we have no way of knowing if any
// particular request supports audit annotations. thus we always attempt to record them.
ev := &auditinternal.Event{Level: auditinternal.LevelMetadata}
ctx = request.WithAuditEvent(ctx, ev)
record.resp, record.ok, record.err = a.authenticator.AuthenticateToken(ctx, token)
record.annotations = ev.Annotations
if !a.cacheErrs && record.err != nil {
return record, nil
}
switch {
case ok && a.successTTL > 0:
a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.successTTL)
case !ok && a.failureTTL > 0:
a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.failureTTL)
case record.ok && a.successTTL > 0:
a.cache.set(key, record, a.successTTL)
case !record.ok && a.failureTTL > 0:
a.cache.set(key, record, a.failureTTL)
}
return lookup{resp, ok}, err
return record, nil
})
select {
case result := <-c:
if result.Err != nil {
return nil, false, result.Err
}
lookup := result.Val.(lookup)
return lookup.resp, lookup.ok, nil
// we always set Val and never set Err
return result.Val.(*cacheRecord)
case <-ctx.Done():
return nil, false, ctx.Err()
// fake a record on context cancel
return &cacheRecord{err: ctx.Err()}
}
}

View File

@ -33,8 +33,11 @@ import (
"github.com/google/go-cmp/cmp"
utilclock "k8s.io/apimachinery/pkg/util/clock"
"k8s.io/apimachinery/pkg/util/uuid"
auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/audit"
"k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/apiserver/pkg/authentication/user"
"k8s.io/apiserver/pkg/endpoints/request"
)
func TestCachedTokenAuthenticator(t *testing.T) {
@ -274,6 +277,144 @@ func TestSharedLookup(t *testing.T) {
})
}
func TestCachedAuditAnnotations(t *testing.T) {
snorlax := &authenticator.Response{User: &user.DefaultInfo{Name: "snorlax"}}
t.Run("annotations from cache", func(t *testing.T) {
var lookups uint32
c := make(chan struct{})
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
<-c
atomic.AddUint32(&lookups, 1)
audit.AddAuditAnnotation(ctx, "snorlax", "rocks")
audit.AddAuditAnnotation(ctx, "pandas", "are amazing")
return snorlax, true, nil
}), false, time.Minute, 0)
allAnnotations := make(chan map[string]string, 10)
defer close(allAnnotations)
var wg sync.WaitGroup
for i := 0; i < cap(allAnnotations); i++ {
wg.Add(1)
go func() {
defer wg.Done()
// exercise both ways of tracking audit annotations
r := mathrand.New(mathrand.NewSource(mathrand.Int63()))
randomChoice := r.Int()%2 == 0
ctx := context.Background()
if randomChoice {
ctx = audit.WithAuditAnnotations(ctx)
} else {
ctx = request.WithAuditEvent(ctx, &auditinternal.Event{Level: auditinternal.LevelMetadata})
}
_, _, _ = a.AuthenticateToken(ctx, "token")
if randomChoice {
allAnnotations <- extractAnnotations(ctx)
} else {
allAnnotations <- request.AuditEventFrom(ctx).Annotations
}
}()
}
// no good way to make sure that all the callers are queued so we sleep.
time.Sleep(1 * time.Second)
close(c)
wg.Wait()
want := map[string]string{"snorlax": "rocks", "pandas": "are amazing"}
for i := 0; i < cap(allAnnotations); i++ {
annotations := <-allAnnotations
if diff := cmp.Diff(want, annotations); diff != "" {
t.Errorf("%d: unexpected annotations (-want +got): %s", i, diff)
}
}
if queued := len(allAnnotations); queued != 0 {
t.Errorf("expected all annoations to be processed: %d", queued)
}
if lookups > 3 {
t.Errorf("unexpected number of lookups: got=%d, wanted less than 3", lookups)
}
})
t.Run("annotations do not change during cache TTL", func(t *testing.T) {
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
audit.AddAuditAnnotation(ctx, "timestamp", time.Now().String())
return snorlax, true, nil
}), false, time.Minute, 0)
allAnnotations := make([]map[string]string, 0, 10)
for i := 0; i < cap(allAnnotations); i++ {
ctx := audit.WithAuditAnnotations(context.Background())
_, _, _ = a.AuthenticateToken(ctx, "token")
allAnnotations = append(allAnnotations, extractAnnotations(ctx))
}
if len(allAnnotations) != cap(allAnnotations) {
t.Errorf("failed to process all annotations")
}
want := allAnnotations[0]
if ok := len(want) == 1 && len(want["timestamp"]) > 0; !ok {
t.Errorf("invalid annotations: %v", want)
}
for i, annotations := range allAnnotations[1:] {
if diff := cmp.Diff(want, annotations); diff != "" {
t.Errorf("%d: unexpected annotations (-want +got): %s", i, diff)
}
}
})
t.Run("different tokens can have different annotations", func(t *testing.T) {
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
audit.AddAuditAnnotation(ctx, "timestamp", time.Now().String())
return snorlax, true, nil
}), false, time.Minute, 0)
ctx1 := audit.WithAuditAnnotations(context.Background())
_, _, _ = a.AuthenticateToken(ctx1, "token1")
annotations1 := extractAnnotations(ctx1)
// guarantee different now times
time.Sleep(time.Second)
ctx2 := audit.WithAuditAnnotations(context.Background())
_, _, _ = a.AuthenticateToken(ctx2, "token2")
annotations2 := extractAnnotations(ctx2)
if ok := len(annotations1) == 1 && len(annotations1["timestamp"]) > 0; !ok {
t.Errorf("invalid annotations 1: %v", annotations1)
}
if ok := len(annotations2) == 1 && len(annotations2["timestamp"]) > 0; !ok {
t.Errorf("invalid annotations 2: %v", annotations2)
}
if annotations1["timestamp"] == annotations2["timestamp"] {
t.Errorf("annotations should have different timestamp value: %v", annotations1)
}
})
}
func extractAnnotations(ctx context.Context) map[string]string {
annotationsSlice := reflect.ValueOf(ctx).Elem().FieldByName("val").Elem().Elem()
annotations := map[string]string{}
for i := 0; i < annotationsSlice.Len(); i++ {
annotation := annotationsSlice.Index(i)
key := annotation.FieldByName("key").String()
val := annotation.FieldByName("value").String()
annotations[key] = val
}
return annotations
}
func BenchmarkCachedTokenAuthenticator(b *testing.B) {
tokenCount := []int{100, 500, 2500, 12500, 62500}
threadCount := []int{1, 16, 256}
@ -318,6 +459,8 @@ func (s *singleBenchmark) makeTokens() {
s.tokenToAuds = map[string]authenticator.Audiences{}
s.tokens = []string{}
rr := mathrand.New(mathrand.NewSource(mathrand.Int63()))
for i := 0; i < s.tokenCount; i++ {
tok := fmt.Sprintf("%v-%v", jwtToken, i)
r := cacheRecord{
@ -327,14 +470,23 @@ func (s *singleBenchmark) makeTokens() {
}
// make different combinations of audience, failures, denies for the tokens.
auds := []string{}
for i := 0; i < mathrand.Intn(4); i++ {
for i := 0; i < rr.Intn(4); i++ {
auds = append(auds, string(uuid.NewUUID()))
}
choice := mathrand.Float64()
choice := rr.Float64()
switch {
case choice < 0.9:
r.ok = true
r.err = nil
// add some realistic annotations on ~20% of successful authentications
if f := rr.Float64(); f < 0.2 {
r.annotations = map[string]string{
"audience.authentication.kubernetes.io": "e8357258-88b1-11ea-bc55-0242ac130003",
"namespace.authentication.kubernetes.io": "kube-system",
"float.authentication.kubernetes.io": fmt.Sprint(f),
}
}
case choice < 0.99:
r.ok = false
r.err = nil
@ -355,6 +507,9 @@ func (s *singleBenchmark) lookup(ctx context.Context, token string) (*authentica
if !ok {
panic("test setup problem")
}
for key, val := range r.annotations {
audit.AddAuditAnnotation(ctx, key, val)
}
return r.resp, r.ok, r.err
}