mirror of https://github.com/dapr/kit.git
Updates based on joshvanl feedback
Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>
This commit is contained in:
parent
5e0eb9625d
commit
a885084755
|
@ -244,10 +244,15 @@ func (s *SPIFFE) fetchIdentity(ctx context.Context) (*Identity, error) {
|
|||
// we are using ParseInsecure here as the expectation is that the
|
||||
// requestSVIDFn will have already parsed and validate the JWT SVID
|
||||
// before returning it.
|
||||
//
|
||||
// we are parsing the token using our SPIFFE ID's trust domain
|
||||
// as the audience as we expect the issuer to always include
|
||||
// that as an audience since that ensures that the token is
|
||||
// valid for us and our trust domain.
|
||||
audiences := []string{spiffeID.TrustDomain().Name()}
|
||||
jwtSvid, err := jwtsvid.ParseInsecure(svidResponse.JWT, audiences)
|
||||
if err != nil {
|
||||
s.log.Warnf("Failed to parse JWT SVID: %v", err)
|
||||
s.log.Errorf("Failed to parse JWT SVID: %v, continuing without JWT SVID", err)
|
||||
} else {
|
||||
identity.JWTSVID = jwtSvid
|
||||
s.log.Infof("Successfully received JWT SVID with expiry: %s", jwtSvid.Expiry.String())
|
||||
|
@ -277,7 +282,7 @@ func (s *SPIFFE) fetchIdentity(ctx context.Context) (*Identity, error) {
|
|||
}
|
||||
|
||||
if svidResponse.JWT != "" {
|
||||
files["token.jwt"] = []byte(svidResponse.JWT)
|
||||
files["jwt_svid.token"] = []byte(svidResponse.JWT)
|
||||
}
|
||||
|
||||
if err := s.dir.Write(files); err != nil {
|
||||
|
|
|
@ -16,6 +16,8 @@ package spiffe
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
|
||||
|
@ -24,7 +26,7 @@ import (
|
|||
var (
|
||||
errNoX509SVIDAvailable = errors.New("no X509 SVID available")
|
||||
errNoJWTSVIDAvailable = errors.New("no JWT SVID available")
|
||||
errAudienceMismatch = errors.New("JWT SVID has different audiences than requested")
|
||||
errAudienceRequired = errors.New("audience is required")
|
||||
)
|
||||
|
||||
// svidSource is an implementation of both go-spiffe x509svid.Source and jwtsvid.Source interfaces.
|
||||
|
@ -48,12 +50,27 @@ func (s *svidSource) GetX509SVID() (*x509svid.SVID, error) {
|
|||
return svid, nil
|
||||
}
|
||||
|
||||
// audienceMismatchError is an error that contains information about mismatched audiences
|
||||
type audienceMismatchError struct {
|
||||
Expected []string
|
||||
Actual []string
|
||||
}
|
||||
|
||||
func (e *audienceMismatchError) Error() string {
|
||||
return fmt.Sprintf("JWT SVID has different audiences than requested: expected %s, got %s",
|
||||
strings.Join(e.Expected, ", "), strings.Join(e.Actual, ", "))
|
||||
}
|
||||
|
||||
// FetchJWTSVID returns the current JWT SVID.
|
||||
// Implements the go-spiffe jwtsvid.Source interface.
|
||||
func (s *svidSource) FetchJWTSVID(_ context.Context, params jwtsvid.Params) (*jwtsvid.SVID, error) {
|
||||
s.spiffe.lock.RLock()
|
||||
defer s.spiffe.lock.RUnlock()
|
||||
|
||||
if params.Audience == "" {
|
||||
return nil, errAudienceRequired
|
||||
}
|
||||
|
||||
<-s.spiffe.readyCh
|
||||
|
||||
svid := s.spiffe.currentJWTSVID
|
||||
|
@ -64,7 +81,10 @@ func (s *svidSource) FetchJWTSVID(_ context.Context, params jwtsvid.Params) (*jw
|
|||
// verify that the audience being requested is the same as the audience in the SVID
|
||||
// WARN: we do not check extra audiences here.
|
||||
if !audiencesMatch(svid.Audience, []string{params.Audience}) {
|
||||
return nil, errAudienceMismatch
|
||||
return nil, &audienceMismatchError{
|
||||
Expected: []string{params.Audience},
|
||||
Actual: svid.Audience,
|
||||
}
|
||||
}
|
||||
|
||||
return svid, nil
|
||||
|
|
|
@ -14,11 +14,178 @@ limitations under the License.
|
|||
package spiffe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_svidSource(*testing.T) {
|
||||
var _ x509svid.Source = new(svidSource)
|
||||
var _ jwtsvid.Source = new(svidSource)
|
||||
}
|
||||
|
||||
// createMockJWTSVID creates a mock JWT SVID for testing
|
||||
func createMockJWTSVID(audiences []string) (*jwtsvid.SVID, error) {
|
||||
td, err := spiffeid.TrustDomainFromString("example.org")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := spiffeid.FromSegments(td, "workload")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svid := &jwtsvid.SVID{
|
||||
ID: id,
|
||||
Audience: audiences,
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
}
|
||||
|
||||
return svid, nil
|
||||
}
|
||||
|
||||
func TestFetchJWTSVID(t *testing.T) {
|
||||
t.Run("should return error when audience is empty", func(t *testing.T) {
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: make(chan struct{}),
|
||||
lock: sync.RWMutex{},
|
||||
},
|
||||
}
|
||||
close(s.spiffe.readyCh) // Mark as ready
|
||||
|
||||
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
|
||||
Audience: "",
|
||||
})
|
||||
|
||||
assert.Nil(t, svid)
|
||||
assert.ErrorIs(t, err, errAudienceRequired)
|
||||
})
|
||||
|
||||
t.Run("should return error when no JWT SVID available", func(t *testing.T) {
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: make(chan struct{}),
|
||||
lock: sync.RWMutex{},
|
||||
currentJWTSVID: nil,
|
||||
},
|
||||
}
|
||||
close(s.spiffe.readyCh) // Mark as ready
|
||||
|
||||
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
|
||||
Audience: "test-audience",
|
||||
})
|
||||
|
||||
assert.Nil(t, svid)
|
||||
assert.ErrorIs(t, err, errNoJWTSVIDAvailable)
|
||||
})
|
||||
|
||||
t.Run("should return error when audience doesn't match", func(t *testing.T) {
|
||||
// Create a mock SVID with a specific audience
|
||||
mockJWTSVID, err := createMockJWTSVID([]string{"actual-audience"})
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: make(chan struct{}),
|
||||
lock: sync.RWMutex{},
|
||||
currentJWTSVID: mockJWTSVID,
|
||||
},
|
||||
}
|
||||
close(s.spiffe.readyCh) // Mark as ready
|
||||
|
||||
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
|
||||
Audience: "requested-audience",
|
||||
})
|
||||
|
||||
assert.Nil(t, svid)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Verify the specific error type and contents
|
||||
audienceErr, ok := err.(*audienceMismatchError)
|
||||
require.True(t, ok, "Expected audienceMismatchError")
|
||||
assert.Equal(t, "JWT SVID has different audiences than requested: expected requested-audience, got actual-audience", audienceErr.Error())
|
||||
})
|
||||
|
||||
t.Run("should return JWT SVID when audience matches", func(t *testing.T) {
|
||||
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience", "extra-audience"})
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: make(chan struct{}),
|
||||
lock: sync.RWMutex{},
|
||||
currentJWTSVID: mockJWTSVID,
|
||||
},
|
||||
}
|
||||
close(s.spiffe.readyCh) // Mark as ready
|
||||
|
||||
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
|
||||
Audience: "test-audience",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, mockJWTSVID, svid)
|
||||
})
|
||||
|
||||
t.Run("should wait for readyCh before checking SVID", func(t *testing.T) {
|
||||
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience"})
|
||||
require.NoError(t, err)
|
||||
|
||||
readyCh := make(chan struct{})
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: readyCh,
|
||||
lock: sync.RWMutex{},
|
||||
currentJWTSVID: mockJWTSVID,
|
||||
},
|
||||
}
|
||||
|
||||
// Start goroutine to fetch SVID
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
resultCh := make(chan struct {
|
||||
svid *jwtsvid.SVID
|
||||
err error
|
||||
})
|
||||
|
||||
go func() {
|
||||
svid, err := s.FetchJWTSVID(ctx, jwtsvid.Params{
|
||||
Audience: "test-audience",
|
||||
})
|
||||
resultCh <- struct {
|
||||
svid *jwtsvid.SVID
|
||||
err error
|
||||
}{svid, err}
|
||||
}()
|
||||
|
||||
// Assert that fetch is blocked
|
||||
select {
|
||||
case <-resultCh:
|
||||
t.Fatal("FetchJWTSVID should be blocked until readyCh is closed")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Expected behavior - fetch is blocked
|
||||
}
|
||||
|
||||
// Close readyCh to unblock fetch
|
||||
close(readyCh)
|
||||
|
||||
// Now fetch should complete
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
assert.NoError(t, result.err)
|
||||
assert.NotNil(t, result.svid)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("FetchJWTSVID should have completed after readyCh was closed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
package trustanchors
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -29,6 +29,7 @@ import (
|
|||
|
||||
"github.com/dapr/kit/concurrency"
|
||||
"github.com/dapr/kit/crypto/pem"
|
||||
"github.com/dapr/kit/crypto/spiffe/trustanchors"
|
||||
"github.com/dapr/kit/fswatcher"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
@ -41,10 +42,10 @@ var (
|
|||
ErrFailedToReadTrustAnchorsFile = errors.New("failed to read trust anchors file")
|
||||
)
|
||||
|
||||
type OptionsFile struct {
|
||||
type Options struct {
|
||||
Log logger.Logger
|
||||
CAPath string
|
||||
JwksPath string
|
||||
JwksPath *string
|
||||
}
|
||||
|
||||
// file is a TrustAnchors implementation that uses a file as the source of trust
|
||||
|
@ -52,7 +53,7 @@ type OptionsFile struct {
|
|||
type file struct {
|
||||
log logger.Logger
|
||||
caPath string
|
||||
jwksPath string
|
||||
jwksPath *string
|
||||
x509Bundle *x509bundle.Bundle
|
||||
jwtBundle *jwtbundle.Bundle
|
||||
rootPEM []byte
|
||||
|
@ -76,7 +77,7 @@ type file struct {
|
|||
caEvent chan struct{}
|
||||
}
|
||||
|
||||
func FromFile(opts OptionsFile) Interface {
|
||||
func From(opts Options) trustanchors.Interface {
|
||||
return &file{
|
||||
fsWatcherInterval: time.Millisecond * 500,
|
||||
initFileWatchInterval: time.Second,
|
||||
|
@ -99,7 +100,12 @@ func (f *file) Run(ctx context.Context) error {
|
|||
defer close(f.closeCh)
|
||||
|
||||
for {
|
||||
if found, err := filesExist(f.caPath, f.jwksPath); err != nil {
|
||||
fs := []string{f.caPath}
|
||||
if f.jwksPath != nil {
|
||||
fs = append(fs, *f.jwksPath)
|
||||
}
|
||||
|
||||
if found, err := filesExist(fs...); err != nil {
|
||||
return err
|
||||
} else if found {
|
||||
break
|
||||
|
@ -121,8 +127,8 @@ func (f *file) Run(ctx context.Context) error {
|
|||
}
|
||||
|
||||
targets := []string{f.caPath}
|
||||
if f.jwksPath != "" {
|
||||
targets = append(targets, f.jwksPath)
|
||||
if f.jwksPath != nil {
|
||||
targets = append(targets, *f.jwksPath)
|
||||
}
|
||||
|
||||
fs, err := fswatcher.New(fswatcher.Options{
|
||||
|
@ -136,7 +142,7 @@ func (f *file) Run(ctx context.Context) error {
|
|||
close(f.readyCh)
|
||||
|
||||
f.log.Infof("Watching trust anchors file '%s' for changes", f.caPath)
|
||||
if f.jwksPath != "" {
|
||||
if f.jwksPath != nil {
|
||||
f.log.Infof("Watching JWT bundle file '%s' for changes", f.jwksPath)
|
||||
}
|
||||
|
||||
|
@ -194,10 +200,10 @@ func (f *file) updateAnchors(ctx context.Context) error {
|
|||
f.rootPEM = rootPEMs
|
||||
f.x509Bundle = x509bundle.FromX509Authorities(spiffeid.TrustDomain{}, trustAnchorCerts)
|
||||
|
||||
if f.jwksPath != "" {
|
||||
jwks, err := os.ReadFile(f.jwksPath)
|
||||
if f.jwksPath != nil {
|
||||
jwks, err := os.ReadFile(*f.jwksPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read JWT bundle file '%s': %w", f.jwksPath, err)
|
||||
return fmt.Errorf("failed to read JWT bundle file '%s': %w", *f.jwksPath, err)
|
||||
}
|
||||
|
||||
jwtBundle, err := jwtbundle.Parse(spiffeid.TrustDomain{}, jwks)
|
|
@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
package trustanchors
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -31,7 +31,7 @@ import (
|
|||
func TestFile_Run(t *testing.T) {
|
||||
t.Run("if Run multiple times, expect error", func(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
ta := FromFile(OptionsFile{
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
|
@ -74,7 +74,7 @@ func TestFile_Run(t *testing.T) {
|
|||
t.Run("if file is not found and context cancelled, should return ctx.Err", func(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
|
@ -102,7 +102,7 @@ func TestFile_Run(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, nil, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
|
@ -127,7 +127,7 @@ func TestFile_Run(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, []byte("garbage data"), 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
|
@ -154,7 +154,7 @@ func TestFile_Run(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, root, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
|
@ -180,7 +180,7 @@ func TestFile_Run(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
|
@ -211,7 +211,7 @@ func TestFile_Run(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, root, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
|
@ -242,7 +242,7 @@ func TestFile_Run(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
|
@ -273,7 +273,7 @@ func TestFile_Run(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
|
@ -311,7 +311,7 @@ func TestFile_GetX509BundleForTrustDomain(t *testing.T) {
|
|||
root := append(pki.RootCertPEM, []byte("garbage data")...)
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, root, 0o600))
|
||||
ta := FromFile(OptionsFile{
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
|
@ -359,7 +359,7 @@ func TestFile_Watch(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
|
@ -400,7 +400,7 @@ func TestFile_Watch(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
|
@ -446,7 +446,7 @@ func TestFile_Watch(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
|
@ -529,7 +529,7 @@ func TestFile_CurrentTrustAnchors(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
|
@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
package trustanchors
|
||||
package multi
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
|
||||
"github.com/dapr/kit/concurrency"
|
||||
"github.com/dapr/kit/crypto/spiffe/trustanchors"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -29,17 +30,17 @@ var (
|
|||
ErrTrustDomainNotFound = errors.New("trust domain not found")
|
||||
)
|
||||
|
||||
type OptionsMulti struct {
|
||||
TrustAnchors map[spiffeid.TrustDomain]Interface
|
||||
type Options struct {
|
||||
TrustAnchors map[spiffeid.TrustDomain]trustanchors.Interface
|
||||
}
|
||||
|
||||
// multi is a TrustAnchors implementation which uses multiple trust anchors
|
||||
// which are indexed by trust domain.
|
||||
type multi struct {
|
||||
trustAnchors map[spiffeid.TrustDomain]Interface
|
||||
trustAnchors map[spiffeid.TrustDomain]trustanchors.Interface
|
||||
}
|
||||
|
||||
func FromMulti(opts OptionsMulti) Interface {
|
||||
func From(opts Options) trustanchors.Interface {
|
||||
return &multi{
|
||||
trustAnchors: opts.TrustAnchors,
|
||||
}
|
|
@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
package trustanchors
|
||||
package static
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
|
||||
"github.com/dapr/kit/crypto/pem"
|
||||
"github.com/dapr/kit/crypto/spiffe/trustanchors"
|
||||
)
|
||||
|
||||
// static is a TrustAcnhors implementation that uses a static list of trust
|
||||
|
@ -36,12 +37,12 @@ type static struct {
|
|||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
type OptionsStatic struct {
|
||||
type Options struct {
|
||||
Anchors []byte
|
||||
Jwks []byte
|
||||
}
|
||||
|
||||
func FromStatic(opts OptionsStatic) (Interface, error) {
|
||||
func From(opts Options) (trustanchors.Interface, error) {
|
||||
// Create empty trust domain for now
|
||||
emptyTD := spiffeid.TrustDomain{}
|
||||
|
|
@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
package trustanchors
|
||||
package static
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -27,30 +27,30 @@ import (
|
|||
|
||||
func TestFromStatic(t *testing.T) {
|
||||
t.Run("empty root should return error", func(t *testing.T) {
|
||||
_, err := FromStatic(OptionsStatic{})
|
||||
_, err := From(Options{})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("garbage data should return error", func(t *testing.T) {
|
||||
_, err := FromStatic(OptionsStatic{Anchors: []byte("garbage data")})
|
||||
_, err := From(Options{Anchors: []byte("garbage data")})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("just garbage data should return error", func(t *testing.T) {
|
||||
_, err := FromStatic(OptionsStatic{Anchors: []byte("garbage data")})
|
||||
_, err := From(Options{Anchors: []byte("garbage data")})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("garbage data in root should return error", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
root := pki.RootCertPEM[10:]
|
||||
_, err := FromStatic(OptionsStatic{Anchors: root})
|
||||
_, err := From(Options{Anchors: root})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("single root should be correctly parsed", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
ta, err := FromStatic(OptionsStatic{Anchors: pki.RootCertPEM})
|
||||
ta, err := From(Options{Anchors: pki.RootCertPEM})
|
||||
require.NoError(t, err)
|
||||
taPEM, err := ta.CurrentTrustAnchors(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
@ -61,7 +61,7 @@ func TestFromStatic(t *testing.T) {
|
|||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
//nolint:gocritic
|
||||
root := append(pki.RootCertPEM, []byte("garbage data")...)
|
||||
ta, err := FromStatic(OptionsStatic{Anchors: root})
|
||||
ta, err := From(Options{Anchors: root})
|
||||
require.NoError(t, err)
|
||||
taPEM, err := ta.CurrentTrustAnchors(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
@ -72,7 +72,7 @@ func TestFromStatic(t *testing.T) {
|
|||
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
|
||||
//nolint:gocritic
|
||||
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
|
||||
ta, err := FromStatic(OptionsStatic{Anchors: roots})
|
||||
ta, err := From(Options{Anchors: roots})
|
||||
require.NoError(t, err)
|
||||
taPEM, err := ta.CurrentTrustAnchors(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
@ -85,7 +85,7 @@ func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
|
|||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
//nolint:gocritic
|
||||
root := append(pki.RootCertPEM, []byte("garbage data")...)
|
||||
ta, err := FromStatic(OptionsStatic{Anchors: root})
|
||||
ta, err := From(Options{Anchors: root})
|
||||
require.NoError(t, err)
|
||||
s, ok := ta.(*static)
|
||||
require.True(t, ok)
|
||||
|
@ -113,7 +113,7 @@ func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
|
|||
func TestStatic_Run(t *testing.T) {
|
||||
t.Run("Run multiple times should return error", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
ta, err := FromStatic(OptionsStatic{Anchors: pki.RootCertPEM})
|
||||
ta, err := From(Options{Anchors: pki.RootCertPEM})
|
||||
require.NoError(t, err)
|
||||
s, ok := ta.(*static)
|
||||
require.True(t, ok)
|
||||
|
@ -154,7 +154,7 @@ func TestStatic_Run(t *testing.T) {
|
|||
func TestStatic_Watch(t *testing.T) {
|
||||
t.Run("should return when context is cancelled", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
ta, err := FromStatic(OptionsStatic{Anchors: pki.RootCertPEM})
|
||||
ta, err := From(Options{Anchors: pki.RootCertPEM})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -176,7 +176,7 @@ func TestStatic_Watch(t *testing.T) {
|
|||
|
||||
t.Run("should return when cancel is closed via closed Run", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
ta, err := FromStatic(OptionsStatic{Anchors: pki.RootCertPEM})
|
||||
ta, err := From(Options{Anchors: pki.RootCertPEM})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
Loading…
Reference in New Issue