make token cache include audience in hash key

Kubernetes-commit: 809f278b032103cd24fcbb5ea2196c6c7caa6f63
This commit is contained in:
Mike Danese 2018-10-16 10:02:01 -07:00 committed by Kubernetes Publisher
parent 2ced48ac6e
commit 908a04653f
4 changed files with 46 additions and 16 deletions

View File

@ -24,36 +24,36 @@ import (
// split cache lookups across N striped caches // split cache lookups across N striped caches
type stripedCache struct { type stripedCache struct {
stripeCount uint32 stripeCount uint32
keyFunc func(string) uint32 hashFunc func(string) uint32
caches []cache caches []cache
} }
type keyFunc func(string) uint32 type hashFunc func(string) uint32
type newCacheFunc func() cache type newCacheFunc func() cache
func newStripedCache(stripeCount int, keyFunc keyFunc, newCacheFunc newCacheFunc) cache { func newStripedCache(stripeCount int, hash hashFunc, newCacheFunc newCacheFunc) cache {
caches := []cache{} caches := []cache{}
for i := 0; i < stripeCount; i++ { for i := 0; i < stripeCount; i++ {
caches = append(caches, newCacheFunc()) caches = append(caches, newCacheFunc())
} }
return &stripedCache{ return &stripedCache{
stripeCount: uint32(stripeCount), stripeCount: uint32(stripeCount),
keyFunc: keyFunc, hashFunc: hash,
caches: caches, caches: caches,
} }
} }
func (c *stripedCache) get(key string) (*cacheRecord, bool) { func (c *stripedCache) get(key string) (*cacheRecord, bool) {
return c.caches[c.keyFunc(key)%c.stripeCount].get(key) return c.caches[c.hashFunc(key)%c.stripeCount].get(key)
} }
func (c *stripedCache) set(key string, value *cacheRecord, ttl time.Duration) { func (c *stripedCache) set(key string, value *cacheRecord, ttl time.Duration) {
c.caches[c.keyFunc(key)%c.stripeCount].set(key, value, ttl) c.caches[c.hashFunc(key)%c.stripeCount].set(key, value, ttl)
} }
func (c *stripedCache) remove(key string) { func (c *stripedCache) remove(key string) {
c.caches[c.keyFunc(key)%c.stripeCount].remove(key) c.caches[c.hashFunc(key)%c.stripeCount].remove(key)
} }
func fnvKeyFunc(key string) uint32 { func fnvHashFunc(key string) uint32 {
f := fnv.New32() f := fnv.New32()
f.Write([]byte(key)) f.Write([]byte(key))
return f.Sum32() return f.Sum32()

View File

@ -22,6 +22,7 @@ import (
"time" "time"
"github.com/pborman/uuid" "github.com/pborman/uuid"
"k8s.io/apimachinery/pkg/util/clock" "k8s.io/apimachinery/pkg/util/clock"
"k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/authentication/user"
@ -36,11 +37,11 @@ func BenchmarkSimpleCache(b *testing.B) {
} }
func TestStripedCache(t *testing.T) { func TestStripedCache(t *testing.T) {
testCache(newStripedCache(32, fnvKeyFunc, func() cache { return newSimpleCache(128, clock.RealClock{}) }), t) testCache(newStripedCache(32, fnvHashFunc, func() cache { return newSimpleCache(128, clock.RealClock{}) }), t)
} }
func BenchmarkStripedCache(b *testing.B) { func BenchmarkStripedCache(b *testing.B) {
benchmarkCache(newStripedCache(32, fnvKeyFunc, func() cache { return newSimpleCache(128, clock.RealClock{}) }), b) benchmarkCache(newStripedCache(32, fnvHashFunc, func() cache { return newSimpleCache(128, clock.RealClock{}) }), b)
} }
func benchmarkCache(cache cache, b *testing.B) { func benchmarkCache(cache cache, b *testing.B) {

View File

@ -18,10 +18,12 @@ package cache
import ( import (
"context" "context"
"fmt"
"time" "time"
utilclock "k8s.io/apimachinery/pkg/util/clock" utilclock "k8s.io/apimachinery/pkg/util/clock"
"k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/apiserver/pkg/endpoints/request"
) )
// cacheRecord holds the three return values of the authenticator.Token AuthenticateToken method // cacheRecord holds the three return values of the authenticator.Token AuthenticateToken method
@ -59,15 +61,16 @@ func newWithClock(authenticator authenticator.Token, successTTL, failureTTL time
authenticator: authenticator, authenticator: authenticator,
successTTL: successTTL, successTTL: successTTL,
failureTTL: failureTTL, failureTTL: failureTTL,
cache: newStripedCache(32, fnvKeyFunc, func() cache { return newSimpleCache(128, clock) }), cache: newStripedCache(32, fnvHashFunc, func() cache { return newSimpleCache(128, clock) }),
} }
} }
// AuthenticateToken implements authenticator.Token // AuthenticateToken implements authenticator.Token
func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token string) (*authenticator.Response, bool, error) { func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token string) (*authenticator.Response, bool, error) {
// TODO(mikedanese): The key needs to incorporate any relevant data in the auds, _ := request.AudiencesFrom(ctx)
// context.
if record, ok := a.cache.get(token); ok { key := keyFunc(auds, token)
if record, ok := a.cache.get(key); ok {
return record.resp, record.ok, record.err return record.resp, record.ok, record.err
} }
@ -75,10 +78,14 @@ func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token
switch { switch {
case ok && a.successTTL > 0: case ok && a.successTTL > 0:
a.cache.set(token, &cacheRecord{resp: resp, ok: ok, err: err}, a.successTTL) a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.successTTL)
case !ok && a.failureTTL > 0: case !ok && a.failureTTL > 0:
a.cache.set(token, &cacheRecord{resp: resp, ok: ok, err: err}, a.failureTTL) a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.failureTTL)
} }
return resp, ok, err return resp, ok, err
} }
func keyFunc(auds []string, token string) string {
return fmt.Sprintf("%#v|%v", auds, token)
}

