Compare commits

...

6 Commits

Author SHA1 Message Date
Joni Collinge 598b032bce
Fix deprecation comment to reference correct function (#125)
Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>
2025-06-16 17:06:11 +01:00
Josh van Leeuwen d7d50a1e1b
events/loop: drain queue on close (#124)
* events/loop: drain queue on close

Update looper to drain the Enqueue loop in the event that an error has
occurred when Enqueuing. Handle an error'd `Run` on `Close` by
respecting the `closedCh` channel.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Update loop.go

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-06-16 08:26:36 -05:00
Josh van Leeuwen baea626399
Update .golangci.yml to remove deprecations (#122)
* Update .golangci.yml  to remove deprecations

Signed-off-by: joshvanl <me@joshvanl.dev>

* Update .golangci.yml

Co-authored-by: Cassie Coyle <cassie.i.coyle@gmail.com>
Signed-off-by: Josh van Leeuwen <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
Signed-off-by: Josh van Leeuwen <me@joshvanl.dev>
Co-authored-by: Cassie Coyle <cassie.i.coyle@gmail.com>
2025-05-22 08:58:18 -05:00
Joni Collinge bc7dc566c4
Add JWT handling to spiffe package (#118)
* Adds JWT handling to spiffe

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Update log

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Add jwtbundle

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Update file watcher

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Clean up jwt spiffe

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* lint

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Update renewal behavior

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Updates based on joshvanl feedback

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* go mod tidy

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* lint

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* lint

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Move ready chan check to avoid race

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Add small delay after fs write to allow watcher to pick up change

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Resolve feedback

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Resolve feedback

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* lint

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

---------

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>
2025-05-16 13:15:56 +01:00
Josh van Leeuwen 98fe567235
events/loop: add reset (#120)
* events/loop: add reset

Update loop implementation is include functionality for Reset which is
useful when caching the loop struct for future use to reduce
allocations.

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-05-15 23:23:38 +01:00
Josh van Leeuwen e3d4a8f1b4
Add Copyright headers to env, and remove utils. (#103)
Chore to add Copyright headers to `env` files.

Move containing funcs and deletes `utils` package. `utils` packages are
generally a code smell, and better placed in a more descriptive package
name that gives context.

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-04-23 11:29:15 -03:00
26 changed files with 1133 additions and 383 deletions

View File

@ -4,7 +4,7 @@ run:
concurrency: 4
# timeout for analysis, e.g. 30s, 5m, default is 1m
deadline: 10m
timeout: 15m
# exit code when at least one issue was found, default is 1
issues-exit-code: 1
@ -15,29 +15,34 @@ run:
# list of build tags, all linters use it. Default is empty list.
build-tags:
- unit
# which dirs to skip: they won't be analyzed;
# can use regexp here: generated.*, regexp is applied on full path;
# default value is empty list, but next dirs are always skipped independently
# from this option's value:
# third_party$, testdata$, examples$, Godeps$, builtin$
skip-dirs:
- ^pkg.*client.*clientset.*versioned.*
- ^pkg.*client.*informers.*externalversions.*
- ^pkg.*proto.*
- allcomponents
- subtlecrypto
# which files to skip: they will be analyzed, but issues from them
# won't be reported. Default value is empty list, but there is
# no need to include all autogenerated files, we confidently recognize
# autogenerated files. If it's not please let us know.
skip-files:
# skip-files:
# - ".*\\.my\\.go$"
# - lib/bad.go
issues:
# which dirs to skip: they won't be analyzed;
# can use regexp here: generated.*, regexp is applied on full path;
# default value is empty list, but next dirs are always skipped independently
# from this option's value:
# third_party$, testdata$, examples$, Godeps$, builtin$
exclude-dirs:
- ^pkg.*client.*clientset.*versioned.*
- ^pkg.*client.*informers.*externalversions.*
- ^pkg.*proto.*
- pkg/proto
# output configuration options
output:
# colored-line-number|line-number|json|tab|checkstyle, default is "colored-line-number"
format: tab
formats:
- format: tab
# print lines of code with issue, default is true
print-issued-lines: true
@ -57,23 +62,19 @@ linters-settings:
# default is false: such cases aren't reported by default.
check-blank: false
# [deprecated] comma-separated list of pairs of the form pkg:regex
# the regex is used to ignore names within pkg. (default "fmt:.*").
# see https://github.com/kisielk/errcheck#the-deprecated-method for details
ignore: fmt:.*,io/ioutil:^Read.*
exclude-functions:
- fmt:.*
- io/ioutil:^Read.*
# path to a file containing a list of functions to exclude from checking
# see https://github.com/kisielk/errcheck#excluding-functions for details
exclude:
# exclude:
funlen:
lines: 60
statements: 40
govet:
# report about shadowed variables
check-shadowing: true
# settings per analyzer
settings:
printf: # analyzer name, run `go tool vet help` to see all analyzers
@ -86,13 +87,12 @@ linters-settings:
# enable or disable analyzers by name
enable:
- atomicalign
enable-all: false
disable:
- shadow
enable-all: false
disable-all: false
golint:
revive:
# minimal confidence for issues, default is 0.8
min-confidence: 0.8
confidence: 0.8
gofmt:
# simplify code: gofmt with `-s` option, true by default
simplify: true
@ -106,9 +106,6 @@ linters-settings:
gocognit:
# minimal code complexity to report, 30 by default (but we recommend 10-20)
min-complexity: 10
maligned:
# print struct with more effective memory layout or not, false by default
suggest-new: true
dupl:
# tokens count to trigger issue, 150 by default
threshold: 100
@ -121,55 +118,60 @@ linters-settings:
rules:
main:
deny:
- pkg: "github.com/Sirupsen/logrus"
desc: "must use github.com/dapr/kit/logger"
- pkg: "github.com/agrea/ptr"
desc: "must use github.com/dapr/kit/ptr"
- pkg: "go.uber.org/atomic"
desc: "must use sync/atomic"
- pkg: "golang.org/x/net/context"
desc: "must use context"
- pkg: "github.com/pkg/errors"
desc: "must use standard library (errors package and/or fmt.Errorf)"
- pkg: "github.com/go-chi/chi$"
desc: "must use github.com/go-chi/chi/v5"
- pkg: "github.com/cenkalti/backoff$"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/cenkalti/backoff/v2"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/cenkalti/backoff/v3"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/benbjohnson/clock"
desc: "must use k8s.io/utils/clock"
- pkg: "github.com/ghodss/yaml"
desc: "must use sigs.k8s.io/yaml"
- pkg: "gopkg.in/yaml.v2"
desc: "must use gopkg.in/yaml.v3"
- pkg: "github.com/golang-jwt/jwt"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v2"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v3"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v4"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/gogo/status"
desc: "must use google.golang.org/grpc/status"
- pkg: "github.com/gogo/protobuf"
desc: "must use google.golang.org/protobuf"
- pkg: "github.com/lestrrat-go/jwx/jwa"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/lestrrat-go/jwx/jwt"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/labstack/gommon/log"
desc: "must use github.com/dapr/kit/logger"
- pkg: "github.com/gobuffalo/logger"
desc: "must use github.com/dapr/kit/logger"
- pkg: "github.com/Sirupsen/logrus"
desc: "must use github.com/dapr/kit/logger"
- pkg: "github.com/agrea/ptr"
desc: "must use github.com/dapr/kit/ptr"
- pkg: "go.uber.org/atomic"
desc: "must use sync/atomic"
- pkg: "golang.org/x/net/context"
desc: "must use context"
- pkg: "github.com/pkg/errors"
desc: "must use standard library (errors package and/or fmt.Errorf)"
- pkg: "github.com/go-chi/chi$"
desc: "must use github.com/go-chi/chi/v5"
- pkg: "github.com/cenkalti/backoff$"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/cenkalti/backoff/v2"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/cenkalti/backoff/v3"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/benbjohnson/clock"
desc: "must use k8s.io/utils/clock"
- pkg: "github.com/ghodss/yaml"
desc: "must use sigs.k8s.io/yaml"
- pkg: "gopkg.in/yaml.v2"
desc: "must use gopkg.in/yaml.v3"
- pkg: "github.com/golang-jwt/jwt"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v2"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v3"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v4"
desc: "must use github.com/lestrrat-go/jwx/v2"
# pkg: Commonly auto-completed by gopls
- pkg: "github.com/gogo/status"
desc: "must use google.golang.org/grpc/status"
- pkg: "github.com/gogo/protobuf"
desc: "must use google.golang.org/protobuf"
- pkg: "github.com/lestrrat-go/jwx/jwa"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/lestrrat-go/jwx/jwt"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/labstack/gommon/log"
desc: "must use github.com/dapr/kit/logger"
- pkg: "github.com/gobuffalo/logger"
desc: "must use github.com/dapr/kit/logger"
- pkg: "k8s.io/utils/pointer"
desc: "must use github.com/dapr/kit/ptr"
- pkg: "k8s.io/utils/ptr"
desc: "must use github.com/dapr/kit/ptr"
misspell:
# Correct spellings using locale preferences for US or UK.
# Default is to use a neutral variety of English.
# Setting locale to US will correct the British spelling of 'colour' to 'color'.
locale: default
# locale: default
ignore-words:
- someword
lll:
@ -178,17 +180,9 @@ linters-settings:
line-length: 120
# tab width in spaces. Default to 1.
tab-width: 1
unparam:
# Inspect exported functions, default is false. Set to true if no external program/library imports your code.
# XXX: if you enable this setting, unparam will report a lot of false-positives in text editors:
# if it's called for subdir of a project it can't find external interfaces. All text editor integrations
# with golangci-lint call it on a directory with the changed file.
check-exported: false
nakedret:
# make an issue if func has more lines of code than this setting and it has naked returns; default is 30
max-func-lines: 30
nolintlint:
allow-unused: true
prealloc:
# XXX: we don't recommend using this linter before doing performance profiling.
# For most programs usage of prealloc will be a premature optimization.
@ -203,7 +197,6 @@ linters-settings:
# See https://go-critic.github.io/overview#checks-overview
# To check which checks are enabled run `GL_DEBUG=gocritic golangci-lint run`
# By default list of stable checks is used.
enabled-checks:
# Which checks should be disabled; can't be combined with 'enabled-checks'; default is empty
disabled-checks:
@ -251,62 +244,51 @@ linters-settings:
allow-assign-and-call: true
# Allow multiline assignments to be cuddled. Default is true.
allow-multiline-assign: true
# Allow case blocks to end with a whitespace.
allow-case-traling-whitespace: true
# Allow declarations (var) to be cuddled.
allow-cuddle-declarations: false
# If the number of lines in a case block is equal to or lager than this number,
# the case *must* end white a newline.
# https://github.com/bombsimon/wsl/blob/master/doc/configuration.md#force-case-trailing-whitespace
# Default: 0
force-case-trailing-whitespace: 1
linters:
fast: false
enable-all: true
disable:
# TODO Enforce the below linters later
- musttag
- dupl
- nonamedreturns
- errcheck
- funlen
- goconst
- gochecknoglobals
- gochecknoinits
- gocyclo
- gocognit
- nosnakecase
- varcheck
- structcheck
- deadcode
- godox
- interfacer
- lll
- maligned
- scopelint
- unparam
- wsl
- gomnd
- testpackage
- nestif
- nlreturn
- exhaustive
- exhaustruct
- noctx
- gci
- golint
- tparallel
- paralleltest
- wrapcheck
- tagliatelle
- ireturn
- exhaustive
- exhaustivestruct
- exhaustruct
- errchkjson
- contextcheck
- gomoddirectives
- godot
- cyclop
- varnamelen
- gosec
- tagalign
- errorlint
- forcetypeassert
- ifshort
- maintidx
- nilnil
- predeclared
@ -315,4 +297,13 @@ linters:
- wastedassign
- containedctx
- gosimple
- forbidigo
- nonamedreturns
- asasalint
- rowserrcheck
- sqlclosecheck
- inamedparam
- tagalign
- mnd
- canonicalheader
- err113
- fatcontext

View File

@ -23,6 +23,7 @@ import (
"encoding/pem"
"errors"
"fmt"
"os"
)
// DecodePEMCertificatesChain takes a PEM-encoded x509 certificates byte array
@ -187,3 +188,24 @@ func PublicKeysEqual(a, b crypto.PublicKey) (bool, error) {
return false, fmt.Errorf("unrecognised public key type: %T", a)
}
}
// GetPEM loads a PEM-encoded file (certificate or key).
func GetPEM(val string) ([]byte, error) {
// If val is already a PEM-encoded string, return it as-is
if IsValidPEM(val) {
return []byte(val), nil
}
// Assume it's a file
pemBytes, err := os.ReadFile(val)
if err != nil {
return nil, fmt.Errorf("value is neither a valid file path or nor a valid PEM-encoded string: %w", err)
}
return pemBytes, nil
}
// IsValidPEM validates the provided input has PEM formatted block.
func IsValidPEM(val string) bool {
block, _ := pem.Decode([]byte(val))
return block != nil
}

View File

@ -1,5 +1,5 @@
/*
Copyright 2024 The Dapr Authors
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
@ -16,6 +16,7 @@ package context
import (
"context"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/dapr/kit/crypto/spiffe"
@ -23,13 +24,48 @@ import (
type ctxkey int
const svidKey ctxkey = iota
const (
x509SvidKey ctxkey = iota
jwtSvidKey
)
// Deprecated: use WithX509 instead.
// With adds the x509 SVID source from the SPIFFE object to the context.
func With(ctx context.Context, spiffe *spiffe.SPIFFE) context.Context {
return context.WithValue(ctx, svidKey, spiffe.SVIDSource())
return context.WithValue(ctx, x509SvidKey, spiffe.X509SVIDSource())
}
// Deprecated: use X509From instead.
// From retrieves the x509 SVID source from the context.
func From(ctx context.Context) (x509svid.Source, bool) {
svid, ok := ctx.Value(svidKey).(x509svid.Source)
svid, ok := ctx.Value(x509SvidKey).(x509svid.Source)
return svid, ok
}
// WithX509 adds an x509 SVID source to the context.
func WithX509(ctx context.Context, source x509svid.Source) context.Context {
return context.WithValue(ctx, x509SvidKey, source)
}
// WithJWT adds a JWT SVID source to the context.
func WithJWT(ctx context.Context, source jwtsvid.Source) context.Context {
return context.WithValue(ctx, jwtSvidKey, source)
}
// X509From retrieves the x509 SVID source from the context.
func X509From(ctx context.Context) (x509svid.Source, bool) {
svid, ok := ctx.Value(x509SvidKey).(x509svid.Source)
return svid, ok
}
// JWTFrom retrieves the JWT SVID source from the context.
func JWTFrom(ctx context.Context) (jwtsvid.Source, bool) {
svid, ok := ctx.Value(jwtSvidKey).(jwtsvid.Source)
return svid, ok
}
// WithSpiffe adds both X509 and JWT SVID sources to the context.
func WithSpiffe(ctx context.Context, spiffe *spiffe.SPIFFE) context.Context {
ctx = WithX509(ctx, spiffe.X509SVIDSource())
return WithJWT(ctx, spiffe.JWTSVIDSource())
}

View File

@ -0,0 +1,64 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package context
import (
"context"
"testing"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/stretchr/testify/assert"
)
type mockX509Source struct{}
func (m *mockX509Source) GetX509SVID() (*x509svid.SVID, error) {
return nil, nil
}
type mockJWTSource struct{}
func (m *mockJWTSource) FetchJWTSVID(context.Context, jwtsvid.Params) (*jwtsvid.SVID, error) {
return nil, nil
}
func TestWithX509FromX509(t *testing.T) {
source := &mockX509Source{}
ctx := WithX509(context.Background(), source)
retrieved, ok := X509From(ctx)
assert.True(t, ok, "Failed to retrieve X509 source from context")
assert.Equal(t, x509svid.Source(source), retrieved, "Retrieved source does not match the original source")
}
func TestWithJWTFromJWT(t *testing.T) {
source := &mockJWTSource{}
ctx := WithJWT(context.Background(), source)
retrieved, ok := JWTFrom(ctx)
assert.True(t, ok, "Failed to retrieve JWT source from context")
assert.Equal(t, jwtsvid.Source(source), retrieved, "Retrieved source does not match the original source")
}
func TestWithFrom(t *testing.T) {
x509Source := &mockX509Source{}
ctx := WithX509(context.Background(), x509Source)
// Should be able to retrieve using the legacy From function
retrieved, ok := From(ctx)
assert.True(t, ok, "Failed to retrieve X509 source from context using legacy From")
assert.Equal(t, x509svid.Source(x509Source), retrieved, "Retrieved source does not match the original source using legacy From")
}

View File

@ -25,6 +25,7 @@ import (
"sync/atomic"
"time"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"k8s.io/utils/clock"
@ -34,8 +35,29 @@ import (
"github.com/dapr/kit/logger"
)
const (
// renewalDivisor represents the divisor for calculating renewal time.
// A value of 2 means renewal at 50% of the validity period.
renewalDivisor = 2
)
// SVIDResponse represents the response from the SVID request function,
// containing both X.509 certificates and a JWT token.
type SVIDResponse struct {
X509Certificates []*x509.Certificate
JWT *string
}
// Identity contains both X.509 and JWT SVIDs for a workload.
type Identity struct {
X509SVID *x509svid.SVID
JWTSVID *jwtsvid.SVID
}
type (
RequestSVIDFn func(context.Context, []byte) ([]*x509.Certificate, error)
// RequestSVIDFn is the function type that requests SVIDs from a SPIFFE server,
// returning both X.509 certificates and a JWT token.
RequestSVIDFn func(context.Context, []byte) (*SVIDResponse, error)
)
type Options struct {
@ -51,11 +73,12 @@ type Options struct {
TrustAnchors trustanchors.Interface
}
// SPIFFE is a readable/writeable store of a SPIFFE X.509 SVID.
// Used to manage a workload SVID, and share read-only interfaces to consumers.
// SPIFFE is a readable/writeable store of SPIFFE SVID credentials.
// Used to manage workload SVIDs, and share read-only interfaces to consumers.
type SPIFFE struct {
currentSVID *x509svid.SVID
requestSVIDFn RequestSVIDFn
currentX509SVID *x509svid.SVID
currentJWTSVID *jwtsvid.SVID
requestSVIDFn RequestSVIDFn
dir *dir.Dir
trustAnchors trustanchors.Interface
@ -92,15 +115,16 @@ func (s *SPIFFE) Run(ctx context.Context) error {
}
s.lock.Lock()
s.log.Info("Fetching initial identity certificate")
initialCert, err := s.fetchIdentityCertificate(ctx)
s.log.Info("Fetching initial identity")
initialIdentity, err := s.fetchIdentity(ctx)
if err != nil {
close(s.readyCh)
s.lock.Unlock()
return fmt.Errorf("failed to retrieve the initial identity certificate: %w", err)
return fmt.Errorf("failed to retrieve the initial identity: %w", err)
}
s.currentSVID = initialCert
s.currentX509SVID = initialIdentity.X509SVID
s.currentJWTSVID = initialIdentity.JWTSVID
close(s.readyCh)
s.lock.Unlock()
@ -121,28 +145,48 @@ func (s *SPIFFE) Ready(ctx context.Context) error {
}
}
// runRotation starts up the manager responsible for renewing the workload
// certificate. Receives the initial certificate to calculate the next rotation
// time.
// logIdentityInfo creates a log message with expiry details for both X.509 and JWT SVIDs
func (s *SPIFFE) logIdentityInfo(prefix string, cert *x509.Certificate, jwtSVID *jwtsvid.SVID, renewTime *time.Time) {
msg := prefix + "; cert expires on: %s"
args := []any{cert.NotAfter.String()}
if jwtSVID != nil {
msg += ", jwt expires on: %s"
args = append(args, jwtSVID.Expiry.String())
}
if renewTime != nil {
msg += ", renewal at: %s"
args = append(args, renewTime.String())
}
s.log.Infof(msg, args...)
}
// runRotation starts up the manager responsible for renewing the workload identity
func (s *SPIFFE) runRotation(ctx context.Context) {
defer s.log.Debug("stopping workload cert expiry watcher")
defer s.log.Debug("stopping workload identity expiry watcher")
s.lock.RLock()
cert := s.currentSVID.Certificates[0]
cert := s.currentX509SVID.Certificates[0]
jwtSVID := s.currentJWTSVID
s.lock.RUnlock()
renewTime := renewalTime(cert.NotBefore, cert.NotAfter)
s.log.Infof("Starting workload cert expiry watcher; current cert expires on: %s, renewing at %s",
cert.NotAfter.String(), renewTime.String())
renewTime := calculateRenewalTime(time.Now(), cert, jwtSVID)
s.logIdentityInfo("Starting workload identity expiry watcher", cert, jwtSVID, renewTime)
for {
select {
case <-s.clock.After(min(time.Minute, renewTime.Sub(s.clock.Now()))):
if s.clock.Now().Before(renewTime) {
if s.clock.Now().Before(*renewTime) {
continue
}
s.log.Infof("Renewing workload cert; current cert expires on: %s", cert.NotAfter.String())
svid, err := s.fetchIdentityCertificate(ctx)
s.logIdentityInfo("Renewing workload identity", cert, jwtSVID, nil)
identity, err := s.fetchIdentity(ctx)
if err != nil {
s.log.Errorf("Error renewing identity certificate, trying again in 10 seconds: %s", err)
s.log.Errorf("Error renewing identity, trying again in 10 seconds: %s", err)
select {
case <-s.clock.After(10 * time.Second):
continue
@ -150,12 +194,16 @@ func (s *SPIFFE) runRotation(ctx context.Context) {
return
}
}
s.lock.Lock()
s.currentSVID = svid
cert = svid.Certificates[0]
s.currentX509SVID = identity.X509SVID
s.currentJWTSVID = identity.JWTSVID
cert = identity.X509SVID.Certificates[0]
jwtSVID = identity.JWTSVID
s.lock.Unlock()
renewTime = renewalTime(cert.NotBefore, cert.NotAfter)
s.log.Infof("Successfully renewed workload cert; new cert expires on: %s", cert.NotAfter.String())
renewTime = calculateRenewalTime(time.Now(), cert, jwtSVID)
s.logIdentityInfo("Successfully renewed workload identity", cert, jwtSVID, renewTime)
case <-ctx.Done():
return
@ -163,8 +211,8 @@ func (s *SPIFFE) runRotation(ctx context.Context) {
}
}
// fetchIdentityCertificate fetches a new SVID using the configured requester.
func (s *SPIFFE) fetchIdentityCertificate(ctx context.Context) (*x509svid.SVID, error) {
// Returns both X.509 SVID and JWT SVID (if available).
func (s *SPIFFE) fetchIdentity(ctx context.Context) (*Identity, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("failed to generate private key: %w", err)
@ -175,27 +223,55 @@ func (s *SPIFFE) fetchIdentityCertificate(ctx context.Context) (*x509svid.SVID,
return nil, fmt.Errorf("failed to create sidecar csr: %w", err)
}
workloadcert, err := s.requestSVIDFn(ctx, csrDER)
svidResponse, err := s.requestSVIDFn(ctx, csrDER)
if err != nil {
return nil, err
}
if len(workloadcert) == 0 {
if len(svidResponse.X509Certificates) == 0 {
return nil, errors.New("no certificates received from sentry")
}
spiffeID, err := x509svid.IDFromCert(workloadcert[0])
spiffeID, err := x509svid.IDFromCert(svidResponse.X509Certificates[0])
if err != nil {
return nil, fmt.Errorf("error parsing spiffe id from newly signed certificate: %w", err)
}
identity := &Identity{
X509SVID: &x509svid.SVID{
ID: spiffeID,
Certificates: svidResponse.X509Certificates,
PrivateKey: key,
},
}
// If we have a JWT token, parse it and include it in the identity
if svidResponse.JWT != nil {
// we are using ParseInsecure here as the expectation is that the
// requestSVIDFn will have already parsed and validate the JWT SVID
// before returning it.
//
// we are parsing the token using our SPIFFE ID's trust domain
// as the audience as we expect the issuer to always include
// that as an audience since that ensures that the token is
// valid for us and our trust domain.
audiences := []string{spiffeID.TrustDomain().Name()}
jwtSvid, err := jwtsvid.ParseInsecure(*svidResponse.JWT, audiences)
if err != nil {
return nil, fmt.Errorf("failed to parse JWT SVID: %w", err)
}
identity.JWTSVID = jwtSvid
s.log.Infof("Successfully received JWT SVID with expiry: %s", jwtSvid.Expiry.String())
}
if s.dir != nil {
pkPEM, err := pem.EncodePrivateKey(key)
if err != nil {
return nil, err
}
certPEM, err := pem.EncodeX509Chain(workloadcert)
certPEM, err := pem.EncodeX509Chain(svidResponse.X509Certificates)
if err != nil {
return nil, err
}
@ -205,27 +281,72 @@ func (s *SPIFFE) fetchIdentityCertificate(ctx context.Context) (*x509svid.SVID,
return nil, err
}
if err := s.dir.Write(map[string][]byte{
files := map[string][]byte{
"key.pem": pkPEM,
"cert.pem": certPEM,
"ca.pem": td,
}); err != nil {
}
if svidResponse.JWT != nil {
files["jwt_svid.token"] = []byte(*svidResponse.JWT)
}
if err := s.dir.Write(files); err != nil {
return nil, err
}
}
return &x509svid.SVID{
ID: spiffeID,
Certificates: workloadcert,
PrivateKey: key,
}, nil
return identity, nil
}
func (s *SPIFFE) SVIDSource() x509svid.Source {
func (s *SPIFFE) X509SVIDSource() x509svid.Source {
return &svidSource{spiffe: s}
}
func (s *SPIFFE) JWTSVIDSource() jwtsvid.Source {
return &svidSource{spiffe: s}
}
// renewalTime is 50% through the certificate validity period.
func renewalTime(notBefore, notAfter time.Time) time.Time {
return notBefore.Add(notAfter.Sub(notBefore) / 2)
return notBefore.Add(notAfter.Sub(notBefore) / renewalDivisor)
}
// calculateRenewalTime returns the earlier renewal time between the X.509 certificate
// and JWT SVID (if available) to ensure timely renewal.
func calculateRenewalTime(now time.Time, cert *x509.Certificate, jwtSVID *jwtsvid.SVID) *time.Time {
certRenewal := renewalTime(cert.NotBefore, cert.NotAfter)
if jwtSVID == nil {
return &certRenewal
}
jwtRenewal := now.Add(jwtSVID.Expiry.Sub(now) / renewalDivisor)
if jwtRenewal.Before(certRenewal) {
return &jwtRenewal
}
return &certRenewal
}
// audiencesMatch checks if the SVID audiences contain all the requested audiences
func audiencesMatch(svidAudiences []string, requestedAudiences []string) bool {
if len(requestedAudiences) == 0 {
return true
}
// Create a map for faster lookup
audienceMap := make(map[string]struct{}, len(svidAudiences))
for _, audience := range svidAudiences {
audienceMap[audience] = struct{}{}
}
// Check if all requested audiences are in the SVID
for _, requested := range requestedAudiences {
if _, ok := audienceMap[requested]; !ok {
return false
}
}
return true
}

View File

@ -22,6 +22,7 @@ import (
"time"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
clocktesting "k8s.io/utils/clock/testing"
@ -39,6 +40,70 @@ func Test_renewalTime(t *testing.T) {
assert.Equal(t, in30, renewalTime(now, in1Min))
}
func Test_calculateRenewalTime(t *testing.T) {
now := time.Now()
certShort := &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(10 * time.Hour),
}
certLong := &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(24 * time.Hour),
}
// Expected renewal times for certificates (50% of validity period)
certShortRenewal := now.Add(5 * time.Hour)
// Create JWT SVIDs with different expiry times
jwtEarlier := &jwtsvid.SVID{
Expiry: now.Add(8 * time.Hour),
}
jwtLater := &jwtsvid.SVID{
Expiry: now.Add(30 * time.Hour),
}
// Expected JWT renewal time (50% of remaining time)
jwtEarlierRenewal := now.Add(4 * time.Hour)
tests := []struct {
name string
cert *x509.Certificate
jwt *jwtsvid.SVID
expected time.Time
}{
{
name: "Certificate only",
cert: certShort,
jwt: nil,
expected: certShortRenewal,
},
{
name: "Certificate and JWT, JWT earlier",
cert: certLong,
jwt: jwtEarlier,
expected: jwtEarlierRenewal,
},
{
name: "Certificate and JWT, Certificate earlier",
cert: certShort,
jwt: jwtLater,
expected: certShortRenewal,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := calculateRenewalTime(now, tt.cert, tt.jwt)
assert.WithinDuration(t, tt.expected, *actual, time.Millisecond,
"Renewal time does not match expected value")
})
}
}
func Test_Run(t *testing.T) {
t.Run("should return error multiple Runs are called", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{
@ -47,8 +112,10 @@ func Test_Run(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) ([]*x509.Certificate, error) {
return []*x509.Certificate{pki.LeafCert}, nil
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
return &SVIDResponse{
X509Certificates: []*x509.Certificate{pki.LeafCert},
}, nil
},
})
@ -79,7 +146,7 @@ func Test_Run(t *testing.T) {
t.Run("should return error if initial fetch errors", func(t *testing.T) {
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) ([]*x509.Certificate, error) {
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
return nil, errors.New("this is an error")
},
})
@ -95,9 +162,11 @@ func Test_Run(t *testing.T) {
var fetches atomic.Int32
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) ([]*x509.Certificate, error) {
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
fetches.Add(1)
return []*x509.Certificate{pki.LeafCert}, nil
return &SVIDResponse{
X509Certificates: []*x509.Certificate{pki.LeafCert},
}, nil
},
})
now := time.Now()
@ -107,15 +176,15 @@ func Test_Run(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error)
go func() {
select {
case <-s.readyCh:
assert.Fail(t, "readyCh should not be closed")
default:
}
errCh <- s.Run(ctx)
}()
select {
case <-s.readyCh:
assert.Fail(t, "readyCh should not be closed")
default:
}
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
assert.Equal(t, int32(1), fetches.Load())
@ -144,9 +213,11 @@ func Test_Run(t *testing.T) {
var fetches atomic.Int32
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) ([]*x509.Certificate, error) {
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
fetches.Add(1)
return respCert, respErr
return &SVIDResponse{
X509Certificates: respCert,
}, respErr
},
})
now := time.Now()
@ -156,15 +227,15 @@ func Test_Run(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error)
go func() {
select {
case <-s.readyCh:
assert.Fail(t, "readyCh should not be closed")
default:
}
errCh <- s.Run(ctx)
}()
select {
case <-s.readyCh:
assert.Fail(t, "readyCh should not be closed")
default:
}
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
assert.Equal(t, int32(1), fetches.Load())

View File

@ -14,27 +14,81 @@ limitations under the License.
package spiffe
import (
"context"
"errors"
"fmt"
"strings"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
)
// svidSource is an implementation of the Go spiffe x509svid Source interface.
var (
errNoX509SVIDAvailable = errors.New("no X509 SVID available")
errNoJWTSVIDAvailable = errors.New("no JWT SVID available")
errAudienceRequired = errors.New("JWT audience is required")
)
// svidSource is an implementation of both go-spiffe x509svid.Source and jwtsvid.Source interfaces.
type svidSource struct {
spiffe *SPIFFE
}
// GetX509SVID returns the current X.509 certificate identity as a SPIFFE SVID.
// Implements the go-spiffe x509 source interface.
// Implements the go-spiffe x509svid.Source interface.
func (s *svidSource) GetX509SVID() (*x509svid.SVID, error) {
s.spiffe.lock.RLock()
defer s.spiffe.lock.RUnlock()
<-s.spiffe.readyCh
svid := s.spiffe.currentSVID
svid := s.spiffe.currentX509SVID
if svid == nil {
return nil, errors.New("no SVID available")
return nil, errNoX509SVIDAvailable
}
return svid, nil
}
// audienceMismatchError is an error that contains information about mismatched audiences
type audienceMismatchError struct {
expected []string
actual []string
}
func (e *audienceMismatchError) Error() string {
return fmt.Sprintf("JWT SVID has different audiences than requested: expected %s, got %s",
strings.Join(e.expected, ", "), strings.Join(e.actual, ", "))
}
// FetchJWTSVID returns the current JWT SVID.
// Implements the go-spiffe jwtsvid.Source interface.
func (s *svidSource) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*jwtsvid.SVID, error) {
s.spiffe.lock.RLock()
defer s.spiffe.lock.RUnlock()
if params.Audience == "" {
return nil, errAudienceRequired
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-s.spiffe.readyCh:
}
svid := s.spiffe.currentJWTSVID
if svid == nil {
return nil, errNoJWTSVIDAvailable
}
// verify that the audience being requested is the same as the audience in the SVID
// WARN: we do not check extra audiences here.
if !audiencesMatch(svid.Audience, []string{params.Audience}) {
return nil, &audienceMismatchError{
expected: []string{params.Audience},
actual: svid.Audience,
}
}
return svid, nil

View File

@ -14,11 +14,177 @@ limitations under the License.
package spiffe
import (
"context"
"sync"
"testing"
"time"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/stretchr/testify/require"
)
func Test_svidSource(*testing.T) {
var _ x509svid.Source = new(svidSource)
var _ jwtsvid.Source = new(svidSource)
}
// createMockJWTSVID creates a mock JWT SVID for testing
func createMockJWTSVID(audiences []string) (*jwtsvid.SVID, error) {
td, err := spiffeid.TrustDomainFromString("example.org")
if err != nil {
return nil, err
}
id, err := spiffeid.FromSegments(td, "workload")
if err != nil {
return nil, err
}
svid := &jwtsvid.SVID{
ID: id,
Audience: audiences,
Expiry: time.Now().Add(time.Hour),
}
return svid, nil
}
func TestFetchJWTSVID(t *testing.T) {
t.Run("should return error when audience is empty", func(t *testing.T) {
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
Audience: "",
})
require.Nil(t, svid)
require.ErrorIs(t, err, errAudienceRequired)
})
t.Run("should return error when no JWT SVID available", func(t *testing.T) {
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: nil,
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
Audience: "test-audience",
})
require.Nil(t, svid)
require.ErrorIs(t, err, errNoJWTSVIDAvailable)
})
t.Run("should return error when audience doesn't match", func(t *testing.T) {
// Create a mock SVID with a specific audience
mockJWTSVID, err := createMockJWTSVID([]string{"actual-audience"})
require.NoError(t, err)
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
Audience: "requested-audience",
})
require.Nil(t, svid)
require.Error(t, err)
// Verify the specific error type and contents
audienceErr, ok := err.(*audienceMismatchError)
require.True(t, ok, "Expected audienceMismatchError")
require.Equal(t, "JWT SVID has different audiences than requested: expected requested-audience, got actual-audience", audienceErr.Error())
})
t.Run("should return JWT SVID when audience matches", func(t *testing.T) {
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience", "extra-audience"})
require.NoError(t, err)
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
Audience: "test-audience",
})
require.NoError(t, err)
require.Equal(t, mockJWTSVID, svid)
})
t.Run("should wait for readyCh before checking SVID", func(t *testing.T) {
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience"})
require.NoError(t, err)
readyCh := make(chan struct{})
s := &svidSource{
spiffe: &SPIFFE{
readyCh: readyCh,
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
},
}
// Start goroutine to fetch SVID
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
resultCh := make(chan struct {
svid *jwtsvid.SVID
err error
})
go func() {
svid, err := s.FetchJWTSVID(ctx, jwtsvid.Params{
Audience: "test-audience",
})
resultCh <- struct {
svid *jwtsvid.SVID
err error
}{svid, err}
}()
// require that fetch is blocked
select {
case <-resultCh:
t.Fatal("FetchJWTSVID should be blocked until readyCh is closed")
case <-time.After(100 * time.Millisecond):
// Expected behavior - fetch is blocked
}
// Close readyCh to unblock fetch
close(readyCh)
// Now fetch should complete
select {
case result := <-resultCh:
require.NoError(t, result.err)
require.NotNil(t, result.svid)
case <-time.After(100 * time.Millisecond):
t.Fatal("FetchJWTSVID should have completed after readyCh was closed")
}
})
}

View File

@ -11,40 +11,52 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
package trustanchors
package file
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"sync"
"sync/atomic"
"time"
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"k8s.io/utils/clock"
"github.com/dapr/kit/concurrency"
"github.com/dapr/kit/crypto/pem"
"github.com/dapr/kit/crypto/spiffe/trustanchors"
"github.com/dapr/kit/fswatcher"
"github.com/dapr/kit/logger"
)
type OptionsFile struct {
Log logger.Logger
Path string
var (
// ErrTrustAnchorsClosed is returned when an operation is performed on closed trust anchors.
ErrTrustAnchorsClosed = errors.New("trust anchors is closed")
// ErrFailedToReadTrustAnchorsFile is returned when the trust anchors file cannot be read.
ErrFailedToReadTrustAnchorsFile = errors.New("failed to read trust anchors file")
)
type Options struct {
Log logger.Logger
CAPath string
JwksPath *string
}
// file is a TrustAnchors implementation that uses a file as the source of trust
// anchors. The trust anchors will be updated when the file changes.
type file struct {
log logger.Logger
path string
bundle *x509bundle.Bundle
rootPEM []byte
log logger.Logger
caPath string
jwksPath *string
x509Bundle *x509bundle.Bundle
jwtBundle *jwtbundle.Bundle
rootPEM []byte
// fswatcherInterval is the interval at which the trust anchors file changes
// are batched. Used for testing only, and 500ms otherwise.
@ -65,17 +77,18 @@ type file struct {
caEvent chan struct{}
}
func FromFile(opts OptionsFile) Interface {
func From(opts Options) trustanchors.Interface {
return &file{
fsWatcherInterval: time.Millisecond * 500,
initFileWatchInterval: time.Second,
log: opts.Log,
path: opts.Path,
clock: clock.RealClock{},
readyCh: make(chan struct{}),
closeCh: make(chan struct{}),
caEvent: make(chan struct{}),
log: opts.Log,
caPath: opts.CAPath,
jwksPath: opts.JwksPath,
clock: clock.RealClock{},
readyCh: make(chan struct{}),
closeCh: make(chan struct{}),
caEvent: make(chan struct{}),
}
}
@ -87,31 +100,39 @@ func (f *file) Run(ctx context.Context) error {
defer close(f.closeCh)
for {
_, err := os.Stat(f.path)
if err == nil {
break
fs := []string{f.caPath}
if f.jwksPath != nil {
fs = append(fs, *f.jwksPath)
}
if !errors.Is(err, os.ErrNotExist) {
if found, err := filesExist(fs...); err != nil {
return err
} else if found {
break
}
// Trust anchors file not be provided yet, wait.
select {
case <-ctx.Done():
return fmt.Errorf("failed to find trust anchors file '%s': %w", f.path, ctx.Err())
return fmt.Errorf("failed to find trust anchors file '%s': %w", f.caPath, ctx.Err())
case <-f.clock.After(f.initFileWatchInterval):
f.log.Warnf("Trust anchors file '%s' not found, waiting...", f.path)
f.log.Warnf("Trust anchors file '%s' not found, waiting...", f.caPath)
}
}
f.log.Infof("Trust anchors file '%s' found", f.path)
f.log.Infof("Trust anchors file '%s' found", f.caPath)
if err := f.updateAnchors(ctx); err != nil {
return err
}
targets := []string{f.caPath}
if f.jwksPath != nil {
targets = append(targets, *f.jwksPath)
}
fs, err := fswatcher.New(fswatcher.Options{
Targets: []string{filepath.Dir(f.path)},
Targets: targets,
Interval: &f.fsWatcherInterval,
})
if err != nil {
@ -120,7 +141,11 @@ func (f *file) Run(ctx context.Context) error {
close(f.readyCh)
f.log.Infof("Watching trust anchors file '%s' for changes", f.path)
f.log.Infof("Watching trust anchors file '%s' for changes", f.caPath)
if f.jwksPath != nil {
f.log.Infof("Watching JWT bundle file '%s' for changes", f.jwksPath)
}
return concurrency.NewRunnerManager(
func(ctx context.Context) error {
return fs.Run(ctx, f.caEvent)
@ -134,7 +159,7 @@ func (f *file) Run(ctx context.Context) error {
f.log.Info("Trust anchors file changed, reloading trust anchors")
if err = f.updateAnchors(ctx); err != nil {
return fmt.Errorf("failed to read trust anchors file '%s': %v", f.path, err)
return fmt.Errorf("%w: '%s': %v", ErrFailedToReadTrustAnchorsFile, f.caPath, err)
}
}
}
@ -147,7 +172,7 @@ func (f *file) CurrentTrustAnchors(ctx context.Context) ([]byte, error) {
case <-ctx.Done():
return nil, ctx.Err()
case <-f.closeCh:
return nil, errors.New("trust anchors is closed")
return nil, ErrTrustAnchorsClosed
case <-f.readyCh:
}
@ -162,9 +187,9 @@ func (f *file) updateAnchors(ctx context.Context) error {
f.lock.Lock()
defer f.lock.Unlock()
rootPEMs, err := os.ReadFile(f.path)
rootPEMs, err := os.ReadFile(f.caPath)
if err != nil {
return fmt.Errorf("failed to read trust anchors file '%s': %w", f.path, err)
return fmt.Errorf("failed to read trust anchors file '%s': %w", f.caPath, err)
}
trustAnchorCerts, err := pem.DecodePEMCertificates(rootPEMs)
@ -173,7 +198,20 @@ func (f *file) updateAnchors(ctx context.Context) error {
}
f.rootPEM = rootPEMs
f.bundle = x509bundle.FromX509Authorities(spiffeid.TrustDomain{}, trustAnchorCerts)
f.x509Bundle = x509bundle.FromX509Authorities(spiffeid.TrustDomain{}, trustAnchorCerts)
if f.jwksPath != nil {
jwks, err := os.ReadFile(*f.jwksPath)
if err != nil {
return fmt.Errorf("failed to read JWT bundle file '%s': %w", *f.jwksPath, err)
}
jwtBundle, err := jwtbundle.Parse(spiffeid.TrustDomain{}, jwks)
if err != nil {
return fmt.Errorf("failed to parse JWT bundle: %w", err)
}
f.jwtBundle = jwtBundle
}
var wg sync.WaitGroup
defer wg.Wait()
@ -195,13 +233,26 @@ func (f *file) updateAnchors(ctx context.Context) error {
func (f *file) GetX509BundleForTrustDomain(_ spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
select {
case <-f.closeCh:
return nil, errors.New("trust anchors is closed")
return nil, ErrTrustAnchorsClosed
case <-f.readyCh:
}
f.lock.RLock()
defer f.lock.RUnlock()
bundle := f.bundle
bundle := f.x509Bundle
return bundle, nil
}
func (f *file) GetJWTBundleForTrustDomain(_ spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
select {
case <-f.closeCh:
return nil, ErrTrustAnchorsClosed
case <-f.readyCh:
}
f.lock.RLock()
defer f.lock.RUnlock()
bundle := f.jwtBundle
return bundle, nil
}
@ -231,3 +282,19 @@ func (f *file) Watch(ctx context.Context, ch chan<- []byte) {
}
}
}
func filesExist(paths ...string) (bool, error) {
for _, path := range paths {
if path == "" {
continue
}
if _, err := os.Stat(path); err != nil {
if errors.Is(err, os.ErrNotExist) {
return false, nil
}
return false, fmt.Errorf("failed to stat file '%s': %w", path, err)
}
}
return true, nil
}

View File

@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
package trustanchors
package file
import (
"context"
@ -31,9 +31,9 @@ import (
func TestFile_Run(t *testing.T) {
t.Run("if Run multiple times, expect error", func(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
ta := FromFile(OptionsFile{
Log: logger.NewLogger("test"),
Path: tmp,
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
@ -74,9 +74,9 @@ func TestFile_Run(t *testing.T) {
t.Run("if file is not found and context cancelled, should return ctx.Err", func(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
ta := FromFile(OptionsFile{
Log: logger.NewLogger("test"),
Path: tmp,
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
@ -102,9 +102,9 @@ func TestFile_Run(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, nil, 0o600))
ta := FromFile(OptionsFile{
Log: logger.NewLogger("test"),
Path: tmp,
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
@ -127,9 +127,9 @@ func TestFile_Run(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, []byte("garbage data"), 0o600))
ta := FromFile(OptionsFile{
Log: logger.NewLogger("test"),
Path: tmp,
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
@ -154,9 +154,9 @@ func TestFile_Run(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, root, 0o600))
ta := FromFile(OptionsFile{
Log: logger.NewLogger("test"),
Path: tmp,
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
@ -180,9 +180,9 @@ func TestFile_Run(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
ta := FromFile(OptionsFile{
Log: logger.NewLogger("test"),
Path: tmp,
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
@ -211,9 +211,9 @@ func TestFile_Run(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, root, 0o600))
ta := FromFile(OptionsFile{
Log: logger.NewLogger("test"),
Path: tmp,
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
@ -242,9 +242,9 @@ func TestFile_Run(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
ta := FromFile(OptionsFile{
Log: logger.NewLogger("test"),
Path: tmp,
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
@ -273,9 +273,9 @@ func TestFile_Run(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
ta := FromFile(OptionsFile{
Log: logger.NewLogger("test"),
Path: tmp,
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
@ -311,9 +311,9 @@ func TestFile_GetX509BundleForTrustDomain(t *testing.T) {
root := append(pki.RootCertPEM, []byte("garbage data")...)
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, root, 0o600))
ta := FromFile(OptionsFile{
Log: logger.NewLogger("test"),
Path: tmp,
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
@ -337,7 +337,7 @@ func TestFile_GetX509BundleForTrustDomain(t *testing.T) {
require.NoError(t, err)
bundle, err := f.GetX509BundleForTrustDomain(trustDomain1)
require.NoError(t, err)
assert.Equal(t, f.bundle, bundle)
assert.Equal(t, f.x509Bundle, bundle)
b1, err := bundle.Marshal()
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, b1)
@ -346,7 +346,7 @@ func TestFile_GetX509BundleForTrustDomain(t *testing.T) {
require.NoError(t, err)
bundle, err = f.GetX509BundleForTrustDomain(trustDomain2)
require.NoError(t, err)
assert.Equal(t, f.bundle, bundle)
assert.Equal(t, f.x509Bundle, bundle)
b2, err := bundle.Marshal()
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, b2)
@ -359,9 +359,9 @@ func TestFile_Watch(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
ta := FromFile(OptionsFile{
Log: logger.NewLogger("test"),
Path: tmp,
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
@ -400,9 +400,9 @@ func TestFile_Watch(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
ta := FromFile(OptionsFile{
Log: logger.NewLogger("test"),
Path: tmp,
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
@ -446,9 +446,9 @@ func TestFile_Watch(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
ta := FromFile(OptionsFile{
Log: logger.NewLogger("test"),
Path: tmp,
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
@ -529,9 +529,9 @@ func TestFile_CurrentTrustAnchors(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
ta := FromFile(OptionsFile{
Log: logger.NewLogger("test"),
Path: tmp,
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
@ -547,6 +547,7 @@ func TestFile_CurrentTrustAnchors(t *testing.T) {
//nolint:gocritic
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure the file watcher has time to pick up the change
assert.EventuallyWithT(t, func(c *assert.CollectT) {
pem, err := ta.CurrentTrustAnchors(context.Background())
require.NoError(t, err)
@ -556,6 +557,7 @@ func TestFile_CurrentTrustAnchors(t *testing.T) {
//nolint:gocritic
roots = append(pki1.RootCertPEM, append(pki2.RootCertPEM, pki3.RootCertPEM...)...)
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure the file watcher has time to pick up the change
assert.EventuallyWithT(t, func(c *assert.CollectT) {
pem, err := ta.CurrentTrustAnchors(context.Background())
require.NoError(t, err)

View File

@ -11,16 +11,18 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
package trustanchors
package multi
import (
"context"
"errors"
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/dapr/kit/concurrency"
"github.com/dapr/kit/crypto/spiffe/trustanchors"
)
var (
@ -28,17 +30,17 @@ var (
ErrTrustDomainNotFound = errors.New("trust domain not found")
)
type OptionsMulti struct {
TrustAnchors map[spiffeid.TrustDomain]Interface
type Options struct {
TrustAnchors map[spiffeid.TrustDomain]trustanchors.Interface
}
// multi is a TrustAnchors implementation which uses multiple trust anchors
// which are indexed by trust domain.
type multi struct {
trustAnchors map[spiffeid.TrustDomain]Interface
trustAnchors map[spiffeid.TrustDomain]trustanchors.Interface
}
func FromMulti(opts OptionsMulti) Interface {
func From(opts Options) trustanchors.Interface {
return &multi{
trustAnchors: opts.TrustAnchors,
}
@ -69,6 +71,16 @@ func (m *multi) GetX509BundleForTrustDomain(td spiffeid.TrustDomain) (*x509bundl
return nil, ErrTrustDomainNotFound
}
func (m *multi) GetJWTBundleForTrustDomain(td spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
for tad, ta := range m.trustAnchors {
if td.Compare(tad) == 0 {
return ta.GetJWTBundleForTrustDomain(td)
}
}
return nil, ErrTrustDomainNotFound
}
func (m *multi) Watch(context.Context, chan<- []byte) {
return
}

View File

@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
package trustanchors
package static
import (
"context"
@ -19,31 +19,52 @@ import (
"fmt"
"sync/atomic"
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/dapr/kit/crypto/pem"
"github.com/dapr/kit/crypto/spiffe/trustanchors"
)
// static is a TrustAcnhors implementation that uses a static list of trust
// anchors.
type static struct {
bundle *x509bundle.Bundle
anchors []byte
running atomic.Bool
closeCh chan struct{}
x509Bundle *x509bundle.Bundle
jwtBundle *jwtbundle.Bundle
anchors []byte
running atomic.Bool
closeCh chan struct{}
}
func FromStatic(anchors []byte) (Interface, error) {
trustAnchorCerts, err := pem.DecodePEMCertificates(anchors)
type Options struct {
Anchors []byte
Jwks []byte
}
func From(opts Options) (trustanchors.Interface, error) {
// Create empty trust domain for now
emptyTD := spiffeid.TrustDomain{}
var jwtBundle *jwtbundle.Bundle
if opts.Jwks != nil {
var err error
jwtBundle, err = jwtbundle.Parse(emptyTD, opts.Jwks)
if err != nil {
return nil, fmt.Errorf("failed to create JWT bundle: %w", err)
}
}
trustAnchorCerts, err := pem.DecodePEMCertificates(opts.Anchors)
if err != nil {
return nil, fmt.Errorf("failed to decode trust anchors: %w", err)
}
return &static{
anchors: anchors,
bundle: x509bundle.FromX509Authorities(spiffeid.TrustDomain{}, trustAnchorCerts),
closeCh: make(chan struct{}),
anchors: opts.Anchors,
x509Bundle: x509bundle.FromX509Authorities(emptyTD, trustAnchorCerts),
jwtBundle: jwtBundle,
closeCh: make(chan struct{}),
}, nil
}
@ -63,7 +84,11 @@ func (s *static) Run(ctx context.Context) error {
}
func (s *static) GetX509BundleForTrustDomain(spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
return s.bundle, nil
return s.x509Bundle, nil
}
func (s *static) GetJWTBundleForTrustDomain(_ spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
return s.jwtBundle, nil
}
func (s *static) Watch(ctx context.Context, _ chan<- []byte) {

View File

@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
package trustanchors
package static
import (
"context"
@ -27,30 +27,30 @@ import (
func TestFromStatic(t *testing.T) {
t.Run("empty root should return error", func(t *testing.T) {
_, err := FromStatic(nil)
_, err := From(Options{})
require.Error(t, err)
})
t.Run("garbage data should return error", func(t *testing.T) {
_, err := FromStatic([]byte("garbage data"))
_, err := From(Options{Anchors: []byte("garbage data")})
require.Error(t, err)
})
t.Run("just garbage data should return error", func(t *testing.T) {
_, err := FromStatic([]byte("garbage data"))
_, err := From(Options{Anchors: []byte("garbage data")})
require.Error(t, err)
})
t.Run("garbage data in root should return error", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
root := pki.RootCertPEM[10:]
_, err := FromStatic(root)
_, err := From(Options{Anchors: root})
require.Error(t, err)
})
t.Run("single root should be correctly parsed", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := FromStatic(pki.RootCertPEM)
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
taPEM, err := ta.CurrentTrustAnchors(context.Background())
require.NoError(t, err)
@ -61,7 +61,7 @@ func TestFromStatic(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
root := append(pki.RootCertPEM, []byte("garbage data")...)
ta, err := FromStatic(root)
ta, err := From(Options{Anchors: root})
require.NoError(t, err)
taPEM, err := ta.CurrentTrustAnchors(context.Background())
require.NoError(t, err)
@ -72,7 +72,7 @@ func TestFromStatic(t *testing.T) {
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
ta, err := FromStatic(roots)
ta, err := From(Options{Anchors: roots})
require.NoError(t, err)
taPEM, err := ta.CurrentTrustAnchors(context.Background())
require.NoError(t, err)
@ -85,7 +85,7 @@ func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
root := append(pki.RootCertPEM, []byte("garbage data")...)
ta, err := FromStatic(root)
ta, err := From(Options{Anchors: root})
require.NoError(t, err)
s, ok := ta.(*static)
require.True(t, ok)
@ -94,7 +94,7 @@ func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
require.NoError(t, err)
bundle, err := s.GetX509BundleForTrustDomain(trustDomain1)
require.NoError(t, err)
assert.Equal(t, s.bundle, bundle)
assert.Equal(t, s.x509Bundle, bundle)
b1, err := bundle.Marshal()
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, b1)
@ -103,7 +103,7 @@ func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
require.NoError(t, err)
bundle, err = s.GetX509BundleForTrustDomain(trustDomain2)
require.NoError(t, err)
assert.Equal(t, s.bundle, bundle)
assert.Equal(t, s.x509Bundle, bundle)
b2, err := bundle.Marshal()
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, b2)
@ -113,7 +113,7 @@ func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
func TestStatic_Run(t *testing.T) {
t.Run("Run multiple times should return error", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := FromStatic(pki.RootCertPEM)
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
s, ok := ta.(*static)
require.True(t, ok)
@ -154,7 +154,7 @@ func TestStatic_Run(t *testing.T) {
func TestStatic_Watch(t *testing.T) {
t.Run("should return when context is cancelled", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := FromStatic(pki.RootCertPEM)
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
@ -176,7 +176,7 @@ func TestStatic_Watch(t *testing.T) {
t.Run("should return when cancel is closed via closed Run", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := FromStatic(pki.RootCertPEM)
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())

View File

@ -16,6 +16,7 @@ package trustanchors
import (
"context"
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
)
@ -23,8 +24,10 @@ import (
// Allows consumers to get the current trust anchor bundle, and subscribe to
// bundle updates.
type Interface interface {
// Source implements the SPIFFE trust anchor bundle source.
// Source implements the SPIFFE trust anchor x509 bundle source.
x509bundle.Source
// Source implements the SPIFFE trust anchor jwt bundle source.
jwtbundle.Source
// CurrentTrustAnchors returns the current trust anchor PEM bundle.
CurrentTrustAnchors(ctx context.Context) ([]byte, error)

42
env/env.go vendored Normal file
View File

@ -0,0 +1,42 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package env
import (
"fmt"
"os"
"time"
)
// GetDurationWithRange returns the time.Duration value of the environment variable specified by `envVar`.
// If the environment variable is not set, it returns `defaultValue`.
// If the value is set but is not valid (not a valid time.Duration or falls outside the specified range
// [minValue, maxValue] inclusively), it returns `defaultValue` and an error.
func GetDurationWithRange(envVar string, defaultValue, min, max time.Duration) (time.Duration, error) {
v := os.Getenv(envVar)
if v == "" {
return defaultValue, nil
}
val, err := time.ParseDuration(v)
if err != nil {
return defaultValue, fmt.Errorf("invalid time.Duration value %s for the %s env variable: %w", val, envVar, err)
}
if val < min || val > max {
return defaultValue, fmt.Errorf("invalid value for the %s env variable: value should be between %s and %s, got %s", envVar, min, max, val)
}
return val, nil
}

View File

@ -1,4 +1,17 @@
package utils
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package env
import (
"testing"
@ -7,7 +20,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestGetEnvIntWithRangeWrongValues(t *testing.T) {
func TestGetIntWithRangeWrongValues(t *testing.T) {
testValues := []struct {
name string
envVarVal string
@ -43,7 +56,7 @@ func TestGetEnvIntWithRangeWrongValues(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("MY_ENV", tt.envVarVal)
val, err := GetEnvDurationWithRange("MY_ENV", defaultValue, tt.min, tt.max)
val, err := GetDurationWithRange("MY_ENV", defaultValue, tt.min, tt.max)
require.Error(t, err)
require.Contains(t, err.Error(), tt.error)
require.Equal(t, defaultValue, val)
@ -75,7 +88,7 @@ func TestGetEnvDurationWithRangeValidValues(t *testing.T) {
t.Setenv("MY_ENV", tt.envVarVal)
}
val, err := GetEnvDurationWithRange("MY_ENV", 3*time.Second, time.Second, 5*time.Second)
val, err := GetDurationWithRange("MY_ENV", 3*time.Second, time.Second, 5*time.Second)
require.NoError(t, err)
require.Equal(t, tt.result, val)
})

65
events/loop/fake/fake.go Normal file
View File

@ -0,0 +1,65 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fake
import (
"context"
"github.com/dapr/kit/events/loop"
)
type Fake[T any] struct {
runFn func(context.Context) error
enqueueFn func(T)
closeFn func(T)
}
func New[T any]() *Fake[T] {
return &Fake[T]{
runFn: func(context.Context) error { return nil },
enqueueFn: func(T) {},
closeFn: func(T) {},
}
}
func (f *Fake[T]) WithRun(fn func(context.Context) error) *Fake[T] {
f.runFn = fn
return f
}
func (f *Fake[T]) WithEnqueue(fn func(T)) *Fake[T] {
f.enqueueFn = fn
return f
}
func (f *Fake[T]) WithClose(fn func(T)) *Fake[T] {
f.closeFn = fn
return f
}
func (f *Fake[T]) Run(ctx context.Context) error {
return f.runFn(ctx)
}
func (f *Fake[T]) Enqueue(t T) {
f.enqueueFn(t)
}
func (f *Fake[T]) Close(t T) {
f.closeFn(t)
}
func (f *Fake[T]) Reset(loop.Handler[T], uint64) loop.Interface[T] {
return f
}

View File

@ -0,0 +1,24 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fake
import (
"testing"
"github.com/dapr/kit/events/loop"
)
func Test_Fake(*testing.T) {
var _ loop.Interface[int] = New[int]()
}

View File

@ -15,58 +15,97 @@ package loop
import (
"context"
"sync"
)
type HandlerFunc[T any] func(context.Context, T) error
type Options[T any] struct {
Handler HandlerFunc[T]
BufferSize *uint64
type Handler[T any] interface {
Handle(ctx context.Context, t T) error
}
type Loop[T any] struct {
type Interface[T any] interface {
Run(ctx context.Context) error
Enqueue(t T)
Close(t T)
Reset(h Handler[T], size uint64) Interface[T]
}
type loop[T any] struct {
queue chan T
handler HandlerFunc[T]
handler Handler[T]
closed bool
closeCh chan struct{}
lock sync.RWMutex
}
func New[T any](opts Options[T]) *Loop[T] {
size := 1
if opts.BufferSize != nil {
size = int(*opts.BufferSize)
}
return &Loop[T]{
func New[T any](h Handler[T], size uint64) Interface[T] {
return &loop[T]{
queue: make(chan T, size),
handler: h,
closeCh: make(chan struct{}),
handler: opts.Handler,
}
}
func (l *Loop[T]) Run(ctx context.Context) error {
func Empty[T any]() Interface[T] {
return new(loop[T])
}
func (l *loop[T]) Run(ctx context.Context) error {
defer close(l.closeCh)
for {
var req T
select {
case req = <-l.queue:
case <-ctx.Done():
req, ok := <-l.queue
if !ok {
return nil
}
if err := ctx.Err(); err != nil {
return err
}
if err := l.handler(ctx, req); err != nil {
if err := l.handler.Handle(ctx, req); err != nil {
return err
}
}
}
func (l *Loop[T]) Enqueue(req T) {
func (l *loop[T]) Enqueue(req T) {
l.lock.RLock()
defer l.lock.RUnlock()
if l.closed {
return
}
select {
case l.queue <- req:
case <-l.closeCh:
}
}
func (l *loop[T]) Close(req T) {
l.lock.Lock()
l.closed = true
select {
case l.queue <- req:
case <-l.closeCh:
}
close(l.queue)
l.lock.Unlock()
<-l.closeCh
}
func (l *loop[T]) Reset(h Handler[T], size uint64) Interface[T] {
if l == nil {
return New[T](h, size)
}
l.lock.Lock()
defer l.lock.Unlock()
l.closed = false
l.closeCh = make(chan struct{})
l.handler = h
// TODO: @joshvanl: use a ring buffer so that we don't need to reallocate and
// improve performance.
l.queue = make(chan T, size)
return l
}

1
go.mod
View File

@ -28,6 +28,7 @@ require (
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/go-jose/go-jose/v3 v3.0.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect

2
go.sum
View File

@ -19,6 +19,7 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.1.0 h1:Hsa8mG0dQ46ij8Sl2AYJDUv1oA9/d6Vk+3LG99Oe02g=
@ -72,6 +73,7 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec
github.com/zeebo/errs v1.3.0 h1:hmiaKqgYZzcVgRL1Vkc1Mn2914BbzB0IBxs+ebeutGs=
github.com/zeebo/errs v1.3.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=

View File

@ -37,9 +37,9 @@ import (
"github.com/lestrrat-go/httprc"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/dapr/kit/crypto/pem"
"github.com/dapr/kit/fswatcher"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/utils"
)
const (
@ -198,7 +198,7 @@ func (c *JWKSCache) initJWKSFromURL(ctx context.Context, url string) error {
// Load CA certificates if we have one
if c.caCertificate != "" {
caCert, err := utils.GetPEM(c.caCertificate)
caCert, err := pem.GetPEM(c.caCertificate)
if err != nil {
return fmt.Errorf("failed to load CA certificate: %w", err)
}

View File

@ -23,7 +23,7 @@ import (
"github.com/mitchellh/mapstructure"
"github.com/dapr/kit/ptr"
"github.com/dapr/kit/utils"
kitstrings "github.com/dapr/kit/strings"
)
func toTruthyBoolHookFunc() mapstructure.DecodeHookFunc {
@ -37,10 +37,10 @@ func toTruthyBoolHookFunc() mapstructure.DecodeHookFunc {
data any,
) (any, error) {
if f == stringType && t == boolType {
return utils.IsTruthy(data.(string)), nil
return kitstrings.IsTruthy(data.(string)), nil
}
if f == stringType && t == boolPtrType {
return ptr.Of(utils.IsTruthy(data.(string))), nil
return ptr.Of(kitstrings.IsTruthy(data.(string))), nil
}
return data, nil
}

View File

@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
package strings
import (
"path/filepath"

View File

@ -1,29 +0,0 @@
package utils
import (
"fmt"
"os"
"time"
)
// GetEnvDurationWithRange returns the time.Duration value of the environment variable specified by `envVar`.
// If the environment variable is not set, it returns `defaultValue`.
// If the value is set but is not valid (not a valid time.Duration or falls outside the specified range
// [minValue, maxValue] inclusively), it returns `defaultValue` and an error.
func GetEnvDurationWithRange(envVar string, defaultValue, min, max time.Duration) (time.Duration, error) {
v := os.Getenv(envVar)
if v == "" {
return defaultValue, nil
}
val, err := time.ParseDuration(v)
if err != nil {
return defaultValue, fmt.Errorf("invalid time.Duration value %s for the %s env variable: %w", val, envVar, err)
}
if val < min || val > max {
return defaultValue, fmt.Errorf("invalid value for the %s env variable: value should be between %s and %s, got %s", envVar, min, max, val)
}
return val, nil
}

View File

@ -1,41 +0,0 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
import (
"encoding/pem"
"fmt"
"os"
)
// GetPEM loads a PEM-encoded file (certificate or key).
func GetPEM(val string) ([]byte, error) {
// If val is already a PEM-encoded string, return it as-is
if IsValidPEM(val) {
return []byte(val), nil
}
// Assume it's a file
pemBytes, err := os.ReadFile(val)
if err != nil {
return nil, fmt.Errorf("value is neither a valid file path or nor a valid PEM-encoded string: %w", err)
}
return pemBytes, nil
}
// IsValidPEM validates the provided input has PEM formatted block.
func IsValidPEM(val string) bool {
block, _ := pem.Decode([]byte(val))
return block != nil
}