mirror of https://github.com/dapr/kit.git
242 lines
7.5 KiB
Go
242 lines
7.5 KiB
Go
/*
|
|
Copyright 2023 The Dapr Authors
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package jwkscache
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/dapr/kit/logger"
|
|
)
|
|
|
|
const (
|
|
testJWKS1 = `{"keys":[{"kid":"mykey","alg":"RS256","kty":"RSA","use":"sig","e":"AQAB","n":"3I2mdIK4mRRu-ywMrYjUZzBxt0NlAVLrMhGlaJsby7PWTMiLpZVip4SBD9GwnCU0TGFD7k2-7tfs0y9U6WV7MwgCjc9m_DUUGbE-kKjEU7JYkLzYlndys-6xuhD4Jf1hu9AZVdfXftpWSy_NNg6fVwTH4nckOAbOSL1hXToOYWQcDDW95Rhw3U4z04PqssEpRKn5KGBuTahNNNiZcWns99pChpLTxgdm93LjMBI1KCGBpOaz7fcQJ9V3c6rSwMKyY3IPm1LwS6PIs7xb2ZJ0Eb8A6MtCkGhgNsodpkxhqKbqtxI-KqTuZy9g4jb8WKjJq9lB9q-HPHoQqIEDom6P8w"}]}`
|
|
testJWKS2 = `{"keys":[{"kid":"mykey","alg":"RS256","kty":"RSA","use":"sig","e":"AQAB","n":"3I2mdIK4mRRu-ywMrYjUZzBxt0NlAVLrMhGlaJsby7PWTMiLpZVip4SBD9GwnCU0TGFD7k2-7tfs0y9U6WV7MwgCjc9m_DUUGbE-kKjEU7JYkLzYlndys-6xuhD4Jf1hu9AZVdfXftpWSy_NNg6fVwTH4nckOAbOSL1hXToOYWQcDDW95Rhw3U4z04PqssEpRKn5KGBuTahNNNiZcWns99pChpLTxgdm93LjMBI1KCGBpOaz7fcQJ9V3c6rSwMKyY3IPm1LwS6PIs7xb2ZJ0Eb8A6MtCkGhgNsodpkxhqKbqtxI-KqTuZy9g4jb8WKjJq9lB9q-HPHoQqIEDom6P8w"},{"alg":"RS256","kty":"RSA","use":"sig","n":"yeNlzlub94YgerT030codqEztjfU_S6X4DbDA_iVKkjAWtYfPHDzz_sPCT1Axz6isZdf3lHpq_gYX4Sz-cbe4rjmigxUxr-FgKHQy3HeCdK6hNq9ASQvMK9LBOpXDNn7mei6RZWom4wo3CMvvsY1w8tjtfLb-yQwJPltHxShZq5-ihC9irpLI9xEBTgG12q5lGIFPhTl_7inA1PFK97LuSLnTJzW0bj096v_TMDg7pOWm_zHtF53qbVsI0e3v5nmdKXdFf9BjIARRfVrbxVxiZHjU6zL6jY5QJdh1QCmENoejj_ytspMmGW7yMRxzUqgxcAqOBpVm0b-_mW3HoBdjQ","e":"AQAB","kid":"testkey"}]}`
|
|
)
|
|
|
|
func TestJWKSCache(t *testing.T) {
|
|
log := logger.NewLogger("test")
|
|
|
|
t.Run("init with value", func(t *testing.T) {
|
|
cache := NewJWKSCache(testJWKS1, log)
|
|
err := cache.initCache(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
set := cache.KeySet()
|
|
require.Equal(t, 1, set.Len())
|
|
|
|
key, ok := set.LookupKeyID("mykey")
|
|
require.True(t, ok)
|
|
require.NotNil(t, key)
|
|
})
|
|
|
|
t.Run("init with base64-encoded value", func(t *testing.T) {
|
|
cache := NewJWKSCache(base64.StdEncoding.EncodeToString([]byte(testJWKS1)), log)
|
|
err := cache.initCache(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
set := cache.KeySet()
|
|
require.Equal(t, 1, set.Len())
|
|
|
|
key, ok := set.LookupKeyID("mykey")
|
|
require.True(t, ok)
|
|
require.NotNil(t, key)
|
|
})
|
|
|
|
t.Run("init with local file", func(t *testing.T) {
|
|
// Create a temporary directory and put the JWKS in there
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "jwks.json")
|
|
err := os.WriteFile(path, []byte(testJWKS1), 0o666)
|
|
require.NoError(t, err)
|
|
|
|
// Should wait for first file to be loaded before initialization is reported as completed
|
|
cache := NewJWKSCache(path, log)
|
|
err = cache.initCache(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
set := cache.KeySet()
|
|
require.Equal(t, 1, set.Len())
|
|
|
|
key, ok := set.LookupKeyID("mykey")
|
|
require.True(t, ok)
|
|
require.NotNil(t, key)
|
|
|
|
// Sleep 1s before writing the file
|
|
time.Sleep(time.Second)
|
|
|
|
// Update the file and verify it's picked up
|
|
err = os.WriteFile(path, []byte(testJWKS2), 0o666)
|
|
require.NoError(t, err)
|
|
|
|
assert.Eventually(t, func() bool {
|
|
return cache.KeySet().Len() == 2
|
|
}, 5*time.Second, 50*time.Millisecond)
|
|
|
|
set = cache.KeySet()
|
|
key, ok = set.LookupKeyID("mykey")
|
|
require.True(t, ok)
|
|
require.NotNil(t, key)
|
|
key, ok = set.LookupKeyID("testkey")
|
|
require.True(t, ok)
|
|
require.NotNil(t, key)
|
|
})
|
|
|
|
t.Run("init with HTTP client", func(t *testing.T) {
|
|
// Create a custom HTTP client with a RoundTripper that doesn't require starting a TCP listener
|
|
client := &http.Client{
|
|
Transport: roundTripFn(func(r *http.Request) *http.Response {
|
|
if r.Method != http.MethodGet || r.URL.Path != "/jwks.json" {
|
|
return &http.Response{
|
|
StatusCode: http.StatusNotFound,
|
|
Header: make(http.Header),
|
|
}
|
|
}
|
|
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Header: http.Header{
|
|
"content-type": []string{"application/json"},
|
|
},
|
|
Body: io.NopCloser(strings.NewReader(testJWKS1)),
|
|
}
|
|
}),
|
|
}
|
|
|
|
cache := NewJWKSCache("http://localhost/jwks.json", log)
|
|
cache.SetHTTPClient(client)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
err := cache.initCache(ctx)
|
|
require.NoError(t, err)
|
|
|
|
set := cache.KeySet()
|
|
require.Equal(t, 1, set.Len())
|
|
|
|
key, ok := set.LookupKeyID("mykey")
|
|
require.True(t, ok)
|
|
require.NotNil(t, key)
|
|
})
|
|
|
|
t.Run("start and wait for init", func(t *testing.T) {
|
|
cache := NewJWKSCache(testJWKS1, log)
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
|
|
// Start in background
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- cache.Start(ctx)
|
|
}()
|
|
|
|
// Wait for initialization
|
|
err := cache.WaitForCacheReady(ctx)
|
|
require.NoError(t, err)
|
|
|
|
// Canceling the context should make Start() return
|
|
cancel()
|
|
require.NoError(t, <-errCh)
|
|
})
|
|
|
|
t.Run("start and init fails", func(t *testing.T) {
|
|
// Create a custom HTTP client with a RoundTripper that doesn't require starting a TCP listener
|
|
client := &http.Client{
|
|
Transport: roundTripFn(func(r *http.Request) *http.Response {
|
|
// Return an error
|
|
return &http.Response{
|
|
StatusCode: http.StatusInternalServerError,
|
|
}
|
|
}),
|
|
}
|
|
|
|
cache := NewJWKSCache("https://localhost/jwks.json", log)
|
|
cache.SetHTTPClient(client)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
|
|
// Start in background
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- cache.Start(ctx)
|
|
}()
|
|
|
|
// Wait for initialization
|
|
err := cache.WaitForCacheReady(ctx)
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "failed to fetch JWKS")
|
|
|
|
// Canceling the context should make Start() return with the init error
|
|
cancel()
|
|
require.Equal(t, err, <-errCh)
|
|
})
|
|
|
|
t.Run("start and init times out", func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 1500*time.Millisecond)
|
|
defer cancel()
|
|
|
|
// Create a custom HTTP client with a RoundTripper that doesn't require starting a TCP listener
|
|
client := &http.Client{
|
|
Transport: roundTripFn(func(r *http.Request) *http.Response {
|
|
// Wait until context is canceled
|
|
<-ctx.Done()
|
|
|
|
// Sleep for another 500ms
|
|
time.Sleep(500 * time.Millisecond)
|
|
|
|
// Return an error
|
|
return &http.Response{
|
|
StatusCode: http.StatusInternalServerError,
|
|
}
|
|
}),
|
|
}
|
|
|
|
cache := NewJWKSCache("https://localhost/jwks.json", log)
|
|
cache.SetHTTPClient(client)
|
|
|
|
// Start in background
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- cache.Start(ctx)
|
|
}()
|
|
|
|
// Wait for initialization
|
|
err := cache.WaitForCacheReady(context.Background())
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "failed to fetch JWKS")
|
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
|
|
|
// Canceling the context should make Start() return with the init error
|
|
cancel()
|
|
require.Equal(t, err, <-errCh)
|
|
})
|
|
}
|
|
|
|
type roundTripFn func(req *http.Request) *http.Response
|
|
|
|
func (f roundTripFn) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
return f(req), nil
|
|
}
|