View File

@ -25,6 +25,7 @@ import (
utilclock "k8s.io/apimachinery/pkg/util/clock" utilclock "k8s.io/apimachinery/pkg/util/clock"
"k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/authentication/user"
"k8s.io/apiserver/pkg/endpoints/request"
) )
func TestCachedTokenAuthenticator(t *testing.T) { func TestCachedTokenAuthenticator(t *testing.T) {
@ -104,3 +105,24 @@ func TestCachedTokenAuthenticator(t *testing.T) {
t.Errorf("Expected token calls, got %v", calledWithToken) t.Errorf("Expected token calls, got %v", calledWithToken)
} }
} }
func TestCachedTokenAuthenticatorWithAudiences(t *testing.T) {
resultUsers := make(map[string]user.Info)
fakeAuth := authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
auds, _ := request.AudiencesFrom(ctx)
return &authenticator.Response{User: resultUsers[auds[0]+token]}, true, nil
})
fakeClock := utilclock.NewFakeClock(time.Now())
a := newWithClock(fakeAuth, time.Minute, 0, fakeClock)
resultUsers["audAusertoken1"] = &user.DefaultInfo{Name: "user1"}
resultUsers["audBusertoken1"] = &user.DefaultInfo{Name: "user1-different"}
if u, ok, _ := a.AuthenticateToken(request.WithAudiences(context.Background(), []string{"audA"}), "usertoken1"); !ok || u.User.GetName() != "user1" {
t.Errorf("Expected user1")
}
if u, ok, _ := a.AuthenticateToken(request.WithAudiences(context.Background(), []string{"audB"}), "usertoken1"); !ok || u.User.GetName() != "user1-different" {
t.Errorf("Expected user1-different")
}
}