Updates based on joshvanl feedback

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>
This commit is contained in:
Jonathan Collinge 2025-05-13 08:16:35 +01:00
parent 5e0eb9625d
commit a885084755
No known key found for this signature in database
GPG Key ID: BF9B59005264DD95
8 changed files with 251 additions and 51 deletions

View File

@ -244,10 +244,15 @@ func (s *SPIFFE) fetchIdentity(ctx context.Context) (*Identity, error) {
// we are using ParseInsecure here as the expectation is that the
// requestSVIDFn will have already parsed and validate the JWT SVID
// before returning it.
//
// we are parsing the token using our SPIFFE ID's trust domain
// as the audience as we expect the issuer to always include
// that as an audience since that ensures that the token is
// valid for us and our trust domain.
audiences := []string{spiffeID.TrustDomain().Name()}
jwtSvid, err := jwtsvid.ParseInsecure(svidResponse.JWT, audiences)
if err != nil {
s.log.Warnf("Failed to parse JWT SVID: %v", err)
s.log.Errorf("Failed to parse JWT SVID: %v, continuing without JWT SVID", err)
} else {
identity.JWTSVID = jwtSvid
s.log.Infof("Successfully received JWT SVID with expiry: %s", jwtSvid.Expiry.String())
@ -277,7 +282,7 @@ func (s *SPIFFE) fetchIdentity(ctx context.Context) (*Identity, error) {
}
if svidResponse.JWT != "" {
files["token.jwt"] = []byte(svidResponse.JWT)
files["jwt_svid.token"] = []byte(svidResponse.JWT)
}
if err := s.dir.Write(files); err != nil {

View File

@ -16,6 +16,8 @@ package spiffe
import (
"context"
"errors"
"fmt"
"strings"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
@ -24,7 +26,7 @@ import (
var (
errNoX509SVIDAvailable = errors.New("no X509 SVID available")
errNoJWTSVIDAvailable = errors.New("no JWT SVID available")
errAudienceMismatch = errors.New("JWT SVID has different audiences than requested")
errAudienceRequired = errors.New("audience is required")
)
// svidSource is an implementation of both go-spiffe x509svid.Source and jwtsvid.Source interfaces.
@ -48,12 +50,27 @@ func (s *svidSource) GetX509SVID() (*x509svid.SVID, error) {
return svid, nil
}
// audienceMismatchError is an error that contains information about mismatched audiences
type audienceMismatchError struct {
Expected []string
Actual []string
}
func (e *audienceMismatchError) Error() string {
return fmt.Sprintf("JWT SVID has different audiences than requested: expected %s, got %s",
strings.Join(e.Expected, ", "), strings.Join(e.Actual, ", "))
}
// FetchJWTSVID returns the current JWT SVID.
// Implements the go-spiffe jwtsvid.Source interface.
func (s *svidSource) FetchJWTSVID(_ context.Context, params jwtsvid.Params) (*jwtsvid.SVID, error) {
s.spiffe.lock.RLock()
defer s.spiffe.lock.RUnlock()
if params.Audience == "" {
return nil, errAudienceRequired
}
<-s.spiffe.readyCh
svid := s.spiffe.currentJWTSVID
@ -64,7 +81,10 @@ func (s *svidSource) FetchJWTSVID(_ context.Context, params jwtsvid.Params) (*jw
// verify that the audience being requested is the same as the audience in the SVID
// WARN: we do not check extra audiences here.
if !audiencesMatch(svid.Audience, []string{params.Audience}) {
return nil, errAudienceMismatch
return nil, &audienceMismatchError{
Expected: []string{params.Audience},
Actual: svid.Audience,
}
}
return svid, nil

View File

@ -14,11 +14,178 @@ limitations under the License.
package spiffe
import (
"context"
"sync"
"testing"
"time"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_svidSource(*testing.T) {
var _ x509svid.Source = new(svidSource)
var _ jwtsvid.Source = new(svidSource)
}
// createMockJWTSVID creates a mock JWT SVID for testing
func createMockJWTSVID(audiences []string) (*jwtsvid.SVID, error) {
td, err := spiffeid.TrustDomainFromString("example.org")
if err != nil {
return nil, err
}
id, err := spiffeid.FromSegments(td, "workload")
if err != nil {
return nil, err
}
svid := &jwtsvid.SVID{
ID: id,
Audience: audiences,
Expiry: time.Now().Add(time.Hour),
}
return svid, nil
}
func TestFetchJWTSVID(t *testing.T) {
t.Run("should return error when audience is empty", func(t *testing.T) {
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
Audience: "",
})
assert.Nil(t, svid)
assert.ErrorIs(t, err, errAudienceRequired)
})
t.Run("should return error when no JWT SVID available", func(t *testing.T) {
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: nil,
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
Audience: "test-audience",
})
assert.Nil(t, svid)
assert.ErrorIs(t, err, errNoJWTSVIDAvailable)
})
t.Run("should return error when audience doesn't match", func(t *testing.T) {
// Create a mock SVID with a specific audience
mockJWTSVID, err := createMockJWTSVID([]string{"actual-audience"})
require.NoError(t, err)
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
Audience: "requested-audience",
})
assert.Nil(t, svid)
assert.Error(t, err)
// Verify the specific error type and contents
audienceErr, ok := err.(*audienceMismatchError)
require.True(t, ok, "Expected audienceMismatchError")
assert.Equal(t, "JWT SVID has different audiences than requested: expected requested-audience, got actual-audience", audienceErr.Error())
})
t.Run("should return JWT SVID when audience matches", func(t *testing.T) {
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience", "extra-audience"})
require.NoError(t, err)
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
Audience: "test-audience",
})
assert.NoError(t, err)
assert.Equal(t, mockJWTSVID, svid)
})
t.Run("should wait for readyCh before checking SVID", func(t *testing.T) {
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience"})
require.NoError(t, err)
readyCh := make(chan struct{})
s := &svidSource{
spiffe: &SPIFFE{
readyCh: readyCh,
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
},
}
// Start goroutine to fetch SVID
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
resultCh := make(chan struct {
svid *jwtsvid.SVID
err error
})
go func() {
svid, err := s.FetchJWTSVID(ctx, jwtsvid.Params{
Audience: "test-audience",
})
resultCh <- struct {
svid *jwtsvid.SVID
err error
}{svid, err}
}()
// Assert that fetch is blocked
select {
case <-resultCh:
t.Fatal("FetchJWTSVID should be blocked until readyCh is closed")
case <-time.After(100 * time.Millisecond):
// Expected behavior - fetch is blocked
}
// Close readyCh to unblock fetch
close(readyCh)
// Now fetch should complete
select {
case result := <-resultCh:
assert.NoError(t, result.err)
assert.NotNil(t, result.svid)
case <-time.After(100 * time.Millisecond):
t.Fatal("FetchJWTSVID should have completed after readyCh was closed")
}
})
}

