JWKS Cache: add WaitForCacheReady method (#47)

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
Alessandro (Ale) Segala 2023-04-18 10:01:51 -07:00 committed by GitHub
parent b60341fe3e
commit c93a9df941
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 115 additions and 1 deletions

View File

@ -60,6 +60,7 @@ type JWKSCache struct {
lock sync.RWMutex
client *http.Client
running atomic.Bool
initCh chan error
}
// NewJWKSCache creates a new JWKSCache object.
@ -70,6 +71,8 @@ func NewJWKSCache(location string, logger logger.Logger) *JWKSCache {
requestTimeout: defaultRequestTimeout,
minRefreshInterval: defaultMinRefreshInterval,
initCh: make(chan error, 1),
}
}
@ -84,9 +87,16 @@ func (c *JWKSCache) Start(ctx context.Context) error {
// Init the cache
err := c.initCache(ctx)
if err != nil {
return fmt.Errorf("failed to init cache: %w", err)
err = fmt.Errorf("failed to init cache: %w", err)
// Store the error in the initCh, then close it
c.initCh <- err
close(c.initCh)
return err
}
// Close initCh
close(c.initCh)
// Block until context is canceled
<-ctx.Done()
@ -116,6 +126,17 @@ func (c *JWKSCache) KeySet() jwk.Set {
return c.jwks
}
// WaitForCacheReady pauses until the cache is ready (the initial JWKS has been fetched) or the passed ctx is canceled.
// It will return the initialization error.
func (c *JWKSCache) WaitForCacheReady(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-c.initCh:
return err
}
}
// Init the cache from the given location.
func (c *JWKSCache) initCache(ctx context.Context) error {
if len(c.location) == 0 {

View File

@ -139,6 +139,99 @@ func TestJWKSCache(t *testing.T) {
require.True(t, ok)
require.NotNil(t, key)
})
t.Run("start and wait for init", func(t *testing.T) {
cache := NewJWKSCache(testJWKS1, log)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
// Start in background
errCh := make(chan error)
go func() {
errCh <- cache.Start(ctx)
}()
// Wait for initialization
err := cache.WaitForCacheReady(ctx)
require.NoError(t, err)
// Canceling the context should make Start() return
cancel()
require.Nil(t, <-errCh)
})
t.Run("start and init fails", func(t *testing.T) {
// Create a custom HTTP client with a RoundTripper that doesn't require starting a TCP listener
client := &http.Client{
Transport: roundTripFn(func(r *http.Request) *http.Response {
// Return an error
return &http.Response{
StatusCode: http.StatusInternalServerError,
}
}),
}
cache := NewJWKSCache("https://localhost/jwks.json", log)
cache.SetHTTPClient(client)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
// Start in background
errCh := make(chan error)
go func() {
errCh <- cache.Start(ctx)
}()
// Wait for initialization
err := cache.WaitForCacheReady(ctx)
require.Error(t, err)
assert.ErrorContains(t, err, "failed to fetch JWKS")
// Canceling the context should make Start() return with the init error
cancel()
require.Equal(t, err, <-errCh)
})
t.Run("start and init times out", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1500*time.Millisecond)
defer cancel()
// Create a custom HTTP client with a RoundTripper that doesn't require starting a TCP listener
client := &http.Client{
Transport: roundTripFn(func(r *http.Request) *http.Response {
// Wait until context is canceled
<-ctx.Done()
// Sleep for another 500ms
time.Sleep(500 * time.Millisecond)
// Return an error
return &http.Response{
StatusCode: http.StatusInternalServerError,
}
}),
}
cache := NewJWKSCache("https://localhost/jwks.json", log)
cache.SetHTTPClient(client)
// Start in background
errCh := make(chan error)
go func() {
errCh <- cache.Start(ctx)
}()
// Wait for initialization
err := cache.WaitForCacheReady(context.Background())
require.Error(t, err)
assert.ErrorContains(t, err, "failed to fetch JWKS")
assert.ErrorIs(t, err, context.DeadlineExceeded)
// Canceling the context should make Start() return with the init error
cancel()
require.Equal(t, err, <-errCh)
})
}
type roundTripFn func(req *http.Request) *http.Response