components-contrib/crypto/pubkey_cache_test.go

366 lines
10 KiB
Go

/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package crypto
import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"errors"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/chebyrash/promise"
	"github.com/lestrrat-go/jwx/v2/jwk"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	kitctx "github.com/dapr/kit/context"
)
// TestPubKeyCacheGetKey covers PubKeyCache.GetKey. It checks that GetKey:
//   - returns entries already in the cache;
//   - fetches keys on a cold cache;
//   - surfaces fetch errors;
//   - collapses concurrent requests for one key into a single fetch;
//   - honours each caller's context cancellation.
//
// Some assertions may run outside the test goroutine. That includes spawned
// goroutines and fetch functions, which the cache/promise may execute on
// another goroutine. Those assertions use assert rather than require, because
// require stops the test via t.FailNow. t.FailNow must only be called from
// the goroutine running the test. Elsewhere it can skip the code that
// unblocks other waiters and turn a failure into a hang.
func TestPubKeyCacheGetKey(t *testing.T) {
	pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)
	testKey, err := jwk.FromRaw(pk.PublicKey)
	require.NoError(t, err)
	pk2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)
	testKey2, err := jwk.FromRaw(pk2.PublicKey)
	require.NoError(t, err)

	// fetchNotCalled returns a fetch function for caches that are pre-populated
	// with every key the subtest requests. It fails t if the cache invokes it.
	fetchNotCalled := func(t *testing.T) func(context.Context, string) func(resolve func(jwk.Key), reject func(error)) {
		return func(context.Context, string) func(resolve func(jwk.Key), reject func(error)) {
			return func(resolve func(jwk.Key), reject func(error)) { assert.Fail(t, "should not be called") }
		}
	}

	// resolvedEntry builds a cache entry whose promise resolves to key.
	resolvedEntry := func(key jwk.Key) pubKeyCacheEntry {
		return pubKeyCacheEntry{
			promise: promise.New(func(resolve func(jwk.Key), reject func(error)) {
				resolve(key)
			}),
			ctx: kitctx.NewPool(),
		}
	}

	t.Run("existing key should return key", func(t *testing.T) {
		t.Parallel()
		cache := NewPubKeyCache(fetchNotCalled(t))
		cache.pubKeys["key"] = resolvedEntry(testKey)
		result, err := cache.GetKey(t.Context(), "key")
		require.NoError(t, err)
		assert.Equal(t, testKey, result)
	})

	t.Run("two different keys should be returned", func(t *testing.T) {
		t.Parallel()
		cache := NewPubKeyCache(fetchNotCalled(t))
		cache.pubKeys["key"] = resolvedEntry(testKey)
		cache.pubKeys["another-key"] = resolvedEntry(testKey2)
		result, err := cache.GetKey(t.Context(), "key")
		require.NoError(t, err)
		assert.Equal(t, testKey, result)
		result, err = cache.GetKey(t.Context(), "another-key")
		require.NoError(t, err)
		assert.Equal(t, testKey2, result)
	})

	t.Run("cold cache should fetch key", func(t *testing.T) {
		t.Parallel()
		// Atomic because the fetch may run on a different goroutine than the
		// one reading the count.
		var called atomic.Int32
		cache := NewPubKeyCache(func(context.Context, string) func(resolve func(jwk.Key), reject func(error)) {
			return func(resolve func(jwk.Key), reject func(error)) {
				called.Add(1)
				resolve(testKey)
			}
		})
		result, err := cache.GetKey(t.Context(), "key")
		require.NoError(t, err)
		assert.Equal(t, testKey, result)
		assert.EqualValues(t, 1, called.Load(), "should be called once")
	})

	t.Run("cold cache should fetch different keys", func(t *testing.T) {
		t.Parallel()
		var called atomic.Int32
		cache := NewPubKeyCache(func(_ context.Context, i string) func(resolve func(jwk.Key), reject func(error)) {
			return func(resolve func(jwk.Key), reject func(error)) {
				called.Add(1)
				switch i {
				case "key":
					resolve(testKey)
				default:
					resolve(testKey2)
				}
			}
		})
		result, err := cache.GetKey(t.Context(), "key")
		require.NoError(t, err)
		assert.Equal(t, testKey, result)
		result, err = cache.GetKey(t.Context(), "another-key")
		require.NoError(t, err)
		assert.Equal(t, testKey2, result)
		assert.EqualValues(t, 2, called.Load(), "should be called once per key")
	})

	t.Run("fetch key which errors, should error getKey", func(t *testing.T) {
		t.Parallel()
		var called atomic.Int32
		cache := NewPubKeyCache(func(_ context.Context, i string) func(resolve func(jwk.Key), reject func(error)) {
			return func(resolve func(jwk.Key), reject func(error)) {
				assert.Equal(t, "key", i)
				called.Add(1)
				reject(assert.AnError)
			}
		})
		result, err := cache.GetKey(t.Context(), "key")
		require.ErrorIs(t, err, assert.AnError)
		assert.Nil(t, result)
		assert.EqualValues(t, 1, called.Load(), "should be called once")
	})

	t.Run("multiple fetch key at the same time should only call getKey once", func(t *testing.T) {
		t.Parallel()
		// Atomic so that a cache which wrongly fetches concurrently is reported
		// as a count mismatch instead of racing (and possibly losing) updates.
		var called atomic.Int32
		cache := NewPubKeyCache(func(_ context.Context, i string) func(resolve func(jwk.Key), reject func(error)) {
			return func(resolve func(jwk.Key), reject func(error)) {
				assert.Equal(t, "key", i)
				called.Add(1)
				resolve(testKey)
			}
		})
		var wg sync.WaitGroup
		wg.Add(10)
		for range 10 {
			go func() {
				defer wg.Done()
				result, err := cache.GetKey(t.Context(), "key")
				assert.NoError(t, err)
				assert.Equal(t, testKey, result)
			}()
		}
		wg.Wait()
		assert.EqualValues(t, 1, called.Load(), "should be called once")
	})

	t.Run("calling get key and context is cancelled should return context error", func(t *testing.T) {
		t.Parallel()
		ctx, cancel := context.WithCancelCause(t.Context())
		getKeyReturned := make(chan struct{})
		cache := NewPubKeyCache(func(_ context.Context, i string) func(resolve func(jwk.Key), reject func(error)) {
			return func(resolve func(jwk.Key), reject func(error)) {
				t.Cleanup(func() {
					select {
					case <-getKeyReturned:
					case <-time.After(1 * time.Second):
						assert.Fail(t, "expected GetKey to return from cancelled context in time")
					}
				})
				assert.Equal(t, "key", i)
				// Cancel the only caller before the fetch completes.
				cancel(assert.AnError)
				resolve(testKey)
			}
		})
		result, err := cache.GetKey(ctx, "key")
		assert.ErrorIs(t, err, context.Canceled)
		assert.Nil(t, result)
		close(getKeyReturned)
	})

	t.Run("only callers with cancelled contexts should return context error", func(t *testing.T) {
		t.Parallel()
		ctx1, cancel1 := context.WithCancelCause(t.Context())
		ctx2 := t.Context()
		getKeyReturned := make(chan struct{})
		var cache *PubKeyCache
		cache = NewPubKeyCache(func(ctx context.Context, i string) func(resolve func(jwk.Key), reject func(error)) {
			return func(resolve func(jwk.Key), reject func(error)) {
				// Wait until both callers have joined the entry's context pool.
				assert.Eventually(t, func() bool {
					cache.lock.Lock()
					defer cache.lock.Unlock()
					entry, ok := cache.pubKeys["key"]
					return ok && entry.ctx.Size() == 2
				}, time.Second*5, time.Millisecond)
				t.Cleanup(func() {
					select {
					case <-getKeyReturned:
					case <-time.After(1 * time.Second):
						assert.Fail(t, "expected GetKey to return from cancelled context in time")
					}
				})
				assert.Equal(t, "key", i)
				cancel1(assert.AnError)
				// One caller is still live, so the fetch context must remain open.
				select {
				case <-ctx.Done():
					assert.Fail(t, "GetKey context should not be cancelled")
				default:
				}
				resolve(testKey)
			}
		})
		go func() {
			defer close(getKeyReturned)
			result, err := cache.GetKey(ctx2, "key")
			assert.NoError(t, err)
			assert.Equal(t, testKey, result)
		}()
		result, err := cache.GetKey(ctx1, "key")
		assert.ErrorIs(t, err, context.Canceled)
		assert.Nil(t, result)
		// Wait for the second caller before returning. ctx2 is t.Context(),
		// which is cancelled as soon as this function returns (just before
		// cleanups run). Returning early could therefore make the second
		// caller spuriously observe cancellation.
		select {
		case <-getKeyReturned:
		case <-time.After(5 * time.Second):
			assert.Fail(t, "expected second GetKey caller to return in time")
		}
	})

	t.Run("if all callers give cancelled contexts, the underlying context should also be cancelled", func(t *testing.T) {
		t.Parallel()
		ctx1, cancel1 := context.WithCancelCause(t.Context())
		ctx2, cancel2 := context.WithCancelCause(t.Context())
		getKeyReturned := make(chan struct{})
		var cache *PubKeyCache
		cache = NewPubKeyCache(func(ctx context.Context, i string) func(resolve func(jwk.Key), reject func(error)) {
			return func(resolve func(jwk.Key), reject func(error)) {
				// Wait until both callers have joined the entry's context pool.
				assert.Eventually(t, func() bool {
					cache.lock.Lock()
					defer cache.lock.Unlock()
					entry, ok := cache.pubKeys["key"]
					return ok && entry.ctx.Size() == 2
				}, time.Second*5, time.Millisecond)
				// Once every caller has cancelled, the fetch context must follow.
				select {
				case <-ctx.Done():
				case <-time.After(1 * time.Second):
					assert.Fail(t, "expected GetKey to get cancelled context in time")
				}
				t.Cleanup(func() {
					select {
					case <-getKeyReturned:
					case <-time.After(1 * time.Second):
						assert.Fail(t, "expected GetKey to return from cancelled context in time")
					}
				})
				reject(errors.New("error which is not surfaced"))
			}
		})
		var wg sync.WaitGroup
		wg.Add(2)
		go func() {
			defer wg.Done()
			result, err := cache.GetKey(ctx1, "key")
			assert.ErrorIs(t, err, context.Canceled)
			assert.Nil(t, result)
		}()
		go func() {
			defer wg.Done()
			result, err := cache.GetKey(ctx2, "key")
			assert.ErrorIs(t, err, context.Canceled)
			assert.Nil(t, result)
		}()
		cancel1(assert.AnError)
		cancel2(assert.AnError)
		wg.Wait()
		close(getKeyReturned)
	})

	t.Run("if first caller cancels their context, other callers should still await", func(t *testing.T) {
		t.Parallel()
		var cache *PubKeyCache
		// assertSize waits until the "key" entry's context pool holds size
		// contexts. It is called off the test goroutine, hence assert.
		assertSize := func(size int) {
			assert.Eventually(t, func() bool {
				cache.lock.Lock()
				defer cache.lock.Unlock()
				entry, ok := cache.pubKeys["key"]
				return ok && entry.ctx.Size() == size
			}, time.Second*5, time.Millisecond)
		}
		ctx, cancel := context.WithCancelCause(t.Context())
		getKeyReturned := make(chan struct{})
		cache = NewPubKeyCache(func(_ context.Context, i string) func(resolve func(jwk.Key), reject func(error)) {
			return func(resolve func(jwk.Key), reject func(error)) {
				assertSize(3)
				// cancel the first caller
				cancel(assert.AnError)
				t.Cleanup(func() {
					select {
					case <-getKeyReturned:
					case <-time.After(1 * time.Second):
						assert.Fail(t, "expected GetKey to return from cancelled context in time")
					}
				})
				resolve(testKey)
			}
		})
		var wg sync.WaitGroup
		wg.Add(3)
		go func() {
			defer wg.Done()
			result, err := cache.GetKey(ctx, "key")
			assert.ErrorIs(t, err, context.Canceled)
			assert.Nil(t, result)
		}()
		// The size waits below order the callers so the cancelling caller
		// joins first.
		go func() {
			defer wg.Done()
			assertSize(1)
			result, err := cache.GetKey(t.Context(), "key")
			assert.NoError(t, err)
			assert.Equal(t, testKey, result)
		}()
		go func() {
			defer wg.Done()
			assertSize(2)
			result, err := cache.GetKey(t.Context(), "key")
			assert.NoError(t, err)
			assert.Equal(t, testKey, result)
		}()
		wg.Wait()
		close(getKeyReturned)
	})
}