View File

@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
package trustanchors
package file
import (
"context"
@ -29,6 +29,7 @@ import (
"github.com/dapr/kit/concurrency"
"github.com/dapr/kit/crypto/pem"
"github.com/dapr/kit/crypto/spiffe/trustanchors"
"github.com/dapr/kit/fswatcher"
"github.com/dapr/kit/logger"
)
@ -41,10 +42,10 @@ var (
ErrFailedToReadTrustAnchorsFile = errors.New("failed to read trust anchors file")
)
type OptionsFile struct {
type Options struct {
Log logger.Logger
CAPath string
JwksPath string
JwksPath *string
}
// file is a TrustAnchors implementation that uses a file as the source of trust
@ -52,7 +53,7 @@ type OptionsFile struct {
type file struct {
log logger.Logger
caPath string
jwksPath string
jwksPath *string
x509Bundle *x509bundle.Bundle
jwtBundle *jwtbundle.Bundle
rootPEM []byte
@ -76,7 +77,7 @@ type file struct {
caEvent chan struct{}
}
func FromFile(opts OptionsFile) Interface {
func From(opts Options) trustanchors.Interface {
return &file{
fsWatcherInterval: time.Millisecond * 500,
initFileWatchInterval: time.Second,
@ -99,7 +100,12 @@ func (f *file) Run(ctx context.Context) error {
defer close(f.closeCh)
for {
if found, err := filesExist(f.caPath, f.jwksPath); err != nil {
fs := []string{f.caPath}
if f.jwksPath != nil {
fs = append(fs, *f.jwksPath)
}
if found, err := filesExist(fs...); err != nil {
return err
} else if found {
break
@ -121,8 +127,8 @@ func (f *file) Run(ctx context.Context) error {
}
targets := []string{f.caPath}
if f.jwksPath != "" {
targets = append(targets, f.jwksPath)
if f.jwksPath != nil {
targets = append(targets, *f.jwksPath)
}
fs, err := fswatcher.New(fswatcher.Options{
@ -136,7 +142,7 @@ func (f *file) Run(ctx context.Context) error {
close(f.readyCh)
f.log.Infof("Watching trust anchors file '%s' for changes", f.caPath)
if f.jwksPath != "" {
if f.jwksPath != nil {
f.log.Infof("Watching JWT bundle file '%s' for changes", f.jwksPath)
}
@ -194,10 +200,10 @@ func (f *file) updateAnchors(ctx context.Context) error {
f.rootPEM = rootPEMs
f.x509Bundle = x509bundle.FromX509Authorities(spiffeid.TrustDomain{}, trustAnchorCerts)
if f.jwksPath != "" {
jwks, err := os.ReadFile(f.jwksPath)
if f.jwksPath != nil {
jwks, err := os.ReadFile(*f.jwksPath)
if err != nil {
return fmt.Errorf("failed to read JWT bundle file '%s': %w", f.jwksPath, err)
return fmt.Errorf("failed to read JWT bundle file '%s': %w", *f.jwksPath, err)
}
jwtBundle, err := jwtbundle.Parse(spiffeid.TrustDomain{}, jwks)

View File

@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
package trustanchors
package file
import (
"context"
@ -31,7 +31,7 @@ import (
func TestFile_Run(t *testing.T) {
t.Run("if Run multiple times, expect error", func(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
ta := FromFile(OptionsFile{
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
@ -74,7 +74,7 @@ func TestFile_Run(t *testing.T) {
t.Run("if file is not found and context cancelled, should return ctx.Err", func(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
ta := FromFile(OptionsFile{
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
@ -102,7 +102,7 @@ func TestFile_Run(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, nil, 0o600))
ta := FromFile(OptionsFile{
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
@ -127,7 +127,7 @@ func TestFile_Run(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, []byte("garbage data"), 0o600))
ta := FromFile(OptionsFile{
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
@ -154,7 +154,7 @@ func TestFile_Run(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, root, 0o600))
ta := FromFile(OptionsFile{
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
@ -180,7 +180,7 @@ func TestFile_Run(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
ta := FromFile(OptionsFile{
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
@ -211,7 +211,7 @@ func TestFile_Run(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, root, 0o600))
ta := FromFile(OptionsFile{
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
@ -242,7 +242,7 @@ func TestFile_Run(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
ta := FromFile(OptionsFile{
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
@ -273,7 +273,7 @@ func TestFile_Run(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
ta := FromFile(OptionsFile{
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
@ -311,7 +311,7 @@ func TestFile_GetX509BundleForTrustDomain(t *testing.T) {
root := append(pki.RootCertPEM, []byte("garbage data")...)
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, root, 0o600))
ta := FromFile(OptionsFile{
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
@ -359,7 +359,7 @@ func TestFile_Watch(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
ta := FromFile(OptionsFile{
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
@ -400,7 +400,7 @@ func TestFile_Watch(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
ta := FromFile(OptionsFile{
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
@ -446,7 +446,7 @@ func TestFile_Watch(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
ta := FromFile(OptionsFile{
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
@ -529,7 +529,7 @@ func TestFile_CurrentTrustAnchors(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
ta := FromFile(OptionsFile{
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})

View File

@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
package trustanchors
package multi
import (
"context"
@ -22,6 +22,7 @@ import (
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/dapr/kit/concurrency"
"github.com/dapr/kit/crypto/spiffe/trustanchors"
)
var (
@ -29,17 +30,17 @@ var (
ErrTrustDomainNotFound = errors.New("trust domain not found")
)
type OptionsMulti struct {
TrustAnchors map[spiffeid.TrustDomain]Interface
type Options struct {
TrustAnchors map[spiffeid.TrustDomain]trustanchors.Interface
}
// multi is a TrustAnchors implementation which uses multiple trust anchors
// which are indexed by trust domain.
type multi struct {
trustAnchors map[spiffeid.TrustDomain]Interface
trustAnchors map[spiffeid.TrustDomain]trustanchors.Interface
}
func FromMulti(opts OptionsMulti) Interface {
func From(opts Options) trustanchors.Interface {
return &multi{
trustAnchors: opts.TrustAnchors,
}

View File

@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
package trustanchors
package static
import (
"context"
@ -24,6 +24,7 @@ import (
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/dapr/kit/crypto/pem"
"github.com/dapr/kit/crypto/spiffe/trustanchors"
)
// static is a TrustAcnhors implementation that uses a static list of trust
@ -36,12 +37,12 @@ type static struct {
closeCh chan struct{}
}
type OptionsStatic struct {
type Options struct {
Anchors []byte
Jwks []byte
}
func FromStatic(opts OptionsStatic) (Interface, error) {
func From(opts Options) (trustanchors.Interface, error) {
// Create empty trust domain for now
emptyTD := spiffeid.TrustDomain{}

View File

@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
package trustanchors
package static
import (
"context"
@ -27,30 +27,30 @@ import (
func TestFromStatic(t *testing.T) {
t.Run("empty root should return error", func(t *testing.T) {
_, err := FromStatic(OptionsStatic{})
_, err := From(Options{})
require.Error(t, err)
})
t.Run("garbage data should return error", func(t *testing.T) {
_, err := FromStatic(OptionsStatic{Anchors: []byte("garbage data")})
_, err := From(Options{Anchors: []byte("garbage data")})
require.Error(t, err)
})
t.Run("just garbage data should return error", func(t *testing.T) {
_, err := FromStatic(OptionsStatic{Anchors: []byte("garbage data")})
_, err := From(Options{Anchors: []byte("garbage data")})
require.Error(t, err)
})
t.Run("garbage data in root should return error", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
root := pki.RootCertPEM[10:]
_, err := FromStatic(OptionsStatic{Anchors: root})
_, err := From(Options{Anchors: root})
require.Error(t, err)
})
t.Run("single root should be correctly parsed", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := FromStatic(OptionsStatic{Anchors: pki.RootCertPEM})
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
taPEM, err := ta.CurrentTrustAnchors(context.Background())
require.NoError(t, err)
@ -61,7 +61,7 @@ func TestFromStatic(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
root := append(pki.RootCertPEM, []byte("garbage data")...)
ta, err := FromStatic(OptionsStatic{Anchors: root})
ta, err := From(Options{Anchors: root})
require.NoError(t, err)
taPEM, err := ta.CurrentTrustAnchors(context.Background())
require.NoError(t, err)
@ -72,7 +72,7 @@ func TestFromStatic(t *testing.T) {
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
ta, err := FromStatic(OptionsStatic{Anchors: roots})
ta, err := From(Options{Anchors: roots})
require.NoError(t, err)
taPEM, err := ta.CurrentTrustAnchors(context.Background())
require.NoError(t, err)
@ -85,7 +85,7 @@ func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
root := append(pki.RootCertPEM, []byte("garbage data")...)
ta, err := FromStatic(OptionsStatic{Anchors: root})
ta, err := From(Options{Anchors: root})
require.NoError(t, err)
s, ok := ta.(*static)
require.True(t, ok)
@ -113,7 +113,7 @@ func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
func TestStatic_Run(t *testing.T) {
t.Run("Run multiple times should return error", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := FromStatic(OptionsStatic{Anchors: pki.RootCertPEM})
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
s, ok := ta.(*static)
require.True(t, ok)
@ -154,7 +154,7 @@ func TestStatic_Run(t *testing.T) {
func TestStatic_Watch(t *testing.T) {
t.Run("should return when context is cancelled", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := FromStatic(OptionsStatic{Anchors: pki.RootCertPEM})
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
@ -176,7 +176,7 @@ func TestStatic_Watch(t *testing.T) {
t.Run("should return when cancel is closed via closed Run", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := FromStatic(OptionsStatic{Anchors: pki.RootCertPEM})
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())