mirror of https://github.com/dapr/kit.git
Compare commits
6 Commits
Author | SHA1 | Date |
---|---|---|
|
598b032bce | |
|
d7d50a1e1b | |
|
baea626399 | |
|
bc7dc566c4 | |
|
98fe567235 | |
|
e3d4a8f1b4 |
195
.golangci.yml
195
.golangci.yml
|
@ -4,7 +4,7 @@ run:
|
|||
concurrency: 4
|
||||
|
||||
# timeout for analysis, e.g. 30s, 5m, default is 1m
|
||||
deadline: 10m
|
||||
timeout: 15m
|
||||
|
||||
# exit code when at least one issue was found, default is 1
|
||||
issues-exit-code: 1
|
||||
|
@ -15,29 +15,34 @@ run:
|
|||
# list of build tags, all linters use it. Default is empty list.
|
||||
build-tags:
|
||||
- unit
|
||||
|
||||
# which dirs to skip: they won't be analyzed;
|
||||
# can use regexp here: generated.*, regexp is applied on full path;
|
||||
# default value is empty list, but next dirs are always skipped independently
|
||||
# from this option's value:
|
||||
# third_party$, testdata$, examples$, Godeps$, builtin$
|
||||
skip-dirs:
|
||||
- ^pkg.*client.*clientset.*versioned.*
|
||||
- ^pkg.*client.*informers.*externalversions.*
|
||||
- ^pkg.*proto.*
|
||||
- allcomponents
|
||||
- subtlecrypto
|
||||
|
||||
# which files to skip: they will be analyzed, but issues from them
|
||||
# won't be reported. Default value is empty list, but there is
|
||||
# no need to include all autogenerated files, we confidently recognize
|
||||
# autogenerated files. If it's not please let us know.
|
||||
skip-files:
|
||||
# skip-files:
|
||||
# - ".*\\.my\\.go$"
|
||||
# - lib/bad.go
|
||||
|
||||
issues:
|
||||
# which dirs to skip: they won't be analyzed;
|
||||
# can use regexp here: generated.*, regexp is applied on full path;
|
||||
# default value is empty list, but next dirs are always skipped independently
|
||||
# from this option's value:
|
||||
# third_party$, testdata$, examples$, Godeps$, builtin$
|
||||
exclude-dirs:
|
||||
- ^pkg.*client.*clientset.*versioned.*
|
||||
- ^pkg.*client.*informers.*externalversions.*
|
||||
- ^pkg.*proto.*
|
||||
- pkg/proto
|
||||
|
||||
# output configuration options
|
||||
output:
|
||||
# colored-line-number|line-number|json|tab|checkstyle, default is "colored-line-number"
|
||||
format: tab
|
||||
formats:
|
||||
- format: tab
|
||||
|
||||
# print lines of code with issue, default is true
|
||||
print-issued-lines: true
|
||||
|
@ -57,23 +62,19 @@ linters-settings:
|
|||
# default is false: such cases aren't reported by default.
|
||||
check-blank: false
|
||||
|
||||
# [deprecated] comma-separated list of pairs of the form pkg:regex
|
||||
# the regex is used to ignore names within pkg. (default "fmt:.*").
|
||||
# see https://github.com/kisielk/errcheck#the-deprecated-method for details
|
||||
ignore: fmt:.*,io/ioutil:^Read.*
|
||||
exclude-functions:
|
||||
- fmt:.*
|
||||
- io/ioutil:^Read.*
|
||||
|
||||
# path to a file containing a list of functions to exclude from checking
|
||||
# see https://github.com/kisielk/errcheck#excluding-functions for details
|
||||
exclude:
|
||||
# exclude:
|
||||
|
||||
funlen:
|
||||
lines: 60
|
||||
statements: 40
|
||||
|
||||
govet:
|
||||
# report about shadowed variables
|
||||
check-shadowing: true
|
||||
|
||||
# settings per analyzer
|
||||
settings:
|
||||
printf: # analyzer name, run `go tool vet help` to see all analyzers
|
||||
|
@ -86,13 +87,12 @@ linters-settings:
|
|||
# enable or disable analyzers by name
|
||||
enable:
|
||||
- atomicalign
|
||||
enable-all: false
|
||||
disable:
|
||||
- shadow
|
||||
enable-all: false
|
||||
disable-all: false
|
||||
golint:
|
||||
revive:
|
||||
# minimal confidence for issues, default is 0.8
|
||||
min-confidence: 0.8
|
||||
confidence: 0.8
|
||||
gofmt:
|
||||
# simplify code: gofmt with `-s` option, true by default
|
||||
simplify: true
|
||||
|
@ -106,9 +106,6 @@ linters-settings:
|
|||
gocognit:
|
||||
# minimal code complexity to report, 30 by default (but we recommend 10-20)
|
||||
min-complexity: 10
|
||||
maligned:
|
||||
# print struct with more effective memory layout or not, false by default
|
||||
suggest-new: true
|
||||
dupl:
|
||||
# tokens count to trigger issue, 150 by default
|
||||
threshold: 100
|
||||
|
@ -121,55 +118,60 @@ linters-settings:
|
|||
rules:
|
||||
main:
|
||||
deny:
|
||||
- pkg: "github.com/Sirupsen/logrus"
|
||||
desc: "must use github.com/dapr/kit/logger"
|
||||
- pkg: "github.com/agrea/ptr"
|
||||
desc: "must use github.com/dapr/kit/ptr"
|
||||
- pkg: "go.uber.org/atomic"
|
||||
desc: "must use sync/atomic"
|
||||
- pkg: "golang.org/x/net/context"
|
||||
desc: "must use context"
|
||||
- pkg: "github.com/pkg/errors"
|
||||
desc: "must use standard library (errors package and/or fmt.Errorf)"
|
||||
- pkg: "github.com/go-chi/chi$"
|
||||
desc: "must use github.com/go-chi/chi/v5"
|
||||
- pkg: "github.com/cenkalti/backoff$"
|
||||
desc: "must use github.com/cenkalti/backoff/v4"
|
||||
- pkg: "github.com/cenkalti/backoff/v2"
|
||||
desc: "must use github.com/cenkalti/backoff/v4"
|
||||
- pkg: "github.com/cenkalti/backoff/v3"
|
||||
desc: "must use github.com/cenkalti/backoff/v4"
|
||||
- pkg: "github.com/benbjohnson/clock"
|
||||
desc: "must use k8s.io/utils/clock"
|
||||
- pkg: "github.com/ghodss/yaml"
|
||||
desc: "must use sigs.k8s.io/yaml"
|
||||
- pkg: "gopkg.in/yaml.v2"
|
||||
desc: "must use gopkg.in/yaml.v3"
|
||||
- pkg: "github.com/golang-jwt/jwt"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/golang-jwt/jwt/v2"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/golang-jwt/jwt/v3"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/golang-jwt/jwt/v4"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/gogo/status"
|
||||
desc: "must use google.golang.org/grpc/status"
|
||||
- pkg: "github.com/gogo/protobuf"
|
||||
desc: "must use google.golang.org/protobuf"
|
||||
- pkg: "github.com/lestrrat-go/jwx/jwa"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/lestrrat-go/jwx/jwt"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/labstack/gommon/log"
|
||||
desc: "must use github.com/dapr/kit/logger"
|
||||
- pkg: "github.com/gobuffalo/logger"
|
||||
desc: "must use github.com/dapr/kit/logger"
|
||||
- pkg: "github.com/Sirupsen/logrus"
|
||||
desc: "must use github.com/dapr/kit/logger"
|
||||
- pkg: "github.com/agrea/ptr"
|
||||
desc: "must use github.com/dapr/kit/ptr"
|
||||
- pkg: "go.uber.org/atomic"
|
||||
desc: "must use sync/atomic"
|
||||
- pkg: "golang.org/x/net/context"
|
||||
desc: "must use context"
|
||||
- pkg: "github.com/pkg/errors"
|
||||
desc: "must use standard library (errors package and/or fmt.Errorf)"
|
||||
- pkg: "github.com/go-chi/chi$"
|
||||
desc: "must use github.com/go-chi/chi/v5"
|
||||
- pkg: "github.com/cenkalti/backoff$"
|
||||
desc: "must use github.com/cenkalti/backoff/v4"
|
||||
- pkg: "github.com/cenkalti/backoff/v2"
|
||||
desc: "must use github.com/cenkalti/backoff/v4"
|
||||
- pkg: "github.com/cenkalti/backoff/v3"
|
||||
desc: "must use github.com/cenkalti/backoff/v4"
|
||||
- pkg: "github.com/benbjohnson/clock"
|
||||
desc: "must use k8s.io/utils/clock"
|
||||
- pkg: "github.com/ghodss/yaml"
|
||||
desc: "must use sigs.k8s.io/yaml"
|
||||
- pkg: "gopkg.in/yaml.v2"
|
||||
desc: "must use gopkg.in/yaml.v3"
|
||||
- pkg: "github.com/golang-jwt/jwt"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/golang-jwt/jwt/v2"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/golang-jwt/jwt/v3"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/golang-jwt/jwt/v4"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
# pkg: Commonly auto-completed by gopls
|
||||
- pkg: "github.com/gogo/status"
|
||||
desc: "must use google.golang.org/grpc/status"
|
||||
- pkg: "github.com/gogo/protobuf"
|
||||
desc: "must use google.golang.org/protobuf"
|
||||
- pkg: "github.com/lestrrat-go/jwx/jwa"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/lestrrat-go/jwx/jwt"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/labstack/gommon/log"
|
||||
desc: "must use github.com/dapr/kit/logger"
|
||||
- pkg: "github.com/gobuffalo/logger"
|
||||
desc: "must use github.com/dapr/kit/logger"
|
||||
- pkg: "k8s.io/utils/pointer"
|
||||
desc: "must use github.com/dapr/kit/ptr"
|
||||
- pkg: "k8s.io/utils/ptr"
|
||||
desc: "must use github.com/dapr/kit/ptr"
|
||||
misspell:
|
||||
# Correct spellings using locale preferences for US or UK.
|
||||
# Default is to use a neutral variety of English.
|
||||
# Setting locale to US will correct the British spelling of 'colour' to 'color'.
|
||||
locale: default
|
||||
# locale: default
|
||||
ignore-words:
|
||||
- someword
|
||||
lll:
|
||||
|
@ -178,17 +180,9 @@ linters-settings:
|
|||
line-length: 120
|
||||
# tab width in spaces. Default to 1.
|
||||
tab-width: 1
|
||||
unparam:
|
||||
# Inspect exported functions, default is false. Set to true if no external program/library imports your code.
|
||||
# XXX: if you enable this setting, unparam will report a lot of false-positives in text editors:
|
||||
# if it's called for subdir of a project it can't find external interfaces. All text editor integrations
|
||||
# with golangci-lint call it on a directory with the changed file.
|
||||
check-exported: false
|
||||
nakedret:
|
||||
# make an issue if func has more lines of code than this setting and it has naked returns; default is 30
|
||||
max-func-lines: 30
|
||||
nolintlint:
|
||||
allow-unused: true
|
||||
prealloc:
|
||||
# XXX: we don't recommend using this linter before doing performance profiling.
|
||||
# For most programs usage of prealloc will be a premature optimization.
|
||||
|
@ -203,7 +197,6 @@ linters-settings:
|
|||
# See https://go-critic.github.io/overview#checks-overview
|
||||
# To check which checks are enabled run `GL_DEBUG=gocritic golangci-lint run`
|
||||
# By default list of stable checks is used.
|
||||
enabled-checks:
|
||||
|
||||
# Which checks should be disabled; can't be combined with 'enabled-checks'; default is empty
|
||||
disabled-checks:
|
||||
|
@ -251,62 +244,51 @@ linters-settings:
|
|||
allow-assign-and-call: true
|
||||
# Allow multiline assignments to be cuddled. Default is true.
|
||||
allow-multiline-assign: true
|
||||
# Allow case blocks to end with a whitespace.
|
||||
allow-case-traling-whitespace: true
|
||||
# Allow declarations (var) to be cuddled.
|
||||
allow-cuddle-declarations: false
|
||||
# If the number of lines in a case block is equal to or lager than this number,
|
||||
# the case *must* end white a newline.
|
||||
# https://github.com/bombsimon/wsl/blob/master/doc/configuration.md#force-case-trailing-whitespace
|
||||
# Default: 0
|
||||
force-case-trailing-whitespace: 1
|
||||
|
||||
linters:
|
||||
fast: false
|
||||
enable-all: true
|
||||
disable:
|
||||
# TODO Enforce the below linters later
|
||||
- musttag
|
||||
- dupl
|
||||
- nonamedreturns
|
||||
- errcheck
|
||||
- funlen
|
||||
- goconst
|
||||
- gochecknoglobals
|
||||
- gochecknoinits
|
||||
- gocyclo
|
||||
- gocognit
|
||||
- nosnakecase
|
||||
- varcheck
|
||||
- structcheck
|
||||
- deadcode
|
||||
- godox
|
||||
- interfacer
|
||||
- lll
|
||||
- maligned
|
||||
- scopelint
|
||||
- unparam
|
||||
- wsl
|
||||
- gomnd
|
||||
- testpackage
|
||||
- nestif
|
||||
- nlreturn
|
||||
- exhaustive
|
||||
- exhaustruct
|
||||
- noctx
|
||||
- gci
|
||||
- golint
|
||||
- tparallel
|
||||
- paralleltest
|
||||
- wrapcheck
|
||||
- tagliatelle
|
||||
- ireturn
|
||||
- exhaustive
|
||||
- exhaustivestruct
|
||||
- exhaustruct
|
||||
- errchkjson
|
||||
- contextcheck
|
||||
- gomoddirectives
|
||||
- godot
|
||||
- cyclop
|
||||
- varnamelen
|
||||
- gosec
|
||||
- tagalign
|
||||
- errorlint
|
||||
- forcetypeassert
|
||||
- ifshort
|
||||
- maintidx
|
||||
- nilnil
|
||||
- predeclared
|
||||
|
@ -315,4 +297,13 @@ linters:
|
|||
- wastedassign
|
||||
- containedctx
|
||||
- gosimple
|
||||
- forbidigo
|
||||
- nonamedreturns
|
||||
- asasalint
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
||||
- inamedparam
|
||||
- tagalign
|
||||
- mnd
|
||||
- canonicalheader
|
||||
- err113
|
||||
- fatcontext
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// DecodePEMCertificatesChain takes a PEM-encoded x509 certificates byte array
|
||||
|
@ -187,3 +188,24 @@ func PublicKeysEqual(a, b crypto.PublicKey) (bool, error) {
|
|||
return false, fmt.Errorf("unrecognised public key type: %T", a)
|
||||
}
|
||||
}
|
||||
|
||||
// GetPEM loads a PEM-encoded file (certificate or key).
|
||||
func GetPEM(val string) ([]byte, error) {
|
||||
// If val is already a PEM-encoded string, return it as-is
|
||||
if IsValidPEM(val) {
|
||||
return []byte(val), nil
|
||||
}
|
||||
|
||||
// Assume it's a file
|
||||
pemBytes, err := os.ReadFile(val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("value is neither a valid file path or nor a valid PEM-encoded string: %w", err)
|
||||
}
|
||||
return pemBytes, nil
|
||||
}
|
||||
|
||||
// IsValidPEM validates the provided input has PEM formatted block.
|
||||
func IsValidPEM(val string) bool {
|
||||
block, _ := pem.Decode([]byte(val))
|
||||
return block != nil
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
@ -16,6 +16,7 @@ package context
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
|
||||
|
||||
"github.com/dapr/kit/crypto/spiffe"
|
||||
|
@ -23,13 +24,48 @@ import (
|
|||
|
||||
type ctxkey int
|
||||
|
||||
const svidKey ctxkey = iota
|
||||
const (
|
||||
x509SvidKey ctxkey = iota
|
||||
jwtSvidKey
|
||||
)
|
||||
|
||||
// Deprecated: use WithX509 instead.
|
||||
// With adds the x509 SVID source from the SPIFFE object to the context.
|
||||
func With(ctx context.Context, spiffe *spiffe.SPIFFE) context.Context {
|
||||
return context.WithValue(ctx, svidKey, spiffe.SVIDSource())
|
||||
return context.WithValue(ctx, x509SvidKey, spiffe.X509SVIDSource())
|
||||
}
|
||||
|
||||
// Deprecated: use X509From instead.
|
||||
// From retrieves the x509 SVID source from the context.
|
||||
func From(ctx context.Context) (x509svid.Source, bool) {
|
||||
svid, ok := ctx.Value(svidKey).(x509svid.Source)
|
||||
svid, ok := ctx.Value(x509SvidKey).(x509svid.Source)
|
||||
return svid, ok
|
||||
}
|
||||
|
||||
// WithX509 adds an x509 SVID source to the context.
|
||||
func WithX509(ctx context.Context, source x509svid.Source) context.Context {
|
||||
return context.WithValue(ctx, x509SvidKey, source)
|
||||
}
|
||||
|
||||
// WithJWT adds a JWT SVID source to the context.
|
||||
func WithJWT(ctx context.Context, source jwtsvid.Source) context.Context {
|
||||
return context.WithValue(ctx, jwtSvidKey, source)
|
||||
}
|
||||
|
||||
// X509From retrieves the x509 SVID source from the context.
|
||||
func X509From(ctx context.Context) (x509svid.Source, bool) {
|
||||
svid, ok := ctx.Value(x509SvidKey).(x509svid.Source)
|
||||
return svid, ok
|
||||
}
|
||||
|
||||
// JWTFrom retrieves the JWT SVID source from the context.
|
||||
func JWTFrom(ctx context.Context) (jwtsvid.Source, bool) {
|
||||
svid, ok := ctx.Value(jwtSvidKey).(jwtsvid.Source)
|
||||
return svid, ok
|
||||
}
|
||||
|
||||
// WithSpiffe adds both X509 and JWT SVID sources to the context.
|
||||
func WithSpiffe(ctx context.Context, spiffe *spiffe.SPIFFE) context.Context {
|
||||
ctx = WithX509(ctx, spiffe.X509SVIDSource())
|
||||
return WithJWT(ctx, spiffe.JWTSVIDSource())
|
||||
}
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package context
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockX509Source struct{}
|
||||
|
||||
func (m *mockX509Source) GetX509SVID() (*x509svid.SVID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type mockJWTSource struct{}
|
||||
|
||||
func (m *mockJWTSource) FetchJWTSVID(context.Context, jwtsvid.Params) (*jwtsvid.SVID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestWithX509FromX509(t *testing.T) {
|
||||
source := &mockX509Source{}
|
||||
ctx := WithX509(context.Background(), source)
|
||||
|
||||
retrieved, ok := X509From(ctx)
|
||||
assert.True(t, ok, "Failed to retrieve X509 source from context")
|
||||
assert.Equal(t, x509svid.Source(source), retrieved, "Retrieved source does not match the original source")
|
||||
}
|
||||
|
||||
func TestWithJWTFromJWT(t *testing.T) {
|
||||
source := &mockJWTSource{}
|
||||
ctx := WithJWT(context.Background(), source)
|
||||
|
||||
retrieved, ok := JWTFrom(ctx)
|
||||
assert.True(t, ok, "Failed to retrieve JWT source from context")
|
||||
assert.Equal(t, jwtsvid.Source(source), retrieved, "Retrieved source does not match the original source")
|
||||
}
|
||||
|
||||
func TestWithFrom(t *testing.T) {
|
||||
x509Source := &mockX509Source{}
|
||||
ctx := WithX509(context.Background(), x509Source)
|
||||
|
||||
// Should be able to retrieve using the legacy From function
|
||||
retrieved, ok := From(ctx)
|
||||
assert.True(t, ok, "Failed to retrieve X509 source from context using legacy From")
|
||||
assert.Equal(t, x509svid.Source(x509Source), retrieved, "Retrieved source does not match the original source using legacy From")
|
||||
}
|
|
@ -25,6 +25,7 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
|
||||
"k8s.io/utils/clock"
|
||||
|
||||
|
@ -34,8 +35,29 @@ import (
|
|||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
// renewalDivisor represents the divisor for calculating renewal time.
|
||||
// A value of 2 means renewal at 50% of the validity period.
|
||||
renewalDivisor = 2
|
||||
)
|
||||
|
||||
// SVIDResponse represents the response from the SVID request function,
|
||||
// containing both X.509 certificates and a JWT token.
|
||||
type SVIDResponse struct {
|
||||
X509Certificates []*x509.Certificate
|
||||
JWT *string
|
||||
}
|
||||
|
||||
// Identity contains both X.509 and JWT SVIDs for a workload.
|
||||
type Identity struct {
|
||||
X509SVID *x509svid.SVID
|
||||
JWTSVID *jwtsvid.SVID
|
||||
}
|
||||
|
||||
type (
|
||||
RequestSVIDFn func(context.Context, []byte) ([]*x509.Certificate, error)
|
||||
// RequestSVIDFn is the function type that requests SVIDs from a SPIFFE server,
|
||||
// returning both X.509 certificates and a JWT token.
|
||||
RequestSVIDFn func(context.Context, []byte) (*SVIDResponse, error)
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
|
@ -51,11 +73,12 @@ type Options struct {
|
|||
TrustAnchors trustanchors.Interface
|
||||
}
|
||||
|
||||
// SPIFFE is a readable/writeable store of a SPIFFE X.509 SVID.
|
||||
// Used to manage a workload SVID, and share read-only interfaces to consumers.
|
||||
// SPIFFE is a readable/writeable store of SPIFFE SVID credentials.
|
||||
// Used to manage workload SVIDs, and share read-only interfaces to consumers.
|
||||
type SPIFFE struct {
|
||||
currentSVID *x509svid.SVID
|
||||
requestSVIDFn RequestSVIDFn
|
||||
currentX509SVID *x509svid.SVID
|
||||
currentJWTSVID *jwtsvid.SVID
|
||||
requestSVIDFn RequestSVIDFn
|
||||
|
||||
dir *dir.Dir
|
||||
trustAnchors trustanchors.Interface
|
||||
|
@ -92,15 +115,16 @@ func (s *SPIFFE) Run(ctx context.Context) error {
|
|||
}
|
||||
|
||||
s.lock.Lock()
|
||||
s.log.Info("Fetching initial identity certificate")
|
||||
initialCert, err := s.fetchIdentityCertificate(ctx)
|
||||
s.log.Info("Fetching initial identity")
|
||||
initialIdentity, err := s.fetchIdentity(ctx)
|
||||
if err != nil {
|
||||
close(s.readyCh)
|
||||
s.lock.Unlock()
|
||||
return fmt.Errorf("failed to retrieve the initial identity certificate: %w", err)
|
||||
return fmt.Errorf("failed to retrieve the initial identity: %w", err)
|
||||
}
|
||||
|
||||
s.currentSVID = initialCert
|
||||
s.currentX509SVID = initialIdentity.X509SVID
|
||||
s.currentJWTSVID = initialIdentity.JWTSVID
|
||||
close(s.readyCh)
|
||||
s.lock.Unlock()
|
||||
|
||||
|
@ -121,28 +145,48 @@ func (s *SPIFFE) Ready(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
// runRotation starts up the manager responsible for renewing the workload
|
||||
// certificate. Receives the initial certificate to calculate the next rotation
|
||||
// time.
|
||||
// logIdentityInfo creates a log message with expiry details for both X.509 and JWT SVIDs
|
||||
func (s *SPIFFE) logIdentityInfo(prefix string, cert *x509.Certificate, jwtSVID *jwtsvid.SVID, renewTime *time.Time) {
|
||||
msg := prefix + "; cert expires on: %s"
|
||||
args := []any{cert.NotAfter.String()}
|
||||
|
||||
if jwtSVID != nil {
|
||||
msg += ", jwt expires on: %s"
|
||||
args = append(args, jwtSVID.Expiry.String())
|
||||
}
|
||||
|
||||
if renewTime != nil {
|
||||
msg += ", renewal at: %s"
|
||||
args = append(args, renewTime.String())
|
||||
}
|
||||
|
||||
s.log.Infof(msg, args...)
|
||||
}
|
||||
|
||||
// runRotation starts up the manager responsible for renewing the workload identity
|
||||
func (s *SPIFFE) runRotation(ctx context.Context) {
|
||||
defer s.log.Debug("stopping workload cert expiry watcher")
|
||||
defer s.log.Debug("stopping workload identity expiry watcher")
|
||||
|
||||
s.lock.RLock()
|
||||
cert := s.currentSVID.Certificates[0]
|
||||
cert := s.currentX509SVID.Certificates[0]
|
||||
jwtSVID := s.currentJWTSVID
|
||||
s.lock.RUnlock()
|
||||
renewTime := renewalTime(cert.NotBefore, cert.NotAfter)
|
||||
s.log.Infof("Starting workload cert expiry watcher; current cert expires on: %s, renewing at %s",
|
||||
cert.NotAfter.String(), renewTime.String())
|
||||
|
||||
renewTime := calculateRenewalTime(time.Now(), cert, jwtSVID)
|
||||
s.logIdentityInfo("Starting workload identity expiry watcher", cert, jwtSVID, renewTime)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.clock.After(min(time.Minute, renewTime.Sub(s.clock.Now()))):
|
||||
if s.clock.Now().Before(renewTime) {
|
||||
if s.clock.Now().Before(*renewTime) {
|
||||
continue
|
||||
}
|
||||
s.log.Infof("Renewing workload cert; current cert expires on: %s", cert.NotAfter.String())
|
||||
svid, err := s.fetchIdentityCertificate(ctx)
|
||||
|
||||
s.logIdentityInfo("Renewing workload identity", cert, jwtSVID, nil)
|
||||
|
||||
identity, err := s.fetchIdentity(ctx)
|
||||
if err != nil {
|
||||
s.log.Errorf("Error renewing identity certificate, trying again in 10 seconds: %s", err)
|
||||
s.log.Errorf("Error renewing identity, trying again in 10 seconds: %s", err)
|
||||
select {
|
||||
case <-s.clock.After(10 * time.Second):
|
||||
continue
|
||||
|
@ -150,12 +194,16 @@ func (s *SPIFFE) runRotation(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.lock.Lock()
|
||||
s.currentSVID = svid
|
||||
cert = svid.Certificates[0]
|
||||
s.currentX509SVID = identity.X509SVID
|
||||
s.currentJWTSVID = identity.JWTSVID
|
||||
cert = identity.X509SVID.Certificates[0]
|
||||
jwtSVID = identity.JWTSVID
|
||||
s.lock.Unlock()
|
||||
renewTime = renewalTime(cert.NotBefore, cert.NotAfter)
|
||||
s.log.Infof("Successfully renewed workload cert; new cert expires on: %s", cert.NotAfter.String())
|
||||
|
||||
renewTime = calculateRenewalTime(time.Now(), cert, jwtSVID)
|
||||
s.logIdentityInfo("Successfully renewed workload identity", cert, jwtSVID, renewTime)
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
@ -163,8 +211,8 @@ func (s *SPIFFE) runRotation(ctx context.Context) {
|
|||
}
|
||||
}
|
||||
|
||||
// fetchIdentityCertificate fetches a new SVID using the configured requester.
|
||||
func (s *SPIFFE) fetchIdentityCertificate(ctx context.Context) (*x509svid.SVID, error) {
|
||||
// Returns both X.509 SVID and JWT SVID (if available).
|
||||
func (s *SPIFFE) fetchIdentity(ctx context.Context) (*Identity, error) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||
|
@ -175,27 +223,55 @@ func (s *SPIFFE) fetchIdentityCertificate(ctx context.Context) (*x509svid.SVID,
|
|||
return nil, fmt.Errorf("failed to create sidecar csr: %w", err)
|
||||
}
|
||||
|
||||
workloadcert, err := s.requestSVIDFn(ctx, csrDER)
|
||||
svidResponse, err := s.requestSVIDFn(ctx, csrDER)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(workloadcert) == 0 {
|
||||
if len(svidResponse.X509Certificates) == 0 {
|
||||
return nil, errors.New("no certificates received from sentry")
|
||||
}
|
||||
|
||||
spiffeID, err := x509svid.IDFromCert(workloadcert[0])
|
||||
spiffeID, err := x509svid.IDFromCert(svidResponse.X509Certificates[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing spiffe id from newly signed certificate: %w", err)
|
||||
}
|
||||
|
||||
identity := &Identity{
|
||||
X509SVID: &x509svid.SVID{
|
||||
ID: spiffeID,
|
||||
Certificates: svidResponse.X509Certificates,
|
||||
PrivateKey: key,
|
||||
},
|
||||
}
|
||||
|
||||
// If we have a JWT token, parse it and include it in the identity
|
||||
if svidResponse.JWT != nil {
|
||||
// we are using ParseInsecure here as the expectation is that the
|
||||
// requestSVIDFn will have already parsed and validate the JWT SVID
|
||||
// before returning it.
|
||||
//
|
||||
// we are parsing the token using our SPIFFE ID's trust domain
|
||||
// as the audience as we expect the issuer to always include
|
||||
// that as an audience since that ensures that the token is
|
||||
// valid for us and our trust domain.
|
||||
audiences := []string{spiffeID.TrustDomain().Name()}
|
||||
jwtSvid, err := jwtsvid.ParseInsecure(*svidResponse.JWT, audiences)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT SVID: %w", err)
|
||||
}
|
||||
|
||||
identity.JWTSVID = jwtSvid
|
||||
s.log.Infof("Successfully received JWT SVID with expiry: %s", jwtSvid.Expiry.String())
|
||||
}
|
||||
|
||||
if s.dir != nil {
|
||||
pkPEM, err := pem.EncodePrivateKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
certPEM, err := pem.EncodeX509Chain(workloadcert)
|
||||
certPEM, err := pem.EncodeX509Chain(svidResponse.X509Certificates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -205,27 +281,72 @@ func (s *SPIFFE) fetchIdentityCertificate(ctx context.Context) (*x509svid.SVID,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.dir.Write(map[string][]byte{
|
||||
files := map[string][]byte{
|
||||
"key.pem": pkPEM,
|
||||
"cert.pem": certPEM,
|
||||
"ca.pem": td,
|
||||
}); err != nil {
|
||||
}
|
||||
|
||||
if svidResponse.JWT != nil {
|
||||
files["jwt_svid.token"] = []byte(*svidResponse.JWT)
|
||||
}
|
||||
|
||||
if err := s.dir.Write(files); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &x509svid.SVID{
|
||||
ID: spiffeID,
|
||||
Certificates: workloadcert,
|
||||
PrivateKey: key,
|
||||
}, nil
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
func (s *SPIFFE) SVIDSource() x509svid.Source {
|
||||
func (s *SPIFFE) X509SVIDSource() x509svid.Source {
|
||||
return &svidSource{spiffe: s}
|
||||
}
|
||||
|
||||
func (s *SPIFFE) JWTSVIDSource() jwtsvid.Source {
|
||||
return &svidSource{spiffe: s}
|
||||
}
|
||||
|
||||
// renewalTime is 50% through the certificate validity period.
|
||||
func renewalTime(notBefore, notAfter time.Time) time.Time {
|
||||
return notBefore.Add(notAfter.Sub(notBefore) / 2)
|
||||
return notBefore.Add(notAfter.Sub(notBefore) / renewalDivisor)
|
||||
}
|
||||
|
||||
// calculateRenewalTime returns the earlier renewal time between the X.509 certificate
|
||||
// and JWT SVID (if available) to ensure timely renewal.
|
||||
func calculateRenewalTime(now time.Time, cert *x509.Certificate, jwtSVID *jwtsvid.SVID) *time.Time {
|
||||
certRenewal := renewalTime(cert.NotBefore, cert.NotAfter)
|
||||
|
||||
if jwtSVID == nil {
|
||||
return &certRenewal
|
||||
}
|
||||
|
||||
jwtRenewal := now.Add(jwtSVID.Expiry.Sub(now) / renewalDivisor)
|
||||
|
||||
if jwtRenewal.Before(certRenewal) {
|
||||
return &jwtRenewal
|
||||
}
|
||||
return &certRenewal
|
||||
}
|
||||
|
||||
// audiencesMatch checks if the SVID audiences contain all the requested audiences
|
||||
func audiencesMatch(svidAudiences []string, requestedAudiences []string) bool {
|
||||
if len(requestedAudiences) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Create a map for faster lookup
|
||||
audienceMap := make(map[string]struct{}, len(svidAudiences))
|
||||
for _, audience := range svidAudiences {
|
||||
audienceMap[audience] = struct{}{}
|
||||
}
|
||||
|
||||
// Check if all requested audiences are in the SVID
|
||||
for _, requested := range requestedAudiences {
|
||||
if _, ok := audienceMap[requested]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
|
@ -39,6 +40,70 @@ func Test_renewalTime(t *testing.T) {
|
|||
assert.Equal(t, in30, renewalTime(now, in1Min))
|
||||
}
|
||||
|
||||
func Test_calculateRenewalTime(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
certShort := &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(10 * time.Hour),
|
||||
}
|
||||
|
||||
certLong := &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
// Expected renewal times for certificates (50% of validity period)
|
||||
certShortRenewal := now.Add(5 * time.Hour)
|
||||
|
||||
// Create JWT SVIDs with different expiry times
|
||||
jwtEarlier := &jwtsvid.SVID{
|
||||
Expiry: now.Add(8 * time.Hour),
|
||||
}
|
||||
|
||||
jwtLater := &jwtsvid.SVID{
|
||||
Expiry: now.Add(30 * time.Hour),
|
||||
}
|
||||
|
||||
// Expected JWT renewal time (50% of remaining time)
|
||||
jwtEarlierRenewal := now.Add(4 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cert *x509.Certificate
|
||||
jwt *jwtsvid.SVID
|
||||
expected time.Time
|
||||
}{
|
||||
{
|
||||
name: "Certificate only",
|
||||
cert: certShort,
|
||||
jwt: nil,
|
||||
expected: certShortRenewal,
|
||||
},
|
||||
{
|
||||
name: "Certificate and JWT, JWT earlier",
|
||||
cert: certLong,
|
||||
jwt: jwtEarlier,
|
||||
expected: jwtEarlierRenewal,
|
||||
},
|
||||
{
|
||||
name: "Certificate and JWT, Certificate earlier",
|
||||
cert: certShort,
|
||||
jwt: jwtLater,
|
||||
expected: certShortRenewal,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := calculateRenewalTime(now, tt.cert, tt.jwt)
|
||||
|
||||
assert.WithinDuration(t, tt.expected, *actual, time.Millisecond,
|
||||
"Renewal time does not match expected value")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Run(t *testing.T) {
|
||||
t.Run("should return error multiple Runs are called", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{
|
||||
|
@ -47,8 +112,10 @@ func Test_Run(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := New(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
RequestSVIDFn: func(context.Context, []byte) ([]*x509.Certificate, error) {
|
||||
return []*x509.Certificate{pki.LeafCert}, nil
|
||||
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
|
||||
return &SVIDResponse{
|
||||
X509Certificates: []*x509.Certificate{pki.LeafCert},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
|
@ -79,7 +146,7 @@ func Test_Run(t *testing.T) {
|
|||
t.Run("should return error if initial fetch errors", func(t *testing.T) {
|
||||
s := New(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
RequestSVIDFn: func(context.Context, []byte) ([]*x509.Certificate, error) {
|
||||
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
|
||||
return nil, errors.New("this is an error")
|
||||
},
|
||||
})
|
||||
|
@ -95,9 +162,11 @@ func Test_Run(t *testing.T) {
|
|||
var fetches atomic.Int32
|
||||
s := New(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
RequestSVIDFn: func(context.Context, []byte) ([]*x509.Certificate, error) {
|
||||
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
|
||||
fetches.Add(1)
|
||||
return []*x509.Certificate{pki.LeafCert}, nil
|
||||
return &SVIDResponse{
|
||||
X509Certificates: []*x509.Certificate{pki.LeafCert},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
now := time.Now()
|
||||
|
@ -107,15 +176,15 @@ func Test_Run(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
select {
|
||||
case <-s.readyCh:
|
||||
assert.Fail(t, "readyCh should not be closed")
|
||||
default:
|
||||
}
|
||||
|
||||
errCh <- s.Run(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-s.readyCh:
|
||||
assert.Fail(t, "readyCh should not be closed")
|
||||
default:
|
||||
}
|
||||
|
||||
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
|
||||
assert.Equal(t, int32(1), fetches.Load())
|
||||
|
||||
|
@ -144,9 +213,11 @@ func Test_Run(t *testing.T) {
|
|||
var fetches atomic.Int32
|
||||
s := New(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
RequestSVIDFn: func(context.Context, []byte) ([]*x509.Certificate, error) {
|
||||
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
|
||||
fetches.Add(1)
|
||||
return respCert, respErr
|
||||
return &SVIDResponse{
|
||||
X509Certificates: respCert,
|
||||
}, respErr
|
||||
},
|
||||
})
|
||||
now := time.Now()
|
||||
|
@ -156,15 +227,15 @@ func Test_Run(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
select {
|
||||
case <-s.readyCh:
|
||||
assert.Fail(t, "readyCh should not be closed")
|
||||
default:
|
||||
}
|
||||
|
||||
errCh <- s.Run(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-s.readyCh:
|
||||
assert.Fail(t, "readyCh should not be closed")
|
||||
default:
|
||||
}
|
||||
|
||||
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
|
||||
assert.Equal(t, int32(1), fetches.Load())
|
||||
|
||||
|
|
|
@ -14,27 +14,81 @@ limitations under the License.
|
|||
package spiffe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
|
||||
)
|
||||
|
||||
// svidSource is an implementation of the Go spiffe x509svid Source interface.
|
||||
var (
|
||||
errNoX509SVIDAvailable = errors.New("no X509 SVID available")
|
||||
errNoJWTSVIDAvailable = errors.New("no JWT SVID available")
|
||||
errAudienceRequired = errors.New("JWT audience is required")
|
||||
)
|
||||
|
||||
// svidSource is an implementation of both go-spiffe x509svid.Source and jwtsvid.Source interfaces.
|
||||
type svidSource struct {
|
||||
spiffe *SPIFFE
|
||||
}
|
||||
|
||||
// GetX509SVID returns the current X.509 certificate identity as a SPIFFE SVID.
|
||||
// Implements the go-spiffe x509 source interface.
|
||||
// Implements the go-spiffe x509svid.Source interface.
|
||||
func (s *svidSource) GetX509SVID() (*x509svid.SVID, error) {
|
||||
s.spiffe.lock.RLock()
|
||||
defer s.spiffe.lock.RUnlock()
|
||||
|
||||
<-s.spiffe.readyCh
|
||||
|
||||
svid := s.spiffe.currentSVID
|
||||
svid := s.spiffe.currentX509SVID
|
||||
if svid == nil {
|
||||
return nil, errors.New("no SVID available")
|
||||
return nil, errNoX509SVIDAvailable
|
||||
}
|
||||
|
||||
return svid, nil
|
||||
}
|
||||
|
||||
// audienceMismatchError is an error that contains information about mismatched audiences
|
||||
type audienceMismatchError struct {
|
||||
expected []string
|
||||
actual []string
|
||||
}
|
||||
|
||||
func (e *audienceMismatchError) Error() string {
|
||||
return fmt.Sprintf("JWT SVID has different audiences than requested: expected %s, got %s",
|
||||
strings.Join(e.expected, ", "), strings.Join(e.actual, ", "))
|
||||
}
|
||||
|
||||
// FetchJWTSVID returns the current JWT SVID.
|
||||
// Implements the go-spiffe jwtsvid.Source interface.
|
||||
func (s *svidSource) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*jwtsvid.SVID, error) {
|
||||
s.spiffe.lock.RLock()
|
||||
defer s.spiffe.lock.RUnlock()
|
||||
|
||||
if params.Audience == "" {
|
||||
return nil, errAudienceRequired
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-s.spiffe.readyCh:
|
||||
}
|
||||
|
||||
svid := s.spiffe.currentJWTSVID
|
||||
if svid == nil {
|
||||
return nil, errNoJWTSVIDAvailable
|
||||
}
|
||||
|
||||
// verify that the audience being requested is the same as the audience in the SVID
|
||||
// WARN: we do not check extra audiences here.
|
||||
if !audiencesMatch(svid.Audience, []string{params.Audience}) {
|
||||
return nil, &audienceMismatchError{
|
||||
expected: []string{params.Audience},
|
||||
actual: svid.Audience,
|
||||
}
|
||||
}
|
||||
|
||||
return svid, nil
|
||||
|
|
|
@ -14,11 +14,177 @@ limitations under the License.
|
|||
package spiffe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_svidSource(*testing.T) {
|
||||
var _ x509svid.Source = new(svidSource)
|
||||
var _ jwtsvid.Source = new(svidSource)
|
||||
}
|
||||
|
||||
// createMockJWTSVID creates a mock JWT SVID for testing
|
||||
func createMockJWTSVID(audiences []string) (*jwtsvid.SVID, error) {
|
||||
td, err := spiffeid.TrustDomainFromString("example.org")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := spiffeid.FromSegments(td, "workload")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svid := &jwtsvid.SVID{
|
||||
ID: id,
|
||||
Audience: audiences,
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
}
|
||||
|
||||
return svid, nil
|
||||
}
|
||||
|
||||
func TestFetchJWTSVID(t *testing.T) {
|
||||
t.Run("should return error when audience is empty", func(t *testing.T) {
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: make(chan struct{}),
|
||||
lock: sync.RWMutex{},
|
||||
},
|
||||
}
|
||||
close(s.spiffe.readyCh) // Mark as ready
|
||||
|
||||
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
|
||||
Audience: "",
|
||||
})
|
||||
|
||||
require.Nil(t, svid)
|
||||
require.ErrorIs(t, err, errAudienceRequired)
|
||||
})
|
||||
|
||||
t.Run("should return error when no JWT SVID available", func(t *testing.T) {
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: make(chan struct{}),
|
||||
lock: sync.RWMutex{},
|
||||
currentJWTSVID: nil,
|
||||
},
|
||||
}
|
||||
close(s.spiffe.readyCh) // Mark as ready
|
||||
|
||||
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
|
||||
Audience: "test-audience",
|
||||
})
|
||||
|
||||
require.Nil(t, svid)
|
||||
require.ErrorIs(t, err, errNoJWTSVIDAvailable)
|
||||
})
|
||||
|
||||
t.Run("should return error when audience doesn't match", func(t *testing.T) {
|
||||
// Create a mock SVID with a specific audience
|
||||
mockJWTSVID, err := createMockJWTSVID([]string{"actual-audience"})
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: make(chan struct{}),
|
||||
lock: sync.RWMutex{},
|
||||
currentJWTSVID: mockJWTSVID,
|
||||
},
|
||||
}
|
||||
close(s.spiffe.readyCh) // Mark as ready
|
||||
|
||||
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
|
||||
Audience: "requested-audience",
|
||||
})
|
||||
|
||||
require.Nil(t, svid)
|
||||
require.Error(t, err)
|
||||
|
||||
// Verify the specific error type and contents
|
||||
audienceErr, ok := err.(*audienceMismatchError)
|
||||
require.True(t, ok, "Expected audienceMismatchError")
|
||||
require.Equal(t, "JWT SVID has different audiences than requested: expected requested-audience, got actual-audience", audienceErr.Error())
|
||||
})
|
||||
|
||||
t.Run("should return JWT SVID when audience matches", func(t *testing.T) {
|
||||
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience", "extra-audience"})
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: make(chan struct{}),
|
||||
lock: sync.RWMutex{},
|
||||
currentJWTSVID: mockJWTSVID,
|
||||
},
|
||||
}
|
||||
close(s.spiffe.readyCh) // Mark as ready
|
||||
|
||||
svid, err := s.FetchJWTSVID(context.Background(), jwtsvid.Params{
|
||||
Audience: "test-audience",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, mockJWTSVID, svid)
|
||||
})
|
||||
|
||||
t.Run("should wait for readyCh before checking SVID", func(t *testing.T) {
|
||||
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience"})
|
||||
require.NoError(t, err)
|
||||
|
||||
readyCh := make(chan struct{})
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: readyCh,
|
||||
lock: sync.RWMutex{},
|
||||
currentJWTSVID: mockJWTSVID,
|
||||
},
|
||||
}
|
||||
|
||||
// Start goroutine to fetch SVID
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
resultCh := make(chan struct {
|
||||
svid *jwtsvid.SVID
|
||||
err error
|
||||
})
|
||||
|
||||
go func() {
|
||||
svid, err := s.FetchJWTSVID(ctx, jwtsvid.Params{
|
||||
Audience: "test-audience",
|
||||
})
|
||||
resultCh <- struct {
|
||||
svid *jwtsvid.SVID
|
||||
err error
|
||||
}{svid, err}
|
||||
}()
|
||||
|
||||
// require that fetch is blocked
|
||||
select {
|
||||
case <-resultCh:
|
||||
t.Fatal("FetchJWTSVID should be blocked until readyCh is closed")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Expected behavior - fetch is blocked
|
||||
}
|
||||
|
||||
// Close readyCh to unblock fetch
|
||||
close(readyCh)
|
||||
|
||||
// Now fetch should complete
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
require.NoError(t, result.err)
|
||||
require.NotNil(t, result.svid)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("FetchJWTSVID should have completed after readyCh was closed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -11,40 +11,52 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
package trustanchors
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
"k8s.io/utils/clock"
|
||||
|
||||
"github.com/dapr/kit/concurrency"
|
||||
"github.com/dapr/kit/crypto/pem"
|
||||
"github.com/dapr/kit/crypto/spiffe/trustanchors"
|
||||
"github.com/dapr/kit/fswatcher"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
type OptionsFile struct {
|
||||
Log logger.Logger
|
||||
Path string
|
||||
var (
|
||||
// ErrTrustAnchorsClosed is returned when an operation is performed on closed trust anchors.
|
||||
ErrTrustAnchorsClosed = errors.New("trust anchors is closed")
|
||||
|
||||
// ErrFailedToReadTrustAnchorsFile is returned when the trust anchors file cannot be read.
|
||||
ErrFailedToReadTrustAnchorsFile = errors.New("failed to read trust anchors file")
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Log logger.Logger
|
||||
CAPath string
|
||||
JwksPath *string
|
||||
}
|
||||
|
||||
// file is a TrustAnchors implementation that uses a file as the source of trust
|
||||
// anchors. The trust anchors will be updated when the file changes.
|
||||
type file struct {
|
||||
log logger.Logger
|
||||
path string
|
||||
bundle *x509bundle.Bundle
|
||||
rootPEM []byte
|
||||
log logger.Logger
|
||||
caPath string
|
||||
jwksPath *string
|
||||
x509Bundle *x509bundle.Bundle
|
||||
jwtBundle *jwtbundle.Bundle
|
||||
rootPEM []byte
|
||||
|
||||
// fswatcherInterval is the interval at which the trust anchors file changes
|
||||
// are batched. Used for testing only, and 500ms otherwise.
|
||||
|
@ -65,17 +77,18 @@ type file struct {
|
|||
caEvent chan struct{}
|
||||
}
|
||||
|
||||
func FromFile(opts OptionsFile) Interface {
|
||||
func From(opts Options) trustanchors.Interface {
|
||||
return &file{
|
||||
fsWatcherInterval: time.Millisecond * 500,
|
||||
initFileWatchInterval: time.Second,
|
||||
|
||||
log: opts.Log,
|
||||
path: opts.Path,
|
||||
clock: clock.RealClock{},
|
||||
readyCh: make(chan struct{}),
|
||||
closeCh: make(chan struct{}),
|
||||
caEvent: make(chan struct{}),
|
||||
log: opts.Log,
|
||||
caPath: opts.CAPath,
|
||||
jwksPath: opts.JwksPath,
|
||||
clock: clock.RealClock{},
|
||||
readyCh: make(chan struct{}),
|
||||
closeCh: make(chan struct{}),
|
||||
caEvent: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -87,31 +100,39 @@ func (f *file) Run(ctx context.Context) error {
|
|||
defer close(f.closeCh)
|
||||
|
||||
for {
|
||||
_, err := os.Stat(f.path)
|
||||
if err == nil {
|
||||
break
|
||||
fs := []string{f.caPath}
|
||||
if f.jwksPath != nil {
|
||||
fs = append(fs, *f.jwksPath)
|
||||
}
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
|
||||
if found, err := filesExist(fs...); err != nil {
|
||||
return err
|
||||
} else if found {
|
||||
break
|
||||
}
|
||||
|
||||
// Trust anchors file not be provided yet, wait.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("failed to find trust anchors file '%s': %w", f.path, ctx.Err())
|
||||
return fmt.Errorf("failed to find trust anchors file '%s': %w", f.caPath, ctx.Err())
|
||||
case <-f.clock.After(f.initFileWatchInterval):
|
||||
f.log.Warnf("Trust anchors file '%s' not found, waiting...", f.path)
|
||||
f.log.Warnf("Trust anchors file '%s' not found, waiting...", f.caPath)
|
||||
}
|
||||
}
|
||||
|
||||
f.log.Infof("Trust anchors file '%s' found", f.path)
|
||||
f.log.Infof("Trust anchors file '%s' found", f.caPath)
|
||||
|
||||
if err := f.updateAnchors(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targets := []string{f.caPath}
|
||||
if f.jwksPath != nil {
|
||||
targets = append(targets, *f.jwksPath)
|
||||
}
|
||||
|
||||
fs, err := fswatcher.New(fswatcher.Options{
|
||||
Targets: []string{filepath.Dir(f.path)},
|
||||
Targets: targets,
|
||||
Interval: &f.fsWatcherInterval,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -120,7 +141,11 @@ func (f *file) Run(ctx context.Context) error {
|
|||
|
||||
close(f.readyCh)
|
||||
|
||||
f.log.Infof("Watching trust anchors file '%s' for changes", f.path)
|
||||
f.log.Infof("Watching trust anchors file '%s' for changes", f.caPath)
|
||||
if f.jwksPath != nil {
|
||||
f.log.Infof("Watching JWT bundle file '%s' for changes", f.jwksPath)
|
||||
}
|
||||
|
||||
return concurrency.NewRunnerManager(
|
||||
func(ctx context.Context) error {
|
||||
return fs.Run(ctx, f.caEvent)
|
||||
|
@ -134,7 +159,7 @@ func (f *file) Run(ctx context.Context) error {
|
|||
f.log.Info("Trust anchors file changed, reloading trust anchors")
|
||||
|
||||
if err = f.updateAnchors(ctx); err != nil {
|
||||
return fmt.Errorf("failed to read trust anchors file '%s': %v", f.path, err)
|
||||
return fmt.Errorf("%w: '%s': %v", ErrFailedToReadTrustAnchorsFile, f.caPath, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -147,7 +172,7 @@ func (f *file) CurrentTrustAnchors(ctx context.Context) ([]byte, error) {
|
|||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-f.closeCh:
|
||||
return nil, errors.New("trust anchors is closed")
|
||||
return nil, ErrTrustAnchorsClosed
|
||||
case <-f.readyCh:
|
||||
}
|
||||
|
||||
|
@ -162,9 +187,9 @@ func (f *file) updateAnchors(ctx context.Context) error {
|
|||
f.lock.Lock()
|
||||
defer f.lock.Unlock()
|
||||
|
||||
rootPEMs, err := os.ReadFile(f.path)
|
||||
rootPEMs, err := os.ReadFile(f.caPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read trust anchors file '%s': %w", f.path, err)
|
||||
return fmt.Errorf("failed to read trust anchors file '%s': %w", f.caPath, err)
|
||||
}
|
||||
|
||||
trustAnchorCerts, err := pem.DecodePEMCertificates(rootPEMs)
|
||||
|
@ -173,7 +198,20 @@ func (f *file) updateAnchors(ctx context.Context) error {
|
|||
}
|
||||
|
||||
f.rootPEM = rootPEMs
|
||||
f.bundle = x509bundle.FromX509Authorities(spiffeid.TrustDomain{}, trustAnchorCerts)
|
||||
f.x509Bundle = x509bundle.FromX509Authorities(spiffeid.TrustDomain{}, trustAnchorCerts)
|
||||
|
||||
if f.jwksPath != nil {
|
||||
jwks, err := os.ReadFile(*f.jwksPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read JWT bundle file '%s': %w", *f.jwksPath, err)
|
||||
}
|
||||
|
||||
jwtBundle, err := jwtbundle.Parse(spiffeid.TrustDomain{}, jwks)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse JWT bundle: %w", err)
|
||||
}
|
||||
f.jwtBundle = jwtBundle
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
@ -195,13 +233,26 @@ func (f *file) updateAnchors(ctx context.Context) error {
|
|||
func (f *file) GetX509BundleForTrustDomain(_ spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
|
||||
select {
|
||||
case <-f.closeCh:
|
||||
return nil, errors.New("trust anchors is closed")
|
||||
return nil, ErrTrustAnchorsClosed
|
||||
case <-f.readyCh:
|
||||
}
|
||||
|
||||
f.lock.RLock()
|
||||
defer f.lock.RUnlock()
|
||||
bundle := f.bundle
|
||||
bundle := f.x509Bundle
|
||||
return bundle, nil
|
||||
}
|
||||
|
||||
func (f *file) GetJWTBundleForTrustDomain(_ spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
|
||||
select {
|
||||
case <-f.closeCh:
|
||||
return nil, ErrTrustAnchorsClosed
|
||||
case <-f.readyCh:
|
||||
}
|
||||
|
||||
f.lock.RLock()
|
||||
defer f.lock.RUnlock()
|
||||
bundle := f.jwtBundle
|
||||
return bundle, nil
|
||||
}
|
||||
|
||||
|
@ -231,3 +282,19 @@ func (f *file) Watch(ctx context.Context, ch chan<- []byte) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filesExist(paths ...string) (bool, error) {
|
||||
for _, path := range paths {
|
||||
if path == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to stat file '%s': %w", path, err)
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
|
@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
package trustanchors
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -31,9 +31,9 @@ import (
|
|||
func TestFile_Run(t *testing.T) {
|
||||
t.Run("if Run multiple times, expect error", func(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
ta := FromFile(OptionsFile{
|
||||
Log: logger.NewLogger("test"),
|
||||
Path: tmp,
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
@ -74,9 +74,9 @@ func TestFile_Run(t *testing.T) {
|
|||
t.Run("if file is not found and context cancelled, should return ctx.Err", func(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
Log: logger.NewLogger("test"),
|
||||
Path: tmp,
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
@ -102,9 +102,9 @@ func TestFile_Run(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, nil, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
Log: logger.NewLogger("test"),
|
||||
Path: tmp,
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
@ -127,9 +127,9 @@ func TestFile_Run(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, []byte("garbage data"), 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
Log: logger.NewLogger("test"),
|
||||
Path: tmp,
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
@ -154,9 +154,9 @@ func TestFile_Run(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, root, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
Log: logger.NewLogger("test"),
|
||||
Path: tmp,
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
@ -180,9 +180,9 @@ func TestFile_Run(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
Log: logger.NewLogger("test"),
|
||||
Path: tmp,
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
@ -211,9 +211,9 @@ func TestFile_Run(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, root, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
Log: logger.NewLogger("test"),
|
||||
Path: tmp,
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
@ -242,9 +242,9 @@ func TestFile_Run(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
Log: logger.NewLogger("test"),
|
||||
Path: tmp,
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
@ -273,9 +273,9 @@ func TestFile_Run(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
Log: logger.NewLogger("test"),
|
||||
Path: tmp,
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
@ -311,9 +311,9 @@ func TestFile_GetX509BundleForTrustDomain(t *testing.T) {
|
|||
root := append(pki.RootCertPEM, []byte("garbage data")...)
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, root, 0o600))
|
||||
ta := FromFile(OptionsFile{
|
||||
Log: logger.NewLogger("test"),
|
||||
Path: tmp,
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
@ -337,7 +337,7 @@ func TestFile_GetX509BundleForTrustDomain(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
bundle, err := f.GetX509BundleForTrustDomain(trustDomain1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, f.bundle, bundle)
|
||||
assert.Equal(t, f.x509Bundle, bundle)
|
||||
b1, err := bundle.Marshal()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pki.RootCertPEM, b1)
|
||||
|
@ -346,7 +346,7 @@ func TestFile_GetX509BundleForTrustDomain(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
bundle, err = f.GetX509BundleForTrustDomain(trustDomain2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, f.bundle, bundle)
|
||||
assert.Equal(t, f.x509Bundle, bundle)
|
||||
b2, err := bundle.Marshal()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pki.RootCertPEM, b2)
|
||||
|
@ -359,9 +359,9 @@ func TestFile_Watch(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
Log: logger.NewLogger("test"),
|
||||
Path: tmp,
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
@ -400,9 +400,9 @@ func TestFile_Watch(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
Log: logger.NewLogger("test"),
|
||||
Path: tmp,
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
@ -446,9 +446,9 @@ func TestFile_Watch(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
Log: logger.NewLogger("test"),
|
||||
Path: tmp,
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
@ -529,9 +529,9 @@ func TestFile_CurrentTrustAnchors(t *testing.T) {
|
|||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
|
||||
|
||||
ta := FromFile(OptionsFile{
|
||||
Log: logger.NewLogger("test"),
|
||||
Path: tmp,
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
@ -547,6 +547,7 @@ func TestFile_CurrentTrustAnchors(t *testing.T) {
|
|||
//nolint:gocritic
|
||||
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
|
||||
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
||||
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure the file watcher has time to pick up the change
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
pem, err := ta.CurrentTrustAnchors(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
@ -556,6 +557,7 @@ func TestFile_CurrentTrustAnchors(t *testing.T) {
|
|||
//nolint:gocritic
|
||||
roots = append(pki1.RootCertPEM, append(pki2.RootCertPEM, pki3.RootCertPEM...)...)
|
||||
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
||||
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure the file watcher has time to pick up the change
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
pem, err := ta.CurrentTrustAnchors(context.Background())
|
||||
require.NoError(t, err)
|
|
@ -11,16 +11,18 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
package trustanchors
|
||||
package multi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
|
||||
"github.com/dapr/kit/concurrency"
|
||||
"github.com/dapr/kit/crypto/spiffe/trustanchors"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -28,17 +30,17 @@ var (
|
|||
ErrTrustDomainNotFound = errors.New("trust domain not found")
|
||||
)
|
||||
|
||||
type OptionsMulti struct {
|
||||
TrustAnchors map[spiffeid.TrustDomain]Interface
|
||||
type Options struct {
|
||||
TrustAnchors map[spiffeid.TrustDomain]trustanchors.Interface
|
||||
}
|
||||
|
||||
// multi is a TrustAnchors implementation which uses multiple trust anchors
|
||||
// which are indexed by trust domain.
|
||||
type multi struct {
|
||||
trustAnchors map[spiffeid.TrustDomain]Interface
|
||||
trustAnchors map[spiffeid.TrustDomain]trustanchors.Interface
|
||||
}
|
||||
|
||||
func FromMulti(opts OptionsMulti) Interface {
|
||||
func From(opts Options) trustanchors.Interface {
|
||||
return &multi{
|
||||
trustAnchors: opts.TrustAnchors,
|
||||
}
|
||||
|
@ -69,6 +71,16 @@ func (m *multi) GetX509BundleForTrustDomain(td spiffeid.TrustDomain) (*x509bundl
|
|||
return nil, ErrTrustDomainNotFound
|
||||
}
|
||||
|
||||
func (m *multi) GetJWTBundleForTrustDomain(td spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
|
||||
for tad, ta := range m.trustAnchors {
|
||||
if td.Compare(tad) == 0 {
|
||||
return ta.GetJWTBundleForTrustDomain(td)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrTrustDomainNotFound
|
||||
}
|
||||
|
||||
func (m *multi) Watch(context.Context, chan<- []byte) {
|
||||
return
|
||||
}
|
|
@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
package trustanchors
|
||||
package static
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -19,31 +19,52 @@ import (
|
|||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
|
||||
"github.com/dapr/kit/crypto/pem"
|
||||
"github.com/dapr/kit/crypto/spiffe/trustanchors"
|
||||
)
|
||||
|
||||
// static is a TrustAcnhors implementation that uses a static list of trust
|
||||
// anchors.
|
||||
type static struct {
|
||||
bundle *x509bundle.Bundle
|
||||
anchors []byte
|
||||
running atomic.Bool
|
||||
closeCh chan struct{}
|
||||
x509Bundle *x509bundle.Bundle
|
||||
jwtBundle *jwtbundle.Bundle
|
||||
anchors []byte
|
||||
running atomic.Bool
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
func FromStatic(anchors []byte) (Interface, error) {
|
||||
trustAnchorCerts, err := pem.DecodePEMCertificates(anchors)
|
||||
type Options struct {
|
||||
Anchors []byte
|
||||
Jwks []byte
|
||||
}
|
||||
|
||||
func From(opts Options) (trustanchors.Interface, error) {
|
||||
// Create empty trust domain for now
|
||||
emptyTD := spiffeid.TrustDomain{}
|
||||
|
||||
var jwtBundle *jwtbundle.Bundle
|
||||
if opts.Jwks != nil {
|
||||
var err error
|
||||
jwtBundle, err = jwtbundle.Parse(emptyTD, opts.Jwks)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create JWT bundle: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
trustAnchorCerts, err := pem.DecodePEMCertificates(opts.Anchors)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode trust anchors: %w", err)
|
||||
}
|
||||
|
||||
return &static{
|
||||
anchors: anchors,
|
||||
bundle: x509bundle.FromX509Authorities(spiffeid.TrustDomain{}, trustAnchorCerts),
|
||||
closeCh: make(chan struct{}),
|
||||
anchors: opts.Anchors,
|
||||
x509Bundle: x509bundle.FromX509Authorities(emptyTD, trustAnchorCerts),
|
||||
jwtBundle: jwtBundle,
|
||||
closeCh: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -63,7 +84,11 @@ func (s *static) Run(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (s *static) GetX509BundleForTrustDomain(spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
|
||||
return s.bundle, nil
|
||||
return s.x509Bundle, nil
|
||||
}
|
||||
|
||||
func (s *static) GetJWTBundleForTrustDomain(_ spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
|
||||
return s.jwtBundle, nil
|
||||
}
|
||||
|
||||
func (s *static) Watch(ctx context.Context, _ chan<- []byte) {
|
|
@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
package trustanchors
|
||||
package static
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -27,30 +27,30 @@ import (
|
|||
|
||||
func TestFromStatic(t *testing.T) {
|
||||
t.Run("empty root should return error", func(t *testing.T) {
|
||||
_, err := FromStatic(nil)
|
||||
_, err := From(Options{})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("garbage data should return error", func(t *testing.T) {
|
||||
_, err := FromStatic([]byte("garbage data"))
|
||||
_, err := From(Options{Anchors: []byte("garbage data")})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("just garbage data should return error", func(t *testing.T) {
|
||||
_, err := FromStatic([]byte("garbage data"))
|
||||
_, err := From(Options{Anchors: []byte("garbage data")})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("garbage data in root should return error", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
root := pki.RootCertPEM[10:]
|
||||
_, err := FromStatic(root)
|
||||
_, err := From(Options{Anchors: root})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("single root should be correctly parsed", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
ta, err := FromStatic(pki.RootCertPEM)
|
||||
ta, err := From(Options{Anchors: pki.RootCertPEM})
|
||||
require.NoError(t, err)
|
||||
taPEM, err := ta.CurrentTrustAnchors(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
@ -61,7 +61,7 @@ func TestFromStatic(t *testing.T) {
|
|||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
//nolint:gocritic
|
||||
root := append(pki.RootCertPEM, []byte("garbage data")...)
|
||||
ta, err := FromStatic(root)
|
||||
ta, err := From(Options{Anchors: root})
|
||||
require.NoError(t, err)
|
||||
taPEM, err := ta.CurrentTrustAnchors(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
@ -72,7 +72,7 @@ func TestFromStatic(t *testing.T) {
|
|||
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
|
||||
//nolint:gocritic
|
||||
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
|
||||
ta, err := FromStatic(roots)
|
||||
ta, err := From(Options{Anchors: roots})
|
||||
require.NoError(t, err)
|
||||
taPEM, err := ta.CurrentTrustAnchors(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
@ -85,7 +85,7 @@ func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
|
|||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
//nolint:gocritic
|
||||
root := append(pki.RootCertPEM, []byte("garbage data")...)
|
||||
ta, err := FromStatic(root)
|
||||
ta, err := From(Options{Anchors: root})
|
||||
require.NoError(t, err)
|
||||
s, ok := ta.(*static)
|
||||
require.True(t, ok)
|
||||
|
@ -94,7 +94,7 @@ func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
bundle, err := s.GetX509BundleForTrustDomain(trustDomain1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, s.bundle, bundle)
|
||||
assert.Equal(t, s.x509Bundle, bundle)
|
||||
b1, err := bundle.Marshal()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pki.RootCertPEM, b1)
|
||||
|
@ -103,7 +103,7 @@ func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
bundle, err = s.GetX509BundleForTrustDomain(trustDomain2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, s.bundle, bundle)
|
||||
assert.Equal(t, s.x509Bundle, bundle)
|
||||
b2, err := bundle.Marshal()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pki.RootCertPEM, b2)
|
||||
|
@ -113,7 +113,7 @@ func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
|
|||
func TestStatic_Run(t *testing.T) {
|
||||
t.Run("Run multiple times should return error", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
ta, err := FromStatic(pki.RootCertPEM)
|
||||
ta, err := From(Options{Anchors: pki.RootCertPEM})
|
||||
require.NoError(t, err)
|
||||
s, ok := ta.(*static)
|
||||
require.True(t, ok)
|
||||
|
@ -154,7 +154,7 @@ func TestStatic_Run(t *testing.T) {
|
|||
func TestStatic_Watch(t *testing.T) {
|
||||
t.Run("should return when context is cancelled", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
ta, err := FromStatic(pki.RootCertPEM)
|
||||
ta, err := From(Options{Anchors: pki.RootCertPEM})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -176,7 +176,7 @@ func TestStatic_Watch(t *testing.T) {
|
|||
|
||||
t.Run("should return when cancel is closed via closed Run", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
ta, err := FromStatic(pki.RootCertPEM)
|
||||
ta, err := From(Options{Anchors: pki.RootCertPEM})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
|
@ -16,6 +16,7 @@ package trustanchors
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
|
||||
)
|
||||
|
||||
|
@ -23,8 +24,10 @@ import (
|
|||
// Allows consumers to get the current trust anchor bundle, and subscribe to
|
||||
// bundle updates.
|
||||
type Interface interface {
|
||||
// Source implements the SPIFFE trust anchor bundle source.
|
||||
// Source implements the SPIFFE trust anchor x509 bundle source.
|
||||
x509bundle.Source
|
||||
// Source implements the SPIFFE trust anchor jwt bundle source.
|
||||
jwtbundle.Source
|
||||
|
||||
// CurrentTrustAnchors returns the current trust anchor PEM bundle.
|
||||
CurrentTrustAnchors(ctx context.Context) ([]byte, error)
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package env
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetDurationWithRange returns the time.Duration value of the environment variable specified by `envVar`.
|
||||
// If the environment variable is not set, it returns `defaultValue`.
|
||||
// If the value is set but is not valid (not a valid time.Duration or falls outside the specified range
|
||||
// [minValue, maxValue] inclusively), it returns `defaultValue` and an error.
|
||||
func GetDurationWithRange(envVar string, defaultValue, min, max time.Duration) (time.Duration, error) {
|
||||
v := os.Getenv(envVar)
|
||||
if v == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
|
||||
val, err := time.ParseDuration(v)
|
||||
if err != nil {
|
||||
return defaultValue, fmt.Errorf("invalid time.Duration value %s for the %s env variable: %w", val, envVar, err)
|
||||
}
|
||||
|
||||
if val < min || val > max {
|
||||
return defaultValue, fmt.Errorf("invalid value for the %s env variable: value should be between %s and %s, got %s", envVar, min, max, val)
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
|
@ -1,4 +1,17 @@
|
|||
package utils
|
||||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package env
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
@ -7,7 +20,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetEnvIntWithRangeWrongValues(t *testing.T) {
|
||||
func TestGetIntWithRangeWrongValues(t *testing.T) {
|
||||
testValues := []struct {
|
||||
name string
|
||||
envVarVal string
|
||||
|
@ -43,7 +56,7 @@ func TestGetEnvIntWithRangeWrongValues(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv("MY_ENV", tt.envVarVal)
|
||||
|
||||
val, err := GetEnvDurationWithRange("MY_ENV", defaultValue, tt.min, tt.max)
|
||||
val, err := GetDurationWithRange("MY_ENV", defaultValue, tt.min, tt.max)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.error)
|
||||
require.Equal(t, defaultValue, val)
|
||||
|
@ -75,7 +88,7 @@ func TestGetEnvDurationWithRangeValidValues(t *testing.T) {
|
|||
t.Setenv("MY_ENV", tt.envVarVal)
|
||||
}
|
||||
|
||||
val, err := GetEnvDurationWithRange("MY_ENV", 3*time.Second, time.Second, 5*time.Second)
|
||||
val, err := GetDurationWithRange("MY_ENV", 3*time.Second, time.Second, 5*time.Second)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.result, val)
|
||||
})
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package fake
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dapr/kit/events/loop"
|
||||
)
|
||||
|
||||
type Fake[T any] struct {
|
||||
runFn func(context.Context) error
|
||||
enqueueFn func(T)
|
||||
closeFn func(T)
|
||||
}
|
||||
|
||||
func New[T any]() *Fake[T] {
|
||||
return &Fake[T]{
|
||||
runFn: func(context.Context) error { return nil },
|
||||
enqueueFn: func(T) {},
|
||||
closeFn: func(T) {},
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Fake[T]) WithRun(fn func(context.Context) error) *Fake[T] {
|
||||
f.runFn = fn
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Fake[T]) WithEnqueue(fn func(T)) *Fake[T] {
|
||||
f.enqueueFn = fn
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Fake[T]) WithClose(fn func(T)) *Fake[T] {
|
||||
f.closeFn = fn
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Fake[T]) Run(ctx context.Context) error {
|
||||
return f.runFn(ctx)
|
||||
}
|
||||
|
||||
func (f *Fake[T]) Enqueue(t T) {
|
||||
f.enqueueFn(t)
|
||||
}
|
||||
|
||||
func (f *Fake[T]) Close(t T) {
|
||||
f.closeFn(t)
|
||||
}
|
||||
|
||||
func (f *Fake[T]) Reset(loop.Handler[T], uint64) loop.Interface[T] {
|
||||
return f
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package fake
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/dapr/kit/events/loop"
|
||||
)
|
||||
|
||||
func Test_Fake(*testing.T) {
|
||||
var _ loop.Interface[int] = New[int]()
|
||||
}
|
|
@ -15,58 +15,97 @@ package loop
|
|||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type HandlerFunc[T any] func(context.Context, T) error
|
||||
|
||||
type Options[T any] struct {
|
||||
Handler HandlerFunc[T]
|
||||
BufferSize *uint64
|
||||
type Handler[T any] interface {
|
||||
Handle(ctx context.Context, t T) error
|
||||
}
|
||||
|
||||
type Loop[T any] struct {
|
||||
type Interface[T any] interface {
|
||||
Run(ctx context.Context) error
|
||||
Enqueue(t T)
|
||||
Close(t T)
|
||||
Reset(h Handler[T], size uint64) Interface[T]
|
||||
}
|
||||
|
||||
type loop[T any] struct {
|
||||
queue chan T
|
||||
handler HandlerFunc[T]
|
||||
handler Handler[T]
|
||||
|
||||
closed bool
|
||||
closeCh chan struct{}
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func New[T any](opts Options[T]) *Loop[T] {
|
||||
size := 1
|
||||
if opts.BufferSize != nil {
|
||||
size = int(*opts.BufferSize)
|
||||
}
|
||||
|
||||
return &Loop[T]{
|
||||
func New[T any](h Handler[T], size uint64) Interface[T] {
|
||||
return &loop[T]{
|
||||
queue: make(chan T, size),
|
||||
handler: h,
|
||||
closeCh: make(chan struct{}),
|
||||
handler: opts.Handler,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Loop[T]) Run(ctx context.Context) error {
|
||||
func Empty[T any]() Interface[T] {
|
||||
return new(loop[T])
|
||||
}
|
||||
|
||||
func (l *loop[T]) Run(ctx context.Context) error {
|
||||
defer close(l.closeCh)
|
||||
|
||||
for {
|
||||
var req T
|
||||
select {
|
||||
case req = <-l.queue:
|
||||
case <-ctx.Done():
|
||||
req, ok := <-l.queue
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := l.handler(ctx, req); err != nil {
|
||||
if err := l.handler.Handle(ctx, req); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Loop[T]) Enqueue(req T) {
|
||||
func (l *loop[T]) Enqueue(req T) {
|
||||
l.lock.RLock()
|
||||
defer l.lock.RUnlock()
|
||||
|
||||
if l.closed {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case l.queue <- req:
|
||||
case <-l.closeCh:
|
||||
}
|
||||
}
|
||||
|
||||
func (l *loop[T]) Close(req T) {
|
||||
l.lock.Lock()
|
||||
l.closed = true
|
||||
select {
|
||||
case l.queue <- req:
|
||||
case <-l.closeCh:
|
||||
}
|
||||
close(l.queue)
|
||||
l.lock.Unlock()
|
||||
<-l.closeCh
|
||||
}
|
||||
|
||||
func (l *loop[T]) Reset(h Handler[T], size uint64) Interface[T] {
|
||||
if l == nil {
|
||||
return New[T](h, size)
|
||||
}
|
||||
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
l.closed = false
|
||||
l.closeCh = make(chan struct{})
|
||||
l.handler = h
|
||||
|
||||
// TODO: @joshvanl: use a ring buffer so that we don't need to reallocate and
|
||||
// improve performance.
|
||||
l.queue = make(chan T, size)
|
||||
|
||||
return l
|
||||
}
|
||||
|
|
1
go.mod
1
go.mod
|
@ -28,6 +28,7 @@ require (
|
|||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
|
||||
github.com/go-jose/go-jose/v3 v3.0.1 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
|
||||
|
|
2
go.sum
2
go.sum
|
@ -19,6 +19,7 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
|||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.1.0 h1:Hsa8mG0dQ46ij8Sl2AYJDUv1oA9/d6Vk+3LG99Oe02g=
|
||||
|
@ -72,6 +73,7 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec
|
|||
github.com/zeebo/errs v1.3.0 h1:hmiaKqgYZzcVgRL1Vkc1Mn2914BbzB0IBxs+ebeutGs=
|
||||
github.com/zeebo/errs v1.3.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
|
||||
|
|
|
@ -37,9 +37,9 @@ import (
|
|||
"github.com/lestrrat-go/httprc"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
|
||||
"github.com/dapr/kit/crypto/pem"
|
||||
"github.com/dapr/kit/fswatcher"
|
||||
"github.com/dapr/kit/logger"
|
||||
"github.com/dapr/kit/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -198,7 +198,7 @@ func (c *JWKSCache) initJWKSFromURL(ctx context.Context, url string) error {
|
|||
|
||||
// Load CA certificates if we have one
|
||||
if c.caCertificate != "" {
|
||||
caCert, err := utils.GetPEM(c.caCertificate)
|
||||
caCert, err := pem.GetPEM(c.caCertificate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load CA certificate: %w", err)
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ import (
|
|||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/dapr/kit/ptr"
|
||||
"github.com/dapr/kit/utils"
|
||||
kitstrings "github.com/dapr/kit/strings"
|
||||
)
|
||||
|
||||
func toTruthyBoolHookFunc() mapstructure.DecodeHookFunc {
|
||||
|
@ -37,10 +37,10 @@ func toTruthyBoolHookFunc() mapstructure.DecodeHookFunc {
|
|||
data any,
|
||||
) (any, error) {
|
||||
if f == stringType && t == boolType {
|
||||
return utils.IsTruthy(data.(string)), nil
|
||||
return kitstrings.IsTruthy(data.(string)), nil
|
||||
}
|
||||
if f == stringType && t == boolPtrType {
|
||||
return ptr.Of(utils.IsTruthy(data.(string))), nil
|
||||
return ptr.Of(kitstrings.IsTruthy(data.(string))), nil
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
package utils
|
||||
package strings
|
||||
|
||||
import (
|
||||
"path/filepath"
|
29
utils/env.go
29
utils/env.go
|
@ -1,29 +0,0 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetEnvDurationWithRange returns the time.Duration value of the environment variable specified by `envVar`.
|
||||
// If the environment variable is not set, it returns `defaultValue`.
|
||||
// If the value is set but is not valid (not a valid time.Duration or falls outside the specified range
|
||||
// [minValue, maxValue] inclusively), it returns `defaultValue` and an error.
|
||||
func GetEnvDurationWithRange(envVar string, defaultValue, min, max time.Duration) (time.Duration, error) {
|
||||
v := os.Getenv(envVar)
|
||||
if v == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
|
||||
val, err := time.ParseDuration(v)
|
||||
if err != nil {
|
||||
return defaultValue, fmt.Errorf("invalid time.Duration value %s for the %s env variable: %w", val, envVar, err)
|
||||
}
|
||||
|
||||
if val < min || val > max {
|
||||
return defaultValue, fmt.Errorf("invalid value for the %s env variable: value should be between %s and %s, got %s", envVar, min, max, val)
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
41
utils/pem.go
41
utils/pem.go
|
@ -1,41 +0,0 @@
|
|||
/*
|
||||
Copyright 2023 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// GetPEM loads a PEM-encoded file (certificate or key).
|
||||
func GetPEM(val string) ([]byte, error) {
|
||||
// If val is already a PEM-encoded string, return it as-is
|
||||
if IsValidPEM(val) {
|
||||
return []byte(val), nil
|
||||
}
|
||||
|
||||
// Assume it's a file
|
||||
pemBytes, err := os.ReadFile(val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("value is neither a valid file path or nor a valid PEM-encoded string: %w", err)
|
||||
}
|
||||
return pemBytes, nil
|
||||
}
|
||||
|
||||
// IsValidPEM validates the provided input has PEM formatted block.
|
||||
func IsValidPEM(val string) bool {
|
||||
block, _ := pem.Decode([]byte(val))
|
||||
return block != nil
|
||||
}
|
Loading…
Reference in New Issue