Merge branch 'master' into merge-release1.11-master

This commit is contained in:
Bernd Verst 2023-05-25 12:34:51 -05:00 committed by GitHub
commit 92b830b085
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 426 additions and 36 deletions

View File

@ -126,11 +126,10 @@ func (k *keyvaultCrypto) getKeyFromVault(parentCtx context.Context, kid keyID) (
}
// Handler for the getKeyCacheFn method
func (k *keyvaultCrypto) getKeyCacheFn(key string) func(resolve func(jwk.Key), reject func(error)) {
func (k *keyvaultCrypto) getKeyCacheFn(ctx context.Context, key string) func(resolve func(jwk.Key), reject func(error)) {
kid := newKeyID(key)
parentCtx := context.Background()
return func(resolve func(jwk.Key), reject func(error)) {
pk, err := k.getKeyFromVault(parentCtx, kid)
pk, err := k.getKeyFromVault(ctx, kid)
if err != nil {
reject(err)
return

View File

@ -19,53 +19,79 @@ import (
"github.com/chebyrash/promise"
"github.com/lestrrat-go/jwx/v2/jwk"
kitctx "github.com/dapr/kit/context"
)
// GetKeyFn is the type of the getKeyFn function used by the PubKeyCache.
type GetKeyFn = func(key string) func(resolve func(jwk.Key), reject func(error))
type GetKeyFn = func(ctx context.Context, key string) func(resolve func(jwk.Key), reject func(error))
// PubKeyCache implements GetKey with a local cache.
// PubKeyCache implements GetKey with a local cache. We use promises for cache
// entries so that multiple callers getting the same key at the same time
// (where the key is not in the cache yet), will result in only a single key
// fetch.
// Each cache item uses a context pool so that a key fetch call will only be
// cancelled once all callers have cancelled their context.
type PubKeyCache struct {
getKeyFn GetKeyFn
pubKeys map[string]*promise.Promise[jwk.Key]
pubKeysLock *sync.Mutex
pubKeys map[string]pubKeyCacheEntry
lock sync.Mutex
}
type pubKeyCacheEntry struct {
promise *promise.Promise[jwk.Key]
ctx *kitctx.Pool
}
// NewPubKeyCache returns a new PubKeyCache object
func NewPubKeyCache(getKeyFn GetKeyFn) *PubKeyCache {
return &PubKeyCache{
getKeyFn: getKeyFn,
pubKeys: map[string]*promise.Promise[jwk.Key]{},
pubKeysLock: &sync.Mutex{},
getKeyFn: getKeyFn,
pubKeys: make(map[string]pubKeyCacheEntry),
}
}
// GetKey returns a public key from the cache, or uses getKeyFn to request it
func (kc *PubKeyCache) GetKey(parentCtx context.Context, key string) (pubKey jwk.Key, err error) {
timeoutPromise := promise.New(func(_ func(jwk.Key), reject func(error)) {
<-parentCtx.Done()
reject(parentCtx.Err())
})
// GetKey returns a public key from the cache, or uses getKeyFn to request it.
func (kc *PubKeyCache) GetKey(ctx context.Context, key string) (jwk.Key, error) {
// Check if the key is in the cache already
kc.pubKeysLock.Lock()
kc.lock.Lock()
p, ok := kc.pubKeys[key]
if ok {
kc.pubKeysLock.Unlock()
return promise.Race(p, timeoutPromise).Await()
// Add the context to the context pool and return the promise (which may
// already be resolved).
kc.pubKeys[key].ctx.Add(ctx)
kc.lock.Unlock()
jwkKey, err := p.promise.Await(ctx)
if err != nil || jwkKey == nil {
return nil, err
}
return *jwkKey, nil
}
// Create a new promise, which resolves with a background context
p = promise.New(kc.getKeyFn(key))
p = promise.Catch(p, func(err error) error {
kc.pubKeysLock.Lock()
delete(kc.pubKeys, key)
kc.pubKeysLock.Unlock()
return err
})
// Key is not in the cache, create the promise in the cache and return
// result. Create a new context pool for the promise. Cancel the pool on
// return so that the context pool doesn't expand indefinitely on cache
// reads.
p.ctx = kitctx.NewPool(ctx)
p.promise = promise.Catch(
promise.New(kc.getKeyFn(p.ctx, key)),
p.ctx,
func(err error) error {
kc.lock.Lock()
p.ctx.Cancel()
delete(kc.pubKeys, key)
kc.lock.Unlock()
return err
},
)
kc.pubKeys[key] = p
kc.pubKeysLock.Unlock()
kc.lock.Unlock()
return promise.Race(p, timeoutPromise).Await()
jwkKey, err := p.promise.Await(ctx)
if err != nil || jwkKey == nil {
return nil, err
}
p.ctx.Cancel()
return *jwkKey, nil
}

365
crypto/pubkey_cache_test.go Normal file
View File

@ -0,0 +1,365 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package crypto
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"errors"
"sync"
"testing"
"time"
"github.com/chebyrash/promise"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
kitctx "github.com/dapr/kit/context"
)
func TestPubKeyCacheGetKey(t *testing.T) {
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
testKey, err := jwk.FromRaw(pk.PublicKey)
require.NoError(t, err)
pk2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
testKey2, err := jwk.FromRaw(pk2.PublicKey)
require.NoError(t, err)
t.Run("existing key should return key", func(t *testing.T) {
t.Parallel()
cache := NewPubKeyCache(func(context.Context, string) func(resolve func(jwk.Key), reject func(error)) {
return func(resolve func(jwk.Key), reject func(error)) { assert.Fail(t, "should not be called") }
})
cache.pubKeys["key"] = pubKeyCacheEntry{
promise: promise.New(func(resolve func(jwk.Key), reject func(error)) {
resolve(testKey)
}),
ctx: kitctx.NewPool(),
}
result, err := cache.GetKey(context.Background(), "key")
assert.NoError(t, err)
assert.Equal(t, testKey, result)
})
t.Run("two different keys should be returned", func(t *testing.T) {
t.Parallel()
cache := NewPubKeyCache(func(context.Context, string) func(resolve func(jwk.Key), reject func(error)) {
return func(resolve func(jwk.Key), reject func(error)) { assert.Fail(t, "should not be called") }
})
cache.pubKeys["key"] = pubKeyCacheEntry{
promise: promise.New(func(resolve func(jwk.Key), reject func(error)) {
resolve(testKey)
}),
ctx: kitctx.NewPool(),
}
cache.pubKeys["another-key"] = pubKeyCacheEntry{
promise: promise.New(func(resolve func(jwk.Key), reject func(error)) {
resolve(testKey2)
}),
ctx: kitctx.NewPool(),
}
result, err := cache.GetKey(context.Background(), "key")
assert.NoError(t, err)
assert.Equal(t, testKey, result)
result, err = cache.GetKey(context.Background(), "another-key")
assert.NoError(t, err)
assert.Equal(t, testKey2, result)
})
t.Run("cold cache should fetch key", func(t *testing.T) {
t.Parallel()
var called int
cache := NewPubKeyCache(func(context.Context, string) func(resolve func(jwk.Key), reject func(error)) {
return func(resolve func(jwk.Key), reject func(error)) {
called++
resolve(testKey)
}
})
result, err := cache.GetKey(context.Background(), "key")
assert.NoError(t, err)
assert.Equal(t, testKey, result)
assert.Equal(t, 1, called, "should be called once")
})
t.Run("cold cache should fetch different keys", func(t *testing.T) {
t.Parallel()
var called int
cache := NewPubKeyCache(func(ctx context.Context, i string) func(resolve func(jwk.Key), reject func(error)) {
return func(resolve func(jwk.Key), reject func(error)) {
called++
switch i {
case "key":
resolve(testKey)
default:
resolve(testKey2)
}
}
})
result, err := cache.GetKey(context.Background(), "key")
assert.NoError(t, err)
assert.Equal(t, testKey, result)
result, err = cache.GetKey(context.Background(), "another-key")
assert.NoError(t, err)
assert.Equal(t, testKey2, result)
assert.Equal(t, 2, called, "should be called once")
})
t.Run("fetch key which errors, should error getKey", func(t *testing.T) {
t.Parallel()
var called int
cache := NewPubKeyCache(func(ctx context.Context, i string) func(resolve func(jwk.Key), reject func(error)) {
return func(resolve func(jwk.Key), reject func(error)) {
assert.Equal(t, "key", i)
called++
reject(assert.AnError)
}
})
result, err := cache.GetKey(context.Background(), "key")
assert.Error(t, err)
assert.Nil(t, result)
assert.Equal(t, 1, called, "should be called once")
})
t.Run("multiple fetch key at the same time should only all getKey once", func(t *testing.T) {
t.Parallel()
var called int
cache := NewPubKeyCache(func(ctx context.Context, i string) func(resolve func(jwk.Key), reject func(error)) {
return func(resolve func(jwk.Key), reject func(error)) {
assert.Equal(t, "key", i)
called++
resolve(testKey)
}
})
var wg sync.WaitGroup
wg.Add(10)
for i := 0; i < 10; i++ {
go func() {
defer wg.Done()
result, err := cache.GetKey(context.Background(), "key")
assert.NoError(t, err)
assert.Equal(t, testKey, result)
}()
}
wg.Wait()
assert.Equal(t, 1, called, "should be called once")
})
t.Run("calling get key and context is cancelled should return context error", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancelCause(context.Background())
getKeyReturned := make(chan struct{})
cache := NewPubKeyCache(func(ctx context.Context, i string) func(resolve func(jwk.Key), reject func(error)) {
return func(resolve func(jwk.Key), reject func(error)) {
t.Cleanup(func() {
select {
case <-getKeyReturned:
case <-time.After(1 * time.Second):
assert.Fail(t, "expected GetKey to return from cancelled context in time")
}
})
assert.Equal(t, "key", i)
cancel(assert.AnError)
resolve(testKey)
}
})
result, err := cache.GetKey(ctx, "key")
assert.Equal(t, context.Canceled, err)
assert.Nil(t, result)
close(getKeyReturned)
})
t.Run("only callers with cancelled contexts should return context error", func(t *testing.T) {
t.Parallel()
ctx1, cancel1 := context.WithCancelCause(context.Background())
ctx2 := context.Background()
getKeyReturned := make(chan struct{})
var cache *PubKeyCache
cache = NewPubKeyCache(func(ctx context.Context, i string) func(resolve func(jwk.Key), reject func(error)) {
return func(resolve func(jwk.Key), reject func(error)) {
require.Eventually(t, func() bool {
cache.lock.Lock()
defer cache.lock.Unlock()
return cache.pubKeys["key"].ctx.Size() == 2
}, time.Second*5, time.Millisecond)
t.Cleanup(func() {
select {
case <-getKeyReturned:
case <-time.After(1 * time.Second):
assert.Fail(t, "expected GetKey to return from cancelled context in time")
}
})
assert.Equal(t, "key", i)
cancel1(assert.AnError)
select {
case <-ctx.Done():
assert.Fail(t, "GetKey context should not be cancelled")
default:
}
resolve(testKey)
}
})
go func() {
result, err := cache.GetKey(ctx2, "key")
assert.NoError(t, err)
assert.Equal(t, testKey, result)
close(getKeyReturned)
}()
result, err := cache.GetKey(ctx1, "key")
assert.Equal(t, context.Canceled, err)
assert.Nil(t, result)
})
t.Run("if all callers give cancelled contexts, the underlying context should also be cancelled", func(t *testing.T) {
t.Parallel()
ctx1, cancel1 := context.WithCancelCause(context.Background())
ctx2, cancel2 := context.WithCancelCause(context.Background())
getKeyReturned := make(chan struct{})
var cache *PubKeyCache
cache = NewPubKeyCache(func(ctx context.Context, i string) func(resolve func(jwk.Key), reject func(error)) {
return func(resolve func(jwk.Key), reject func(error)) {
require.Eventually(t, func() bool {
cache.lock.Lock()
defer cache.lock.Unlock()
pk, ok := cache.pubKeys["key"]
return ok && pk.ctx.Size() == 2
}, time.Second*5, time.Millisecond)
select {
case <-ctx.Done():
case <-time.After(1 * time.Second):
assert.Fail(t, "expected GetKey to get cancelled context in time")
}
t.Cleanup(func() {
select {
case <-getKeyReturned:
case <-time.After(1 * time.Second):
assert.Fail(t, "expected GetKey to return from cancelled context in time")
}
})
reject(errors.New("error which is not surfaced"))
}
})
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
result, err := cache.GetKey(ctx1, "key")
assert.Equal(t, context.Canceled, err)
assert.Nil(t, result)
}()
go func() {
defer wg.Done()
result, err := cache.GetKey(ctx2, "key")
assert.Equal(t, context.Canceled, err)
assert.Nil(t, result)
}()
cancel1(assert.AnError)
cancel2(assert.AnError)
wg.Wait()
close(getKeyReturned)
})
t.Run("if first caller cancels their context, other callers should still await", func(t *testing.T) {
t.Parallel()
var cache *PubKeyCache
assertSize := func(size int) {
require.Eventually(t, func() bool {
cache.lock.Lock()
defer cache.lock.Unlock()
pk, ok := cache.pubKeys["key"]
return ok && pk.ctx.Size() == size
}, time.Second*5, time.Millisecond)
}
ctx, cancel := context.WithCancelCause(context.Background())
getKeyReturned := make(chan struct{})
cache = NewPubKeyCache(func(ctx context.Context, i string) func(resolve func(jwk.Key), reject func(error)) {
return func(resolve func(jwk.Key), reject func(error)) {
assertSize(3)
// cancel the first caller
cancel(assert.AnError)
t.Cleanup(func() {
select {
case <-getKeyReturned:
case <-time.After(1 * time.Second):
assert.Fail(t, "expected GetKey to return from cancelled context in time")
}
})
resolve(testKey)
}
})
var wg sync.WaitGroup
wg.Add(3)
go func() {
defer wg.Done()
result, err := cache.GetKey(ctx, "key")
assert.Equal(t, context.Canceled, err)
assert.Nil(t, result)
}()
go func() {
defer wg.Done()
assertSize(1)
result, err := cache.GetKey(context.Background(), "key")
assert.NoError(t, err)
assert.Equal(t, testKey, result)
}()
go func() {
defer wg.Done()
assertSize(2)
result, err := cache.GetKey(context.Background(), "key")
assert.NoError(t, err)
assert.Equal(t, testKey, result)
}()
wg.Wait()
close(getKeyReturned)
})
}

2
go.mod
View File

@ -43,7 +43,7 @@ require (
github.com/bradfitz/gomemcache v0.0.0-20230124162541-5f7a7d875746
github.com/camunda/zeebe/clients/go/v8 v8.1.8
github.com/cenkalti/backoff/v4 v4.2.1
github.com/chebyrash/promise v0.0.0-20220530143319-1123826567d6
github.com/chebyrash/promise v0.0.0-20230414144155-dd8f641675f4
github.com/cinience/go_rocketmq v0.0.2
github.com/cloudevents/sdk-go/binding/format/protobuf/v2 v2.13.0
github.com/cloudevents/sdk-go/v2 v2.13.0

4
go.sum
View File

@ -676,8 +676,8 @@ github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chebyrash/promise v0.0.0-20220530143319-1123826567d6 h1:AtcTeZIfucJjiqhIeMoOAR292ti2QOyo2aqN3SoWopo=
github.com/chebyrash/promise v0.0.0-20220530143319-1123826567d6/go.mod h1:4DRxP3p0R7/5msq1uKcI1THYmfWgFXxQqr0DutaIAEk=
github.com/chebyrash/promise v0.0.0-20230414144155-dd8f641675f4 h1:UgduI3q9y7ShLNoZacVZqFFR8CJAzv7CqFOpwvFIxBc=
github.com/chebyrash/promise v0.0.0-20230414144155-dd8f641675f4/go.mod h1:f0Wnnt1WnX1xY9NnS4856Rsww/8Emacybw7kwbLEfHc=
github.com/chenzhuoyu/iasm v0.0.0-20220818063314-28c361dae733/go.mod h1:wOQ0nsbeOLa2awv8bUYFW/EHXbjQMlZ10fAlXDB2sz8=
github.com/chenzhuoyu/iasm v0.0.0-20230222070914-0b1b64b0e762 h1:4+00EOUb1t9uxAbgY8VvgfKJKDpim3co4MqsAbelIbs=
github.com/chenzhuoyu/iasm v0.0.0-20230222070914-0b1b64b0e762/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog=

View File

@ -87,7 +87,7 @@ require (
github.com/bytedance/gopkg v0.0.0-20220817015305-b879a72dc90f // indirect
github.com/cenkalti/backoff v2.2.1+incompatible // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chebyrash/promise v0.0.0-20220530143319-1123826567d6 // indirect
github.com/chebyrash/promise v0.0.0-20230414144155-dd8f641675f4 // indirect
github.com/chenzhuoyu/iasm v0.0.0-20230222070914-0b1b64b0e762 // indirect
github.com/choleraehyq/pid v0.0.16 // indirect
github.com/cloudevents/sdk-go/binding/format/protobuf/v2 v2.13.0 // indirect

View File

@ -252,8 +252,8 @@ github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chebyrash/promise v0.0.0-20220530143319-1123826567d6 h1:AtcTeZIfucJjiqhIeMoOAR292ti2QOyo2aqN3SoWopo=
github.com/chebyrash/promise v0.0.0-20220530143319-1123826567d6/go.mod h1:4DRxP3p0R7/5msq1uKcI1THYmfWgFXxQqr0DutaIAEk=
github.com/chebyrash/promise v0.0.0-20230414144155-dd8f641675f4 h1:UgduI3q9y7ShLNoZacVZqFFR8CJAzv7CqFOpwvFIxBc=
github.com/chebyrash/promise v0.0.0-20230414144155-dd8f641675f4/go.mod h1:f0Wnnt1WnX1xY9NnS4856Rsww/8Emacybw7kwbLEfHc=
github.com/chenzhuoyu/iasm v0.0.0-20220818063314-28c361dae733/go.mod h1:wOQ0nsbeOLa2awv8bUYFW/EHXbjQMlZ10fAlXDB2sz8=
github.com/chenzhuoyu/iasm v0.0.0-20230222070914-0b1b64b0e762 h1:4+00EOUb1t9uxAbgY8VvgfKJKDpim3co4MqsAbelIbs=
github.com/chenzhuoyu/iasm v0.0.0-20230222070914-0b1b64b0e762/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog=