Adds JWT handling to spiffe

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>
This commit is contained in:
Jonathan Collinge 2025-04-11 09:54:32 +01:00
parent e3d4a8f1b4
commit 15329f8a80
No known key found for this signature in database
GPG Key ID: BF9B59005264DD95
7 changed files with 258 additions and 46 deletions

View File

@ -1,5 +1,5 @@
/*
Copyright 2024 The Dapr Authors
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
@ -16,6 +16,7 @@ package context
import (
"context"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/dapr/kit/crypto/spiffe"
@ -23,13 +24,48 @@ import (
type ctxkey int
const svidKey ctxkey = iota
const (
x509SvidKey ctxkey = iota
jwtSvidKey
)
// Deprecated: use WithX509 instead.
// With adds the x509 SVID source from the SPIFFE object to the context.
func With(ctx context.Context, spiffe *spiffe.SPIFFE) context.Context {
return context.WithValue(ctx, svidKey, spiffe.SVIDSource())
return context.WithValue(ctx, x509SvidKey, spiffe.X509SVIDSource())
}
// Deprecated: use FromX509 instead.
// From retrieves the x509 SVID source from the context.
func From(ctx context.Context) (x509svid.Source, bool) {
svid, ok := ctx.Value(svidKey).(x509svid.Source)
svid, ok := ctx.Value(x509SvidKey).(x509svid.Source)
return svid, ok
}
// WithX509 adds an x509 SVID source to the context.
func WithX509(ctx context.Context, source x509svid.Source) context.Context {
return context.WithValue(ctx, x509SvidKey, source)
}
// WithJWT adds a JWT SVID source to the context.
func WithJWT(ctx context.Context, source jwtsvid.Source) context.Context {
return context.WithValue(ctx, jwtSvidKey, source)
}
// FromX509 retrieves the x509 SVID source from the context.
func FromX509(ctx context.Context) (x509svid.Source, bool) {
svid, ok := ctx.Value(x509SvidKey).(x509svid.Source)
return svid, ok
}
// FromJWT retrieves the JWT SVID source from the context.
func FromJWT(ctx context.Context) (jwtsvid.Source, bool) {
svid, ok := ctx.Value(jwtSvidKey).(jwtsvid.Source)
return svid, ok
}
// WithSpiffe adds both X509 and JWT SVID sources to the context.
func WithSpiffe(ctx context.Context, spiffe *spiffe.SPIFFE) context.Context {
ctx = WithX509(ctx, spiffe.X509SVIDSource())
return WithJWT(ctx, spiffe.JWTSVIDSource())
}

View File

@ -0,0 +1,74 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package context
import (
"context"
"testing"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
)
type mockX509Source struct{}
func (m *mockX509Source) GetX509SVID() (*x509svid.SVID, error) {
return nil, nil
}
type mockJWTSource struct{}
func (m *mockJWTSource) FetchJWTSVID(context.Context, jwtsvid.Params) (*jwtsvid.SVID, error) {
return nil, nil
}
func TestWithX509FromX509(t *testing.T) {
source := &mockX509Source{}
ctx := WithX509(context.Background(), source)
retrieved, ok := FromX509(ctx)
if !ok {
t.Error("Failed to retrieve X509 source from context")
}
if retrieved != source {
t.Error("Retrieved source does not match the original source")
}
}
func TestWithJWTFromJWT(t *testing.T) {
source := &mockJWTSource{}
ctx := WithJWT(context.Background(), source)
retrieved, ok := FromJWT(ctx)
if !ok {
t.Error("Failed to retrieve JWT source from context")
}
if retrieved != source {
t.Error("Retrieved source does not match the original source")
}
}
func TestWithFrom(t *testing.T) {
x509Source := &mockX509Source{}
ctx := WithX509(context.Background(), x509Source)
// Should be able to retrieve using the legacy From function
retrieved, ok := From(ctx)
if !ok {
t.Error("Failed to retrieve X509 source from context using legacy From")
}
if retrieved != x509Source {
t.Error("Retrieved source does not match the original source using legacy From")
}
}

View File

@ -25,6 +25,7 @@ import (
"sync/atomic"
"time"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"k8s.io/utils/clock"
@ -34,8 +35,24 @@ import (
"github.com/dapr/kit/logger"
)
// SVIDResponse represents the response from the SVID request function,
// containing both X.509 certificates and a JWT token.
type SVIDResponse struct {
X509Certificates []*x509.Certificate
JWT string
Audiences []string
}
// Identity contains both X.509 and JWT SVIDs for a workload.
type Identity struct {
X509SVID *x509svid.SVID
JWTSVID *jwtsvid.SVID
}
type (
RequestSVIDFn func(context.Context, []byte) ([]*x509.Certificate, error)
// RequestSVIDFn is the function type that requests SVIDs from a SPIFFE server,
// returning both X.509 certificates and a JWT token.
RequestSVIDFn func(context.Context, []byte) (*SVIDResponse, error)
)
type Options struct {
@ -51,11 +68,12 @@ type Options struct {
TrustAnchors trustanchors.Interface
}
// SPIFFE is a readable/writeable store of a SPIFFE X.509 SVID.
// Used to manage a workload SVID, and share read-only interfaces to consumers.
// SPIFFE is a readable/writeable store of SPIFFE SVID credentials.
// Used to manage workload SVIDs, and share read-only interfaces to consumers.
type SPIFFE struct {
currentSVID *x509svid.SVID
requestSVIDFn RequestSVIDFn
currentX509SVID *x509svid.SVID
currentJWTSVID *jwtsvid.SVID
requestSVIDFn RequestSVIDFn
dir *dir.Dir
trustAnchors trustanchors.Interface
@ -92,15 +110,16 @@ func (s *SPIFFE) Run(ctx context.Context) error {
}
s.lock.Lock()
s.log.Info("Fetching initial identity certificate")
initialCert, err := s.fetchIdentityCertificate(ctx)
s.log.Info("Fetching initial identity")
initialIdentity, err := s.fetchIdentity(ctx)
if err != nil {
close(s.readyCh)
s.lock.Unlock()
return fmt.Errorf("failed to retrieve the initial identity certificate: %w", err)
return fmt.Errorf("failed to retrieve the initial identity: %w", err)
}
s.currentSVID = initialCert
s.currentX509SVID = initialIdentity.X509SVID
s.currentJWTSVID = initialIdentity.JWTSVID
close(s.readyCh)
s.lock.Unlock()
@ -122,12 +141,12 @@ func (s *SPIFFE) Ready(ctx context.Context) error {
}
// runRotation starts up the manager responsible for renewing the workload
// certificate. Receives the initial certificate to calculate the next rotation
// identity. Receives the initial identity to calculate the next rotation
// time.
func (s *SPIFFE) runRotation(ctx context.Context) {
defer s.log.Debug("stopping workload cert expiry watcher")
s.lock.RLock()
cert := s.currentSVID.Certificates[0]
cert := s.currentX509SVID.Certificates[0]
s.lock.RUnlock()
renewTime := renewalTime(cert.NotBefore, cert.NotAfter)
s.log.Infof("Starting workload cert expiry watcher; current cert expires on: %s, renewing at %s",
@ -139,10 +158,10 @@ func (s *SPIFFE) runRotation(ctx context.Context) {
if s.clock.Now().Before(renewTime) {
continue
}
s.log.Infof("Renewing workload cert; current cert expires on: %s", cert.NotAfter.String())
svid, err := s.fetchIdentityCertificate(ctx)
s.log.Infof("Renewing workload identity; current cert expires on: %s", cert.NotAfter.String())
identity, err := s.fetchIdentity(ctx)
if err != nil {
s.log.Errorf("Error renewing identity certificate, trying again in 10 seconds: %s", err)
s.log.Errorf("Error renewing identity, trying again in 10 seconds: %s", err)
select {
case <-s.clock.After(10 * time.Second):
continue
@ -151,11 +170,15 @@ func (s *SPIFFE) runRotation(ctx context.Context) {
}
}
s.lock.Lock()
s.currentSVID = svid
cert = svid.Certificates[0]
s.currentX509SVID = identity.X509SVID
s.currentJWTSVID = identity.JWTSVID
cert = identity.X509SVID.Certificates[0]
s.lock.Unlock()
renewTime = renewalTime(cert.NotBefore, cert.NotAfter)
s.log.Infof("Successfully renewed workload cert; new cert expires on: %s", cert.NotAfter.String())
s.log.Infof("Successfully renewed workload identity; new cert expires on: %s", cert.NotAfter.String())
if identity.JWTSVID != nil {
s.log.Infof("New JWT SVID expires on: %s", identity.JWTSVID.Expiry.String())
}
case <-ctx.Done():
return
@ -163,8 +186,9 @@ func (s *SPIFFE) runRotation(ctx context.Context) {
}
}
// fetchIdentityCertificate fetches a new SVID using the configured requester.
func (s *SPIFFE) fetchIdentityCertificate(ctx context.Context) (*x509svid.SVID, error) {
// fetchIdentity fetches a new identity using the configured requester.
// Returns both X.509 SVID and JWT SVID (if available).
func (s *SPIFFE) fetchIdentity(ctx context.Context) (*Identity, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("failed to generate private key: %w", err)
@ -175,27 +199,46 @@ func (s *SPIFFE) fetchIdentityCertificate(ctx context.Context) (*x509svid.SVID,
return nil, fmt.Errorf("failed to create sidecar csr: %w", err)
}
workloadcert, err := s.requestSVIDFn(ctx, csrDER)
svidResponse, err := s.requestSVIDFn(ctx, csrDER)
if err != nil {
return nil, err
}
if len(workloadcert) == 0 {
if len(svidResponse.X509Certificates) == 0 {
return nil, errors.New("no certificates received from sentry")
}
spiffeID, err := x509svid.IDFromCert(workloadcert[0])
spiffeID, err := x509svid.IDFromCert(svidResponse.X509Certificates[0])
if err != nil {
return nil, fmt.Errorf("error parsing spiffe id from newly signed certificate: %w", err)
}
identity := &Identity{
X509SVID: &x509svid.SVID{
ID: spiffeID,
Certificates: svidResponse.X509Certificates,
PrivateKey: key,
},
}
// If we have a JWT token, parse it and include it in the identity
if svidResponse.JWT != "" {
jwtSvid, err := jwtsvid.ParseInsecure(svidResponse.JWT, svidResponse.Audiences)
if err != nil {
s.log.Warnf("Failed to parse JWT SVID: %v", err)
} else {
identity.JWTSVID = jwtSvid
s.log.Infof("Successfully received JWT SVID with expiry: %s", jwtSvid.Expiry.String())
}
}
if s.dir != nil {
pkPEM, err := pem.EncodePrivateKey(key)
if err != nil {
return nil, err
}
certPEM, err := pem.EncodeX509Chain(workloadcert)
certPEM, err := pem.EncodeX509Chain(svidResponse.X509Certificates)
if err != nil {
return nil, err
}
@ -205,23 +248,29 @@ func (s *SPIFFE) fetchIdentityCertificate(ctx context.Context) (*x509svid.SVID,
return nil, err
}
if err := s.dir.Write(map[string][]byte{
files := map[string][]byte{
"key.pem": pkPEM,
"cert.pem": certPEM,
"ca.pem": td,
}); err != nil {
}
if svidResponse.JWT != "" {
files["token.jwt"] = []byte(svidResponse.JWT)
}
if err := s.dir.Write(files); err != nil {
return nil, err
}
}
return &x509svid.SVID{
ID: spiffeID,
Certificates: workloadcert,
PrivateKey: key,
}, nil
return identity, nil
}
func (s *SPIFFE) SVIDSource() x509svid.Source {
func (s *SPIFFE) X509SVIDSource() x509svid.Source {
return &svidSource{spiffe: s}
}
func (s *SPIFFE) JWTSVIDSource() jwtsvid.Source {
return &svidSource{spiffe: s}
}
@ -229,3 +278,25 @@ func (s *SPIFFE) SVIDSource() x509svid.Source {
func renewalTime(notBefore, notAfter time.Time) time.Time {
return notBefore.Add(notAfter.Sub(notBefore) / 2)
}
// audiencesMatch checks if the SVID audiences contain all the requested audiences
func audiencesMatch(svidAudiences []string, requestedAudiences []string) bool {
if len(requestedAudiences) == 0 {
return true
}
// Create a map for faster lookup
audienceMap := make(map[string]struct{}, len(svidAudiences))
for _, audience := range svidAudiences {
audienceMap[audience] = struct{}{}
}
// Check if all requested audiences are in the SVID
for _, requested := range requestedAudiences {
if _, ok := audienceMap[requested]; !ok {
return false
}
}
return true
}

View File

@ -47,8 +47,10 @@ func Test_Run(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) ([]*x509.Certificate, error) {
return []*x509.Certificate{pki.LeafCert}, nil
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
return &SVIDResponse{
X509Certificates: []*x509.Certificate{pki.LeafCert},
}, nil
},
})
@ -79,7 +81,7 @@ func Test_Run(t *testing.T) {
t.Run("should return error if initial fetch errors", func(t *testing.T) {
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) ([]*x509.Certificate, error) {
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
return nil, errors.New("this is an error")
},
})
@ -95,9 +97,11 @@ func Test_Run(t *testing.T) {
var fetches atomic.Int32
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) ([]*x509.Certificate, error) {
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
fetches.Add(1)
return []*x509.Certificate{pki.LeafCert}, nil
return &SVIDResponse{
X509Certificates: []*x509.Certificate{pki.LeafCert},
}, nil
},
})
now := time.Now()
@ -144,9 +148,11 @@ func Test_Run(t *testing.T) {
var fetches atomic.Int32
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) ([]*x509.Certificate, error) {
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
fetches.Add(1)
return respCert, respErr
return &SVIDResponse{
X509Certificates: respCert,
}, respErr
},
})
now := time.Now()

View File

@ -14,27 +14,49 @@ limitations under the License.
package spiffe
import (
"context"
"errors"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
)
// svidSource is an implementation of the Go spiffe x509svid Source interface.
// svidSource is an implementation of both go-spiffe x509svid.Source and jwtsvid.Source interfaces.
type svidSource struct {
spiffe *SPIFFE
}
// GetX509SVID returns the current X.509 certificate identity as a SPIFFE SVID.
// Implements the go-spiffe x509 source interface.
// Implements the go-spiffe x509svid.Source interface.
func (s *svidSource) GetX509SVID() (*x509svid.SVID, error) {
s.spiffe.lock.RLock()
defer s.spiffe.lock.RUnlock()
<-s.spiffe.readyCh
svid := s.spiffe.currentSVID
svid := s.spiffe.currentX509SVID
if svid == nil {
return nil, errors.New("no SVID available")
return nil, errors.New("no X509 SVID available")
}
return svid, nil
}
// FetchJWTSVID returns the current JWT SVID.
// Implements the go-spiffe jwtsvid.Source interface.
func (s *svidSource) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*jwtsvid.SVID, error) {
s.spiffe.lock.RLock()
defer s.spiffe.lock.RUnlock()
<-s.spiffe.readyCh
svid := s.spiffe.currentJWTSVID
if svid == nil {
return nil, errors.New("no JWT SVID available")
}
if !audiencesMatch(svid.Audience, []string{params.Audience}) {
return nil, errors.New("JWT SVID has different audiences than requested")
}
return svid, nil

1
go.mod
View File

@ -28,6 +28,7 @@ require (
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/go-jose/go-jose/v3 v3.0.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect

2
go.sum
View File

@ -19,6 +19,7 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.1.0 h1:Hsa8mG0dQ46ij8Sl2AYJDUv1oA9/d6Vk+3LG99Oe02g=
@ -72,6 +73,7 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec
github.com/zeebo/errs v1.3.0 h1:hmiaKqgYZzcVgRL1Vkc1Mn2914BbzB0IBxs+ebeutGs=
github.com/zeebo/errs v1.3.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=