mirror of https://github.com/dapr/kit.git
Compare commits
38 Commits
Author | SHA1 | Date |
---|---|---|
|
8b780b4d81 | |
|
9d4f384c57 | |
|
7c4cedad37 | |
|
7409957e9e | |
|
34f8820d2a | |
|
598b032bce | |
|
d7d50a1e1b | |
|
baea626399 | |
|
bc7dc566c4 | |
|
98fe567235 | |
|
e3d4a8f1b4 | |
|
77af8ac182 | |
|
a3f06e444a | |
|
39c4bf57bd | |
|
6271c8be59 | |
|
c46009f360 | |
|
c90b807d32 | |
|
fb19570696 | |
|
65ba3783f2 | |
|
30e2c24840 | |
|
24b59a803d | |
|
d37dc603d0 | |
|
866002abe6 | |
|
bc3a4f0fb4 | |
|
2d6ff15a97 | |
|
3823663aa4 | |
|
502671bade | |
|
26b564d9d0 | |
|
58c6d9df14 | |
|
e2508d6e9e | |
|
106329e583 | |
|
ccffb60016 | |
|
a3f906d609 | |
|
0c7cfce53d | |
|
6c3b2ee1ef | |
|
e33fbab745 | |
|
050e34c9b9 | |
|
9e733a35f1 |
|
@ -24,7 +24,7 @@ jobs:
|
|||
GOOS: ${{ matrix.target_os }}
|
||||
GOARCH: ${{ matrix.target_arch }}
|
||||
GOPROXY: https://proxy.golang.org
|
||||
GOLANGCI_LINT_VER: v1.55.1
|
||||
GOLANGCI_LINT_VER: v1.64.8
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macOS-latest]
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
**/.DS_Store
|
||||
.idea
|
||||
.vscode
|
||||
.vs
|
||||
vendor
|
196
.golangci.yml
196
.golangci.yml
|
@ -4,7 +4,7 @@ run:
|
|||
concurrency: 4
|
||||
|
||||
# timeout for analysis, e.g. 30s, 5m, default is 1m
|
||||
deadline: 10m
|
||||
timeout: 15m
|
||||
|
||||
# exit code when at least one issue was found, default is 1
|
||||
issues-exit-code: 1
|
||||
|
@ -15,29 +15,34 @@ run:
|
|||
# list of build tags, all linters use it. Default is empty list.
|
||||
build-tags:
|
||||
- unit
|
||||
|
||||
# which dirs to skip: they won't be analyzed;
|
||||
# can use regexp here: generated.*, regexp is applied on full path;
|
||||
# default value is empty list, but next dirs are always skipped independently
|
||||
# from this option's value:
|
||||
# third_party$, testdata$, examples$, Godeps$, builtin$
|
||||
skip-dirs:
|
||||
- ^pkg.*client.*clientset.*versioned.*
|
||||
- ^pkg.*client.*informers.*externalversions.*
|
||||
- ^pkg.*proto.*
|
||||
- allcomponents
|
||||
- subtlecrypto
|
||||
|
||||
# which files to skip: they will be analyzed, but issues from them
|
||||
# won't be reported. Default value is empty list, but there is
|
||||
# no need to include all autogenerated files, we confidently recognize
|
||||
# autogenerated files. If it's not please let us know.
|
||||
skip-files:
|
||||
# skip-files:
|
||||
# - ".*\\.my\\.go$"
|
||||
# - lib/bad.go
|
||||
|
||||
issues:
|
||||
# which dirs to skip: they won't be analyzed;
|
||||
# can use regexp here: generated.*, regexp is applied on full path;
|
||||
# default value is empty list, but next dirs are always skipped independently
|
||||
# from this option's value:
|
||||
# third_party$, testdata$, examples$, Godeps$, builtin$
|
||||
exclude-dirs:
|
||||
- ^pkg.*client.*clientset.*versioned.*
|
||||
- ^pkg.*client.*informers.*externalversions.*
|
||||
- ^pkg.*proto.*
|
||||
- pkg/proto
|
||||
|
||||
# output configuration options
|
||||
output:
|
||||
# colored-line-number|line-number|json|tab|checkstyle, default is "colored-line-number"
|
||||
format: tab
|
||||
formats:
|
||||
- format: tab
|
||||
|
||||
# print lines of code with issue, default is true
|
||||
print-issued-lines: true
|
||||
|
@ -57,23 +62,19 @@ linters-settings:
|
|||
# default is false: such cases aren't reported by default.
|
||||
check-blank: false
|
||||
|
||||
# [deprecated] comma-separated list of pairs of the form pkg:regex
|
||||
# the regex is used to ignore names within pkg. (default "fmt:.*").
|
||||
# see https://github.com/kisielk/errcheck#the-deprecated-method for details
|
||||
ignore: fmt:.*,io/ioutil:^Read.*
|
||||
exclude-functions:
|
||||
- fmt:.*
|
||||
- io/ioutil:^Read.*
|
||||
|
||||
# path to a file containing a list of functions to exclude from checking
|
||||
# see https://github.com/kisielk/errcheck#excluding-functions for details
|
||||
exclude:
|
||||
# exclude:
|
||||
|
||||
funlen:
|
||||
lines: 60
|
||||
statements: 40
|
||||
|
||||
govet:
|
||||
# report about shadowed variables
|
||||
check-shadowing: true
|
||||
|
||||
# settings per analyzer
|
||||
settings:
|
||||
printf: # analyzer name, run `go tool vet help` to see all analyzers
|
||||
|
@ -86,13 +87,12 @@ linters-settings:
|
|||
# enable or disable analyzers by name
|
||||
enable:
|
||||
- atomicalign
|
||||
enable-all: false
|
||||
disable:
|
||||
- shadow
|
||||
enable-all: false
|
||||
disable-all: false
|
||||
golint:
|
||||
revive:
|
||||
# minimal confidence for issues, default is 0.8
|
||||
min-confidence: 0.8
|
||||
confidence: 0.8
|
||||
gofmt:
|
||||
# simplify code: gofmt with `-s` option, true by default
|
||||
simplify: true
|
||||
|
@ -106,9 +106,6 @@ linters-settings:
|
|||
gocognit:
|
||||
# minimal code complexity to report, 30 by default (but we recommend 10-20)
|
||||
min-complexity: 10
|
||||
maligned:
|
||||
# print struct with more effective memory layout or not, false by default
|
||||
suggest-new: true
|
||||
dupl:
|
||||
# tokens count to trigger issue, 150 by default
|
||||
threshold: 100
|
||||
|
@ -121,55 +118,60 @@ linters-settings:
|
|||
rules:
|
||||
main:
|
||||
deny:
|
||||
- pkg: "github.com/Sirupsen/logrus"
|
||||
desc: "must use github.com/dapr/kit/logger"
|
||||
- pkg: "github.com/agrea/ptr"
|
||||
desc: "must use github.com/dapr/kit/ptr"
|
||||
- pkg: "go.uber.org/atomic"
|
||||
desc: "must use sync/atomic"
|
||||
- pkg: "golang.org/x/net/context"
|
||||
desc: "must use context"
|
||||
- pkg: "github.com/pkg/errors"
|
||||
desc: "must use standard library (errors package and/or fmt.Errorf)"
|
||||
- pkg: "github.com/go-chi/chi$"
|
||||
desc: "must use github.com/go-chi/chi/v5"
|
||||
- pkg: "github.com/cenkalti/backoff$"
|
||||
desc: "must use github.com/cenkalti/backoff/v4"
|
||||
- pkg: "github.com/cenkalti/backoff/v2"
|
||||
desc: "must use github.com/cenkalti/backoff/v4"
|
||||
- pkg: "github.com/cenkalti/backoff/v3"
|
||||
desc: "must use github.com/cenkalti/backoff/v4"
|
||||
- pkg: "github.com/benbjohnson/clock"
|
||||
desc: "must use k8s.io/utils/clock"
|
||||
- pkg: "github.com/ghodss/yaml"
|
||||
desc: "must use sigs.k8s.io/yaml"
|
||||
- pkg: "gopkg.in/yaml.v2"
|
||||
desc: "must use gopkg.in/yaml.v3"
|
||||
- pkg: "github.com/golang-jwt/jwt"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/golang-jwt/jwt/v2"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/golang-jwt/jwt/v3"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/golang-jwt/jwt/v4"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/gogo/status"
|
||||
desc: "must use google.golang.org/grpc/status"
|
||||
- pkg: "github.com/gogo/protobuf"
|
||||
desc: "must use google.golang.org/protobuf"
|
||||
- pkg: "github.com/lestrrat-go/jwx/jwa"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/lestrrat-go/jwx/jwt"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/labstack/gommon/log"
|
||||
desc: "must use github.com/dapr/kit/logger"
|
||||
- pkg: "github.com/gobuffalo/logger"
|
||||
desc: "must use github.com/dapr/kit/logger"
|
||||
- pkg: "github.com/Sirupsen/logrus"
|
||||
desc: "must use github.com/dapr/kit/logger"
|
||||
- pkg: "github.com/agrea/ptr"
|
||||
desc: "must use github.com/dapr/kit/ptr"
|
||||
- pkg: "go.uber.org/atomic"
|
||||
desc: "must use sync/atomic"
|
||||
- pkg: "golang.org/x/net/context"
|
||||
desc: "must use context"
|
||||
- pkg: "github.com/pkg/errors"
|
||||
desc: "must use standard library (errors package and/or fmt.Errorf)"
|
||||
- pkg: "github.com/go-chi/chi$"
|
||||
desc: "must use github.com/go-chi/chi/v5"
|
||||
- pkg: "github.com/cenkalti/backoff$"
|
||||
desc: "must use github.com/cenkalti/backoff/v4"
|
||||
- pkg: "github.com/cenkalti/backoff/v2"
|
||||
desc: "must use github.com/cenkalti/backoff/v4"
|
||||
- pkg: "github.com/cenkalti/backoff/v3"
|
||||
desc: "must use github.com/cenkalti/backoff/v4"
|
||||
- pkg: "github.com/benbjohnson/clock"
|
||||
desc: "must use k8s.io/utils/clock"
|
||||
- pkg: "github.com/ghodss/yaml"
|
||||
desc: "must use sigs.k8s.io/yaml"
|
||||
- pkg: "gopkg.in/yaml.v2"
|
||||
desc: "must use gopkg.in/yaml.v3"
|
||||
- pkg: "github.com/golang-jwt/jwt"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/golang-jwt/jwt/v2"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/golang-jwt/jwt/v3"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/golang-jwt/jwt/v4"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
# pkg: Commonly auto-completed by gopls
|
||||
- pkg: "github.com/gogo/status"
|
||||
desc: "must use google.golang.org/grpc/status"
|
||||
- pkg: "github.com/gogo/protobuf"
|
||||
desc: "must use google.golang.org/protobuf"
|
||||
- pkg: "github.com/lestrrat-go/jwx/jwa"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/lestrrat-go/jwx/jwt"
|
||||
desc: "must use github.com/lestrrat-go/jwx/v2"
|
||||
- pkg: "github.com/labstack/gommon/log"
|
||||
desc: "must use github.com/dapr/kit/logger"
|
||||
- pkg: "github.com/gobuffalo/logger"
|
||||
desc: "must use github.com/dapr/kit/logger"
|
||||
- pkg: "k8s.io/utils/pointer"
|
||||
desc: "must use github.com/dapr/kit/ptr"
|
||||
- pkg: "k8s.io/utils/ptr"
|
||||
desc: "must use github.com/dapr/kit/ptr"
|
||||
misspell:
|
||||
# Correct spellings using locale preferences for US or UK.
|
||||
# Default is to use a neutral variety of English.
|
||||
# Setting locale to US will correct the British spelling of 'colour' to 'color'.
|
||||
locale: default
|
||||
# locale: default
|
||||
ignore-words:
|
||||
- someword
|
||||
lll:
|
||||
|
@ -178,17 +180,9 @@ linters-settings:
|
|||
line-length: 120
|
||||
# tab width in spaces. Default to 1.
|
||||
tab-width: 1
|
||||
unparam:
|
||||
# Inspect exported functions, default is false. Set to true if no external program/library imports your code.
|
||||
# XXX: if you enable this setting, unparam will report a lot of false-positives in text editors:
|
||||
# if it's called for subdir of a project it can't find external interfaces. All text editor integrations
|
||||
# with golangci-lint call it on a directory with the changed file.
|
||||
check-exported: false
|
||||
nakedret:
|
||||
# make an issue if func has more lines of code than this setting and it has naked returns; default is 30
|
||||
max-func-lines: 30
|
||||
nolintlint:
|
||||
allow-unused: true
|
||||
prealloc:
|
||||
# XXX: we don't recommend using this linter before doing performance profiling.
|
||||
# For most programs usage of prealloc will be a premature optimization.
|
||||
|
@ -203,7 +197,6 @@ linters-settings:
|
|||
# See https://go-critic.github.io/overview#checks-overview
|
||||
# To check which checks are enabled run `GL_DEBUG=gocritic golangci-lint run`
|
||||
# By default list of stable checks is used.
|
||||
enabled-checks:
|
||||
|
||||
# Which checks should be disabled; can't be combined with 'enabled-checks'; default is empty
|
||||
disabled-checks:
|
||||
|
@ -251,63 +244,51 @@ linters-settings:
|
|||
allow-assign-and-call: true
|
||||
# Allow multiline assignments to be cuddled. Default is true.
|
||||
allow-multiline-assign: true
|
||||
# Allow case blocks to end with a whitespace.
|
||||
allow-case-traling-whitespace: true
|
||||
# Allow declarations (var) to be cuddled.
|
||||
allow-cuddle-declarations: false
|
||||
# If the number of lines in a case block is equal to or lager than this number,
|
||||
# the case *must* end white a newline.
|
||||
# https://github.com/bombsimon/wsl/blob/master/doc/configuration.md#force-case-trailing-whitespace
|
||||
# Default: 0
|
||||
force-case-trailing-whitespace: 1
|
||||
|
||||
linters:
|
||||
fast: false
|
||||
enable-all: true
|
||||
disable:
|
||||
# TODO Enforce the below linters later
|
||||
- musttag
|
||||
- dupl
|
||||
- nonamedreturns
|
||||
- errcheck
|
||||
- funlen
|
||||
- goconst
|
||||
- gochecknoglobals
|
||||
- gochecknoinits
|
||||
- gocyclo
|
||||
- gocognit
|
||||
- nosnakecase
|
||||
- varcheck
|
||||
- structcheck
|
||||
- deadcode
|
||||
- godox
|
||||
- interfacer
|
||||
- lll
|
||||
- maligned
|
||||
- scopelint
|
||||
- unparam
|
||||
- wsl
|
||||
- gomnd
|
||||
- testpackage
|
||||
- goerr113
|
||||
- nestif
|
||||
- nlreturn
|
||||
- exhaustive
|
||||
- exhaustruct
|
||||
- noctx
|
||||
- gci
|
||||
- golint
|
||||
- tparallel
|
||||
- paralleltest
|
||||
- wrapcheck
|
||||
- tagliatelle
|
||||
- ireturn
|
||||
- exhaustive
|
||||
- exhaustivestruct
|
||||
- exhaustruct
|
||||
- errchkjson
|
||||
- contextcheck
|
||||
- gomoddirectives
|
||||
- godot
|
||||
- cyclop
|
||||
- varnamelen
|
||||
- gosec
|
||||
- tagalign
|
||||
- errorlint
|
||||
- forcetypeassert
|
||||
- ifshort
|
||||
- maintidx
|
||||
- nilnil
|
||||
- predeclared
|
||||
|
@ -316,4 +297,13 @@ linters:
|
|||
- wastedassign
|
||||
- containedctx
|
||||
- gosimple
|
||||
- forbidigo
|
||||
- nonamedreturns
|
||||
- asasalint
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
||||
- inamedparam
|
||||
- tagalign
|
||||
- mnd
|
||||
- canonicalheader
|
||||
- err113
|
||||
- fatcontext
|
||||
|
|
7
Makefile
7
Makefile
|
@ -78,6 +78,13 @@ test-race:
|
|||
lint:
|
||||
$(GOLANGCI_LINT) run --timeout=20m
|
||||
|
||||
################################################################################
|
||||
# Target: lint-fix #
|
||||
################################################################################
|
||||
.PHONY: lint-fix
|
||||
lint-fix:
|
||||
$(GOLANGCI_LINT) run --timeout=20m --fix
|
||||
|
||||
################################################################################
|
||||
# Target: go.mod #
|
||||
################################################################################
|
||||
|
|
|
@ -32,7 +32,7 @@ func TestByteSlicePool(t *testing.T) {
|
|||
assert.Equal(t, &bs, &bs2)
|
||||
assert.Equal(t, minCap, cap(bs2))
|
||||
|
||||
for i := 0; i < minCap; i++ {
|
||||
for range minCap {
|
||||
bs2 = append(bs2, 0)
|
||||
}
|
||||
|
||||
|
|
|
@ -26,11 +26,7 @@ import (
|
|||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrManagerAlreadyClosed = errors.New("runner manager already closed")
|
||||
|
||||
log = logger.NewLogger("dapr.kit.concurrency")
|
||||
)
|
||||
var ErrManagerAlreadyClosed = errors.New("runner manager already closed")
|
||||
|
||||
// RunnerCloserManager is a RunnerManager that also implements Closing of the
|
||||
// added closers once the main runners are done.
|
||||
|
@ -64,7 +60,7 @@ type RunnerCloserManager struct {
|
|||
// NewRunnerCloserManager creates a new RunnerCloserManager with the given
|
||||
// grace period and runners.
|
||||
// If gracePeriod is nil, the grace period is infinite.
|
||||
func NewRunnerCloserManager(gracePeriod *time.Duration, runners ...Runner) *RunnerCloserManager {
|
||||
func NewRunnerCloserManager(log logger.Logger, gracePeriod *time.Duration, runners ...Runner) *RunnerCloserManager {
|
||||
c := &RunnerCloserManager{
|
||||
mngr: NewRunnerManager(runners...),
|
||||
clock: clock.RealClock{},
|
||||
|
|
|
@ -24,8 +24,12 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
var log = logger.NewLogger("test")
|
||||
|
||||
type mockCloser func() error
|
||||
|
||||
func (m mockCloser) Close() error {
|
||||
|
@ -34,21 +38,21 @@ func (m mockCloser) Close() error {
|
|||
|
||||
func Test_RunnerClosterManager(t *testing.T) {
|
||||
t.Run("runner with no tasks or closers should return nil", func(t *testing.T) {
|
||||
require.NoError(t, NewRunnerCloserManager(nil).Run(context.Background()))
|
||||
require.NoError(t, NewRunnerCloserManager(log, nil).Run(t.Context()))
|
||||
})
|
||||
|
||||
t.Run("runner with a task that completes should return nil", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
require.NoError(t, NewRunnerCloserManager(nil, func(context.Context) error {
|
||||
require.NoError(t, NewRunnerCloserManager(log, nil, func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
}).Run(context.Background()))
|
||||
}).Run(t.Context()))
|
||||
assert.Equal(t, int32(1), i.Load())
|
||||
})
|
||||
|
||||
t.Run("runner with a task and closer that completes should return nil", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
mngr := NewRunnerCloserManager(nil, func(context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, nil, func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
})
|
||||
|
@ -56,13 +60,13 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
i.Add(1)
|
||||
return nil
|
||||
}))
|
||||
require.NoError(t, mngr.Run(context.Background()))
|
||||
require.NoError(t, mngr.Run(t.Context()))
|
||||
assert.Equal(t, int32(2), i.Load())
|
||||
})
|
||||
|
||||
t.Run("runner with multiple tasks and closers that complete should return nil", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
mngr := NewRunnerCloserManager(nil,
|
||||
mngr := NewRunnerCloserManager(log, nil,
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
|
@ -94,82 +98,82 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
}),
|
||||
))
|
||||
|
||||
require.NoError(t, mngr.Run(context.Background()))
|
||||
require.NoError(t, mngr.Run(t.Context()))
|
||||
assert.Equal(t, int32(7), i.Load())
|
||||
})
|
||||
|
||||
t.Run("a runner that errors should error but still call the closers", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
mngr := NewRunnerCloserManager(nil,
|
||||
func(ctx context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, nil,
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return errors.New("error")
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
require.NoError(t, mngr.AddCloser(
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
require.EqualError(t, mngr.Run(context.Background()), "error")
|
||||
require.EqualError(t, mngr.Run(t.Context()), "error")
|
||||
assert.Equal(t, int32(4), i.Load())
|
||||
})
|
||||
|
||||
t.Run("a runner that has closter errors should error", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
mngr := NewRunnerCloserManager(nil,
|
||||
func(ctx context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, nil,
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
require.NoError(t, mngr.AddCloser(
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return errors.New("error")
|
||||
},
|
||||
))
|
||||
|
||||
require.EqualError(t, mngr.Run(context.Background()), "error")
|
||||
require.EqualError(t, mngr.Run(t.Context()), "error")
|
||||
assert.Equal(t, int32(4), i.Load())
|
||||
})
|
||||
|
||||
t.Run("a runner with multiple errors should collect all errors (string match)", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
mngr := NewRunnerCloserManager(nil,
|
||||
func(ctx context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, nil,
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return errors.New("error")
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return errors.New("error")
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return errors.New("error")
|
||||
},
|
||||
)
|
||||
require.NoError(t, mngr.AddCloser(
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return errors.New("closererror")
|
||||
},
|
||||
|
@ -183,7 +187,7 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
}),
|
||||
))
|
||||
|
||||
err := mngr.Run(context.Background())
|
||||
err := mngr.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "error\nerror\nerror\nclosererror\nclosererror\nclosererror") //nolint:dupword
|
||||
assert.Equal(t, int32(6), i.Load())
|
||||
|
@ -191,22 +195,22 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
|
||||
t.Run("a runner with multiple errors should collect all errors (unique)", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
mngr := NewRunnerCloserManager(nil,
|
||||
func(ctx context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, nil,
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return errors.New("error1")
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return errors.New("error2")
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return errors.New("error3")
|
||||
},
|
||||
)
|
||||
require.NoError(t, mngr.AddCloser(
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return errors.New("closererror1")
|
||||
},
|
||||
|
@ -220,7 +224,7 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
}),
|
||||
))
|
||||
|
||||
err := mngr.Run(context.Background())
|
||||
err := mngr.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.ElementsMatch(t,
|
||||
[]string{"error1", "error2", "error3", "closererror1", "closererror2", "closererror3"},
|
||||
|
@ -231,26 +235,26 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
|
||||
t.Run("should be able to add runner with New, Add and AddCloser", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
mngr := NewRunnerCloserManager(nil,
|
||||
func(ctx context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, nil,
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
require.NoError(t, mngr.Add(
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
require.NoError(t, mngr.Add(
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
require.NoError(t, mngr.AddCloser(
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
},
|
||||
|
@ -261,14 +265,14 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
},
|
||||
))
|
||||
|
||||
require.NoError(t, mngr.Run(context.Background()))
|
||||
require.NoError(t, mngr.Run(t.Context()))
|
||||
assert.Equal(t, int32(5), i.Load())
|
||||
})
|
||||
|
||||
t.Run("when a runner returns, expect context to be cancelled for other runners, but not for closers returning", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
mngr := NewRunnerCloserManager(nil,
|
||||
func(ctx context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, nil,
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
},
|
||||
|
@ -295,7 +299,7 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
closer1Ch := make(chan struct{})
|
||||
closer2Ch := make(chan struct{})
|
||||
require.NoError(t, mngr.AddCloser(
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
close(closer1Ch)
|
||||
return nil
|
||||
|
@ -321,13 +325,13 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
},
|
||||
))
|
||||
|
||||
require.NoError(t, mngr.Run(context.Background()))
|
||||
require.NoError(t, mngr.Run(t.Context()))
|
||||
assert.Equal(t, int32(6), i.Load())
|
||||
})
|
||||
|
||||
t.Run("when a runner errors, expect context to be cancelled for other runners, but closers should still run", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
mngr := NewRunnerCloserManager(nil,
|
||||
mngr := NewRunnerCloserManager(log, nil,
|
||||
func(ctx context.Context) error {
|
||||
i.Add(1)
|
||||
select {
|
||||
|
@ -346,7 +350,7 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
}
|
||||
return errors.New("error2")
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
i.Add(1)
|
||||
return errors.New("error3")
|
||||
},
|
||||
|
@ -373,7 +377,7 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
},
|
||||
))
|
||||
|
||||
err := mngr.Run(context.Background())
|
||||
err := mngr.Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.ElementsMatch(t,
|
||||
[]string{"error1", "error2", "error3", "closererror1", "closererror2"},
|
||||
|
@ -384,45 +388,45 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
|
||||
t.Run("a manger started twice should error", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
m := NewRunnerCloserManager(nil, func(ctx context.Context) error {
|
||||
m := NewRunnerCloserManager(log, nil, func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, m.Run(context.Background()))
|
||||
require.NoError(t, m.Run(t.Context()))
|
||||
assert.Equal(t, int32(1), i.Load())
|
||||
require.EqualError(t, m.Run(context.Background()), "runner manager already started")
|
||||
require.EqualError(t, m.Run(t.Context()), "runner manager already started")
|
||||
assert.Equal(t, int32(1), i.Load())
|
||||
})
|
||||
|
||||
t.Run("a manger started twice should error", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
m := NewRunnerCloserManager(nil, func(ctx context.Context) error {
|
||||
m := NewRunnerCloserManager(log, nil, func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
})
|
||||
|
||||
require.NoError(t, m.AddCloser(func(ctx context.Context) error {
|
||||
require.NoError(t, m.AddCloser(func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
}))
|
||||
|
||||
require.NoError(t, m.Run(context.Background()))
|
||||
require.NoError(t, m.Run(t.Context()))
|
||||
assert.Equal(t, int32(2), i.Load())
|
||||
require.NoError(t, m.Close())
|
||||
require.NoError(t, m.Close())
|
||||
require.EqualError(t, m.Run(context.Background()), "runner manager already started")
|
||||
require.EqualError(t, m.Run(t.Context()), "runner manager already started")
|
||||
assert.Equal(t, int32(2), i.Load())
|
||||
})
|
||||
|
||||
t.Run("adding a task to a started manager should error", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
m := NewRunnerCloserManager(nil, func(ctx context.Context) error {
|
||||
m := NewRunnerCloserManager(log, nil, func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, m.Run(context.Background()))
|
||||
require.NoError(t, m.Run(t.Context()))
|
||||
assert.Equal(t, int32(1), i.Load())
|
||||
err := m.Add(func(ctx context.Context) error {
|
||||
err := m.Add(func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
})
|
||||
|
@ -433,14 +437,14 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
|
||||
t.Run("adding a closer to a closing manager should error", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
m := NewRunnerCloserManager(nil, func(ctx context.Context) error {
|
||||
m := NewRunnerCloserManager(log, nil, func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, m.Run(context.Background()))
|
||||
require.NoError(t, m.Run(t.Context()))
|
||||
assert.Equal(t, int32(1), i.Load())
|
||||
require.NoError(t, m.Close())
|
||||
err := m.AddCloser(func(ctx context.Context) error {
|
||||
err := m.AddCloser(func(context.Context) error {
|
||||
i.Add(1)
|
||||
return nil
|
||||
})
|
||||
|
@ -450,19 +454,19 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("if grace period is not given, should have no force shutdown", func(t *testing.T) {
|
||||
mngr := NewRunnerCloserManager(nil)
|
||||
mngr := NewRunnerCloserManager(log, nil)
|
||||
assert.Empty(t, mngr.closers)
|
||||
})
|
||||
|
||||
t.Run("if grace period is given, should have force shutdown", func(t *testing.T) {
|
||||
dur := time.Second
|
||||
mngr := NewRunnerCloserManager(&dur)
|
||||
mngr := NewRunnerCloserManager(log, &dur)
|
||||
assert.Len(t, mngr.closers, 1)
|
||||
})
|
||||
|
||||
t.Run("if closing but grace period not reached, should return", func(t *testing.T) {
|
||||
dur := time.Second
|
||||
mngr := NewRunnerCloserManager(&dur)
|
||||
mngr := NewRunnerCloserManager(log, &dur)
|
||||
|
||||
var i atomic.Int32
|
||||
require.NoError(t, mngr.AddCloser(func() {
|
||||
|
@ -482,7 +486,7 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- mngr.Run(context.Background())
|
||||
errCh <- mngr.Run(t.Context())
|
||||
}()
|
||||
|
||||
select {
|
||||
|
@ -505,13 +509,13 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
|
||||
t.Run("if closing and grace period is reached, should force shutdown", func(t *testing.T) {
|
||||
dur := time.Second
|
||||
mngr := NewRunnerCloserManager(&dur)
|
||||
mngr := NewRunnerCloserManager(log, &dur)
|
||||
assert.Len(t, mngr.closers, 1)
|
||||
|
||||
clock := clocktesting.NewFakeClock(time.Now())
|
||||
mngr.clock = clock
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fatalCalled := make(chan struct{})
|
||||
|
@ -533,7 +537,7 @@ func Test_RunnerClosterManager(t *testing.T) {
|
|||
}
|
||||
})
|
||||
go func() {
|
||||
errCh <- mngr.Run(context.Background())
|
||||
errCh <- mngr.Run(t.Context())
|
||||
}()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
|
@ -555,7 +559,7 @@ func TestClose(t *testing.T) {
|
|||
t.Run("calling close should stop the main runner and call all closers", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
runnerWaiting := make(chan struct{})
|
||||
mngr := NewRunnerCloserManager(nil, func(ctx context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, nil, func(ctx context.Context) error {
|
||||
close(runnerWaiting)
|
||||
<-ctx.Done()
|
||||
i.Add(1)
|
||||
|
@ -567,7 +571,7 @@ func TestClose(t *testing.T) {
|
|||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- mngr.Run(context.Background())
|
||||
errCh <- mngr.Run(t.Context())
|
||||
}()
|
||||
|
||||
select {
|
||||
|
@ -591,7 +595,7 @@ func TestClose(t *testing.T) {
|
|||
t.Run("calling close should wait for all closers to return", func(t *testing.T) {
|
||||
var i atomic.Int32
|
||||
runnerWaiting := make(chan struct{})
|
||||
mngr := NewRunnerCloserManager(nil, func(ctx context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, nil, func(ctx context.Context) error {
|
||||
close(runnerWaiting)
|
||||
<-ctx.Done()
|
||||
i.Add(1)
|
||||
|
@ -625,7 +629,7 @@ func TestClose(t *testing.T) {
|
|||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- mngr.Run(context.Background())
|
||||
errCh <- mngr.Run(t.Context())
|
||||
}()
|
||||
|
||||
select {
|
||||
|
@ -667,7 +671,7 @@ func TestClose(t *testing.T) {
|
|||
dur := time.Second
|
||||
var i atomic.Int32
|
||||
runnerWaiting := make(chan struct{})
|
||||
mngr := NewRunnerCloserManager(&dur, func(ctx context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, &dur, func(ctx context.Context) error {
|
||||
close(runnerWaiting)
|
||||
<-ctx.Done()
|
||||
i.Add(1)
|
||||
|
@ -710,7 +714,7 @@ func TestClose(t *testing.T) {
|
|||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- mngr.Run(context.Background())
|
||||
errCh <- mngr.Run(t.Context())
|
||||
}()
|
||||
|
||||
select {
|
||||
|
@ -754,7 +758,7 @@ func TestClose(t *testing.T) {
|
|||
dur := time.Second
|
||||
var i atomic.Int32
|
||||
runnerWaiting := make(chan struct{})
|
||||
mngr := NewRunnerCloserManager(&dur, func(ctx context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, &dur, func(ctx context.Context) error {
|
||||
close(runnerWaiting)
|
||||
<-ctx.Done()
|
||||
i.Add(1)
|
||||
|
@ -772,7 +776,7 @@ func TestClose(t *testing.T) {
|
|||
assert.Len(t, mngr.closers, 1)
|
||||
|
||||
returnClose := make(chan struct{})
|
||||
for n := 0; n < 4; n++ {
|
||||
for range 4 {
|
||||
require.NoError(t, mngr.AddCloser(func() {
|
||||
i.Add(1)
|
||||
<-returnClose
|
||||
|
@ -783,7 +787,7 @@ func TestClose(t *testing.T) {
|
|||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- mngr.Run(context.Background())
|
||||
errCh <- mngr.Run(t.Context())
|
||||
}()
|
||||
|
||||
select {
|
||||
|
@ -820,14 +824,14 @@ func TestClose(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("calling close should return the errors from the main runner and all closers", func(t *testing.T) {
|
||||
mngr := NewRunnerCloserManager(nil,
|
||||
func(ctx context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, nil,
|
||||
func(context.Context) error {
|
||||
return errors.New("error1")
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
return errors.New("error2")
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
func(context.Context) error {
|
||||
return errors.New("error3")
|
||||
},
|
||||
)
|
||||
|
@ -846,7 +850,7 @@ func TestClose(t *testing.T) {
|
|||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- mngr.Run(context.Background())
|
||||
errCh <- mngr.Run(t.Context())
|
||||
}()
|
||||
|
||||
var err error
|
||||
|
@ -864,8 +868,8 @@ func TestClose(t *testing.T) {
|
|||
|
||||
t.Run("calling Close before Run should return immediately", func(t *testing.T) {
|
||||
dur := time.Second
|
||||
mngr := NewRunnerCloserManager(&dur,
|
||||
func(ctx context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, &dur,
|
||||
func(context.Context) error {
|
||||
return errors.New("error1")
|
||||
},
|
||||
)
|
||||
|
@ -875,7 +879,7 @@ func TestClose(t *testing.T) {
|
|||
|
||||
require.NoError(t, mngr.Close())
|
||||
require.NoError(t, mngr.Close())
|
||||
assert.Equal(t, mngr.Run(context.Background()), errors.New("runner manager already started"))
|
||||
assert.Equal(t, mngr.Run(t.Context()), errors.New("runner manager already started"))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -892,7 +896,7 @@ func TestAddCloser(t *testing.T) {
|
|||
expErr: errors.Join(errors.New("unsupported closer type: int")),
|
||||
},
|
||||
"Add various supported closer types": {
|
||||
closers: []any{new(mockCloser), func(ctx context.Context) error { return nil }, func() error { return nil }, func() {}},
|
||||
closers: []any{new(mockCloser), func(context.Context) error { return nil }, func() error { return nil }, func() {}},
|
||||
expErr: nil,
|
||||
},
|
||||
"Add combination of supported and unsupported closer types": {
|
||||
|
@ -903,18 +907,18 @@ func TestAddCloser(t *testing.T) {
|
|||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := NewRunnerCloserManager(nil).AddCloser(test.closers...)
|
||||
err := NewRunnerCloserManager(log, nil).AddCloser(test.closers...)
|
||||
assert.Equalf(t, test.expErr, err, "%v", err)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("no error if adding a closer during main routine", func(t *testing.T) {
|
||||
mngr := NewRunnerCloserManager(nil, func(ctx context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, nil, func(ctx context.Context) error {
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- mngr.Run(ctx)
|
||||
|
@ -925,9 +929,9 @@ func TestAddCloser(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should error if closing", func(t *testing.T) {
|
||||
mngr := NewRunnerCloserManager(nil)
|
||||
mngr := NewRunnerCloserManager(log, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
closerCh := make(chan struct{})
|
||||
require.NoError(t, mngr.AddCloser(func() {
|
||||
cancel()
|
||||
|
@ -936,7 +940,7 @@ func TestAddCloser(t *testing.T) {
|
|||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- mngr.Run(context.Background())
|
||||
errCh <- mngr.Run(t.Context())
|
||||
}()
|
||||
|
||||
select {
|
||||
|
@ -968,15 +972,15 @@ func TestAddCloser(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("should error if manager already returned", func(t *testing.T) {
|
||||
mngr := NewRunnerCloserManager(nil)
|
||||
require.NoError(t, mngr.Run(context.Background()))
|
||||
mngr := NewRunnerCloserManager(log, nil)
|
||||
require.NoError(t, mngr.Run(t.Context()))
|
||||
assert.Equal(t, mngr.AddCloser(nil), errors.New("runner manager already closed"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitUntilShutdown(t *testing.T) {
|
||||
dur := time.Second * 3
|
||||
mngr := NewRunnerCloserManager(&dur, func(ctx context.Context) error {
|
||||
mngr := NewRunnerCloserManager(log, &dur, func(ctx context.Context) error {
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
})
|
||||
|
@ -995,7 +999,7 @@ func TestWaitUntilShutdown(t *testing.T) {
|
|||
<-returnClose
|
||||
}))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cmap
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type AtomicValue[T constraints.Integer] struct {
|
||||
lock sync.RWMutex
|
||||
value T
|
||||
}
|
||||
|
||||
func (a *AtomicValue[T]) Load() T {
|
||||
a.lock.RLock()
|
||||
defer a.lock.RUnlock()
|
||||
return a.value
|
||||
}
|
||||
|
||||
func (a *AtomicValue[T]) Store(v T) {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
a.value = v
|
||||
}
|
||||
|
||||
func (a *AtomicValue[T]) Add(v T) T {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
a.value += v
|
||||
return a.value
|
||||
}
|
||||
|
||||
type Atomic[K comparable, T constraints.Integer] interface {
|
||||
Get(key K) (*AtomicValue[T], bool)
|
||||
GetOrCreate(key K, createT T) *AtomicValue[T]
|
||||
Delete(key K)
|
||||
ForEach(fn func(key K, value *AtomicValue[T]))
|
||||
Clear()
|
||||
}
|
||||
|
||||
type atomicMap[K comparable, T constraints.Integer] struct {
|
||||
lock sync.RWMutex
|
||||
items map[K]*AtomicValue[T]
|
||||
}
|
||||
|
||||
func NewAtomic[K comparable, T constraints.Integer]() Atomic[K, T] {
|
||||
return &atomicMap[K, T]{
|
||||
items: make(map[K]*AtomicValue[T]),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *atomicMap[K, T]) Get(key K) (*AtomicValue[T], bool) {
|
||||
a.lock.RLock()
|
||||
defer a.lock.RUnlock()
|
||||
|
||||
item, ok := a.items[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return item, true
|
||||
}
|
||||
|
||||
func (a *atomicMap[K, T]) GetOrCreate(key K, createT T) *AtomicValue[T] {
|
||||
a.lock.RLock()
|
||||
item, ok := a.items[key]
|
||||
a.lock.RUnlock()
|
||||
if !ok {
|
||||
a.lock.Lock()
|
||||
// Double-check the key exists to avoid race condition
|
||||
item, ok = a.items[key]
|
||||
if !ok {
|
||||
item = &AtomicValue[T]{value: createT}
|
||||
a.items[key] = item
|
||||
}
|
||||
a.lock.Unlock()
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
func (a *atomicMap[K, T]) Delete(key K) {
|
||||
a.lock.Lock()
|
||||
delete(a.items, key)
|
||||
a.lock.Unlock()
|
||||
}
|
||||
|
||||
func (a *atomicMap[K, T]) ForEach(fn func(key K, value *AtomicValue[T])) {
|
||||
a.lock.RLock()
|
||||
defer a.lock.RUnlock()
|
||||
for k, v := range a.items {
|
||||
fn(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *atomicMap[K, T]) Clear() {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
clear(a.items)
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cmap
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAtomicInt32_New_Get_Delete(t *testing.T) {
|
||||
m := NewAtomic[string, int32]().(*atomicMap[string, int32])
|
||||
|
||||
require.NotNil(t, m)
|
||||
require.NotNil(t, m.items)
|
||||
require.Empty(t, m.items)
|
||||
|
||||
t.Run("basic operations", func(t *testing.T) {
|
||||
key := "key1"
|
||||
value := int32(10)
|
||||
|
||||
// Initially, the key should not exist
|
||||
_, ok := m.Get(key)
|
||||
require.False(t, ok)
|
||||
|
||||
// Add a value and check it
|
||||
m.GetOrCreate(key, 0).Store(value)
|
||||
result, ok := m.Get(key)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, value, result.Load())
|
||||
|
||||
// Delete the key and check it no longer exists
|
||||
m.Delete(key)
|
||||
_, ok = m.Get(key)
|
||||
require.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("concurrent access multiple keys", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
keys := []string{"key1", "key2", "key3"}
|
||||
iterations := 100
|
||||
|
||||
wg.Add(len(keys) * 2)
|
||||
for _, key := range keys {
|
||||
go func(k string) {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
m.GetOrCreate(k, 0).Add(1)
|
||||
}
|
||||
}(key)
|
||||
go func(k string) {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
m.GetOrCreate(k, 0).Add(-1)
|
||||
}
|
||||
}(key)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for _, key := range keys {
|
||||
val, ok := m.Get(key)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int32(0), val.Load())
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,99 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cmap
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Map is a simple _typed_ map which is safe for concurrent use.
|
||||
// Favoured over sync.Map as it is typed.
|
||||
type Map[K comparable, T any] interface {
|
||||
Clear()
|
||||
Delete(key K)
|
||||
Load(key K) (T, bool)
|
||||
LoadAndDelete(key K) (T, bool)
|
||||
Range(fn func(key K, value T) bool)
|
||||
Store(key K, value T)
|
||||
Len() int
|
||||
Keys() []K
|
||||
}
|
||||
|
||||
type mapimpl[K comparable, T any] struct {
|
||||
lock sync.RWMutex
|
||||
m map[K]T
|
||||
}
|
||||
|
||||
func NewMap[K comparable, T any]() Map[K, T] {
|
||||
return &mapimpl[K, T]{m: make(map[K]T)}
|
||||
}
|
||||
|
||||
func (m *mapimpl[K, T]) Clear() {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
m.m = make(map[K]T)
|
||||
}
|
||||
|
||||
func (m *mapimpl[K, T]) Delete(k K) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
delete(m.m, k)
|
||||
}
|
||||
|
||||
func (m *mapimpl[K, T]) Load(k K) (T, bool) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
v, ok := m.m[k]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (m *mapimpl[K, T]) LoadAndDelete(k K) (T, bool) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
v, ok := m.m[k]
|
||||
delete(m.m, k)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (m *mapimpl[K, T]) Range(fn func(K, T) bool) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
for k, v := range m.m {
|
||||
if !fn(k, v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mapimpl[K, T]) Store(k K, v T) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
m.m[k] = v
|
||||
}
|
||||
|
||||
func (m *mapimpl[K, T]) Len() int {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
return len(m.m)
|
||||
}
|
||||
|
||||
func (m *mapimpl[K, T]) Keys() []K {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
keys := make([]K, 0, len(m.m))
|
||||
for k := range m.m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cmap
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Mutex is an interface that defines a thread-safe map with keys of type T associated to
|
||||
// read-write mutexes (sync.RWMutex), allowing for granular locking on a per-key basis.
|
||||
// This can be useful for scenarios where fine-grained concurrency control is needed.
|
||||
//
|
||||
// Methods:
|
||||
// - Lock(key T): Acquires an exclusive lock on the mutex associated with the given key.
|
||||
// - Unlock(key T): Releases the exclusive lock on the mutex associated with the given key.
|
||||
// - RLock(key T): Acquires a read lock on the mutex associated with the given key.
|
||||
// - RUnlock(key T): Releases the read lock on the mutex associated with the given key.
|
||||
// - Delete(key T): Removes the mutex associated with the given key from the map.
|
||||
// - Clear(): Removes all mutexes from the map.
|
||||
// - ItemCount() int: Returns the number of items (mutexes) in the map.
|
||||
// - DeleteUnlock(key T): Removes the mutex associated with the given key from the map and releases the lock.
|
||||
// - DeleteRUnlock(key T): Removes the mutex associated with the given key from the map and releases the read lock.
|
||||
type Mutex[T comparable] interface {
|
||||
Lock(key T)
|
||||
Unlock(key T)
|
||||
RLock(key T)
|
||||
RUnlock(key T)
|
||||
Delete(key T)
|
||||
Clear()
|
||||
ItemCount() int
|
||||
DeleteUnlock(key T)
|
||||
DeleteRUnlock(key T)
|
||||
}
|
||||
|
||||
type mutex[T comparable] struct {
|
||||
lock sync.RWMutex
|
||||
items map[T]*sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMutex[T comparable]() Mutex[T] {
|
||||
return &mutex[T]{
|
||||
items: make(map[T]*sync.RWMutex),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *mutex[T]) Lock(key T) {
|
||||
a.lock.RLock()
|
||||
mutex, ok := a.items[key]
|
||||
a.lock.RUnlock()
|
||||
if ok {
|
||||
mutex.Lock()
|
||||
return
|
||||
}
|
||||
|
||||
a.lock.Lock()
|
||||
mutex, ok = a.items[key]
|
||||
if !ok {
|
||||
mutex = &sync.RWMutex{}
|
||||
a.items[key] = mutex
|
||||
}
|
||||
a.lock.Unlock()
|
||||
mutex.Lock()
|
||||
}
|
||||
|
||||
func (a *mutex[T]) Unlock(key T) {
|
||||
a.lock.RLock()
|
||||
mutex, ok := a.items[key]
|
||||
if ok {
|
||||
mutex.Unlock()
|
||||
}
|
||||
a.lock.RUnlock()
|
||||
}
|
||||
|
||||
func (a *mutex[T]) RLock(key T) {
|
||||
a.lock.RLock()
|
||||
mutex, ok := a.items[key]
|
||||
a.lock.RUnlock()
|
||||
|
||||
if ok {
|
||||
mutex.RLock()
|
||||
return
|
||||
}
|
||||
|
||||
a.lock.Lock()
|
||||
mutex, ok = a.items[key]
|
||||
if !ok {
|
||||
mutex = &sync.RWMutex{}
|
||||
a.items[key] = mutex
|
||||
}
|
||||
a.lock.Unlock()
|
||||
mutex.RLock()
|
||||
}
|
||||
|
||||
func (a *mutex[T]) RUnlock(key T) {
|
||||
a.lock.RLock()
|
||||
mutex, ok := a.items[key]
|
||||
if ok {
|
||||
mutex.RUnlock()
|
||||
}
|
||||
a.lock.RUnlock()
|
||||
}
|
||||
|
||||
func (a *mutex[T]) Delete(key T) {
|
||||
a.lock.Lock()
|
||||
delete(a.items, key)
|
||||
a.lock.Unlock()
|
||||
}
|
||||
|
||||
func (a *mutex[T]) DeleteUnlock(key T) {
|
||||
a.lock.Lock()
|
||||
mutex, ok := a.items[key]
|
||||
if ok {
|
||||
mutex.Unlock()
|
||||
}
|
||||
delete(a.items, key)
|
||||
a.lock.Unlock()
|
||||
}
|
||||
|
||||
func (a *mutex[T]) DeleteRUnlock(key T) {
|
||||
a.lock.Lock()
|
||||
mutex, ok := a.items[key]
|
||||
if ok {
|
||||
mutex.RUnlock()
|
||||
}
|
||||
delete(a.items, key)
|
||||
a.lock.Unlock()
|
||||
}
|
||||
|
||||
func (a *mutex[T]) Clear() {
|
||||
a.lock.Lock()
|
||||
clear(a.items)
|
||||
a.lock.Unlock()
|
||||
}
|
||||
|
||||
func (a *mutex[T]) ItemCount() int {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
return len(a.items)
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cmap
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewMutex_Add_Delete(t *testing.T) {
|
||||
mm := NewMutex[string]().(*mutex[string])
|
||||
|
||||
t.Run("New mutex map", func(t *testing.T) {
|
||||
require.NotNil(t, mm)
|
||||
require.NotNil(t, mm.items)
|
||||
require.Empty(t, mm.items)
|
||||
})
|
||||
|
||||
t.Run("Lock and unlock mutex", func(t *testing.T) {
|
||||
mm.Lock("key1")
|
||||
_, ok := mm.items["key1"]
|
||||
require.True(t, ok)
|
||||
mm.Unlock("key1")
|
||||
})
|
||||
|
||||
t.Run("Concurrently lock and unlock mutexes", func(t *testing.T) {
|
||||
var counter atomic.Int64
|
||||
var wg sync.WaitGroup
|
||||
|
||||
numGoroutines := 10
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
// Concurrently lock and unlock for each key
|
||||
for range numGoroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
mm.Lock("key1")
|
||||
counter.Add(1)
|
||||
mm.Unlock("key1")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
require.Equal(t, int64(10), counter.Load())
|
||||
})
|
||||
|
||||
t.Run("RLock and RUnlock mutex", func(t *testing.T) {
|
||||
mm.RLock("key1")
|
||||
_, ok := mm.items["key1"]
|
||||
require.True(t, ok)
|
||||
mm.RUnlock("key1")
|
||||
})
|
||||
|
||||
t.Run("Concurrently RLock and RUnlock mutexes", func(t *testing.T) {
|
||||
var counter atomic.Int64
|
||||
var wg sync.WaitGroup
|
||||
|
||||
numGoroutines := 10
|
||||
wg.Add(numGoroutines * 2)
|
||||
|
||||
// Concurrently RLock and RUnlock for each key
|
||||
for range numGoroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
mm.RLock("key1")
|
||||
counter.Add(1)
|
||||
}()
|
||||
}
|
||||
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
assert.Equal(ct, int64(10), counter.Load())
|
||||
}, 5*time.Second, 10*time.Millisecond)
|
||||
|
||||
for range numGoroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
mm.RUnlock("key1")
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run("Delete mutex", func(t *testing.T) {
|
||||
mm.Lock("key1")
|
||||
mm.Unlock("key1")
|
||||
mm.Delete("key1")
|
||||
_, ok := mm.items["key1"]
|
||||
require.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Clear all mutexes, and check item count", func(t *testing.T) {
|
||||
mm.Lock("key1")
|
||||
mm.Unlock("key1")
|
||||
mm.Lock("key2")
|
||||
mm.Unlock("key2")
|
||||
|
||||
require.Equal(t, 2, mm.ItemCount())
|
||||
|
||||
mm.Clear()
|
||||
require.Empty(t, mm.items)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ctesting
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dapr/kit/concurrency"
|
||||
"github.com/dapr/kit/concurrency/ctesting/internal"
|
||||
)
|
||||
|
||||
type RunnerFn func(context.Context, assert.TestingT)
|
||||
|
||||
// Assert runs the provided test functions in parallel and asserts that they
|
||||
// all pass.
|
||||
func Assert(t *testing.T, runners ...RunnerFn) {
|
||||
t.Helper()
|
||||
|
||||
if len(runners) == 0 {
|
||||
require.Fail(t, "at least one runner function is required")
|
||||
}
|
||||
|
||||
tt := internal.Assert(t)
|
||||
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
t.Cleanup(func() { cancel(nil) })
|
||||
|
||||
doneCh := make(chan struct{}, len(runners))
|
||||
for _, runner := range runners {
|
||||
go func(rfn RunnerFn) {
|
||||
rfn(ctx, tt)
|
||||
if errs := tt.Errors(); len(errs) > 0 {
|
||||
cancel(errors.Join(errs...))
|
||||
}
|
||||
doneCh <- struct{}{}
|
||||
}(runner)
|
||||
}
|
||||
|
||||
for range runners {
|
||||
select {
|
||||
case <-doneCh:
|
||||
case <-t.Context().Done():
|
||||
require.FailNow(t, "test context was cancelled before all runners completed")
|
||||
}
|
||||
}
|
||||
|
||||
for _, err := range tt.Errors() {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertCleanup runs the provided test functions in parallel and asserts that they
|
||||
// all pass, only after Cleanup,.
|
||||
func AssertCleanup(t *testing.T, runners ...concurrency.Runner) {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
|
||||
errCh := make(chan error, len(runners))
|
||||
for _, runner := range runners {
|
||||
go func(rfn concurrency.Runner) {
|
||||
errCh <- rfn(ctx)
|
||||
}(runner)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
cancel(nil)
|
||||
for range runners {
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(10 * time.Second):
|
||||
assert.Fail(t, "timeout waiting for runner to stop")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type Interface interface {
|
||||
assert.TestingT
|
||||
Errors() []error
|
||||
}
|
||||
|
||||
type assertT struct {
|
||||
t *testing.T
|
||||
lock sync.Mutex
|
||||
errs []error
|
||||
}
|
||||
|
||||
func Assert(t *testing.T) Interface {
|
||||
return &assertT{t: t}
|
||||
}
|
||||
|
||||
func (a *assertT) Errorf(format string, args ...any) {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
a.errs = append(a.errs, fmt.Errorf(format, args...))
|
||||
}
|
||||
|
||||
func (a *assertT) Errors() []error {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
return a.errs
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package dir
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Log logger.Logger
|
||||
Target string
|
||||
}
|
||||
|
||||
// Dir atomically writes files to a given directory.
|
||||
type Dir struct {
|
||||
log logger.Logger
|
||||
|
||||
base string
|
||||
target string
|
||||
targetDir string
|
||||
|
||||
prev *string
|
||||
}
|
||||
|
||||
func New(opts Options) *Dir {
|
||||
return &Dir{
|
||||
log: opts.Log,
|
||||
base: filepath.Dir(opts.Target),
|
||||
target: opts.Target,
|
||||
targetDir: filepath.Base(opts.Target),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dir) Write(files map[string][]byte) error {
|
||||
newDir := filepath.Join(d.base, fmt.Sprintf("%d-%s", time.Now().UTC().UnixNano(), d.targetDir))
|
||||
|
||||
if err := os.MkdirAll(d.base, 0o700); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(newDir, 0o700); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for file, b := range files {
|
||||
path := filepath.Join(newDir, file)
|
||||
if err := os.WriteFile(path, b, 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
d.log.Infof("Written file %s", file)
|
||||
}
|
||||
|
||||
if err := os.Symlink(newDir, d.target+".new"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.log.Infof("Syslink %s to %s.new", newDir, d.target)
|
||||
|
||||
if err := os.Rename(d.target+".new", d.target); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.log.Infof("Atomic write to %s", d.target)
|
||||
|
||||
if d.prev != nil {
|
||||
if err := os.RemoveAll(*d.prev); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
d.prev = &newDir
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package fifo
|
||||
|
||||
// Map is a map of mutexes whose locks are acquired in a FIFO order. The map is
|
||||
// pruned automatically when all locks have been released for a key.
|
||||
type Map[T comparable] interface {
|
||||
Lock(key T)
|
||||
Unlock(key T)
|
||||
}
|
||||
|
||||
type mapItem struct {
|
||||
ilen uint64
|
||||
mutex *Mutex
|
||||
}
|
||||
|
||||
type fifoMap[T comparable] struct {
|
||||
lock *Mutex
|
||||
items map[T]*mapItem
|
||||
}
|
||||
|
||||
func NewMap[T comparable]() Map[T] {
|
||||
return &fifoMap[T]{
|
||||
lock: New(),
|
||||
items: make(map[T]*mapItem),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *fifoMap[T]) Lock(key T) {
|
||||
a.lock.Lock()
|
||||
m, ok := a.items[key]
|
||||
if !ok {
|
||||
m = &mapItem{mutex: New()}
|
||||
a.items[key] = m
|
||||
}
|
||||
m.ilen++
|
||||
a.lock.Unlock()
|
||||
|
||||
m.mutex.Lock()
|
||||
}
|
||||
|
||||
func (a *fifoMap[T]) Unlock(key T) {
|
||||
a.lock.Lock()
|
||||
m := a.items[key]
|
||||
m.ilen--
|
||||
if m.ilen == 0 {
|
||||
delete(a.items, key)
|
||||
}
|
||||
a.lock.Unlock()
|
||||
m.mutex.Unlock()
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package fifo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Map(t *testing.T) {
|
||||
m := NewMap[string]().(*fifoMap[string])
|
||||
|
||||
assert.Empty(t, m.items)
|
||||
|
||||
m.Lock("key1")
|
||||
assert.Len(t, m.items, 1)
|
||||
assert.Equal(t, uint64(1), m.items["key1"].ilen)
|
||||
|
||||
go func() {
|
||||
m.Lock("key1")
|
||||
}()
|
||||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
m.lock.Lock()
|
||||
assert.Equal(c, uint64(2), m.items["key1"].ilen)
|
||||
m.lock.Unlock()
|
||||
}, time.Second*3, time.Millisecond*10)
|
||||
|
||||
m.Unlock("key1")
|
||||
assert.Equal(t, uint64(1), m.items["key1"].ilen)
|
||||
|
||||
m.Unlock("key1")
|
||||
assert.Empty(t, m.items)
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package fifo
|
||||
|
||||
// Mutex is a mutex lock whose lock and unlock operations are
|
||||
// first-in-first-out (FIFO).
|
||||
type Mutex struct {
|
||||
lock chan struct{}
|
||||
}
|
||||
|
||||
func New() *Mutex {
|
||||
return &Mutex{
|
||||
lock: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mutex) Lock() {
|
||||
m.lock <- struct{}{}
|
||||
}
|
||||
|
||||
func (m *Mutex) Unlock() {
|
||||
<-m.lock
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package lock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Context is a ready write mutex lock where Locking can return early with an
|
||||
// error if the context is done. No error response means the lock is acquired.
|
||||
type Context struct {
|
||||
lock sync.RWMutex
|
||||
locked chan struct{}
|
||||
}
|
||||
|
||||
func NewContext() *Context {
|
||||
return &Context{
|
||||
locked: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Context) Lock(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case c.locked <- struct{}{}:
|
||||
c.lock.Lock()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Context) Unlock() {
|
||||
c.lock.Unlock()
|
||||
<-c.locked
|
||||
}
|
||||
|
||||
func (c *Context) RLock(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case c.locked <- struct{}{}:
|
||||
c.lock.RLock()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Context) RUnlock() {
|
||||
c.lock.RUnlock()
|
||||
<-c.locked
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package lock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Context(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
name string
|
||||
action func(l *Context) error
|
||||
expectError bool
|
||||
}{
|
||||
"Successful Lock": {
|
||||
action: func(l *Context) error {
|
||||
return l.Lock(t.Context())
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
"Lock with Context Timeout": {
|
||||
action: func(l *Context) error {
|
||||
l.Lock(t.Context())
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond*50)
|
||||
defer cancel()
|
||||
return l.Lock(ctx)
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
"Successful RLock": {
|
||||
action: func(l *Context) error {
|
||||
return l.RLock(t.Context())
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
"RLock with Context Timeout": {
|
||||
action: func(l *Context) error {
|
||||
l.Lock(t.Context())
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond*50)
|
||||
defer cancel()
|
||||
return l.RLock(ctx)
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := NewContext()
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
done <- test.action(l)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.Equal(t, (err != nil), test.expectError, "unexpected error, expected error: %v, got: %v", test.expectError, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Errorf("test timed out")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,199 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package lock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/dapr/kit/concurrency/fifo"
|
||||
)
|
||||
|
||||
var errLockClosed = errors.New("lock closed")
|
||||
|
||||
type hold struct {
|
||||
writeLock bool
|
||||
rctx context.Context
|
||||
respCh chan *holdresp
|
||||
}
|
||||
|
||||
type holdresp struct {
|
||||
rctx context.Context
|
||||
cancel context.CancelFunc
|
||||
err error
|
||||
}
|
||||
|
||||
type OuterCancel struct {
|
||||
ch chan *hold
|
||||
cancelErr error
|
||||
gracefulTimeout time.Duration
|
||||
|
||||
lock chan struct{}
|
||||
|
||||
wg sync.WaitGroup
|
||||
rcancelLock sync.Mutex
|
||||
rcancelx uint64
|
||||
rcancels map[uint64]context.CancelFunc
|
||||
|
||||
closeCh chan struct{}
|
||||
shutdownLock *fifo.Mutex
|
||||
}
|
||||
|
||||
func NewOuterCancel(cancelErr error, gracefulTimeout time.Duration) *OuterCancel {
|
||||
return &OuterCancel{
|
||||
lock: make(chan struct{}, 1),
|
||||
ch: make(chan *hold, 1),
|
||||
rcancels: make(map[uint64]context.CancelFunc),
|
||||
closeCh: make(chan struct{}),
|
||||
shutdownLock: fifo.New(),
|
||||
cancelErr: cancelErr,
|
||||
gracefulTimeout: gracefulTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *OuterCancel) Run(ctx context.Context) {
|
||||
defer func() {
|
||||
o.rcancelLock.Lock()
|
||||
defer o.rcancelLock.Unlock()
|
||||
|
||||
for _, cancel := range o.rcancels {
|
||||
go cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
close(o.closeCh)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-o.closeCh:
|
||||
return
|
||||
case h := <-o.ch:
|
||||
o.handleHold(h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (o *OuterCancel) handleHold(h *hold) {
|
||||
if h.rctx != nil {
|
||||
select {
|
||||
case o.lock <- struct{}{}:
|
||||
case <-h.rctx.Done():
|
||||
h.respCh <- &holdresp{err: h.rctx.Err()}
|
||||
return
|
||||
}
|
||||
} else {
|
||||
o.lock <- struct{}{}
|
||||
}
|
||||
|
||||
o.rcancelLock.Lock()
|
||||
|
||||
if h.writeLock {
|
||||
for _, cancel := range o.rcancels {
|
||||
go cancel()
|
||||
}
|
||||
o.rcancelx = 0
|
||||
o.rcancelLock.Unlock()
|
||||
o.wg.Wait()
|
||||
|
||||
h.respCh <- &holdresp{cancel: func() { <-o.lock }}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
o.wg.Add(1)
|
||||
var done bool
|
||||
doneCh := make(chan bool)
|
||||
rctx, cancel := context.WithCancelCause(h.rctx)
|
||||
i := o.rcancelx
|
||||
|
||||
rcancel := func() {
|
||||
o.rcancelLock.Lock()
|
||||
if !done {
|
||||
close(doneCh)
|
||||
cancel(o.cancelErr)
|
||||
delete(o.rcancels, i)
|
||||
o.wg.Done()
|
||||
done = true
|
||||
}
|
||||
o.rcancelLock.Unlock()
|
||||
}
|
||||
|
||||
rcancelGrace := func() {
|
||||
select {
|
||||
case <-time.After(o.gracefulTimeout):
|
||||
case <-o.closeCh:
|
||||
case <-doneCh:
|
||||
}
|
||||
rcancel()
|
||||
}
|
||||
|
||||
o.rcancels[i] = rcancelGrace
|
||||
o.rcancelx++
|
||||
|
||||
o.rcancelLock.Unlock()
|
||||
|
||||
h.respCh <- &holdresp{rctx: rctx, cancel: rcancel}
|
||||
|
||||
<-o.lock
|
||||
}
|
||||
|
||||
func (o *OuterCancel) Lock() context.CancelFunc {
|
||||
h := hold{
|
||||
writeLock: true,
|
||||
respCh: make(chan *holdresp, 1),
|
||||
}
|
||||
|
||||
select {
|
||||
case <-o.closeCh:
|
||||
o.shutdownLock.Lock()
|
||||
return o.shutdownLock.Unlock
|
||||
case o.ch <- &h:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-o.closeCh:
|
||||
o.shutdownLock.Lock()
|
||||
return o.shutdownLock.Unlock
|
||||
case resp := <-h.respCh:
|
||||
return resp.cancel
|
||||
}
|
||||
}
|
||||
|
||||
func (o *OuterCancel) RLock(ctx context.Context) (context.Context, context.CancelFunc, error) {
|
||||
h := hold{
|
||||
writeLock: false,
|
||||
rctx: ctx,
|
||||
respCh: make(chan *holdresp, 1),
|
||||
}
|
||||
|
||||
select {
|
||||
case <-o.closeCh:
|
||||
return nil, nil, errLockClosed
|
||||
case <-ctx.Done():
|
||||
return nil, nil, ctx.Err()
|
||||
case o.ch <- &h:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-o.closeCh:
|
||||
return nil, nil, errLockClosed
|
||||
case resp := <-h.respCh:
|
||||
return resp.rctx, resp.cancel, resp.err
|
||||
}
|
||||
}
|
|
@ -0,0 +1,225 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package lock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_OuterCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint:err113
|
||||
terr := errors.New("test")
|
||||
|
||||
t.Run("can rlock multiple times", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := NewOuterCancel(terr, time.Second)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
go l.Run(ctx)
|
||||
|
||||
ctx1, c1, err := l.RLock(ctx)
|
||||
require.NoError(t, err)
|
||||
ctx2, c2, err := l.RLock(ctx)
|
||||
require.NoError(t, err)
|
||||
ctx3, c3, err := l.RLock(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, ctx1.Err())
|
||||
require.NoError(t, ctx2.Err())
|
||||
require.NoError(t, ctx3.Err())
|
||||
|
||||
c1()
|
||||
require.Error(t, ctx1.Err())
|
||||
require.NoError(t, ctx2.Err())
|
||||
require.NoError(t, ctx3.Err())
|
||||
|
||||
c2()
|
||||
require.Error(t, ctx1.Err())
|
||||
require.Error(t, ctx2.Err())
|
||||
require.NoError(t, ctx3.Err())
|
||||
|
||||
c3()
|
||||
require.Error(t, ctx1.Err())
|
||||
require.Error(t, ctx2.Err())
|
||||
require.Error(t, ctx3.Err())
|
||||
})
|
||||
|
||||
t.Run("rlock unlock removes cancel state", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := NewOuterCancel(terr, time.Second)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
go l.Run(ctx)
|
||||
|
||||
_, c1, err := l.RLock(ctx)
|
||||
require.NoError(t, err)
|
||||
_, c2, err := l.RLock(ctx)
|
||||
require.NoError(t, err)
|
||||
_, c3, err := l.RLock(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, l.rcancels, 3)
|
||||
c1()
|
||||
assert.Len(t, l.rcancels, 2)
|
||||
c2()
|
||||
assert.Len(t, l.rcancels, 1)
|
||||
c3()
|
||||
assert.Empty(t, l.rcancels, 0)
|
||||
})
|
||||
|
||||
t.Run("calling lock cancels all current rlocks", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := NewOuterCancel(terr, time.Second)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
go l.Run(ctx)
|
||||
|
||||
ctx1, _, err := l.RLock(ctx)
|
||||
require.NoError(t, err)
|
||||
ctx2, _, err := l.RLock(ctx)
|
||||
require.NoError(t, err)
|
||||
ctx3, _, err := l.RLock(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, ctx1.Err())
|
||||
require.NoError(t, ctx2.Err())
|
||||
require.NoError(t, ctx3.Err())
|
||||
|
||||
mcancel := l.Lock()
|
||||
require.Error(t, ctx1.Err())
|
||||
require.Error(t, ctx2.Err())
|
||||
require.Error(t, ctx3.Err())
|
||||
mcancel()
|
||||
|
||||
assert.Empty(t, l.rcancels)
|
||||
})
|
||||
|
||||
t.Run("rlock when closed should error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := NewOuterCancel(terr, time.Second)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel()
|
||||
|
||||
go l.Run(ctx)
|
||||
|
||||
select {
|
||||
case <-l.closeCh:
|
||||
case <-time.After(time.Second * 5):
|
||||
assert.Fail(t, "expected close")
|
||||
}
|
||||
|
||||
_, _, err := l.RLock(t.Context())
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("lock continues to work after close", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := NewOuterCancel(terr, time.Second)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel()
|
||||
|
||||
l.Run(ctx)
|
||||
|
||||
lcancel := l.Lock()
|
||||
lcancel()
|
||||
lcancel = l.Lock()
|
||||
lcancel()
|
||||
})
|
||||
|
||||
t.Run("rlock blocks until outter unlocks", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := NewOuterCancel(terr, time.Second)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
go l.Run(ctx)
|
||||
|
||||
lcancel := l.Lock()
|
||||
|
||||
gotRLock := make(chan struct{})
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, c1, err := l.RLock(ctx)
|
||||
errCh <- err
|
||||
t.Cleanup(c1)
|
||||
close(gotRLock)
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, <-errCh)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-time.After(time.Millisecond * 500):
|
||||
case <-gotRLock:
|
||||
require.Fail(t, "unexpected rlock")
|
||||
}
|
||||
|
||||
lcancel()
|
||||
<-gotRLock
|
||||
})
|
||||
|
||||
t.Run("lock blocks until outter unlocks", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := NewOuterCancel(terr, time.Second)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
go l.Run(ctx)
|
||||
|
||||
lcancel := l.Lock()
|
||||
|
||||
gotLock := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
lockcancel := l.Lock()
|
||||
t.Cleanup(lockcancel)
|
||||
close(gotLock)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(time.Millisecond * 500):
|
||||
case <-gotLock:
|
||||
require.Fail(t, "unexpected rlock")
|
||||
}
|
||||
|
||||
lcancel()
|
||||
})
|
||||
}
|
|
@ -60,17 +60,12 @@ func (r *RunnerManager) Run(ctx context.Context) error {
|
|||
return ErrManagerAlreadyStarted
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
ctx, cancel := context.WithCancelCause(ctx)
|
||||
defer cancel(nil)
|
||||
|
||||
errCh := make(chan error)
|
||||
for _, runner := range r.runners {
|
||||
go func(runner Runner) {
|
||||
// Since the task returned, we need to cancel all other tasks.
|
||||
// This is a noop if the parent context is already cancelled, or another
|
||||
// task returned before this one.
|
||||
defer cancel()
|
||||
|
||||
// Ignore context cancelled errors since errors from a runner manager
|
||||
// will likely determine the exit code of the program.
|
||||
// Context cancelled errors are also not really useful to the user in
|
||||
|
@ -78,15 +73,20 @@ func (r *RunnerManager) Run(ctx context.Context) error {
|
|||
rErr := runner(ctx)
|
||||
if rErr != nil && !errors.Is(rErr, context.Canceled) {
|
||||
errCh <- rErr
|
||||
// Since the task returned, we need to cancel all other tasks.
|
||||
// This is a noop if the parent context is already cancelled, or another
|
||||
// task returned before this one.
|
||||
cancel(rErr)
|
||||
return
|
||||
}
|
||||
errCh <- nil
|
||||
cancel(nil)
|
||||
}(runner)
|
||||
}
|
||||
|
||||
// Collect all errors
|
||||
errObjs := make([]error, 0)
|
||||
for i := 0; i < len(r.runners); i++ {
|
||||
for range len(r.runners) {
|
||||
err := <-errCh
|
||||
if err != nil {
|
||||
errObjs = append(errObjs, err)
|
||||
|
|
|
@ -27,7 +27,7 @@ import (
|
|||
|
||||
func Test_RunnerManager(t *testing.T) {
|
||||
t.Run("runner with no tasks should return nil", func(t *testing.T) {
|
||||
require.NoError(t, NewRunnerManager().Run(context.Background()))
|
||||
require.NoError(t, NewRunnerManager().Run(t.Context()))
|
||||
})
|
||||
|
||||
t.Run("runner with a task that completes should return nil", func(t *testing.T) {
|
||||
|
@ -35,7 +35,7 @@ func Test_RunnerManager(t *testing.T) {
|
|||
require.NoError(t, NewRunnerManager(func(ctx context.Context) error {
|
||||
atomic.AddInt32(&i, 1)
|
||||
return nil
|
||||
}).Run(context.Background()))
|
||||
}).Run(t.Context()))
|
||||
assert.Equal(t, int32(1), i)
|
||||
})
|
||||
|
||||
|
@ -54,7 +54,7 @@ func Test_RunnerManager(t *testing.T) {
|
|||
atomic.AddInt32(&i, 1)
|
||||
return nil
|
||||
},
|
||||
).Run(context.Background()))
|
||||
).Run(t.Context()))
|
||||
assert.Equal(t, int32(3), i)
|
||||
})
|
||||
|
||||
|
@ -73,7 +73,7 @@ func Test_RunnerManager(t *testing.T) {
|
|||
atomic.AddInt32(&i, 1)
|
||||
return nil
|
||||
},
|
||||
).Run(context.Background()), "error")
|
||||
).Run(t.Context()), "error")
|
||||
assert.Equal(t, int32(3), i)
|
||||
})
|
||||
|
||||
|
@ -92,7 +92,7 @@ func Test_RunnerManager(t *testing.T) {
|
|||
atomic.AddInt32(&i, 1)
|
||||
return errors.New("error")
|
||||
},
|
||||
).Run(context.Background())
|
||||
).Run(t.Context())
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "error\nerror\nerror") //nolint:dupword
|
||||
assert.Equal(t, int32(3), i)
|
||||
|
@ -113,7 +113,7 @@ func Test_RunnerManager(t *testing.T) {
|
|||
atomic.AddInt32(&i, 1)
|
||||
return errors.New("error3")
|
||||
},
|
||||
).Run(context.Background())
|
||||
).Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.ElementsMatch(t, []string{"error1", "error2", "error3"}, strings.Split(err.Error(), "\n"))
|
||||
assert.Equal(t, int32(3), i)
|
||||
|
@ -139,7 +139,7 @@ func Test_RunnerManager(t *testing.T) {
|
|||
return nil
|
||||
},
|
||||
))
|
||||
require.NoError(t, mngr.Run(context.Background()))
|
||||
require.NoError(t, mngr.Run(t.Context()))
|
||||
assert.Equal(t, int32(3), i)
|
||||
})
|
||||
|
||||
|
@ -168,7 +168,7 @@ func Test_RunnerManager(t *testing.T) {
|
|||
}
|
||||
return nil
|
||||
},
|
||||
).Run(context.Background()))
|
||||
).Run(t.Context()))
|
||||
assert.Equal(t, int32(3), i)
|
||||
})
|
||||
|
||||
|
@ -197,7 +197,7 @@ func Test_RunnerManager(t *testing.T) {
|
|||
atomic.AddInt32(&i, 1)
|
||||
return errors.New("error3")
|
||||
},
|
||||
).Run(context.Background())
|
||||
).Run(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.ElementsMatch(t, []string{"error1", "error2", "error3"}, strings.Split(err.Error(), "\n"))
|
||||
assert.Equal(t, int32(3), i)
|
||||
|
@ -209,9 +209,9 @@ func Test_RunnerManager(t *testing.T) {
|
|||
atomic.AddInt32(&i, 1)
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, m.Run(context.Background()))
|
||||
require.NoError(t, m.Run(t.Context()))
|
||||
assert.Equal(t, int32(1), i)
|
||||
require.EqualError(t, m.Run(context.Background()), "runner manager already started")
|
||||
require.EqualError(t, m.Run(t.Context()), "runner manager already started")
|
||||
assert.Equal(t, int32(1), i)
|
||||
})
|
||||
|
||||
|
@ -221,7 +221,7 @@ func Test_RunnerManager(t *testing.T) {
|
|||
atomic.AddInt32(&i, 1)
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, m.Run(context.Background()))
|
||||
require.NoError(t, m.Run(t.Context()))
|
||||
assert.Equal(t, int32(1), i)
|
||||
err := m.Add(func(ctx context.Context) error {
|
||||
atomic.AddInt32(&i, 1)
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package slice
|
||||
|
||||
import "sync"
|
||||
|
||||
// Slice is a concurrent safe types slice
|
||||
type Slice[T any] interface {
|
||||
Append(items ...T) int
|
||||
Len() int
|
||||
Slice() []T
|
||||
Store(items ...T)
|
||||
}
|
||||
|
||||
type slice[T any] struct {
|
||||
lock sync.RWMutex
|
||||
data []T
|
||||
}
|
||||
|
||||
func New[T any]() Slice[T] {
|
||||
return new(slice[T])
|
||||
}
|
||||
|
||||
func (s *slice[T]) Append(items ...T) int {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.data = append(s.data, items...)
|
||||
return len(s.data)
|
||||
}
|
||||
|
||||
func (s *slice[T]) Len() int {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
return len(s.data)
|
||||
}
|
||||
|
||||
func (s *slice[T]) Slice() []T {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
return s.data
|
||||
}
|
||||
|
||||
func (s *slice[T]) Store(items ...T) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.data = items
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package slice
|
||||
|
||||
func String() Slice[string] {
|
||||
return new(slice[string])
|
||||
}
|
|
@ -81,7 +81,7 @@ func decodeString(f reflect.Type, t reflect.Type, data any) (any, error) {
|
|||
if t.Implements(typeStringDecoder) {
|
||||
result = reflect.New(t.Elem()).Interface()
|
||||
decoder = result.(StringDecoder)
|
||||
} else if reflect.PtrTo(t).Implements(typeStringDecoder) {
|
||||
} else if reflect.PointerTo(t).Implements(typeStringDecoder) {
|
||||
result = reflect.New(t).Interface()
|
||||
decoder = result.(StringDecoder)
|
||||
}
|
||||
|
|
|
@ -52,6 +52,9 @@ func NewPool(ctx ...context.Context) *Pool {
|
|||
go func() {
|
||||
defer cancel()
|
||||
defer p.lock.RUnlock()
|
||||
//nolint:intrange
|
||||
// for loops are evaluated on every loop while range are evaluated over a snapshot of the slice as it
|
||||
// existed when the loop started
|
||||
for i := 0; i < len(p.pool); i++ {
|
||||
ch := p.pool[i]
|
||||
p.lock.RUnlock()
|
||||
|
|
|
@ -37,7 +37,7 @@ func Test_Pool(t *testing.T) {
|
|||
|
||||
t.Run("a cancelled context given to pool, should have pool cancelled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel()
|
||||
pool := NewPool(ctx)
|
||||
select {
|
||||
|
@ -49,10 +49,10 @@ func Test_Pool(t *testing.T) {
|
|||
|
||||
t.Run("a cancelled context given to pool, given a new context, should still have pool cancelled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel()
|
||||
pool := NewPool(ctx)
|
||||
pool.Add(context.Background())
|
||||
pool.Add(t.Context())
|
||||
select {
|
||||
case <-pool.Done():
|
||||
case <-time.After(time.Second):
|
||||
|
@ -65,13 +65,13 @@ func Test_Pool(t *testing.T) {
|
|||
var ctx [50]context.Context
|
||||
var cancel [50]context.CancelFunc
|
||||
|
||||
ctx[0], cancel[0] = context.WithCancel(context.Background())
|
||||
pool := NewPool(ctx[0])
|
||||
ctxPool := make([]context.Context, 0, 50)
|
||||
|
||||
for i := 1; i < 50; i++ {
|
||||
ctx[i], cancel[i] = context.WithCancel(context.Background())
|
||||
pool.Add(ctx[i])
|
||||
for i := range 50 {
|
||||
ctx[i], cancel[i] = context.WithCancel(t.Context())
|
||||
ctxPool = append(ctxPool, ctx[i])
|
||||
}
|
||||
pool := NewPool(ctxPool...)
|
||||
|
||||
//nolint:gosec
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
@ -80,7 +80,7 @@ func Test_Pool(t *testing.T) {
|
|||
cancel[i], cancel[j] = cancel[j], cancel[i]
|
||||
})
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
for i := range 50 {
|
||||
select {
|
||||
case <-pool.Done():
|
||||
t.Error("expected context to not be cancelled")
|
||||
|
@ -99,8 +99,8 @@ func Test_Pool(t *testing.T) {
|
|||
t.Run("pool size will not increase if the given contexts have been cancelled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
ctx1, cancel1 := context.WithCancel(t.Context())
|
||||
ctx2, cancel2 := context.WithCancel(t.Context())
|
||||
pool := NewPool(ctx1, ctx2)
|
||||
assert.Equal(t, 2, pool.Size())
|
||||
|
||||
|
@ -111,19 +111,19 @@ func Test_Pool(t *testing.T) {
|
|||
case <-time.After(time.Second):
|
||||
t.Error("expected context pool to be cancelled")
|
||||
}
|
||||
pool.Add(context.Background())
|
||||
pool.Add(t.Context())
|
||||
assert.Equal(t, 2, pool.Size())
|
||||
})
|
||||
|
||||
t.Run("pool size will not increase if the pool has been closed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx1 := context.Background()
|
||||
ctx2 := context.Background()
|
||||
ctx1 := t.Context()
|
||||
ctx2 := t.Context()
|
||||
pool := NewPool(ctx1, ctx2)
|
||||
assert.Equal(t, 2, pool.Size())
|
||||
pool.Cancel()
|
||||
pool.Add(context.Background())
|
||||
pool.Add(t.Context())
|
||||
assert.Equal(t, 0, pool.Size())
|
||||
select {
|
||||
case <-pool.Done():
|
||||
|
@ -131,4 +131,24 @@ func Test_Pool(t *testing.T) {
|
|||
t.Error("expected context pool to be cancelled")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wait for added context to be closed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx1, cancel1 := context.WithCancel(t.Context())
|
||||
pool := NewPool(ctx1)
|
||||
|
||||
ctx2, cancel2 := context.WithCancel(t.Context())
|
||||
pool.Add(ctx2)
|
||||
|
||||
assert.Equal(t, 2, pool.Size())
|
||||
cancel1()
|
||||
|
||||
select {
|
||||
case <-pool.Done():
|
||||
t.Error("expected context pool to not be cancelled")
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
}
|
||||
cancel2()
|
||||
})
|
||||
}
|
||||
|
|
|
@ -170,7 +170,7 @@ func TestChainDelayIfStillRunning(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
started, done = j.Started(), j.Done()
|
||||
if started != 2 || done != 2 {
|
||||
c.Errorf("expected both jobs done, got %v %v", started, done) //nolint:testifylint
|
||||
c.Errorf("expected both jobs done, got %v %v", started, done)
|
||||
}
|
||||
}, 100*time.Millisecond, 10*time.Millisecond)
|
||||
})
|
||||
|
@ -230,7 +230,7 @@ func TestChainSkipIfStillRunning(t *testing.T) {
|
|||
var j countJob
|
||||
j.delay = 10 * time.Millisecond
|
||||
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
|
||||
for i := 0; i < 11; i++ {
|
||||
for range 11 {
|
||||
go wrappedJob.Run()
|
||||
}
|
||||
assert.Eventually(t, j.clock.HasWaiters, 50*time.Millisecond, 10*time.Millisecond)
|
||||
|
@ -248,7 +248,7 @@ func TestChainSkipIfStillRunning(t *testing.T) {
|
|||
chain := NewChain(SkipIfStillRunning(DiscardLogger))
|
||||
wrappedJob1 := chain.Then(&j1)
|
||||
wrappedJob2 := chain.Then(&j2)
|
||||
for i := 0; i < 11; i++ {
|
||||
for range 11 {
|
||||
go wrappedJob1.Run()
|
||||
go wrappedJob2.Run()
|
||||
}
|
||||
|
|
|
@ -14,31 +14,23 @@ You can check the original license at:
|
|||
https://github.com/robfig/cron/blob/master/LICENSE
|
||||
*/
|
||||
|
||||
//nolint
|
||||
package cron
|
||||
|
||||
import "time"
|
||||
|
||||
// ConstantDelaySchedule represents a simple recurring duty cycle, e.g. "Every 5 minutes".
|
||||
// It does not support jobs more frequent than once a second.
|
||||
type ConstantDelaySchedule struct {
|
||||
Delay time.Duration
|
||||
}
|
||||
|
||||
// Every returns a crontab Schedule that activates once every duration.
|
||||
// Delays of less than a second are not supported (will round up to 1 second).
|
||||
// Any fields less than a Second are truncated.
|
||||
func Every(duration time.Duration) ConstantDelaySchedule {
|
||||
if duration < time.Second {
|
||||
duration = time.Second
|
||||
}
|
||||
return ConstantDelaySchedule{
|
||||
Delay: duration - time.Duration(duration.Nanoseconds())%time.Second,
|
||||
Delay: duration,
|
||||
}
|
||||
}
|
||||
|
||||
// Next returns the next time this should be run.
|
||||
// This rounds so that the next activation time will be on the second.
|
||||
func (schedule ConstantDelaySchedule) Next(t time.Time) time.Time {
|
||||
return t.Add(schedule.Delay - time.Duration(t.Nanosecond())*time.Nanosecond)
|
||||
return t.Add(schedule.Delay)
|
||||
}
|
||||
|
|
|
@ -14,7 +14,6 @@ You can check the original license at:
|
|||
https://github.com/robfig/cron/blob/master/LICENSE
|
||||
*/
|
||||
|
||||
//nolint
|
||||
package cron
|
||||
|
||||
import (
|
||||
|
@ -29,9 +28,12 @@ func TestConstantDelayNext(t *testing.T) {
|
|||
expected string
|
||||
}{
|
||||
// Simple cases
|
||||
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
|
||||
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00:00.00000005 2012"},
|
||||
{"Mon Jul 9 14:59 2012", 15 * time.Minute, "Mon Jul 9 15:14 2012"},
|
||||
{"Mon Jul 9 14:59:59 2012", 15 * time.Minute, "Mon Jul 9 15:14:59 2012"},
|
||||
{"Mon Jul 9 14:45:00 2012", 15 * time.Millisecond, "Mon Jul 9 14:45:00.015 2012"},
|
||||
{"Mon Jul 9 14:45:00.015 2012", 15 * time.Millisecond, "Mon Jul 9 14:45:00.030 2012"},
|
||||
{"Mon Jul 9 14:45:00.000000050 2012", 15 * time.Nanosecond, "Mon Jul 9 14:45:00.000000065 2012"},
|
||||
|
||||
// Wrap around hours
|
||||
{"Mon Jul 9 15:45 2012", 35 * time.Minute, "Mon Jul 9 16:20 2012"},
|
||||
|
@ -47,18 +49,6 @@ func TestConstantDelayNext(t *testing.T) {
|
|||
|
||||
// Wrap around minute, hour, day, month, and year
|
||||
{"Mon Dec 31 23:59:45 2012", 15 * time.Second, "Tue Jan 1 00:00:00 2013"},
|
||||
|
||||
// Round to nearest second on the delay
|
||||
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
|
||||
|
||||
// Round up to 1 second if the duration is less.
|
||||
{"Mon Jul 9 14:45:00 2012", 15 * time.Millisecond, "Mon Jul 9 14:45:01 2012"},
|
||||
|
||||
// Round to nearest second when calculating the next time.
|
||||
{"Mon Jul 9 14:45:00.005 2012", 15 * time.Minute, "Mon Jul 9 15:00 2012"},
|
||||
|
||||
// Round to nearest second for both.
|
||||
{"Mon Jul 9 14:45:00.005 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
|
||||
}
|
||||
|
||||
for _, c := range tests {
|
||||
|
|
|
@ -14,7 +14,6 @@ You can check the original license at:
|
|||
https://github.com/robfig/cron/blob/master/LICENSE
|
||||
*/
|
||||
|
||||
//nolint:dupword
|
||||
package cron
|
||||
|
||||
import (
|
||||
|
@ -35,7 +34,7 @@ import (
|
|||
// for it to run. This amount is just slightly larger than 1 second to
|
||||
// compensate for a few milliseconds of runtime.
|
||||
//
|
||||
//nolint:revive
|
||||
|
||||
const OneSecond = 1*time.Second + 50*time.Millisecond
|
||||
|
||||
type syncWriter struct {
|
||||
|
@ -783,13 +782,68 @@ func TestMockClock(t *testing.T) {
|
|||
})
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
for i := 0; i <= 10; i++ {
|
||||
for range 11 {
|
||||
assert.Eventually(t, clk.HasWaiters, OneSecond, 10*time.Millisecond)
|
||||
clk.Step(1 * time.Second)
|
||||
}
|
||||
assert.Equal(t, int64(10), counter.Load())
|
||||
}
|
||||
|
||||
func TestMillisecond(t *testing.T) {
|
||||
clk := clocktesting.NewFakeClock(time.Now())
|
||||
cron := New(WithClock(clk))
|
||||
counter1ms := atomic.Int64{}
|
||||
counter15ms := atomic.Int64{}
|
||||
counter100ms := atomic.Int64{}
|
||||
|
||||
cron.AddFunc("@every 1ms", func() {
|
||||
counter1ms.Add(1)
|
||||
})
|
||||
cron.AddFunc("@every 15ms", func() {
|
||||
counter15ms.Add(1)
|
||||
})
|
||||
cron.AddFunc("@every 100ms", func() {
|
||||
counter100ms.Add(1)
|
||||
})
|
||||
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
for range 1000 {
|
||||
assert.Eventually(t, clk.HasWaiters, OneSecond, 1*time.Millisecond)
|
||||
clk.Step(1 * time.Millisecond)
|
||||
}
|
||||
ctx := cron.Stop()
|
||||
<-ctx.Done()
|
||||
|
||||
assert.Equal(t, int64(1000), counter1ms.Load())
|
||||
assert.Equal(t, int64(66), counter15ms.Load())
|
||||
assert.Equal(t, int64(10), counter100ms.Load())
|
||||
}
|
||||
|
||||
func TestNanoseconds(t *testing.T) {
|
||||
clk := clocktesting.NewFakeClock(time.Now())
|
||||
cron := New(WithClock(clk))
|
||||
|
||||
counter100ns := atomic.Int64{}
|
||||
cron.AddFunc("@every 100ns", func() {
|
||||
counter100ns.Add(1)
|
||||
})
|
||||
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
for range 500 {
|
||||
assert.Eventually(t, clk.HasWaiters, OneSecond, 1*time.Millisecond)
|
||||
clk.Step(5 * time.Nanosecond)
|
||||
}
|
||||
ctx := cron.Stop()
|
||||
<-ctx.Done()
|
||||
|
||||
// 500 * 5 ns = 2500 ns
|
||||
// 2500 every 100ns = 25
|
||||
assert.Equal(t, int64(25), counter100ns.Load())
|
||||
}
|
||||
|
||||
func TestMultiThreadedStartAndStop(*testing.T) {
|
||||
cron := New()
|
||||
go cron.Run()
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
//nolint
|
||||
/*
|
||||
This package is a fork of "github.com/robfig/cron/v3" that implements cron spec parser and job runner with support for mocking the time.
|
||||
|
||||
|
@ -36,7 +35,9 @@ them in their own goroutines.
|
|||
# Time mocking
|
||||
|
||||
import (
|
||||
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
|
||||
)
|
||||
|
||||
clk := clocktesting.NewFakeClock(time.Now())
|
||||
|
|
|
@ -59,7 +59,7 @@ func TestWithVerboseLogger(t *testing.T) {
|
|||
out := buf.String()
|
||||
if !strings.Contains(out, "schedule,") ||
|
||||
!strings.Contains(out, "run,") {
|
||||
c.Errorf("expected to see some actions, got: %v", out) //nolint:testifylint
|
||||
c.Errorf("expected to see some actions, got: %v", out)
|
||||
}
|
||||
}, time.Second, time.Millisecond*10)
|
||||
}
|
||||
|
|
|
@ -14,7 +14,6 @@ You can check the original license at:
|
|||
https://github.com/robfig/cron/blob/master/LICENSE
|
||||
*/
|
||||
|
||||
//nolint
|
||||
package cron
|
||||
|
||||
import (
|
||||
|
@ -167,6 +166,8 @@ func TestParseSchedule(t *testing.T) {
|
|||
{standardParser, "CRON_TZ=UTC 5 * * * *", every5min(time.UTC)},
|
||||
{secondParser, "CRON_TZ=Asia/Tokyo 0 5 * * * *", every5min(tokyo)},
|
||||
{secondParser, "@every 5m", ConstantDelaySchedule{5 * time.Minute}},
|
||||
{secondParser, "@every 5ms", ConstantDelaySchedule{5 * time.Millisecond}},
|
||||
{secondParser, "@every 5ns", ConstantDelaySchedule{5 * time.Nanosecond}},
|
||||
{secondParser, "@midnight", midnight(time.Local)},
|
||||
{secondParser, "TZ=UTC @midnight", midnight(time.UTC)},
|
||||
{secondParser, "TZ=Asia/Tokyo @midnight", midnight(tokyo)},
|
||||
|
|
|
@ -214,9 +214,9 @@ func (aead *aesCBCAEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]b
|
|||
}
|
||||
|
||||
// Computes the HMAC tag as per specs.
|
||||
func (aead aesCBCAEAD) hmacTag(h hash.Hash, additionalData, nonce, ciphertext []byte, l int) []byte {
|
||||
func (aead *aesCBCAEAD) hmacTag(h hash.Hash, additionalData, nonce, ciphertext []byte, l int) []byte {
|
||||
al := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(al, uint64(len(additionalData)<<3)) // In bits
|
||||
binary.BigEndian.PutUint64(al, uint64(len(additionalData)<<3)) // #nosec G115 // In bits
|
||||
|
||||
h.Write(additionalData)
|
||||
h.Write(nonce)
|
||||
|
|
|
@ -48,14 +48,14 @@ func Wrap(block cipher.Block, cek []byte) ([]byte, error) {
|
|||
copy(r[i], cek[i*8:])
|
||||
}
|
||||
|
||||
for j := 0; j <= 5; j++ {
|
||||
for j := range 6 {
|
||||
for i := 1; i <= n; i++ {
|
||||
b := arrConcat(a, r[i-1])
|
||||
block.Encrypt(b, b)
|
||||
|
||||
t := (n * j) + i
|
||||
tBytes := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(tBytes, uint64(t))
|
||||
binary.BigEndian.PutUint64(tBytes, uint64(t)) // #nosec G115
|
||||
|
||||
copy(a, arrXor(b[:len(b)/2], tBytes))
|
||||
copy(r[i-1], b[len(b)/2:])
|
||||
|
@ -92,7 +92,7 @@ func Unwrap(block cipher.Block, cipherText []byte) ([]byte, error) {
|
|||
for i := n; i >= 1; i-- {
|
||||
t := (n * j) + i
|
||||
tBytes := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(tBytes, uint64(t))
|
||||
binary.BigEndian.PutUint64(tBytes, uint64(t)) // #nosec G115
|
||||
|
||||
b := arrConcat(arrXor(a, tBytes), r[i-1])
|
||||
block.Decrypt(b, b)
|
||||
|
|
|
@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
//nolint:nosnakecase,stylecheck,revive
|
||||
//nolint:nosnakecase,stylecheck
|
||||
package crypto
|
||||
|
||||
import (
|
||||
|
|
|
@ -0,0 +1,211 @@
|
|||
/*
|
||||
Copyright 2023 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package pem
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// DecodePEMCertificatesChain takes a PEM-encoded x509 certificates byte array
|
||||
// and returns all certificates in a slice of x509.Certificate objects.
|
||||
// Expects certificates to be a chain with leaf certificate to be first in the
|
||||
// byte array.
|
||||
func DecodePEMCertificatesChain(crtb []byte) ([]*x509.Certificate, error) {
|
||||
certs, err := DecodePEMCertificates(crtb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range len(certs) - 1 {
|
||||
if certs[i].CheckSignatureFrom(certs[i+1]) != nil {
|
||||
return nil, errors.New("certificate chain is not valid")
|
||||
}
|
||||
}
|
||||
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
// DecodePEMCertificatesChain takes a PEM-encoded x509 certificates byte array
|
||||
// and returns all certificates in a slice of x509.Certificate objects.
|
||||
func DecodePEMCertificates(crtb []byte) ([]*x509.Certificate, error) {
|
||||
certs := []*x509.Certificate{}
|
||||
for len(crtb) > 0 {
|
||||
var err error
|
||||
var cert *x509.Certificate
|
||||
|
||||
cert, crtb, err = decodeCertificatePEM(crtb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cert != nil {
|
||||
// it's a cert, add to pool
|
||||
certs = append(certs, cert)
|
||||
}
|
||||
}
|
||||
|
||||
if len(certs) == 0 {
|
||||
return nil, errors.New("no certificates found")
|
||||
}
|
||||
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
func decodeCertificatePEM(crtb []byte) (*x509.Certificate, []byte, error) {
|
||||
block, crtb := pem.Decode(crtb)
|
||||
if block == nil {
|
||||
return nil, nil, nil
|
||||
}
|
||||
if block.Type != "CERTIFICATE" {
|
||||
return nil, nil, nil
|
||||
}
|
||||
c, err := x509.ParseCertificate(block.Bytes)
|
||||
return c, crtb, err
|
||||
}
|
||||
|
||||
// DecodePEMPrivateKey takes a key PEM byte array and returns an object that
|
||||
// represents either an RSA or EC private key.
|
||||
func DecodePEMPrivateKey(key []byte) (crypto.Signer, error) {
|
||||
block, _ := pem.Decode(key)
|
||||
if block == nil {
|
||||
return nil, errors.New("key is not PEM encoded")
|
||||
}
|
||||
|
||||
switch block.Type {
|
||||
case "EC PRIVATE KEY":
|
||||
return x509.ParseECPrivateKey(block.Bytes)
|
||||
case "RSA PRIVATE KEY":
|
||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
case "PRIVATE KEY":
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key.(crypto.Signer), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported block type %s", block.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// EncodePrivateKey will encode a private key into PEM format.
|
||||
func EncodePrivateKey(key any) ([]byte, error) {
|
||||
var (
|
||||
keyBytes []byte
|
||||
err error
|
||||
blockType string
|
||||
)
|
||||
|
||||
switch key := key.(type) {
|
||||
case *ecdsa.PrivateKey, *ed25519.PrivateKey:
|
||||
keyBytes, err = x509.MarshalPKCS8PrivateKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blockType = "PRIVATE KEY"
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported key type %T", key)
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(&pem.Block{
|
||||
Type: blockType, Bytes: keyBytes,
|
||||
}), nil
|
||||
}
|
||||
|
||||
// EncodeX509 will encode a single *x509.Certificate into PEM format.
|
||||
func EncodeX509(cert *x509.Certificate) ([]byte, error) {
|
||||
caPem := bytes.NewBuffer([]byte{})
|
||||
err := pem.Encode(caPem, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return caPem.Bytes(), nil
|
||||
}
|
||||
|
||||
// EncodeX509Chain will encode a list of *x509.Certificates into a PEM format chain.
|
||||
// Self-signed certificates are not included as per
|
||||
// https://datatracker.ietf.org/doc/html/rfc5246#section-7.4.2
|
||||
// Certificates are output in the order they're given; if the input is not ordered
|
||||
// as specified in RFC5246 section 7.4.2, the resulting chain might not be valid
|
||||
// for use in TLS.
|
||||
func EncodeX509Chain(certs []*x509.Certificate) ([]byte, error) {
|
||||
if len(certs) == 0 {
|
||||
return nil, errors.New("no certificates in chain")
|
||||
}
|
||||
|
||||
certPEM := bytes.NewBuffer([]byte{})
|
||||
for _, cert := range certs {
|
||||
if cert == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if cert.CheckSignatureFrom(cert) == nil {
|
||||
// Don't include self-signed certificate
|
||||
continue
|
||||
}
|
||||
|
||||
err := pem.Encode(certPEM, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return certPEM.Bytes(), nil
|
||||
}
|
||||
|
||||
// PublicKeysEqual compares two given public keys for equality.
|
||||
// The definition of "equality" depends on the type of the public keys.
|
||||
// Returns true if the keys are the same, false if they differ or an error if
|
||||
// the key type of `a` cannot be determined.
|
||||
func PublicKeysEqual(a, b crypto.PublicKey) (bool, error) {
|
||||
switch pub := a.(type) {
|
||||
case *rsa.PublicKey:
|
||||
return pub.Equal(b), nil
|
||||
case *ecdsa.PublicKey:
|
||||
return pub.Equal(b), nil
|
||||
case ed25519.PublicKey:
|
||||
return pub.Equal(b), nil
|
||||
default:
|
||||
return false, fmt.Errorf("unrecognised public key type: %T", a)
|
||||
}
|
||||
}
|
||||
|
||||
// GetPEM loads a PEM-encoded file (certificate or key).
|
||||
func GetPEM(val string) ([]byte, error) {
|
||||
// If val is already a PEM-encoded string, return it as-is
|
||||
if IsValidPEM(val) {
|
||||
return []byte(val), nil
|
||||
}
|
||||
|
||||
// Assume it's a file
|
||||
pemBytes, err := os.ReadFile(val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("value is neither a valid file path or nor a valid PEM-encoded string: %w", err)
|
||||
}
|
||||
return pemBytes, nil
|
||||
}
|
||||
|
||||
// IsValidPEM validates the provided input has PEM formatted block.
|
||||
func IsValidPEM(val string) bool {
|
||||
block, _ := pem.Decode([]byte(val))
|
||||
return block != nil
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package context
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
|
||||
|
||||
"github.com/dapr/kit/crypto/spiffe"
|
||||
)
|
||||
|
||||
type ctxkey int
|
||||
|
||||
const (
|
||||
x509SvidKey ctxkey = iota
|
||||
jwtSvidKey
|
||||
)
|
||||
|
||||
// Deprecated: use WithX509 instead.
|
||||
// With adds the x509 SVID source from the SPIFFE object to the context.
|
||||
func With(ctx context.Context, spiffe *spiffe.SPIFFE) context.Context {
|
||||
return context.WithValue(ctx, x509SvidKey, spiffe.X509SVIDSource())
|
||||
}
|
||||
|
||||
// Deprecated: use X509From instead.
|
||||
// From retrieves the x509 SVID source from the context.
|
||||
func From(ctx context.Context) (x509svid.Source, bool) {
|
||||
svid, ok := ctx.Value(x509SvidKey).(x509svid.Source)
|
||||
return svid, ok
|
||||
}
|
||||
|
||||
// WithX509 adds an x509 SVID source to the context.
|
||||
func WithX509(ctx context.Context, source x509svid.Source) context.Context {
|
||||
return context.WithValue(ctx, x509SvidKey, source)
|
||||
}
|
||||
|
||||
// WithJWT adds a JWT SVID source to the context.
|
||||
func WithJWT(ctx context.Context, source jwtsvid.Source) context.Context {
|
||||
return context.WithValue(ctx, jwtSvidKey, source)
|
||||
}
|
||||
|
||||
// X509From retrieves the x509 SVID source from the context.
|
||||
func X509From(ctx context.Context) (x509svid.Source, bool) {
|
||||
svid, ok := ctx.Value(x509SvidKey).(x509svid.Source)
|
||||
return svid, ok
|
||||
}
|
||||
|
||||
// JWTFrom retrieves the JWT SVID source from the context.
|
||||
func JWTFrom(ctx context.Context) (jwtsvid.Source, bool) {
|
||||
svid, ok := ctx.Value(jwtSvidKey).(jwtsvid.Source)
|
||||
return svid, ok
|
||||
}
|
||||
|
||||
// WithSpiffe adds both X509 and JWT SVID sources to the context.
|
||||
func WithSpiffe(ctx context.Context, spiffe *spiffe.SPIFFE) context.Context {
|
||||
ctx = WithX509(ctx, spiffe.X509SVIDSource())
|
||||
return WithJWT(ctx, spiffe.JWTSVIDSource())
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package context
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockX509Source struct{}
|
||||
|
||||
func (m *mockX509Source) GetX509SVID() (*x509svid.SVID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type mockJWTSource struct{}
|
||||
|
||||
func (m *mockJWTSource) FetchJWTSVID(context.Context, jwtsvid.Params) (*jwtsvid.SVID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestWithX509FromX509(t *testing.T) {
|
||||
source := &mockX509Source{}
|
||||
ctx := WithX509(t.Context(), source)
|
||||
|
||||
retrieved, ok := X509From(ctx)
|
||||
assert.True(t, ok, "Failed to retrieve X509 source from context")
|
||||
assert.Equal(t, x509svid.Source(source), retrieved, "Retrieved source does not match the original source")
|
||||
}
|
||||
|
||||
func TestWithJWTFromJWT(t *testing.T) {
|
||||
source := &mockJWTSource{}
|
||||
ctx := WithJWT(t.Context(), source)
|
||||
|
||||
retrieved, ok := JWTFrom(ctx)
|
||||
assert.True(t, ok, "Failed to retrieve JWT source from context")
|
||||
assert.Equal(t, jwtsvid.Source(source), retrieved, "Retrieved source does not match the original source")
|
||||
}
|
||||
|
||||
func TestWithFrom(t *testing.T) {
|
||||
x509Source := &mockX509Source{}
|
||||
ctx := WithX509(t.Context(), x509Source)
|
||||
|
||||
// Should be able to retrieve using the legacy From function
|
||||
retrieved, ok := From(ctx)
|
||||
assert.True(t, ok, "Failed to retrieve X509 source from context using legacy From")
|
||||
assert.Equal(t, x509svid.Source(x509Source), retrieved, "Retrieved source does not match the original source using legacy From")
|
||||
}
|
|
@ -0,0 +1,352 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package spiffe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
|
||||
"k8s.io/utils/clock"
|
||||
|
||||
"github.com/dapr/kit/concurrency/dir"
|
||||
"github.com/dapr/kit/crypto/pem"
|
||||
"github.com/dapr/kit/crypto/spiffe/trustanchors"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
// renewalDivisor represents the divisor for calculating renewal time.
|
||||
// A value of 2 means renewal at 50% of the validity period.
|
||||
renewalDivisor = 2
|
||||
)
|
||||
|
||||
// SVIDResponse represents the response from the SVID request function,
|
||||
// containing both X.509 certificates and a JWT token.
|
||||
type SVIDResponse struct {
|
||||
X509Certificates []*x509.Certificate
|
||||
JWT *string
|
||||
}
|
||||
|
||||
// Identity contains both X.509 and JWT SVIDs for a workload.
|
||||
type Identity struct {
|
||||
X509SVID *x509svid.SVID
|
||||
JWTSVID *jwtsvid.SVID
|
||||
}
|
||||
|
||||
type (
|
||||
// RequestSVIDFn is the function type that requests SVIDs from a SPIFFE server,
|
||||
// returning both X.509 certificates and a JWT token.
|
||||
RequestSVIDFn func(context.Context, []byte) (*SVIDResponse, error)
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Log logger.Logger
|
||||
RequestSVIDFn RequestSVIDFn
|
||||
|
||||
// WriteIdentityToFile is used to write the identity private key and
|
||||
// certificate chain to file. The certificate chain and private key will be
|
||||
// written to the `tls.cert` and `tls.key` files respectively in the given
|
||||
// directory.
|
||||
WriteIdentityToFile *string
|
||||
|
||||
TrustAnchors trustanchors.Interface
|
||||
}
|
||||
|
||||
// SPIFFE is a readable/writeable store of SPIFFE SVID credentials.
|
||||
// Used to manage workload SVIDs, and share read-only interfaces to consumers.
|
||||
type SPIFFE struct {
|
||||
currentX509SVID *x509svid.SVID
|
||||
currentJWTSVID *jwtsvid.SVID
|
||||
requestSVIDFn RequestSVIDFn
|
||||
|
||||
dir *dir.Dir
|
||||
trustAnchors trustanchors.Interface
|
||||
|
||||
log logger.Logger
|
||||
lock sync.RWMutex
|
||||
clock clock.Clock
|
||||
running atomic.Bool
|
||||
readyCh chan struct{}
|
||||
}
|
||||
|
||||
func New(opts Options) *SPIFFE {
|
||||
var sdir *dir.Dir
|
||||
if opts.WriteIdentityToFile != nil {
|
||||
sdir = dir.New(dir.Options{
|
||||
Log: opts.Log,
|
||||
Target: *opts.WriteIdentityToFile,
|
||||
})
|
||||
}
|
||||
|
||||
return &SPIFFE{
|
||||
requestSVIDFn: opts.RequestSVIDFn,
|
||||
dir: sdir,
|
||||
trustAnchors: opts.TrustAnchors,
|
||||
log: opts.Log,
|
||||
clock: clock.RealClock{},
|
||||
readyCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SPIFFE) Run(ctx context.Context) error {
|
||||
if !s.running.CompareAndSwap(false, true) {
|
||||
return errors.New("already running")
|
||||
}
|
||||
|
||||
s.lock.Lock()
|
||||
s.log.Info("Fetching initial identity")
|
||||
initialIdentity, err := s.fetchIdentity(ctx)
|
||||
if err != nil {
|
||||
close(s.readyCh)
|
||||
s.lock.Unlock()
|
||||
return fmt.Errorf("failed to retrieve the initial identity: %w", err)
|
||||
}
|
||||
|
||||
s.currentX509SVID = initialIdentity.X509SVID
|
||||
s.currentJWTSVID = initialIdentity.JWTSVID
|
||||
close(s.readyCh)
|
||||
s.lock.Unlock()
|
||||
|
||||
s.log.Infof("Security is initialized successfully")
|
||||
s.runRotation(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ready blocks until SPIFFE is ready or the context is done which will return
|
||||
// the context error.
|
||||
func (s *SPIFFE) Ready(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-s.readyCh:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// logIdentityInfo creates a log message with expiry details for both X.509 and JWT SVIDs
|
||||
func (s *SPIFFE) logIdentityInfo(prefix string, cert *x509.Certificate, jwtSVID *jwtsvid.SVID, renewTime *time.Time) {
|
||||
msg := prefix + "; cert expires on: %s"
|
||||
args := []any{cert.NotAfter.String()}
|
||||
|
||||
if jwtSVID != nil {
|
||||
msg += ", jwt expires on: %s"
|
||||
args = append(args, jwtSVID.Expiry.String())
|
||||
}
|
||||
|
||||
if renewTime != nil {
|
||||
msg += ", renewal at: %s"
|
||||
args = append(args, renewTime.String())
|
||||
}
|
||||
|
||||
s.log.Infof(msg, args...)
|
||||
}
|
||||
|
||||
// runRotation starts up the manager responsible for renewing the workload identity
|
||||
func (s *SPIFFE) runRotation(ctx context.Context) {
|
||||
defer s.log.Debug("stopping workload identity expiry watcher")
|
||||
|
||||
s.lock.RLock()
|
||||
cert := s.currentX509SVID.Certificates[0]
|
||||
jwtSVID := s.currentJWTSVID
|
||||
s.lock.RUnlock()
|
||||
|
||||
renewTime := calculateRenewalTime(time.Now(), cert, jwtSVID)
|
||||
s.logIdentityInfo("Starting workload identity expiry watcher", cert, jwtSVID, renewTime)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.clock.After(min(time.Minute, renewTime.Sub(s.clock.Now()))):
|
||||
if s.clock.Now().Before(*renewTime) {
|
||||
continue
|
||||
}
|
||||
|
||||
s.logIdentityInfo("Renewing workload identity", cert, jwtSVID, nil)
|
||||
|
||||
identity, err := s.fetchIdentity(ctx)
|
||||
if err != nil {
|
||||
s.log.Errorf("Error renewing identity, trying again in 10 seconds: %s", err)
|
||||
select {
|
||||
case <-s.clock.After(10 * time.Second):
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.lock.Lock()
|
||||
s.currentX509SVID = identity.X509SVID
|
||||
s.currentJWTSVID = identity.JWTSVID
|
||||
cert = identity.X509SVID.Certificates[0]
|
||||
jwtSVID = identity.JWTSVID
|
||||
s.lock.Unlock()
|
||||
|
||||
renewTime = calculateRenewalTime(time.Now(), cert, jwtSVID)
|
||||
s.logIdentityInfo("Successfully renewed workload identity", cert, jwtSVID, renewTime)
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns both X.509 SVID and JWT SVID (if available).
|
||||
func (s *SPIFFE) fetchIdentity(ctx context.Context) (*Identity, error) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
|
||||
csrDER, err := x509.CreateCertificateRequest(rand.Reader, new(x509.CertificateRequest), key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create sidecar csr: %w", err)
|
||||
}
|
||||
|
||||
svidResponse, err := s.requestSVIDFn(ctx, csrDER)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(svidResponse.X509Certificates) == 0 {
|
||||
return nil, errors.New("no certificates received from sentry")
|
||||
}
|
||||
|
||||
spiffeID, err := x509svid.IDFromCert(svidResponse.X509Certificates[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing spiffe id from newly signed certificate: %w", err)
|
||||
}
|
||||
|
||||
identity := &Identity{
|
||||
X509SVID: &x509svid.SVID{
|
||||
ID: spiffeID,
|
||||
Certificates: svidResponse.X509Certificates,
|
||||
PrivateKey: key,
|
||||
},
|
||||
}
|
||||
|
||||
// If we have a JWT token, parse it and include it in the identity
|
||||
if svidResponse.JWT != nil {
|
||||
// we are using ParseInsecure here as the expectation is that the
|
||||
// requestSVIDFn will have already parsed and validate the JWT SVID
|
||||
// before returning it.
|
||||
//
|
||||
// we are parsing the token using our SPIFFE ID's trust domain
|
||||
// as the audience as we expect the issuer to always include
|
||||
// that as an audience since that ensures that the token is
|
||||
// valid for us and our trust domain.
|
||||
audiences := []string{spiffeID.TrustDomain().Name()}
|
||||
jwtSvid, err := jwtsvid.ParseInsecure(*svidResponse.JWT, audiences)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT SVID: %w", err)
|
||||
}
|
||||
|
||||
identity.JWTSVID = jwtSvid
|
||||
s.log.Infof("Successfully received JWT SVID with expiry: %s", jwtSvid.Expiry.String())
|
||||
}
|
||||
|
||||
if s.dir != nil {
|
||||
pkPEM, err := pem.EncodePrivateKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
certPEM, err := pem.EncodeX509Chain(svidResponse.X509Certificates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
td, err := s.trustAnchors.CurrentTrustAnchors(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
files := map[string][]byte{
|
||||
"key.pem": pkPEM,
|
||||
"cert.pem": certPEM,
|
||||
"ca.pem": td,
|
||||
}
|
||||
|
||||
if svidResponse.JWT != nil {
|
||||
files["jwt_svid.token"] = []byte(*svidResponse.JWT)
|
||||
}
|
||||
|
||||
if err := s.dir.Write(files); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
func (s *SPIFFE) X509SVIDSource() x509svid.Source {
|
||||
return &svidSource{spiffe: s}
|
||||
}
|
||||
|
||||
func (s *SPIFFE) JWTSVIDSource() jwtsvid.Source {
|
||||
return &svidSource{spiffe: s}
|
||||
}
|
||||
|
||||
// renewalTime is 50% through the certificate validity period.
|
||||
func renewalTime(notBefore, notAfter time.Time) time.Time {
|
||||
return notBefore.Add(notAfter.Sub(notBefore) / renewalDivisor)
|
||||
}
|
||||
|
||||
// calculateRenewalTime returns the earlier renewal time between the X.509 certificate
|
||||
// and JWT SVID (if available) to ensure timely renewal.
|
||||
func calculateRenewalTime(now time.Time, cert *x509.Certificate, jwtSVID *jwtsvid.SVID) *time.Time {
|
||||
certRenewal := renewalTime(cert.NotBefore, cert.NotAfter)
|
||||
|
||||
if jwtSVID == nil {
|
||||
return &certRenewal
|
||||
}
|
||||
|
||||
jwtRenewal := now.Add(jwtSVID.Expiry.Sub(now) / renewalDivisor)
|
||||
|
||||
if jwtRenewal.Before(certRenewal) {
|
||||
return &jwtRenewal
|
||||
}
|
||||
return &certRenewal
|
||||
}
|
||||
|
||||
// audiencesMatch checks if the SVID audiences contain all the requested audiences
|
||||
func audiencesMatch(svidAudiences []string, requestedAudiences []string) bool {
|
||||
if len(requestedAudiences) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Create a map for faster lookup
|
||||
audienceMap := make(map[string]struct{}, len(svidAudiences))
|
||||
for _, audience := range svidAudiences {
|
||||
audienceMap[audience] = struct{}{}
|
||||
}
|
||||
|
||||
// Check if all requested audiences are in the SVID
|
||||
for _, requested := range requestedAudiences {
|
||||
if _, ok := audienceMap[requested]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
|
@ -0,0 +1,269 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package spiffe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
|
||||
"github.com/dapr/kit/crypto/test"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
func Test_renewalTime(t *testing.T) {
|
||||
now := time.Now()
|
||||
assert.Equal(t, now, renewalTime(now, now))
|
||||
|
||||
in1Min := now.Add(time.Minute)
|
||||
in30 := now.Add(time.Second * 30)
|
||||
assert.Equal(t, in30, renewalTime(now, in1Min))
|
||||
}
|
||||
|
||||
func Test_calculateRenewalTime(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
certShort := &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(10 * time.Hour),
|
||||
}
|
||||
|
||||
certLong := &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
// Expected renewal times for certificates (50% of validity period)
|
||||
certShortRenewal := now.Add(5 * time.Hour)
|
||||
|
||||
// Create JWT SVIDs with different expiry times
|
||||
jwtEarlier := &jwtsvid.SVID{
|
||||
Expiry: now.Add(8 * time.Hour),
|
||||
}
|
||||
|
||||
jwtLater := &jwtsvid.SVID{
|
||||
Expiry: now.Add(30 * time.Hour),
|
||||
}
|
||||
|
||||
// Expected JWT renewal time (50% of remaining time)
|
||||
jwtEarlierRenewal := now.Add(4 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cert *x509.Certificate
|
||||
jwt *jwtsvid.SVID
|
||||
expected time.Time
|
||||
}{
|
||||
{
|
||||
name: "Certificate only",
|
||||
cert: certShort,
|
||||
jwt: nil,
|
||||
expected: certShortRenewal,
|
||||
},
|
||||
{
|
||||
name: "Certificate and JWT, JWT earlier",
|
||||
cert: certLong,
|
||||
jwt: jwtEarlier,
|
||||
expected: jwtEarlierRenewal,
|
||||
},
|
||||
{
|
||||
name: "Certificate and JWT, Certificate earlier",
|
||||
cert: certShort,
|
||||
jwt: jwtLater,
|
||||
expected: certShortRenewal,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := calculateRenewalTime(now, tt.cert, tt.jwt)
|
||||
|
||||
assert.WithinDuration(t, tt.expected, *actual, time.Millisecond,
|
||||
"Renewal time does not match expected value")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Run(t *testing.T) {
|
||||
t.Run("should return error multiple Runs are called", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{
|
||||
LeafID: spiffeid.RequireFromString("spiffe://example.com/foo/bar"),
|
||||
})
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
s := New(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
|
||||
return &SVIDResponse{
|
||||
X509Certificates: []*x509.Certificate{pki.LeafCert},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- s.Run(ctx)
|
||||
}()
|
||||
go func() {
|
||||
errCh <- s.Run(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.Error(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "Expected error")
|
||||
}
|
||||
|
||||
cancel()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "First Run should have returned and returned no error ")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should return error if initial fetch errors", func(t *testing.T) {
|
||||
s := New(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
|
||||
return nil, errors.New("this is an error")
|
||||
},
|
||||
})
|
||||
|
||||
require.Error(t, s.Run(t.Context()))
|
||||
})
|
||||
|
||||
t.Run("should renew certificate when it has expired", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{
|
||||
LeafID: spiffeid.RequireFromString("spiffe://example.com/foo/bar"),
|
||||
})
|
||||
|
||||
var fetches atomic.Int32
|
||||
s := New(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
|
||||
fetches.Add(1)
|
||||
return &SVIDResponse{
|
||||
X509Certificates: []*x509.Certificate{pki.LeafCert},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
now := time.Now()
|
||||
clock := clocktesting.NewFakeClock(now)
|
||||
s.clock = clock
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
select {
|
||||
case <-s.readyCh:
|
||||
assert.Fail(t, "readyCh should not be closed")
|
||||
default:
|
||||
}
|
||||
|
||||
errCh <- s.Run(ctx)
|
||||
}()
|
||||
|
||||
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
|
||||
assert.Equal(t, int32(1), fetches.Load())
|
||||
|
||||
clock.Step(pki.LeafCert.NotAfter.Sub(now) / 2)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
assert.Equal(c, int32(2), fetches.Load())
|
||||
}, time.Second, time.Millisecond)
|
||||
|
||||
cancel()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "First Run should have returned and returned no error ")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("if renewal failed, should try again in 10 seconds", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{
|
||||
LeafID: spiffeid.RequireFromString("spiffe://example.com/foo/bar"),
|
||||
})
|
||||
|
||||
respCert := []*x509.Certificate{pki.LeafCert}
|
||||
var respErr error
|
||||
|
||||
var fetches atomic.Int32
|
||||
s := New(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
|
||||
fetches.Add(1)
|
||||
return &SVIDResponse{
|
||||
X509Certificates: respCert,
|
||||
}, respErr
|
||||
},
|
||||
})
|
||||
now := time.Now()
|
||||
clock := clocktesting.NewFakeClock(now)
|
||||
s.clock = clock
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
select {
|
||||
case <-s.readyCh:
|
||||
assert.Fail(t, "readyCh should not be closed")
|
||||
default:
|
||||
}
|
||||
|
||||
errCh <- s.Run(ctx)
|
||||
}()
|
||||
|
||||
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
|
||||
assert.Equal(t, int32(1), fetches.Load())
|
||||
|
||||
respCert = nil
|
||||
respErr = errors.New("this is an error")
|
||||
clock.Step(pki.LeafCert.NotAfter.Sub(now) / 2)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
assert.Equal(c, int32(2), fetches.Load())
|
||||
}, time.Second, time.Millisecond)
|
||||
|
||||
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
|
||||
clock.Step(time.Second * 5)
|
||||
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
|
||||
assert.Equal(t, int32(2), fetches.Load())
|
||||
|
||||
clock.Step(time.Second * 5)
|
||||
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
|
||||
clock.Step(1)
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
assert.Equal(c, int32(3), fetches.Load())
|
||||
}, time.Second, time.Millisecond)
|
||||
|
||||
cancel()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "First Run should have returned and returned no error ")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package spiffe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
|
||||
)
|
||||
|
||||
var (
|
||||
errNoX509SVIDAvailable = errors.New("no X509 SVID available")
|
||||
errNoJWTSVIDAvailable = errors.New("no JWT SVID available")
|
||||
errAudienceRequired = errors.New("JWT audience is required")
|
||||
)
|
||||
|
||||
// svidSource is an implementation of both go-spiffe x509svid.Source and jwtsvid.Source interfaces.
|
||||
type svidSource struct {
|
||||
spiffe *SPIFFE
|
||||
}
|
||||
|
||||
// GetX509SVID returns the current X.509 certificate identity as a SPIFFE SVID.
|
||||
// Implements the go-spiffe x509svid.Source interface.
|
||||
func (s *svidSource) GetX509SVID() (*x509svid.SVID, error) {
|
||||
s.spiffe.lock.RLock()
|
||||
defer s.spiffe.lock.RUnlock()
|
||||
|
||||
<-s.spiffe.readyCh
|
||||
|
||||
svid := s.spiffe.currentX509SVID
|
||||
if svid == nil {
|
||||
return nil, errNoX509SVIDAvailable
|
||||
}
|
||||
|
||||
return svid, nil
|
||||
}
|
||||
|
||||
// audienceMismatchError is an error that contains information about mismatched audiences
|
||||
type audienceMismatchError struct {
|
||||
expected []string
|
||||
actual []string
|
||||
}
|
||||
|
||||
func (e *audienceMismatchError) Error() string {
|
||||
return fmt.Sprintf("JWT SVID has different audiences than requested: expected %s, got %s",
|
||||
strings.Join(e.expected, ", "), strings.Join(e.actual, ", "))
|
||||
}
|
||||
|
||||
// FetchJWTSVID returns the current JWT SVID.
|
||||
// Implements the go-spiffe jwtsvid.Source interface.
|
||||
func (s *svidSource) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*jwtsvid.SVID, error) {
|
||||
s.spiffe.lock.RLock()
|
||||
defer s.spiffe.lock.RUnlock()
|
||||
|
||||
if params.Audience == "" {
|
||||
return nil, errAudienceRequired
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-s.spiffe.readyCh:
|
||||
}
|
||||
|
||||
svid := s.spiffe.currentJWTSVID
|
||||
if svid == nil {
|
||||
return nil, errNoJWTSVIDAvailable
|
||||
}
|
||||
|
||||
// verify that the audience being requested is the same as the audience in the SVID
|
||||
// WARN: we do not check extra audiences here.
|
||||
if !audiencesMatch(svid.Audience, []string{params.Audience}) {
|
||||
return nil, &audienceMismatchError{
|
||||
expected: []string{params.Audience},
|
||||
actual: svid.Audience,
|
||||
}
|
||||
}
|
||||
|
||||
return svid, nil
|
||||
}
|
|
@ -0,0 +1,190 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package spiffe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_svidSource(*testing.T) {
|
||||
var _ x509svid.Source = new(svidSource)
|
||||
var _ jwtsvid.Source = new(svidSource)
|
||||
}
|
||||
|
||||
// createMockJWTSVID creates a mock JWT SVID for testing
|
||||
func createMockJWTSVID(audiences []string) (*jwtsvid.SVID, error) {
|
||||
td, err := spiffeid.TrustDomainFromString("example.org")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := spiffeid.FromSegments(td, "workload")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svid := &jwtsvid.SVID{
|
||||
ID: id,
|
||||
Audience: audiences,
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
}
|
||||
|
||||
return svid, nil
|
||||
}
|
||||
|
||||
func TestFetchJWTSVID(t *testing.T) {
|
||||
t.Run("should return error when audience is empty", func(t *testing.T) {
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: make(chan struct{}),
|
||||
lock: sync.RWMutex{},
|
||||
},
|
||||
}
|
||||
close(s.spiffe.readyCh) // Mark as ready
|
||||
|
||||
svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{
|
||||
Audience: "",
|
||||
})
|
||||
|
||||
require.Nil(t, svid)
|
||||
require.ErrorIs(t, err, errAudienceRequired)
|
||||
})
|
||||
|
||||
t.Run("should return error when no JWT SVID available", func(t *testing.T) {
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: make(chan struct{}),
|
||||
lock: sync.RWMutex{},
|
||||
currentJWTSVID: nil,
|
||||
},
|
||||
}
|
||||
close(s.spiffe.readyCh) // Mark as ready
|
||||
|
||||
svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{
|
||||
Audience: "test-audience",
|
||||
})
|
||||
|
||||
require.Nil(t, svid)
|
||||
require.ErrorIs(t, err, errNoJWTSVIDAvailable)
|
||||
})
|
||||
|
||||
t.Run("should return error when audience doesn't match", func(t *testing.T) {
|
||||
// Create a mock SVID with a specific audience
|
||||
mockJWTSVID, err := createMockJWTSVID([]string{"actual-audience"})
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: make(chan struct{}),
|
||||
lock: sync.RWMutex{},
|
||||
currentJWTSVID: mockJWTSVID,
|
||||
},
|
||||
}
|
||||
close(s.spiffe.readyCh) // Mark as ready
|
||||
|
||||
svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{
|
||||
Audience: "requested-audience",
|
||||
})
|
||||
|
||||
require.Nil(t, svid)
|
||||
require.Error(t, err)
|
||||
|
||||
// Verify the specific error type and contents
|
||||
audienceErr, ok := err.(*audienceMismatchError)
|
||||
require.True(t, ok, "Expected audienceMismatchError")
|
||||
require.Equal(t, "JWT SVID has different audiences than requested: expected requested-audience, got actual-audience", audienceErr.Error())
|
||||
})
|
||||
|
||||
t.Run("should return JWT SVID when audience matches", func(t *testing.T) {
|
||||
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience", "extra-audience"})
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: make(chan struct{}),
|
||||
lock: sync.RWMutex{},
|
||||
currentJWTSVID: mockJWTSVID,
|
||||
},
|
||||
}
|
||||
close(s.spiffe.readyCh) // Mark as ready
|
||||
|
||||
svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{
|
||||
Audience: "test-audience",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, mockJWTSVID, svid)
|
||||
})
|
||||
|
||||
t.Run("should wait for readyCh before checking SVID", func(t *testing.T) {
|
||||
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience"})
|
||||
require.NoError(t, err)
|
||||
|
||||
readyCh := make(chan struct{})
|
||||
s := &svidSource{
|
||||
spiffe: &SPIFFE{
|
||||
readyCh: readyCh,
|
||||
lock: sync.RWMutex{},
|
||||
currentJWTSVID: mockJWTSVID,
|
||||
},
|
||||
}
|
||||
|
||||
// Start goroutine to fetch SVID
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
resultCh := make(chan struct {
|
||||
svid *jwtsvid.SVID
|
||||
err error
|
||||
})
|
||||
|
||||
go func() {
|
||||
svid, err := s.FetchJWTSVID(ctx, jwtsvid.Params{
|
||||
Audience: "test-audience",
|
||||
})
|
||||
resultCh <- struct {
|
||||
svid *jwtsvid.SVID
|
||||
err error
|
||||
}{svid, err}
|
||||
}()
|
||||
|
||||
// require that fetch is blocked
|
||||
select {
|
||||
case <-resultCh:
|
||||
t.Fatal("FetchJWTSVID should be blocked until readyCh is closed")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Expected behavior - fetch is blocked
|
||||
}
|
||||
|
||||
// Close readyCh to unblock fetch
|
||||
close(readyCh)
|
||||
|
||||
// Now fetch should complete
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
require.NoError(t, result.err)
|
||||
require.NotNil(t, result.svid)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("FetchJWTSVID should have completed after readyCh was closed")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,300 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
"k8s.io/utils/clock"
|
||||
|
||||
"github.com/dapr/kit/concurrency"
|
||||
"github.com/dapr/kit/crypto/pem"
|
||||
"github.com/dapr/kit/crypto/spiffe/trustanchors"
|
||||
"github.com/dapr/kit/fswatcher"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrTrustAnchorsClosed is returned when an operation is performed on closed trust anchors.
|
||||
ErrTrustAnchorsClosed = errors.New("trust anchors is closed")
|
||||
|
||||
// ErrFailedToReadTrustAnchorsFile is returned when the trust anchors file cannot be read.
|
||||
ErrFailedToReadTrustAnchorsFile = errors.New("failed to read trust anchors file")
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Log logger.Logger
|
||||
CAPath string
|
||||
JwksPath *string
|
||||
}
|
||||
|
||||
// file is a TrustAnchors implementation that uses a file as the source of trust
|
||||
// anchors. The trust anchors will be updated when the file changes.
|
||||
type file struct {
|
||||
log logger.Logger
|
||||
caPath string
|
||||
jwksPath *string
|
||||
x509Bundle *x509bundle.Bundle
|
||||
jwtBundle *jwtbundle.Bundle
|
||||
rootPEM []byte
|
||||
|
||||
// fswatcherInterval is the interval at which the trust anchors file changes
|
||||
// are batched. Used for testing only, and 500ms otherwise.
|
||||
fsWatcherInterval time.Duration
|
||||
|
||||
// initFileWatchInterval is the interval at which the trust anchors file is
|
||||
// checked for the first time. Used for testing only, and 1 second otherwise.
|
||||
initFileWatchInterval time.Duration
|
||||
|
||||
// subs is a list of channels to notify when the trust anchors are updated.
|
||||
subs []chan<- struct{}
|
||||
|
||||
lock sync.RWMutex
|
||||
clock clock.Clock
|
||||
running atomic.Bool
|
||||
readyCh chan struct{}
|
||||
closeCh chan struct{}
|
||||
caEvent chan struct{}
|
||||
}
|
||||
|
||||
func From(opts Options) trustanchors.Interface {
|
||||
return &file{
|
||||
fsWatcherInterval: time.Millisecond * 500,
|
||||
initFileWatchInterval: time.Second,
|
||||
|
||||
log: opts.Log,
|
||||
caPath: opts.CAPath,
|
||||
jwksPath: opts.JwksPath,
|
||||
clock: clock.RealClock{},
|
||||
readyCh: make(chan struct{}),
|
||||
closeCh: make(chan struct{}),
|
||||
caEvent: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *file) Run(ctx context.Context) error {
|
||||
if !f.running.CompareAndSwap(false, true) {
|
||||
return errors.New("trust anchors is already running")
|
||||
}
|
||||
|
||||
defer close(f.closeCh)
|
||||
|
||||
for {
|
||||
fs := []string{f.caPath}
|
||||
if f.jwksPath != nil {
|
||||
fs = append(fs, *f.jwksPath)
|
||||
}
|
||||
|
||||
if found, err := filesExist(fs...); err != nil {
|
||||
return err
|
||||
} else if found {
|
||||
break
|
||||
}
|
||||
|
||||
// Trust anchors file not be provided yet, wait.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("failed to find trust anchors file '%s': %w", f.caPath, ctx.Err())
|
||||
case <-f.clock.After(f.initFileWatchInterval):
|
||||
f.log.Warnf("Trust anchors file '%s' not found, waiting...", f.caPath)
|
||||
}
|
||||
}
|
||||
|
||||
f.log.Infof("Trust anchors file '%s' found", f.caPath)
|
||||
|
||||
if err := f.updateAnchors(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targets := []string{f.caPath}
|
||||
if f.jwksPath != nil {
|
||||
targets = append(targets, *f.jwksPath)
|
||||
}
|
||||
|
||||
fs, err := fswatcher.New(fswatcher.Options{
|
||||
Targets: targets,
|
||||
Interval: &f.fsWatcherInterval,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create file watcher: %w", err)
|
||||
}
|
||||
|
||||
close(f.readyCh)
|
||||
|
||||
f.log.Infof("Watching trust anchors file '%s' for changes", f.caPath)
|
||||
if f.jwksPath != nil {
|
||||
f.log.Infof("Watching JWT bundle file '%s' for changes", f.jwksPath)
|
||||
}
|
||||
|
||||
return concurrency.NewRunnerManager(
|
||||
func(ctx context.Context) error {
|
||||
return fs.Run(ctx, f.caEvent)
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-f.caEvent:
|
||||
f.log.Info("Trust anchors file changed, reloading trust anchors")
|
||||
|
||||
if err = f.updateAnchors(ctx); err != nil {
|
||||
return fmt.Errorf("%w: '%s': %v", ErrFailedToReadTrustAnchorsFile, f.caPath, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
).Run(ctx)
|
||||
}
|
||||
|
||||
func (f *file) CurrentTrustAnchors(ctx context.Context) ([]byte, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-f.closeCh:
|
||||
return nil, ErrTrustAnchorsClosed
|
||||
case <-f.readyCh:
|
||||
}
|
||||
|
||||
f.lock.RLock()
|
||||
defer f.lock.RUnlock()
|
||||
rootPEM := make([]byte, len(f.rootPEM))
|
||||
copy(rootPEM, f.rootPEM)
|
||||
return rootPEM, nil
|
||||
}
|
||||
|
||||
func (f *file) updateAnchors(ctx context.Context) error {
|
||||
f.lock.Lock()
|
||||
defer f.lock.Unlock()
|
||||
|
||||
rootPEMs, err := os.ReadFile(f.caPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read trust anchors file '%s': %w", f.caPath, err)
|
||||
}
|
||||
|
||||
trustAnchorCerts, err := pem.DecodePEMCertificates(rootPEMs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode trust anchors: %w", err)
|
||||
}
|
||||
|
||||
f.rootPEM = rootPEMs
|
||||
f.x509Bundle = x509bundle.FromX509Authorities(spiffeid.TrustDomain{}, trustAnchorCerts)
|
||||
|
||||
if f.jwksPath != nil {
|
||||
jwks, err := os.ReadFile(*f.jwksPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read JWT bundle file '%s': %w", *f.jwksPath, err)
|
||||
}
|
||||
|
||||
jwtBundle, err := jwtbundle.Parse(spiffeid.TrustDomain{}, jwks)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse JWT bundle: %w", err)
|
||||
}
|
||||
f.jwtBundle = jwtBundle
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
wg.Add(len(f.subs))
|
||||
for _, ch := range f.subs {
|
||||
go func(chi chan<- struct{}) {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case chi <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}(ch)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *file) GetX509BundleForTrustDomain(_ spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
|
||||
select {
|
||||
case <-f.closeCh:
|
||||
return nil, ErrTrustAnchorsClosed
|
||||
case <-f.readyCh:
|
||||
}
|
||||
|
||||
f.lock.RLock()
|
||||
defer f.lock.RUnlock()
|
||||
bundle := f.x509Bundle
|
||||
return bundle, nil
|
||||
}
|
||||
|
||||
func (f *file) GetJWTBundleForTrustDomain(_ spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
|
||||
select {
|
||||
case <-f.closeCh:
|
||||
return nil, ErrTrustAnchorsClosed
|
||||
case <-f.readyCh:
|
||||
}
|
||||
|
||||
f.lock.RLock()
|
||||
defer f.lock.RUnlock()
|
||||
bundle := f.jwtBundle
|
||||
return bundle, nil
|
||||
}
|
||||
|
||||
func (f *file) Watch(ctx context.Context, ch chan<- []byte) {
|
||||
f.lock.Lock()
|
||||
sub := make(chan struct{}, 5)
|
||||
f.subs = append(f.subs, sub)
|
||||
f.lock.Unlock()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-f.closeCh:
|
||||
return
|
||||
case <-sub:
|
||||
f.lock.RLock()
|
||||
rootPEM := make([]byte, len(f.rootPEM))
|
||||
copy(rootPEM, f.rootPEM)
|
||||
f.lock.RUnlock()
|
||||
|
||||
select {
|
||||
case ch <- rootPEM:
|
||||
case <-ctx.Done():
|
||||
case <-f.closeCh:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filesExist(paths ...string) (bool, error) {
|
||||
for _, path := range paths {
|
||||
if path == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to stat file '%s': %w", path, err)
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
|
@ -0,0 +1,580 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dapr/kit/crypto/test"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
func TestFile_Run(t *testing.T) {
|
||||
t.Run("if Run multiple times, expect error", func(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
f.initFileWatchInterval = time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- f.Run(ctx)
|
||||
}()
|
||||
go func() {
|
||||
errCh <- f.Run(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.Error(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "Expected error")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-f.closeCh:
|
||||
assert.Fail(t, "closeCh should not be closed")
|
||||
default:
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "First Run should have returned and returned no error ")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("if file is not found and context cancelled, should return ctx.Err", func(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
f.initFileWatchInterval = time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- f.Run(ctx)
|
||||
}()
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "First Run should have returned and returned no error ")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("if file found but is empty, should return error", func(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, nil, 0o600))
|
||||
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
f.initFileWatchInterval = time.Millisecond
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- f.Run(t.Context())
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.Error(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("if file found but is only garbage data, expect error", func(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, []byte("garbage data"), 0o600))
|
||||
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
f.initFileWatchInterval = time.Millisecond
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- f.Run(t.Context())
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.Error(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("if file found but is only garbage data in root, expect error", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
root := pki.RootCertPEM[10:]
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, root, 0o600))
|
||||
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
f.initFileWatchInterval = time.Millisecond
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- f.Run(t.Context())
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.Error(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("single root should be correctly parsed from file", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
|
||||
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
f.initFileWatchInterval = time.Millisecond
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- f.Run(t.Context())
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-f.readyCh:
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected to be ready in time")
|
||||
}
|
||||
|
||||
b, err := f.CurrentTrustAnchors(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pki.RootCertPEM, b)
|
||||
})
|
||||
|
||||
t.Run("garbage data outside of root should be ignored", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
//nolint:gocritic
|
||||
root := append(pki.RootCertPEM, []byte("garbage data")...)
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, root, 0o600))
|
||||
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
f.initFileWatchInterval = time.Millisecond
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- f.Run(t.Context())
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-f.readyCh:
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected to be ready in time")
|
||||
}
|
||||
|
||||
b, err := f.CurrentTrustAnchors(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, root, b)
|
||||
})
|
||||
|
||||
t.Run("multiple roots should be parsed", func(t *testing.T) {
|
||||
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
|
||||
//nolint:gocritic
|
||||
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
||||
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
f.initFileWatchInterval = time.Millisecond
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- f.Run(t.Context())
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-f.readyCh:
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected to be ready in time")
|
||||
}
|
||||
|
||||
b, err := f.CurrentTrustAnchors(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, roots, b)
|
||||
})
|
||||
|
||||
t.Run("writing a bad root PEM file should make Run return error", func(t *testing.T) {
|
||||
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
|
||||
//nolint:gocritic
|
||||
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
||||
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
f.initFileWatchInterval = time.Millisecond
|
||||
f.fsWatcherInterval = time.Millisecond
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- f.Run(t.Context())
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-f.readyCh:
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected to be ready in time")
|
||||
}
|
||||
|
||||
require.NoError(t, os.WriteFile(tmp, []byte("garbage data"), 0o600))
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.Error(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected error to be returned from Run")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFile_GetX509BundleForTrustDomain(t *testing.T) {
|
||||
t.Run("Should return full PEM regardless given trust domain", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
//nolint:gocritic
|
||||
root := append(pki.RootCertPEM, []byte("garbage data")...)
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, root, 0o600))
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
|
||||
errCh := make(chan error)
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
go func() {
|
||||
errCh <- ta.Run(ctx)
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
cancel()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected Run to return")
|
||||
}
|
||||
})
|
||||
|
||||
trustDomain1, err := spiffeid.TrustDomainFromString("example.com")
|
||||
require.NoError(t, err)
|
||||
bundle, err := f.GetX509BundleForTrustDomain(trustDomain1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, f.x509Bundle, bundle)
|
||||
b1, err := bundle.Marshal()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pki.RootCertPEM, b1)
|
||||
|
||||
trustDomain2, err := spiffeid.TrustDomainFromString("another-example.org")
|
||||
require.NoError(t, err)
|
||||
bundle, err = f.GetX509BundleForTrustDomain(trustDomain2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, f.x509Bundle, bundle)
|
||||
b2, err := bundle.Marshal()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pki.RootCertPEM, b2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFile_Watch(t *testing.T) {
|
||||
t.Run("should return when Run context has been cancelled", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
|
||||
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
f.initFileWatchInterval = time.Millisecond
|
||||
|
||||
errCh := make(chan error)
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
go func() {
|
||||
errCh <- f.Run(ctx)
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure f.Run has finished and running
|
||||
|
||||
watchDone := make(chan struct{})
|
||||
go func() {
|
||||
ta.Watch(t.Context(), make(chan []byte))
|
||||
close(watchDone)
|
||||
}()
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected error to be returned from Run")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-watchDone:
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected Watch to have returned")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should return when given context has been cancelled", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
|
||||
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
f.initFileWatchInterval = time.Millisecond
|
||||
|
||||
errCh := make(chan error)
|
||||
ctx1, cancel1 := context.WithCancel(t.Context())
|
||||
go func() {
|
||||
errCh <- f.Run(ctx1)
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure f.Run has finished and running
|
||||
|
||||
watchDone := make(chan struct{})
|
||||
ctx2, cancel2 := context.WithCancel(t.Context())
|
||||
go func() {
|
||||
ta.Watch(ctx2, make(chan []byte))
|
||||
close(watchDone)
|
||||
}()
|
||||
|
||||
cancel2()
|
||||
|
||||
select {
|
||||
case <-watchDone:
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected Watch to have returned")
|
||||
}
|
||||
|
||||
cancel1()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected error to be returned from Run")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should update Watch subscribers when root PEM has been changed", func(t *testing.T) {
|
||||
pki1 := test.GenPKI(t, test.PKIOptions{})
|
||||
pki2 := test.GenPKI(t, test.PKIOptions{})
|
||||
pki3 := test.GenPKI(t, test.PKIOptions{})
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
|
||||
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
f.initFileWatchInterval = time.Millisecond
|
||||
f.fsWatcherInterval = time.Millisecond
|
||||
|
||||
errCh := make(chan error)
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
go func() {
|
||||
errCh <- f.Run(ctx)
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure f.Run has finished and running
|
||||
|
||||
select {
|
||||
case <-f.readyCh:
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected to be ready in time")
|
||||
}
|
||||
|
||||
watchDone1, watchDone2 := make(chan struct{}), make(chan struct{})
|
||||
tCh1, tCh2 := make(chan []byte), make(chan []byte)
|
||||
go func() {
|
||||
ta.Watch(t.Context(), tCh1)
|
||||
close(watchDone1)
|
||||
}()
|
||||
go func() {
|
||||
ta.Watch(t.Context(), tCh2)
|
||||
close(watchDone2)
|
||||
}()
|
||||
|
||||
//nolint:gocritic
|
||||
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
|
||||
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
||||
|
||||
for _, ch := range []chan []byte{tCh1, tCh2} {
|
||||
select {
|
||||
case b := <-ch:
|
||||
assert.Equal(t, string(roots), string(b))
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "failed to get subscribed file watch in time")
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gocritic
|
||||
roots = append(pki1.RootCertPEM, append(pki2.RootCertPEM, pki3.RootCertPEM...)...)
|
||||
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
||||
|
||||
for _, ch := range []chan []byte{tCh1, tCh2} {
|
||||
select {
|
||||
case b := <-ch:
|
||||
assert.Equal(t, string(roots), string(b))
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "failed to get subscribed file watch in time")
|
||||
}
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
for _, ch := range []chan struct{}{watchDone1, watchDone2} {
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected Watch to have returned")
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected error to be returned from Run")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFile_CurrentTrustAnchors(t *testing.T) {
|
||||
t.Run("returns trust anchors as they change", func(t *testing.T) {
|
||||
pki1, pki2, pki3 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
|
||||
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
||||
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
|
||||
|
||||
ta := From(Options{
|
||||
Log: logger.NewLogger("test"),
|
||||
CAPath: tmp,
|
||||
})
|
||||
f, ok := ta.(*file)
|
||||
require.True(t, ok)
|
||||
f.initFileWatchInterval = time.Millisecond
|
||||
f.fsWatcherInterval = time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- f.Run(ctx)
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure f.Run has finished and running
|
||||
//nolint:gocritic
|
||||
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
|
||||
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
||||
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure the file watcher has time to pick up the change
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
pem, err := ta.CurrentTrustAnchors(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(c, roots, pem)
|
||||
}, time.Second, time.Millisecond)
|
||||
|
||||
//nolint:gocritic
|
||||
roots = append(pki1.RootCertPEM, append(pki2.RootCertPEM, pki3.RootCertPEM...)...)
|
||||
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
||||
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure the file watcher has time to pick up the change
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
pem, err := ta.CurrentTrustAnchors(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(c, roots, pem)
|
||||
}, time.Second, time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "expected error to be returned from Run")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,86 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package multi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
|
||||
"github.com/dapr/kit/concurrency"
|
||||
"github.com/dapr/kit/crypto/spiffe/trustanchors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotImplemented = errors.New("not implemented")
|
||||
ErrTrustDomainNotFound = errors.New("trust domain not found")
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
TrustAnchors map[spiffeid.TrustDomain]trustanchors.Interface
|
||||
}
|
||||
|
||||
// multi is a TrustAnchors implementation which uses multiple trust anchors
|
||||
// which are indexed by trust domain.
|
||||
type multi struct {
|
||||
trustAnchors map[spiffeid.TrustDomain]trustanchors.Interface
|
||||
}
|
||||
|
||||
func From(opts Options) trustanchors.Interface {
|
||||
return &multi{
|
||||
trustAnchors: opts.TrustAnchors,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *multi) Run(ctx context.Context) error {
|
||||
r := concurrency.NewRunnerManager()
|
||||
for _, ta := range m.trustAnchors {
|
||||
if err := r.Add(ta.Run); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return r.Run(ctx)
|
||||
}
|
||||
|
||||
func (m *multi) CurrentTrustAnchors(context.Context) ([]byte, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (m *multi) GetX509BundleForTrustDomain(td spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
|
||||
for tad, ta := range m.trustAnchors {
|
||||
if td.Compare(tad) == 0 {
|
||||
return ta.GetX509BundleForTrustDomain(td)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrTrustDomainNotFound
|
||||
}
|
||||
|
||||
func (m *multi) GetJWTBundleForTrustDomain(td spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
|
||||
for tad, ta := range m.trustAnchors {
|
||||
if td.Compare(tad) == 0 {
|
||||
return ta.GetJWTBundleForTrustDomain(td)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrTrustDomainNotFound
|
||||
}
|
||||
|
||||
func (m *multi) Watch(context.Context, chan<- []byte) {
|
||||
return
|
||||
}
|
|
@ -0,0 +1,99 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package static
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
|
||||
"github.com/dapr/kit/crypto/pem"
|
||||
"github.com/dapr/kit/crypto/spiffe/trustanchors"
|
||||
)
|
||||
|
||||
// static is a TrustAcnhors implementation that uses a static list of trust
|
||||
// anchors.
|
||||
type static struct {
|
||||
x509Bundle *x509bundle.Bundle
|
||||
jwtBundle *jwtbundle.Bundle
|
||||
anchors []byte
|
||||
running atomic.Bool
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Anchors []byte
|
||||
Jwks []byte
|
||||
}
|
||||
|
||||
func From(opts Options) (trustanchors.Interface, error) {
|
||||
// Create empty trust domain for now
|
||||
emptyTD := spiffeid.TrustDomain{}
|
||||
|
||||
var jwtBundle *jwtbundle.Bundle
|
||||
if opts.Jwks != nil {
|
||||
var err error
|
||||
jwtBundle, err = jwtbundle.Parse(emptyTD, opts.Jwks)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create JWT bundle: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
trustAnchorCerts, err := pem.DecodePEMCertificates(opts.Anchors)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode trust anchors: %w", err)
|
||||
}
|
||||
|
||||
return &static{
|
||||
anchors: opts.Anchors,
|
||||
x509Bundle: x509bundle.FromX509Authorities(emptyTD, trustAnchorCerts),
|
||||
jwtBundle: jwtBundle,
|
||||
closeCh: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *static) CurrentTrustAnchors(context.Context) ([]byte, error) {
|
||||
bundle := make([]byte, len(s.anchors))
|
||||
copy(bundle, s.anchors)
|
||||
return bundle, nil
|
||||
}
|
||||
|
||||
func (s *static) Run(ctx context.Context) error {
|
||||
if !s.running.CompareAndSwap(false, true) {
|
||||
return errors.New("trust anchors source is already running")
|
||||
}
|
||||
<-ctx.Done()
|
||||
close(s.closeCh)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *static) GetX509BundleForTrustDomain(spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
|
||||
return s.x509Bundle, nil
|
||||
}
|
||||
|
||||
func (s *static) GetJWTBundleForTrustDomain(_ spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
|
||||
return s.jwtBundle, nil
|
||||
}
|
||||
|
||||
func (s *static) Watch(ctx context.Context, _ chan<- []byte) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-s.closeCh:
|
||||
}
|
||||
}
|
|
@ -0,0 +1,210 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package static
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dapr/kit/crypto/test"
|
||||
)
|
||||
|
||||
func TestFromStatic(t *testing.T) {
|
||||
t.Run("empty root should return error", func(t *testing.T) {
|
||||
_, err := From(Options{})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("garbage data should return error", func(t *testing.T) {
|
||||
_, err := From(Options{Anchors: []byte("garbage data")})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("just garbage data should return error", func(t *testing.T) {
|
||||
_, err := From(Options{Anchors: []byte("garbage data")})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("garbage data in root should return error", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
root := pki.RootCertPEM[10:]
|
||||
_, err := From(Options{Anchors: root})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("single root should be correctly parsed", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
ta, err := From(Options{Anchors: pki.RootCertPEM})
|
||||
require.NoError(t, err)
|
||||
taPEM, err := ta.CurrentTrustAnchors(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pki.RootCertPEM, taPEM)
|
||||
})
|
||||
|
||||
t.Run("garbage data outside of root should be ignored", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
//nolint:gocritic
|
||||
root := append(pki.RootCertPEM, []byte("garbage data")...)
|
||||
ta, err := From(Options{Anchors: root})
|
||||
require.NoError(t, err)
|
||||
taPEM, err := ta.CurrentTrustAnchors(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, root, taPEM)
|
||||
})
|
||||
|
||||
t.Run("multiple roots should be correctly parsed", func(t *testing.T) {
|
||||
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
|
||||
//nolint:gocritic
|
||||
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
|
||||
ta, err := From(Options{Anchors: roots})
|
||||
require.NoError(t, err)
|
||||
taPEM, err := ta.CurrentTrustAnchors(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, roots, taPEM)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
|
||||
t.Run("Should return full PEM regardless given trust domain", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
//nolint:gocritic
|
||||
root := append(pki.RootCertPEM, []byte("garbage data")...)
|
||||
ta, err := From(Options{Anchors: root})
|
||||
require.NoError(t, err)
|
||||
s, ok := ta.(*static)
|
||||
require.True(t, ok)
|
||||
|
||||
trustDomain1, err := spiffeid.TrustDomainFromString("example.com")
|
||||
require.NoError(t, err)
|
||||
bundle, err := s.GetX509BundleForTrustDomain(trustDomain1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, s.x509Bundle, bundle)
|
||||
b1, err := bundle.Marshal()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pki.RootCertPEM, b1)
|
||||
|
||||
trustDomain2, err := spiffeid.TrustDomainFromString("another-example.org")
|
||||
require.NoError(t, err)
|
||||
bundle, err = s.GetX509BundleForTrustDomain(trustDomain2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, s.x509Bundle, bundle)
|
||||
b2, err := bundle.Marshal()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pki.RootCertPEM, b2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStatic_Run(t *testing.T) {
|
||||
t.Run("Run multiple times should return error", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
ta, err := From(Options{Anchors: pki.RootCertPEM})
|
||||
require.NoError(t, err)
|
||||
s, ok := ta.(*static)
|
||||
require.True(t, ok)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- s.Run(ctx)
|
||||
}()
|
||||
go func() {
|
||||
errCh <- s.Run(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.Error(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "Expected error")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-s.closeCh:
|
||||
assert.Fail(t, "closeCh should not be closed")
|
||||
default:
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "First Run should have returned and returned no error ")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStatic_Watch(t *testing.T) {
|
||||
t.Run("should return when context is cancelled", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
ta, err := From(Options{Anchors: pki.RootCertPEM})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
doneCh := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
ta.Watch(ctx, nil)
|
||||
close(doneCh)
|
||||
}()
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-doneCh:
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "Expected doneCh to be closed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should return when cancel is closed via closed Run", func(t *testing.T) {
|
||||
pki := test.GenPKI(t, test.PKIOptions{})
|
||||
ta, err := From(Options{Anchors: pki.RootCertPEM})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
doneCh := make(chan struct{})
|
||||
errCh := make(chan error)
|
||||
|
||||
go func() {
|
||||
errCh <- ta.Run(ctx)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
ta.Watch(t.Context(), nil)
|
||||
close(doneCh)
|
||||
}()
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-doneCh:
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "Expected doneCh to be closed")
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "Expected Run to return no error")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package trustanchors
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
|
||||
)
|
||||
|
||||
// Interface exposes a SPIFFE trust anchor from a source.
|
||||
// Allows consumers to get the current trust anchor bundle, and subscribe to
|
||||
// bundle updates.
|
||||
type Interface interface {
|
||||
// Source implements the SPIFFE trust anchor x509 bundle source.
|
||||
x509bundle.Source
|
||||
// Source implements the SPIFFE trust anchor jwt bundle source.
|
||||
jwtbundle.Source
|
||||
|
||||
// CurrentTrustAnchors returns the current trust anchor PEM bundle.
|
||||
CurrentTrustAnchors(ctx context.Context) ([]byte, error)
|
||||
|
||||
// Watch watches for changes to the trust domains and returns the PEM encoded
|
||||
// trust domain roots.
|
||||
// Returns when the given context is canceled.
|
||||
Watch(ctx context.Context, ch chan<- []byte)
|
||||
|
||||
// Run starts the trust anchor source.
|
||||
Run(ctx context.Context) error
|
||||
}
|
|
@ -0,0 +1,240 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implieh.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
|
||||
"github.com/spiffe/go-spiffe/v2/spiffegrpc/grpccredentials"
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
"github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/examples/helloworld/helloworld"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
type PKIOptions struct {
|
||||
LeafDNS string
|
||||
LeafID spiffeid.ID
|
||||
ClientDNS string
|
||||
ClientID spiffeid.ID
|
||||
}
|
||||
|
||||
type PKI struct {
|
||||
RootCertPEM []byte
|
||||
RootCert *x509.Certificate
|
||||
LeafCert *x509.Certificate
|
||||
LeafCertPEM []byte
|
||||
LeafPKPEM []byte
|
||||
LeafPK crypto.Signer
|
||||
ClientCertPEM []byte
|
||||
ClientCert *x509.Certificate
|
||||
ClientPKPEM []byte
|
||||
ClientPK crypto.Signer
|
||||
|
||||
leafID spiffeid.ID
|
||||
clientID spiffeid.ID
|
||||
}
|
||||
|
||||
func GenPKI(t *testing.T, opts PKIOptions) PKI {
|
||||
t.Helper()
|
||||
pki, err := GenPKIError(opts)
|
||||
require.NoError(t, err)
|
||||
return pki
|
||||
}
|
||||
|
||||
func GenPKIError(opts PKIOptions) (PKI, error) {
|
||||
rootPK, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return PKI{}, err
|
||||
}
|
||||
|
||||
rootCert := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "Dapr Test Root CA"},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
IsCA: true,
|
||||
KeyUsage: x509.KeyUsageCertSign,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
rootCertBytes, err := x509.CreateCertificate(rand.Reader, rootCert, rootCert, &rootPK.PublicKey, rootPK)
|
||||
if err != nil {
|
||||
return PKI{}, err
|
||||
}
|
||||
|
||||
rootCert, err = x509.ParseCertificate(rootCertBytes)
|
||||
if err != nil {
|
||||
return PKI{}, err
|
||||
}
|
||||
|
||||
rootCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: rootCertBytes})
|
||||
|
||||
leafCertPEM, leafPKPEM, leafCert, leafPK, err := genLeafCert(rootPK, rootCert, opts.LeafID, opts.LeafDNS)
|
||||
if err != nil {
|
||||
return PKI{}, err
|
||||
}
|
||||
clientCertPEM, clientPKPEM, clientCert, clientPK, err := genLeafCert(rootPK, rootCert, opts.ClientID, opts.ClientDNS)
|
||||
if err != nil {
|
||||
return PKI{}, err
|
||||
}
|
||||
|
||||
return PKI{
|
||||
RootCert: rootCert,
|
||||
RootCertPEM: rootCertPEM,
|
||||
LeafCertPEM: leafCertPEM,
|
||||
LeafPKPEM: leafPKPEM,
|
||||
LeafCert: leafCert,
|
||||
LeafPK: leafPK,
|
||||
ClientCertPEM: clientCertPEM,
|
||||
ClientPKPEM: clientPKPEM,
|
||||
ClientCert: clientCert,
|
||||
ClientPK: clientPK,
|
||||
leafID: opts.LeafID,
|
||||
clientID: opts.ClientID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p PKI) ClientGRPCCtx(t *testing.T) context.Context {
|
||||
t.Helper()
|
||||
|
||||
bundle := x509bundle.New(spiffeid.RequireTrustDomainFromString("example.org"))
|
||||
bundle.AddX509Authority(p.RootCert)
|
||||
serverSVID := &mockSVID{
|
||||
bundle: bundle,
|
||||
svid: &x509svid.SVID{
|
||||
ID: p.leafID,
|
||||
Certificates: []*x509.Certificate{p.LeafCert},
|
||||
PrivateKey: p.LeafPK,
|
||||
},
|
||||
}
|
||||
|
||||
clientSVID := &mockSVID{
|
||||
bundle: bundle,
|
||||
svid: &x509svid.SVID{
|
||||
ID: p.clientID,
|
||||
Certificates: []*x509.Certificate{p.ClientCert},
|
||||
PrivateKey: p.ClientPK,
|
||||
},
|
||||
}
|
||||
|
||||
server := grpc.NewServer(grpc.Creds(grpccredentials.MTLSServerCredentials(serverSVID, serverSVID, tlsconfig.AuthorizeAny())))
|
||||
gs := new(greeterServer)
|
||||
helloworld.RegisterGreeterServer(server, gs)
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
server.Serve(lis)
|
||||
}()
|
||||
//nolint:staticcheck
|
||||
conn, err := grpc.DialContext(t.Context(), lis.Addr().String(),
|
||||
grpc.WithTransportCredentials(grpccredentials.MTLSClientCredentials(clientSVID, clientSVID, tlsconfig.AuthorizeAny())),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = helloworld.NewGreeterClient(conn).SayHello(t.Context(), new(helloworld.HelloRequest))
|
||||
require.NoError(t, err)
|
||||
|
||||
lis.Close()
|
||||
server.Stop()
|
||||
|
||||
return gs.ctx
|
||||
}
|
||||
|
||||
func genLeafCert(rootPK *ecdsa.PrivateKey, rootCert *x509.Certificate, id spiffeid.ID, dns string) ([]byte, []byte, *x509.Certificate, crypto.Signer, error) {
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
pkBytes, err := x509.MarshalPKCS8PrivateKey(pk)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
cert := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
},
|
||||
}
|
||||
|
||||
if len(dns) > 0 {
|
||||
cert.DNSNames = []string{dns}
|
||||
}
|
||||
|
||||
if !id.IsZero() {
|
||||
cert.URIs = []*url.URL{id.URL()}
|
||||
}
|
||||
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, cert, rootCert, &pk.PublicKey, rootPK)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
cert, err = x509.ParseCertificate(certBytes)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
pkPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: pkBytes})
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes})
|
||||
|
||||
return certPEM, pkPEM, cert, pk, nil
|
||||
}
|
||||
|
||||
type mockSVID struct {
|
||||
svid *x509svid.SVID
|
||||
bundle *x509bundle.Bundle
|
||||
}
|
||||
|
||||
func (m *mockSVID) GetX509BundleForTrustDomain(_ spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
|
||||
return m.bundle, nil
|
||||
}
|
||||
|
||||
func (m *mockSVID) GetX509SVID() (*x509svid.SVID, error) {
|
||||
return m.svid, nil
|
||||
}
|
||||
|
||||
type greeterServer struct {
|
||||
helloworld.UnimplementedGreeterServer
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s *greeterServer) SayHello(ctx context.Context, _ *helloworld.HelloRequest) (*helloworld.HelloReply, error) {
|
||||
p, _ := peer.FromContext(ctx)
|
||||
s.ctx = peer.NewContext(context.Background(), p)
|
||||
return new(helloworld.HelloReply), nil
|
||||
}
|
|
@ -41,7 +41,7 @@ pHZ3vWGFAoGAc5Um3YYkhh2QScQBy5+kumH40LhFFy2ETznWEp0tS2NwmTfTm/Nl
|
|||
Sg+Ct2nOw93cIhwDjWyoilkIapuuX2obY+sUc3kj2ugU+hONfuBStsF020IPP1sk
|
||||
A9okIZVbz8ycqcjaBiNc4+TeiXED1K7bV9Kg+A9lxDxfGRybJ1/ECWA=
|
||||
-----END RSA PRIVATE KEY-----
|
||||
`
|
||||
` // #nosec G101
|
||||
privateKeyRSAPKCS8 = `-----BEGIN PRIVATE KEY-----
|
||||
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDcjaZ0griZFG77
|
||||
LAytiNRnMHG3Q2UBUusyEaVomxvLs9ZMyIullWKnhIEP0bCcJTRMYUPuTb7u1+zT
|
||||
|
@ -128,7 +128,7 @@ MHcCAQEEIOcFe4Q6ardS97ml2tV4+194nmlfQPh8o9ir/qsacEozoAoGCCqGSM49
|
|||
AwEHoUQDQgAEUMn1c2ioMNi2DqvC8hdBVUERFZ97eVFsNVcQIgR0Hsq5PVrQ/dQ4
|
||||
uI5u97b6k4wXHYFXMvPmsW1T6qZAE9bB3Q==
|
||||
-----END EC PRIVATE KEY-----
|
||||
`
|
||||
` // #nosec G101
|
||||
privateKeyP256PKCS8 = `-----BEGIN PRIVATE KEY-----
|
||||
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQg5wV7hDpqt1L3uaXa
|
||||
1Xj7X3ieaV9A+Hyj2Kv+qxpwSjOhRANCAARQyfVzaKgw2LYOq8LyF0FVQREVn3t5
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package env
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetDurationWithRange returns the time.Duration value of the environment variable specified by `envVar`.
|
||||
// If the environment variable is not set, it returns `defaultValue`.
|
||||
// If the value is set but is not valid (not a valid time.Duration or falls outside the specified range
|
||||
// [minValue, maxValue] inclusively), it returns `defaultValue` and an error.
|
||||
func GetDurationWithRange(envVar string, defaultValue, min, max time.Duration) (time.Duration, error) {
|
||||
v := os.Getenv(envVar)
|
||||
if v == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
|
||||
val, err := time.ParseDuration(v)
|
||||
if err != nil {
|
||||
return defaultValue, fmt.Errorf("invalid time.Duration value %s for the %s env variable: %w", val, envVar, err)
|
||||
}
|
||||
|
||||
if val < min || val > max {
|
||||
return defaultValue, fmt.Errorf("invalid value for the %s env variable: value should be between %s and %s, got %s", envVar, min, max, val)
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package env
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetIntWithRangeWrongValues(t *testing.T) {
|
||||
testValues := []struct {
|
||||
name string
|
||||
envVarVal string
|
||||
min time.Duration
|
||||
max time.Duration
|
||||
error string
|
||||
}{
|
||||
{
|
||||
"should error if value is not a valid time.Duration",
|
||||
"0.5",
|
||||
time.Second,
|
||||
2 * time.Second,
|
||||
"invalid time.Duration value 0s for the MY_ENV env variable",
|
||||
},
|
||||
{
|
||||
"should error if value is lower than 1s",
|
||||
"0s",
|
||||
time.Second,
|
||||
10 * time.Second,
|
||||
"value should be between 1s and 10s",
|
||||
},
|
||||
{
|
||||
"should error if value is higher than 10s",
|
||||
"2m",
|
||||
time.Second,
|
||||
10 * time.Second,
|
||||
"value should be between 1s and 10s",
|
||||
},
|
||||
}
|
||||
|
||||
defaultValue := 3 * time.Second
|
||||
for _, tt := range testValues {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv("MY_ENV", tt.envVarVal)
|
||||
|
||||
val, err := GetDurationWithRange("MY_ENV", defaultValue, tt.min, tt.max)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.error)
|
||||
require.Equal(t, defaultValue, val)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetEnvDurationWithRangeValidValues(t *testing.T) {
|
||||
testValues := []struct {
|
||||
name string
|
||||
envVarVal string
|
||||
result time.Duration
|
||||
}{
|
||||
{
|
||||
"should return default value if env variable is not set",
|
||||
"",
|
||||
3 * time.Second,
|
||||
},
|
||||
{
|
||||
"should return result is env variable value is valid",
|
||||
"4s",
|
||||
4 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testValues {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.envVarVal != "" {
|
||||
t.Setenv("MY_ENV", tt.envVarVal)
|
||||
}
|
||||
|
||||
val, err := GetDurationWithRange("MY_ENV", 3*time.Second, time.Second, 5*time.Second)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.result, val)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
# Errors
|
||||
|
||||
The standardizing of errors to be used in Dapr based on the gRPC Richer Error Model and [accepted dapr/proposal](https://github.com/dapr/proposals/blob/main/0009-BCIRS-error-handling-codes.md).
|
||||
The standardizing of errors to be used in Dapr based on the gRPC Richer Error Model and [accepted dapr/proposal](https://github.com/dapr/proposals/blob/main/20230511-BCIRS-error-handling-codes.md).
|
||||
|
||||
## Usage
|
||||
|
||||
|
|
|
@ -24,11 +24,13 @@ const (
|
|||
CodePrefixStateStore = "DAPR_STATE_"
|
||||
CodePrefixPubSub = "DAPR_PUBSUB_"
|
||||
CodePrefixBindings = "DAPR_BINDING_"
|
||||
CodePrefixSecretStore = "DAPR_SECRET_"
|
||||
CodePrefixSecretStore = "DAPR_SECRET_" // #nosec G101
|
||||
CodePrefixConfigurationStore = "DAPR_CONFIGURATION_"
|
||||
CodePrefixLock = "DAPR_LOCK_"
|
||||
CodePrefixNameResolution = "DAPR_NAME_RESOLUTION_"
|
||||
CodePrefixMiddleware = "DAPR_MIDDLEWARE_"
|
||||
CodePrefixCryptography = "DAPR_CRYPTOGRAPHY_"
|
||||
CodePrefixPlacement = "DAPR_PLACEMENT_"
|
||||
|
||||
// State
|
||||
CodePostfixGetStateFailed = "GET_STATE_FAILED"
|
||||
|
|
|
@ -56,11 +56,14 @@ type Error struct {
|
|||
|
||||
// Tag is a string identifying the error, used with HTTP responses only.
|
||||
tag string
|
||||
|
||||
// Category is a string identifying the category of the error (i.e. "actor", "job", "pubsub), used for error code metrics only.
|
||||
category string
|
||||
}
|
||||
|
||||
// ErrorBuilder is used to build the error
|
||||
type ErrorBuilder struct {
|
||||
err Error
|
||||
err *Error
|
||||
}
|
||||
|
||||
// errorJSON is used to build the error for the HTTP Methods json output
|
||||
|
@ -84,13 +87,31 @@ func (e *Error) GrpcStatusCode() grpcCodes.Code {
|
|||
return e.grpcCode
|
||||
}
|
||||
|
||||
// ErrorCode returns the error code from the error, prioritizing the legacy Error.Tag, otherwise the ErrorInfo.Reason
|
||||
func (e *Error) ErrorCode() string {
|
||||
errorCode := e.tag
|
||||
for _, detail := range e.details {
|
||||
if _, ok := detail.(*errdetails.ErrorInfo); ok {
|
||||
if _, errInfoReason := convertErrorDetails(detail, *e); errInfoReason != "" {
|
||||
return errInfoReason
|
||||
}
|
||||
}
|
||||
}
|
||||
return errorCode
|
||||
}
|
||||
|
||||
// Category returns the error code's category
|
||||
func (e *Error) Category() string {
|
||||
return e.category
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e Error) Error() string {
|
||||
func (e *Error) Error() string {
|
||||
return e.String()
|
||||
}
|
||||
|
||||
// String returns the string representation.
|
||||
func (e Error) String() string {
|
||||
func (e *Error) String() string {
|
||||
return fmt.Sprintf(errStringFormat, e.grpcCode.String(), e.message)
|
||||
}
|
||||
|
||||
|
@ -119,9 +140,9 @@ func FromError(err error) (*Error, bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
var kitErr Error
|
||||
var kitErr *Error
|
||||
if errors.As(err, &kitErr) {
|
||||
return &kitErr, true
|
||||
return kitErr, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
|
@ -130,7 +151,7 @@ func FromError(err error) (*Error, bool) {
|
|||
/*** GRPC Methods ***/
|
||||
|
||||
// GRPCStatus returns the gRPC status.Status object.
|
||||
func (e Error) GRPCStatus() *status.Status {
|
||||
func (e *Error) GRPCStatus() *status.Status {
|
||||
stat := status.New(e.grpcCode, e.message)
|
||||
|
||||
// convert details from proto.Msg -> protoiface.MsgV1
|
||||
|
@ -157,7 +178,7 @@ func (e Error) GRPCStatus() *status.Status {
|
|||
/*** HTTP Methods ***/
|
||||
|
||||
// JSONErrorValue implements the errorResponseValue interface.
|
||||
func (e Error) JSONErrorValue() []byte {
|
||||
func (e *Error) JSONErrorValue() []byte {
|
||||
grpcStatus := e.GRPCStatus().Proto()
|
||||
|
||||
// Make httpCode human readable
|
||||
|
@ -179,7 +200,7 @@ func (e Error) JSONErrorValue() []byte {
|
|||
if len(details) > 0 {
|
||||
errJSON.Details = make([]any, len(details))
|
||||
for i, detail := range details {
|
||||
detailMap, errorCode := convertErrorDetails(detail, e)
|
||||
detailMap, errorCode := convertErrorDetails(detail, *e)
|
||||
errJSON.Details[i] = detailMap
|
||||
|
||||
// If there is an errorCode, update the overall ErrorCode
|
||||
|
@ -334,14 +355,15 @@ ErrorBuilder
|
|||
**************************************/
|
||||
|
||||
// NewBuilder create a new ErrorBuilder using the supplied required error fields
|
||||
func NewBuilder(grpcCode grpcCodes.Code, httpCode int, message string, tag string) *ErrorBuilder {
|
||||
func NewBuilder(grpcCode grpcCodes.Code, httpCode int, message string, tag string, category string) *ErrorBuilder {
|
||||
return &ErrorBuilder{
|
||||
err: Error{
|
||||
err: &Error{
|
||||
details: make([]proto.Message, 0),
|
||||
grpcCode: grpcCode,
|
||||
httpCode: httpCode,
|
||||
message: message,
|
||||
tag: tag,
|
||||
category: category,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,11 +45,12 @@ func TestError_HTTPStatusCode(t *testing.T) {
|
|||
httpStatusCode,
|
||||
"Test Msg",
|
||||
"SOME_ERROR",
|
||||
"some_category",
|
||||
).
|
||||
WithErrorInfo("fake", map[string]string{"fake": "test"}).
|
||||
Build()
|
||||
|
||||
err, ok := kitErr.(Error)
|
||||
err, ok := kitErr.(*Error)
|
||||
require.True(t, ok, httpStatusCode, err.HTTPStatusCode())
|
||||
}
|
||||
|
||||
|
@ -60,11 +61,12 @@ func TestError_GrpcStatusCode(t *testing.T) {
|
|||
http.StatusTeapot,
|
||||
"Test Msg",
|
||||
"SOME_ERROR",
|
||||
"some_category",
|
||||
).
|
||||
WithErrorInfo("fake", map[string]string{"fake": "test"}).
|
||||
Build()
|
||||
|
||||
err, ok := kitErr.(Error)
|
||||
err, ok := kitErr.(*Error)
|
||||
require.True(t, ok, grpcStatusCode, err.GrpcStatusCode())
|
||||
}
|
||||
|
||||
|
@ -125,6 +127,7 @@ func TestError_Error(t *testing.T) {
|
|||
http.StatusTeapot,
|
||||
"Msg",
|
||||
"SOME_ERROR",
|
||||
"some_category",
|
||||
).WithErrorInfo("fake", map[string]string{"fake": "test"}),
|
||||
fields: fields{
|
||||
message: "Msg",
|
||||
|
@ -139,6 +142,7 @@ func TestError_Error(t *testing.T) {
|
|||
http.StatusTeapot,
|
||||
"Msg",
|
||||
"SOME_ERROR",
|
||||
"some_category",
|
||||
).WithErrorInfo("fake", map[string]string{"fake": "test"}),
|
||||
fields: fields{
|
||||
message: "Msg",
|
||||
|
@ -152,6 +156,7 @@ func TestError_Error(t *testing.T) {
|
|||
http.StatusTeapot,
|
||||
"Msg",
|
||||
"SOME_ERROR",
|
||||
"some_category",
|
||||
).WithErrorInfo("fake", map[string]string{"fake": "test"}),
|
||||
fields: fields{
|
||||
grpcCode: grpcCodes.Canceled,
|
||||
|
@ -166,7 +171,7 @@ func TestError_Error(t *testing.T) {
|
|||
t.Errorf("got = %v, want %v", got, tt.want)
|
||||
}
|
||||
|
||||
err, ok := kitErr.(Error)
|
||||
err, ok := kitErr.(*Error)
|
||||
require.True(t, ok, err.Is(kitErr))
|
||||
})
|
||||
}
|
||||
|
@ -181,11 +186,12 @@ func TestErrorBuilder_WithErrorInfo(t *testing.T) {
|
|||
Metadata: metadata,
|
||||
}
|
||||
|
||||
expected := Error{
|
||||
expected := &Error{
|
||||
grpcCode: grpcCodes.ResourceExhausted,
|
||||
httpCode: http.StatusTeapot,
|
||||
message: "fake_message",
|
||||
tag: "DAPR_FAKE_TAG",
|
||||
category: "some_category",
|
||||
details: []proto.Message{
|
||||
details,
|
||||
},
|
||||
|
@ -196,6 +202,7 @@ func TestErrorBuilder_WithErrorInfo(t *testing.T) {
|
|||
http.StatusTeapot,
|
||||
"fake_message",
|
||||
"DAPR_FAKE_TAG",
|
||||
"some_category",
|
||||
).
|
||||
WithErrorInfo(reason, metadata)
|
||||
|
||||
|
@ -222,6 +229,7 @@ func TestErrorBuilder_WithDetails(t *testing.T) {
|
|||
httpCode int
|
||||
message string
|
||||
tag string
|
||||
category string
|
||||
}
|
||||
|
||||
type args struct {
|
||||
|
@ -232,7 +240,7 @@ func TestErrorBuilder_WithDetails(t *testing.T) {
|
|||
name string
|
||||
fields fields
|
||||
args args
|
||||
want Error
|
||||
want *Error
|
||||
}{
|
||||
{
|
||||
name: "Has_Multiple_Details",
|
||||
|
@ -255,7 +263,7 @@ func TestErrorBuilder_WithDetails(t *testing.T) {
|
|||
Description: "test_description",
|
||||
},
|
||||
}},
|
||||
want: Error{
|
||||
want: &Error{
|
||||
grpcCode: grpcCodes.ResourceExhausted,
|
||||
httpCode: http.StatusTeapot,
|
||||
message: "fake_message",
|
||||
|
@ -283,6 +291,7 @@ func TestErrorBuilder_WithDetails(t *testing.T) {
|
|||
test.fields.httpCode,
|
||||
test.fields.message,
|
||||
test.fields.tag,
|
||||
test.fields.category,
|
||||
).WithDetails(test.args.a...)
|
||||
|
||||
assert.Equal(t, test.want, kitErr.Build())
|
||||
|
@ -292,7 +301,7 @@ func TestErrorBuilder_WithDetails(t *testing.T) {
|
|||
|
||||
func TestWithErrorHelp(t *testing.T) {
|
||||
// Initialize the Error struct with some default values
|
||||
err := NewBuilder(grpcCodes.InvalidArgument, http.StatusBadRequest, "Internal error", "INTERNAL_ERROR")
|
||||
err := NewBuilder(grpcCodes.InvalidArgument, http.StatusBadRequest, "Internal error", "INTERNAL_ERROR", "some_category")
|
||||
|
||||
// Define test data for the help links
|
||||
links := []*errdetails.Help_Link{
|
||||
|
@ -319,7 +328,7 @@ func TestWithErrorHelp(t *testing.T) {
|
|||
|
||||
func TestWithErrorFieldViolation(t *testing.T) {
|
||||
// Initialize the Error struct with some default values
|
||||
err := NewBuilder(grpcCodes.InvalidArgument, http.StatusBadRequest, "Internal error", "INTERNAL_ERROR")
|
||||
err := NewBuilder(grpcCodes.InvalidArgument, http.StatusBadRequest, "Internal error", "INTERNAL_ERROR", "some_category")
|
||||
|
||||
// Define test data for the field violation
|
||||
fieldName := "testField"
|
||||
|
@ -348,6 +357,7 @@ func TestError_JSONErrorValue(t *testing.T) {
|
|||
httpCode int
|
||||
message string
|
||||
tag string
|
||||
category string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
@ -657,7 +667,7 @@ func TestError_JSONErrorValue(t *testing.T) {
|
|||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
kitErr := NewBuilder(test.fields.grpcCode, test.fields.httpCode, test.fields.message, test.fields.tag).
|
||||
kitErr := NewBuilder(test.fields.grpcCode, test.fields.httpCode, test.fields.message, test.fields.tag, test.fields.category).
|
||||
WithDetails(test.fields.details...)
|
||||
|
||||
got := kitErr.err.JSONErrorValue()
|
||||
|
@ -705,6 +715,7 @@ func TestError_GRPCStatus(t *testing.T) {
|
|||
httpCode int
|
||||
message string
|
||||
tag string
|
||||
category string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
@ -769,6 +780,7 @@ func TestError_GRPCStatus(t *testing.T) {
|
|||
test.fields.httpCode,
|
||||
test.fields.message,
|
||||
test.fields.tag,
|
||||
test.fields.category,
|
||||
).WithDetails(test.fields.details...)
|
||||
|
||||
got := kitErr.err.GRPCStatus()
|
||||
|
@ -787,9 +799,10 @@ func TestErrorBuilder_Build(t *testing.T) {
|
|||
http.StatusTeapot,
|
||||
"Test Msg",
|
||||
"SOME_ERROR",
|
||||
"some_category",
|
||||
).WithErrorInfo("fake", map[string]string{"fake": "test"}).Build()
|
||||
|
||||
builtErr, ok := built.(Error)
|
||||
builtErr, ok := built.(*Error)
|
||||
require.True(t, ok)
|
||||
|
||||
containsErrorInfo := false
|
||||
|
@ -803,6 +816,33 @@ func TestErrorBuilder_Build(t *testing.T) {
|
|||
}
|
||||
|
||||
assert.True(t, containsErrorInfo)
|
||||
assert.Equal(t, "SOME_ERROR", builtErr.ErrorCode())
|
||||
})
|
||||
|
||||
t.Run("With_ErrorInfo (legacy tag absent)", func(t *testing.T) {
|
||||
built := NewBuilder(
|
||||
grpcCodes.ResourceExhausted,
|
||||
http.StatusTeapot,
|
||||
"Test Msg",
|
||||
"",
|
||||
"some_category",
|
||||
).WithErrorInfo("SOME_ERROR", map[string]string{"fake": "test"}).Build()
|
||||
|
||||
builtErr, ok := built.(*Error)
|
||||
require.True(t, ok)
|
||||
|
||||
containsErrorInfo := false
|
||||
|
||||
for _, detail := range builtErr.details {
|
||||
_, isErrInfo := detail.(*errdetails.ErrorInfo)
|
||||
if isErrInfo {
|
||||
containsErrorInfo = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, containsErrorInfo)
|
||||
assert.Equal(t, "SOME_ERROR", builtErr.ErrorCode())
|
||||
})
|
||||
|
||||
t.Run("Without_ErrorInfo", func(t *testing.T) {
|
||||
|
@ -811,6 +851,7 @@ func TestErrorBuilder_Build(t *testing.T) {
|
|||
http.StatusTeapot,
|
||||
"Test Msg",
|
||||
"SOME_ERROR",
|
||||
"some_category",
|
||||
)
|
||||
|
||||
assert.PanicsWithValue(t, "Must include ErrorInfo in error details.", func() {
|
||||
|
@ -949,7 +990,7 @@ func TestFromError(t *testing.T) {
|
|||
t.Errorf("Expected result to be nil and ok to be false, got result: %v, ok: %t", result, ok)
|
||||
}
|
||||
|
||||
kitErr := Error{
|
||||
kitErr := &Error{
|
||||
grpcCode: grpcCodes.ResourceExhausted,
|
||||
httpCode: http.StatusTeapot,
|
||||
message: "fake_message",
|
||||
|
@ -958,8 +999,8 @@ func TestFromError(t *testing.T) {
|
|||
}
|
||||
|
||||
result, ok = FromError(kitErr)
|
||||
if !ok || !reflect.DeepEqual(result, &kitErr) {
|
||||
t.Errorf("Expected result to be %#v and ok to be true, got result: %#v, ok: %t", &kitErr, result, ok)
|
||||
if !ok || !reflect.DeepEqual(result, kitErr) {
|
||||
t.Errorf("Expected result to be %#v and ok to be true, got result: %#v, ok: %t", kitErr, result, ok)
|
||||
}
|
||||
|
||||
var nonKitError error
|
||||
|
@ -970,7 +1011,7 @@ func TestFromError(t *testing.T) {
|
|||
|
||||
wrapped := fmt.Errorf("wrapped: %w", kitErr)
|
||||
result, ok = FromError(wrapped)
|
||||
if !ok || !reflect.DeepEqual(result, &kitErr) {
|
||||
t.Errorf("Expected result to be %#v and ok to be true, got result: %#v, ok: %t", &kitErr, result, ok)
|
||||
if !ok || !reflect.DeepEqual(result, kitErr) {
|
||||
t.Errorf("Expected result to be %#v and ok to be true, got result: %#v, ok: %t", kitErr, result, ok)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||
package batcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -23,14 +24,25 @@ import (
|
|||
"github.com/dapr/kit/events/queue"
|
||||
)
|
||||
|
||||
type eventCh[T any] struct {
|
||||
id int
|
||||
ch chan<- T
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Interval time.Duration
|
||||
Clock clock.Clock
|
||||
}
|
||||
|
||||
// Batcher is a one to many event batcher. It batches events and sends them to
|
||||
// the added event channel subscribers. Events are sent to the channels after
|
||||
// the interval has elapsed. If events with the same key are received within
|
||||
// the interval, the timer is reset.
|
||||
type Batcher[T comparable] struct {
|
||||
interval time.Duration
|
||||
eventChs []chan<- struct{}
|
||||
queue *queue.Processor[T, *item[T]]
|
||||
type Batcher[K comparable, T any] struct {
|
||||
interval time.Duration
|
||||
eventChs []*eventCh[T]
|
||||
queue *queue.Processor[K, *item[K, T]]
|
||||
currentID int
|
||||
|
||||
clock clock.Clock
|
||||
lock sync.Mutex
|
||||
|
@ -40,85 +52,129 @@ type Batcher[T comparable] struct {
|
|||
}
|
||||
|
||||
// New creates a new Batcher with the given interval and key type.
|
||||
func New[T comparable](interval time.Duration) *Batcher[T] {
|
||||
b := &Batcher[T]{
|
||||
interval: interval,
|
||||
clock: clock.RealClock{},
|
||||
func New[K comparable, T any](opts Options) *Batcher[K, T] {
|
||||
cl := opts.Clock
|
||||
if cl == nil {
|
||||
cl = clock.RealClock{}
|
||||
}
|
||||
|
||||
b := &Batcher[K, T]{
|
||||
interval: opts.Interval,
|
||||
clock: cl,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
b.queue = queue.NewProcessor[T, *item[T]](b.execute)
|
||||
b.queue = queue.NewProcessor[K, *item[K, T]](queue.Options[K, *item[K, T]]{
|
||||
ExecuteFn: b.execute,
|
||||
Clock: opts.Clock,
|
||||
})
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithClock sets the clock used by the batcher. Used for testing.
|
||||
func (b *Batcher[T]) WithClock(clock clock.Clock) {
|
||||
b.queue.WithClock(clock)
|
||||
b.clock = clock
|
||||
}
|
||||
|
||||
// Subscribe adds a new event channel subscriber. If the batcher is closed, the
|
||||
// subscriber is silently dropped.
|
||||
func (b *Batcher[T]) Subscribe(eventCh ...chan<- struct{}) {
|
||||
func (b *Batcher[K, T]) Subscribe(ctx context.Context, ch ...chan<- T) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
if b.closed.Load() {
|
||||
return
|
||||
for _, c := range ch {
|
||||
b.subscribe(ctx, c)
|
||||
}
|
||||
b.eventChs = append(b.eventChs, eventCh...)
|
||||
}
|
||||
|
||||
func (b *Batcher[T]) execute(_ *item[T]) {
|
||||
func (b *Batcher[K, T]) subscribe(ctx context.Context, ch chan<- T) {
|
||||
if b.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
id := b.currentID
|
||||
b.currentID++
|
||||
bufferedCh := make(chan T, 50)
|
||||
b.eventChs = append(b.eventChs, &eventCh[T]{
|
||||
id: id,
|
||||
ch: bufferedCh,
|
||||
})
|
||||
|
||||
b.wg.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
b.lock.Lock()
|
||||
close(ch)
|
||||
for i, eventCh := range b.eventChs {
|
||||
if eventCh.id == id {
|
||||
b.eventChs = append(b.eventChs[:i], b.eventChs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
b.lock.Unlock()
|
||||
b.wg.Done()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-b.closeCh:
|
||||
return
|
||||
case env := <-bufferedCh:
|
||||
select {
|
||||
case ch <- env:
|
||||
case <-ctx.Done():
|
||||
case <-b.closeCh:
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *Batcher[K, T]) execute(i *item[K, T]) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
if b.closed.Load() {
|
||||
return
|
||||
}
|
||||
b.wg.Add(len(b.eventChs))
|
||||
for _, eventCh := range b.eventChs {
|
||||
go func(eventCh chan<- struct{}) {
|
||||
defer b.wg.Done()
|
||||
select {
|
||||
case eventCh <- struct{}{}:
|
||||
case <-b.closeCh:
|
||||
}
|
||||
}(eventCh)
|
||||
for _, ev := range b.eventChs {
|
||||
select {
|
||||
case ev.ch <- i.value:
|
||||
case <-b.closeCh:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Batch adds the given key to the batcher. If an event for this key is already
|
||||
// active, the timer is reset. If the batcher is closed, the key is silently
|
||||
// dropped.
|
||||
func (b *Batcher[T]) Batch(key T) {
|
||||
b.queue.Enqueue(&item[T]{
|
||||
key: key,
|
||||
ttl: b.clock.Now().Add(b.interval),
|
||||
func (b *Batcher[K, T]) Batch(key K, value T) {
|
||||
b.queue.Enqueue(&item[K, T]{
|
||||
key: key,
|
||||
value: value,
|
||||
ttl: b.clock.Now().Add(b.interval),
|
||||
})
|
||||
}
|
||||
|
||||
// Close closes the batcher. It blocks until all events have been sent to the
|
||||
// subscribers. The batcher will be a no-op after this call.
|
||||
func (b *Batcher[T]) Close() {
|
||||
func (b *Batcher[K, T]) Close() {
|
||||
defer b.wg.Wait()
|
||||
b.queue.Close()
|
||||
b.lock.Lock()
|
||||
if b.closed.CompareAndSwap(false, true) {
|
||||
close(b.closeCh)
|
||||
}
|
||||
b.lock.Unlock()
|
||||
b.queue.Close()
|
||||
}
|
||||
|
||||
// item implements queue.queueable.
|
||||
type item[T comparable] struct {
|
||||
key T
|
||||
ttl time.Time
|
||||
type item[K comparable, T any] struct {
|
||||
key K
|
||||
value T
|
||||
ttl time.Time
|
||||
}
|
||||
|
||||
func (b *item[T]) Key() T {
|
||||
func (b *item[K, T]) Key() K {
|
||||
return b.key
|
||||
}
|
||||
|
||||
func (b *item[T]) ScheduledTime() time.Time {
|
||||
func (b *item[K, T]) ScheduledTime() time.Time {
|
||||
return b.ttl
|
||||
}
|
||||
|
|
|
@ -25,24 +25,25 @@ func TestNew(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
interval := time.Millisecond * 10
|
||||
b := New[string](interval)
|
||||
assert.Equal(t, interval, b.interval)
|
||||
b := New[string, struct{}](Options{Interval: interval})
|
||||
assert.False(t, b.closed.Load())
|
||||
}
|
||||
|
||||
func TestWithClock(t *testing.T) {
|
||||
b := New[string](time.Millisecond * 10)
|
||||
fakeClock := testingclock.NewFakeClock(time.Now())
|
||||
b.WithClock(fakeClock)
|
||||
b := New[string, struct{}](Options{
|
||||
Interval: time.Millisecond * 10,
|
||||
Clock: fakeClock,
|
||||
})
|
||||
assert.Equal(t, fakeClock, b.clock)
|
||||
}
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := New[string](time.Millisecond * 10)
|
||||
b := New[string, struct{}](Options{Interval: time.Millisecond * 10})
|
||||
ch := make(chan struct{})
|
||||
b.Subscribe(ch)
|
||||
b.Subscribe(t.Context(), ch)
|
||||
assert.Len(t, b.eventChs, 1)
|
||||
}
|
||||
|
||||
|
@ -50,22 +51,24 @@ func TestBatch(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
fakeClock := testingclock.NewFakeClock(time.Now())
|
||||
b := New[string](time.Millisecond * 10)
|
||||
b.WithClock(fakeClock)
|
||||
b := New[string, struct{}](Options{
|
||||
Interval: time.Millisecond * 10,
|
||||
Clock: fakeClock,
|
||||
})
|
||||
ch1 := make(chan struct{})
|
||||
ch2 := make(chan struct{})
|
||||
ch3 := make(chan struct{})
|
||||
b.Subscribe(ch1, ch2)
|
||||
b.Subscribe(ch3)
|
||||
b.Subscribe(t.Context(), ch1, ch2)
|
||||
b.Subscribe(t.Context(), ch3)
|
||||
|
||||
b.Batch("key1")
|
||||
b.Batch("key1")
|
||||
b.Batch("key1")
|
||||
b.Batch("key1")
|
||||
b.Batch("key2")
|
||||
b.Batch("key2")
|
||||
b.Batch("key3")
|
||||
b.Batch("key3")
|
||||
b.Batch("key1", struct{}{})
|
||||
b.Batch("key1", struct{}{})
|
||||
b.Batch("key1", struct{}{})
|
||||
b.Batch("key1", struct{}{})
|
||||
b.Batch("key2", struct{}{})
|
||||
b.Batch("key2", struct{}{})
|
||||
b.Batch("key3", struct{}{})
|
||||
b.Batch("key3", struct{}{})
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return fakeClock.HasWaiters()
|
||||
|
@ -91,7 +94,7 @@ func TestBatch(t *testing.T) {
|
|||
|
||||
fakeClock.Step(time.Millisecond * 5)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
for range 3 {
|
||||
for _, ch := range []chan struct{}{ch1, ch2, ch3} {
|
||||
select {
|
||||
case <-ch:
|
||||
|
@ -100,16 +103,48 @@ func TestBatch(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("ensure items are received in order with latest value", func(t *testing.T) {
|
||||
fakeClock := testingclock.NewFakeClock(time.Now())
|
||||
b := New[int, int](Options{
|
||||
Interval: time.Millisecond * 10,
|
||||
Clock: fakeClock,
|
||||
})
|
||||
t.Cleanup(b.Close)
|
||||
ch1 := make(chan int, 10)
|
||||
ch2 := make(chan int, 10)
|
||||
ch3 := make(chan int, 10)
|
||||
b.Subscribe(t.Context(), ch1, ch2)
|
||||
b.Subscribe(t.Context(), ch3)
|
||||
|
||||
for i := range 10 {
|
||||
b.Batch(i, i)
|
||||
b.Batch(i, i+1)
|
||||
b.Batch(i, i+2)
|
||||
fakeClock.Step(time.Millisecond * 10)
|
||||
}
|
||||
|
||||
for _, ch := range []chan int{ch1} {
|
||||
for i := range 10 {
|
||||
select {
|
||||
case v := <-ch:
|
||||
assert.Equal(t, i+2, v)
|
||||
case <-time.After(time.Second):
|
||||
assert.Fail(t, "should be triggered")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := New[string](time.Millisecond * 10)
|
||||
b := New[string, struct{}](Options{Interval: time.Millisecond * 10})
|
||||
ch := make(chan struct{})
|
||||
b.Subscribe(ch)
|
||||
b.Subscribe(t.Context(), ch)
|
||||
assert.Len(t, b.eventChs, 1)
|
||||
b.Batch("key1")
|
||||
b.Batch("key1", struct{}{})
|
||||
b.Close()
|
||||
assert.True(t, b.closed.Load())
|
||||
}
|
||||
|
@ -117,9 +152,9 @@ func TestClose(t *testing.T) {
|
|||
func TestSubscribeAfterClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := New[string](time.Millisecond * 10)
|
||||
b := New[string, struct{}](Options{Interval: time.Millisecond * 10})
|
||||
b.Close()
|
||||
ch := make(chan struct{})
|
||||
b.Subscribe(ch)
|
||||
b.Subscribe(t.Context(), ch)
|
||||
assert.Empty(t, b.eventChs)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package broadcaster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
const bufferSize = 10
|
||||
|
||||
type eventCh[T any] struct {
|
||||
id uint64
|
||||
ch chan<- T
|
||||
closeEventCh chan struct{}
|
||||
}
|
||||
|
||||
type Broadcaster[T any] struct {
|
||||
eventChs []*eventCh[T]
|
||||
currentID uint64
|
||||
|
||||
lock sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
closeCh chan struct{}
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
// New creates a new Broadcaster with the given interval and key type.
|
||||
func New[T any]() *Broadcaster[T] {
|
||||
return &Broadcaster[T]{
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe adds a new event channel subscriber. If the batcher is closed, the
|
||||
// subscriber is silently dropped.
|
||||
func (b *Broadcaster[T]) Subscribe(ctx context.Context, ch ...chan<- T) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
for _, c := range ch {
|
||||
b.subscribe(ctx, c)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broadcaster[T]) subscribe(ctx context.Context, ch chan<- T) {
|
||||
if b.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
id := b.currentID
|
||||
b.currentID++
|
||||
bufferedCh := make(chan T, bufferSize)
|
||||
closeEventCh := make(chan struct{})
|
||||
b.eventChs = append(b.eventChs, &eventCh[T]{
|
||||
id: id,
|
||||
ch: bufferedCh,
|
||||
closeEventCh: closeEventCh,
|
||||
})
|
||||
|
||||
b.wg.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
close(closeEventCh)
|
||||
|
||||
b.lock.Lock()
|
||||
for i, eventCh := range b.eventChs {
|
||||
if eventCh.id == id {
|
||||
b.eventChs = append(b.eventChs[:i], b.eventChs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
b.lock.Unlock()
|
||||
b.wg.Done()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-b.closeCh:
|
||||
return
|
||||
case val := <-bufferedCh:
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-b.closeCh:
|
||||
return
|
||||
case ch <- val:
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Broadcast sends the given value to all subscribers.
|
||||
func (b *Broadcaster[T]) Broadcast(value T) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
if b.closed.Load() {
|
||||
return
|
||||
}
|
||||
for _, ev := range b.eventChs {
|
||||
select {
|
||||
case <-ev.closeEventCh:
|
||||
case ev.ch <- value:
|
||||
case <-b.closeCh:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the Broadcaster. It blocks until all events have been sent to
|
||||
// the subscribers. The Broadcaster will be a no-op after this call.
|
||||
func (b *Broadcaster[T]) Close() {
|
||||
defer b.wg.Wait()
|
||||
b.lock.Lock()
|
||||
if b.closed.CompareAndSwap(false, true) {
|
||||
close(b.closeCh)
|
||||
}
|
||||
b.lock.Unlock()
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package fake
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dapr/kit/events/loop"
|
||||
)
|
||||
|
||||
type Fake[T any] struct {
|
||||
runFn func(context.Context) error
|
||||
enqueueFn func(T)
|
||||
closeFn func(T)
|
||||
}
|
||||
|
||||
func New[T any]() *Fake[T] {
|
||||
return &Fake[T]{
|
||||
runFn: func(context.Context) error { return nil },
|
||||
enqueueFn: func(T) {},
|
||||
closeFn: func(T) {},
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Fake[T]) WithRun(fn func(context.Context) error) *Fake[T] {
|
||||
f.runFn = fn
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Fake[T]) WithEnqueue(fn func(T)) *Fake[T] {
|
||||
f.enqueueFn = fn
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Fake[T]) WithClose(fn func(T)) *Fake[T] {
|
||||
f.closeFn = fn
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Fake[T]) Run(ctx context.Context) error {
|
||||
return f.runFn(ctx)
|
||||
}
|
||||
|
||||
func (f *Fake[T]) Enqueue(t T) {
|
||||
f.enqueueFn(t)
|
||||
}
|
||||
|
||||
func (f *Fake[T]) Close(t T) {
|
||||
f.closeFn(t)
|
||||
}
|
||||
|
||||
func (f *Fake[T]) Reset(loop.Handler[T], uint64) loop.Interface[T] {
|
||||
return f
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package fake
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/dapr/kit/events/loop"
|
||||
)
|
||||
|
||||
func Test_Fake(*testing.T) {
|
||||
var _ loop.Interface[int] = New[int]()
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package loop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Handler[T any] interface {
|
||||
Handle(ctx context.Context, t T) error
|
||||
}
|
||||
|
||||
type Interface[T any] interface {
|
||||
Run(ctx context.Context) error
|
||||
Enqueue(t T)
|
||||
Close(t T)
|
||||
Reset(h Handler[T], size uint64) Interface[T]
|
||||
}
|
||||
|
||||
type loop[T any] struct {
|
||||
queue chan T
|
||||
handler Handler[T]
|
||||
|
||||
closed bool
|
||||
closeCh chan struct{}
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func New[T any](h Handler[T], size uint64) Interface[T] {
|
||||
return &loop[T]{
|
||||
queue: make(chan T, size),
|
||||
handler: h,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func Empty[T any]() Interface[T] {
|
||||
return new(loop[T])
|
||||
}
|
||||
|
||||
func (l *loop[T]) Run(ctx context.Context) error {
|
||||
defer close(l.closeCh)
|
||||
|
||||
for {
|
||||
req, ok := <-l.queue
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := l.handler.Handle(ctx, req); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *loop[T]) Enqueue(req T) {
|
||||
l.lock.RLock()
|
||||
defer l.lock.RUnlock()
|
||||
|
||||
if l.closed {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case l.queue <- req:
|
||||
case <-l.closeCh:
|
||||
}
|
||||
}
|
||||
|
||||
func (l *loop[T]) Close(req T) {
|
||||
l.lock.Lock()
|
||||
l.closed = true
|
||||
select {
|
||||
case l.queue <- req:
|
||||
case <-l.closeCh:
|
||||
}
|
||||
close(l.queue)
|
||||
l.lock.Unlock()
|
||||
<-l.closeCh
|
||||
}
|
||||
|
||||
func (l *loop[T]) Reset(h Handler[T], size uint64) Interface[T] {
|
||||
if l == nil {
|
||||
return New[T](h, size)
|
||||
}
|
||||
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
l.closed = false
|
||||
l.closeCh = make(chan struct{})
|
||||
l.handler = h
|
||||
|
||||
// TODO: @joshvanl: use a ring buffer so that we don't need to reallocate and
|
||||
// improve performance.
|
||||
l.queue = make(chan T, size)
|
||||
|
||||
return l
|
||||
}
|
|
@ -43,7 +43,9 @@ func ExampleProcessor() {
|
|||
}
|
||||
|
||||
// Create the processor
|
||||
processor := NewProcessor[string, *queueableItem](executeFn)
|
||||
processor := NewProcessor[string, *queueableItem](Options[string, *queueableItem]{
|
||||
ExecuteFn: executeFn,
|
||||
})
|
||||
|
||||
// Add items to the processor, in any order, using Enqueue
|
||||
processor.Enqueue(&queueableItem{Name: "item1", ExecutionTime: time.Now().Add(500 * time.Millisecond)})
|
||||
|
@ -57,7 +59,7 @@ func ExampleProcessor() {
|
|||
// Using Dequeue allows removing an item from the queue
|
||||
processor.Dequeue("item4")
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
for range 3 {
|
||||
fmt.Println(<-executed)
|
||||
}
|
||||
// Output:
|
||||
|
|
|
@ -14,7 +14,6 @@ limitations under the License.
|
|||
package queue
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -22,11 +21,13 @@ import (
|
|||
kclock "k8s.io/utils/clock"
|
||||
)
|
||||
|
||||
// ErrProcessorStopped is returned when the processor is not running.
|
||||
var ErrProcessorStopped = errors.New("processor is stopped")
|
||||
type Options[K comparable, T Queueable[K]] struct {
|
||||
ExecuteFn func(r T)
|
||||
Clock kclock.Clock
|
||||
}
|
||||
|
||||
// Processor manages the queue of items and processes them at the correct time.
|
||||
type Processor[K comparable, T queueable[K]] struct {
|
||||
type Processor[K comparable, T Queueable[K]] struct {
|
||||
executeFn func(r T)
|
||||
queue queue[K, T]
|
||||
clock kclock.Clock
|
||||
|
@ -40,48 +41,51 @@ type Processor[K comparable, T queueable[K]] struct {
|
|||
|
||||
// NewProcessor returns a new Processor object.
|
||||
// executeFn is the callback invoked when the item is to be executed; this will be invoked in a background goroutine.
|
||||
func NewProcessor[K comparable, T queueable[K]](executeFn func(r T)) *Processor[K, T] {
|
||||
func NewProcessor[K comparable, T Queueable[K]](opts Options[K, T]) *Processor[K, T] {
|
||||
cl := opts.Clock
|
||||
if cl == nil {
|
||||
cl = kclock.RealClock{}
|
||||
}
|
||||
return &Processor[K, T]{
|
||||
executeFn: executeFn,
|
||||
executeFn: opts.ExecuteFn,
|
||||
queue: newQueue[K, T](),
|
||||
processorRunningCh: make(chan struct{}, 1),
|
||||
stopCh: make(chan struct{}),
|
||||
resetCh: make(chan struct{}, 1),
|
||||
clock: kclock.RealClock{},
|
||||
clock: cl,
|
||||
}
|
||||
}
|
||||
|
||||
// WithClock sets the clock used by the processor. Used for testing.
|
||||
func (p *Processor[K, T]) WithClock(clock kclock.Clock) *Processor[K, T] {
|
||||
p.clock = clock
|
||||
return p
|
||||
}
|
||||
|
||||
// Enqueue adds a new item to the queue.
|
||||
// Enqueue adds a new items to the queue.
|
||||
// If a item with the same ID already exists, it'll be replaced.
|
||||
func (p *Processor[K, T]) Enqueue(r T) error {
|
||||
func (p *Processor[K, T]) Enqueue(rs ...T) {
|
||||
if p.stopped.Load() {
|
||||
return ErrProcessorStopped
|
||||
return
|
||||
}
|
||||
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
for _, r := range rs {
|
||||
p.enqueue(r)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Processor[K, T]) enqueue(r T) {
|
||||
// Insert or replace the item in the queue
|
||||
// If the item added or replaced is the first one in the queue, we need to know that
|
||||
p.lock.Lock()
|
||||
peek, ok := p.queue.Peek()
|
||||
isFirst := (ok && peek.Key() == r.Key()) // This is going to be true if the item being replaced is the first one in the queue
|
||||
p.queue.Insert(r, true)
|
||||
peek, _ = p.queue.Peek() // No need to check for "ok" here because we know this will return an item
|
||||
isFirst = isFirst || (peek == r) // This is also going to be true if the item just added landed at the front of the queue
|
||||
p.process(isFirst)
|
||||
p.lock.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Dequeue removes a item from the queue.
|
||||
func (p *Processor[K, T]) Dequeue(key K) error {
|
||||
func (p *Processor[K, T]) Dequeue(key K) {
|
||||
if p.stopped.Load() {
|
||||
return ErrProcessorStopped
|
||||
return
|
||||
}
|
||||
|
||||
// We need to check if this is the next item in the queue, as that requires stopping the processor
|
||||
|
@ -93,8 +97,6 @@ func (p *Processor[K, T]) Dequeue(key K) error {
|
|||
p.process(true)
|
||||
}
|
||||
p.lock.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close stops the processor.
|
||||
|
@ -226,5 +228,5 @@ func (p *Processor[K, T]) execute(r T) {
|
|||
return
|
||||
}
|
||||
|
||||
go p.executeFn(r)
|
||||
p.executeFn(r)
|
||||
}
|
||||
|
|
|
@ -31,10 +31,12 @@ func TestProcessor(t *testing.T) {
|
|||
// Create the processor
|
||||
clock := clocktesting.NewFakeClock(time.Now())
|
||||
executeCh := make(chan *queueableItem)
|
||||
processor := NewProcessor[string](func(r *queueableItem) {
|
||||
executeCh <- r
|
||||
processor := NewProcessor[string, *queueableItem](Options[string, *queueableItem]{
|
||||
ExecuteFn: func(r *queueableItem) {
|
||||
executeCh <- r
|
||||
},
|
||||
Clock: clock,
|
||||
})
|
||||
processor.clock = clock
|
||||
|
||||
assertExecutedItem := func(t *testing.T) *queueableItem {
|
||||
t.Helper()
|
||||
|
@ -63,10 +65,9 @@ func TestProcessor(t *testing.T) {
|
|||
|
||||
t.Run("enqueue items", func(t *testing.T) {
|
||||
for i := 1; i <= 5; i++ {
|
||||
err := processor.Enqueue(
|
||||
processor.Enqueue(
|
||||
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Advance tickers by 500ms to start
|
||||
|
@ -83,8 +84,7 @@ func TestProcessor(t *testing.T) {
|
|||
|
||||
t.Run("enqueue item to be executed right away", func(t *testing.T) {
|
||||
r := newTestItem(1, clock.Now())
|
||||
err := processor.Enqueue(r)
|
||||
require.NoError(t, err)
|
||||
processor.Enqueue(r)
|
||||
|
||||
clock.Step(500 * time.Millisecond)
|
||||
|
||||
|
@ -95,10 +95,9 @@ func TestProcessor(t *testing.T) {
|
|||
t.Run("enqueue item at the front of the queue", func(t *testing.T) {
|
||||
// Enqueue 4 items
|
||||
for i := 1; i <= 4; i++ {
|
||||
err := processor.Enqueue(
|
||||
processor.Enqueue(
|
||||
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Eventually(t, clock.HasWaiters, time.Second, 100*time.Millisecond)
|
||||
|
@ -111,10 +110,9 @@ func TestProcessor(t *testing.T) {
|
|||
assert.Equal(t, "1", received.Name)
|
||||
|
||||
// Add a new item at the front of the queue
|
||||
err := processor.Enqueue(
|
||||
processor.Enqueue(
|
||||
newTestItem(99, clock.Now()),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Advance tickers and assert messages are coming in order
|
||||
for i := 1; i <= 4; i++ {
|
||||
|
@ -136,19 +134,16 @@ func TestProcessor(t *testing.T) {
|
|||
|
||||
// Enqueue 5 items
|
||||
for i := 1; i <= 5; i++ {
|
||||
err := processor.Enqueue(
|
||||
processor.Enqueue(
|
||||
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, 5, processor.queue.Len())
|
||||
|
||||
// Dequeue items 2 and 4
|
||||
// Note that this is a string because it's the key
|
||||
err := processor.Dequeue("2")
|
||||
require.NoError(t, err)
|
||||
err = processor.Dequeue("4")
|
||||
require.NoError(t, err)
|
||||
processor.Dequeue("2")
|
||||
processor.Dequeue("4")
|
||||
|
||||
assert.Equal(t, 3, processor.queue.Len())
|
||||
|
||||
|
@ -173,10 +168,9 @@ func TestProcessor(t *testing.T) {
|
|||
t.Run("dequeue item from the front of the queue", func(t *testing.T) {
|
||||
// Enqueue 6 items
|
||||
for i := 1; i <= 6; i++ {
|
||||
err := processor.Enqueue(
|
||||
processor.Enqueue(
|
||||
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Advance tickers and assert messages are coming in order
|
||||
|
@ -187,8 +181,7 @@ func TestProcessor(t *testing.T) {
|
|||
if i == 2 || i == 5 {
|
||||
// Dequeue the item at the front of the queue
|
||||
// Note that this is a string because it's the key
|
||||
err := processor.Dequeue(strconv.Itoa(i))
|
||||
require.NoError(t, err)
|
||||
processor.Dequeue(strconv.Itoa(i))
|
||||
|
||||
// Skip items that have been removed
|
||||
t.Logf("Should not receive signal %d", i)
|
||||
|
@ -206,15 +199,13 @@ func TestProcessor(t *testing.T) {
|
|||
t.Run("replace item", func(t *testing.T) {
|
||||
// Enqueue 5 items
|
||||
for i := 1; i <= 5; i++ {
|
||||
err := processor.Enqueue(
|
||||
processor.Enqueue(
|
||||
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Replace item 4, bumping its priority down
|
||||
err := processor.Enqueue(newTestItem(4, clock.Now().Add(6*time.Second)))
|
||||
require.NoError(t, err)
|
||||
processor.Enqueue(newTestItem(4, clock.Now().Add(6*time.Second)))
|
||||
|
||||
// Advance tickers and assert messages are coming in order
|
||||
for i := 1; i <= 6; i++ {
|
||||
|
@ -241,10 +232,9 @@ func TestProcessor(t *testing.T) {
|
|||
t.Run("replace item at the front of the queue", func(t *testing.T) {
|
||||
// Enqueue 5 items
|
||||
for i := 1; i <= 5; i++ {
|
||||
err := processor.Enqueue(
|
||||
processor.Enqueue(
|
||||
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Advance tickers and assert messages are coming in order
|
||||
|
@ -253,8 +243,7 @@ func TestProcessor(t *testing.T) {
|
|||
|
||||
if i == 2 {
|
||||
// Replace item 2, bumping its priority down, while it's at the front of the queue
|
||||
err := processor.Enqueue(newTestItem(2, clock.Now().Add(5*time.Second)))
|
||||
require.NoError(t, err)
|
||||
processor.Enqueue(newTestItem(2, clock.Now().Add(5*time.Second)))
|
||||
|
||||
// This item has been pushed down
|
||||
t.Logf("Should not receive signal %d now", i)
|
||||
|
@ -282,13 +271,12 @@ func TestProcessor(t *testing.T) {
|
|||
)
|
||||
now := clock.Now()
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < count; i++ {
|
||||
for i := range count {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
execTime := now.Add(time.Second * time.Duration(rand.Intn(maxDelay))) //nolint:gosec
|
||||
err := processor.Enqueue(newTestItem(i, execTime))
|
||||
require.NoError(t, err)
|
||||
processor.Enqueue(newTestItem(i, execTime))
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
@ -324,7 +312,7 @@ func TestProcessor(t *testing.T) {
|
|||
close(doneCh)
|
||||
|
||||
// Ensure all items are true
|
||||
for i := 0; i < count; i++ {
|
||||
for i := range count {
|
||||
assert.Truef(t, collected[i], "item %d not received", i)
|
||||
}
|
||||
})
|
||||
|
@ -332,10 +320,9 @@ func TestProcessor(t *testing.T) {
|
|||
t.Run("stop processor", func(t *testing.T) {
|
||||
// Enqueue 5 items
|
||||
for i := 1; i <= 5; i++ {
|
||||
err := processor.Enqueue(
|
||||
processor.Enqueue(
|
||||
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Eventually(t, clock.HasWaiters, time.Second, 100*time.Millisecond)
|
||||
|
@ -348,10 +335,8 @@ func TestProcessor(t *testing.T) {
|
|||
assertNoExecutedItem(t)
|
||||
|
||||
// Enqueuing and dequeueing should fail
|
||||
err := processor.Enqueue(newTestItem(99, clock.Now()))
|
||||
require.ErrorIs(t, err, ErrProcessorStopped)
|
||||
err = processor.Dequeue("99")
|
||||
require.ErrorIs(t, err, ErrProcessorStopped)
|
||||
processor.Enqueue(newTestItem(99, clock.Now()))
|
||||
processor.Dequeue("99")
|
||||
|
||||
// Stopping again is a nop (should not crash)
|
||||
require.NoError(t, processor.Close())
|
||||
|
@ -364,10 +349,12 @@ func TestClose(t *testing.T) {
|
|||
// Create the processor
|
||||
clock := clocktesting.NewFakeClock(time.Now())
|
||||
executeCh := make(chan *queueableItem)
|
||||
processor := NewProcessor[string](func(r *queueableItem) {
|
||||
executeCh <- r
|
||||
processor := NewProcessor[string, *queueableItem](Options[string, *queueableItem]{
|
||||
ExecuteFn: func(r *queueableItem) {
|
||||
executeCh <- r
|
||||
},
|
||||
Clock: clock,
|
||||
})
|
||||
processor.clock = clock
|
||||
|
||||
processor.Enqueue(newTestItem(1, clock.Now().Add(time.Second)))
|
||||
processor.Enqueue(newTestItem(2, clock.Now().Add(time.Second*2)))
|
||||
|
@ -415,7 +402,7 @@ func TestClose(t *testing.T) {
|
|||
default:
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
for range 3 {
|
||||
select {
|
||||
case err := <-closeCh:
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -18,8 +18,8 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// queueable is the interface for items that can be added to the queue.
|
||||
type queueable[T comparable] interface {
|
||||
// Queueable is the interface for items that can be added to the queue.
|
||||
type Queueable[T comparable] interface {
|
||||
comparable
|
||||
Key() T
|
||||
ScheduledTime() time.Time
|
||||
|
@ -29,13 +29,13 @@ type queueable[T comparable] interface {
|
|||
// It acts as a "priority queue", in which items are added in order of when they're scheduled.
|
||||
// Internally, it uses a heap (from container/heap) that allows Insert and Pop operations to be completed in O(log N) time (where N is the queue's length).
|
||||
// Note: methods in this struct are not safe for concurrent use. Callers should use locks to ensure consistency.
|
||||
type queue[K comparable, T queueable[K]] struct {
|
||||
type queue[K comparable, T Queueable[K]] struct {
|
||||
heap *queueHeap[K, T]
|
||||
items map[K]*queueItem[K, T]
|
||||
}
|
||||
|
||||
// newQueue creates a new queue.
|
||||
func newQueue[K comparable, T queueable[K]]() queue[K, T] {
|
||||
func newQueue[K comparable, T Queueable[K]]() queue[K, T] {
|
||||
return queue[K, T]{
|
||||
heap: new(queueHeap[K, T]),
|
||||
items: make(map[K]*queueItem[K, T]),
|
||||
|
@ -122,14 +122,14 @@ func (p *queue[K, T]) Update(r T) {
|
|||
heap.Fix(p.heap, item.index)
|
||||
}
|
||||
|
||||
type queueItem[K comparable, T queueable[K]] struct {
|
||||
type queueItem[K comparable, T Queueable[K]] struct {
|
||||
value T
|
||||
|
||||
// The index of the item in the heap. This is maintained by the heap.Interface methods.
|
||||
index int
|
||||
}
|
||||
|
||||
type queueHeap[K comparable, T queueable[K]] []*queueItem[K, T]
|
||||
type queueHeap[K comparable, T Queueable[K]] []*queueItem[K, T]
|
||||
|
||||
func (pq queueHeap[K, T]) Len() int {
|
||||
return len(pq)
|
||||
|
|
|
@ -39,7 +39,7 @@ func TestCoalescing(t *testing.T) {
|
|||
ch := make(chan struct{})
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- c.Run(context.Background(), ch)
|
||||
errCh <- c.Run(t.Context(), ch)
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
|
@ -78,7 +78,7 @@ func TestCoalescing(t *testing.T) {
|
|||
c, err := NewCoalescing(OptionsCoalescing{})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- c.Run(ctx, make(chan struct{}))
|
||||
|
@ -100,7 +100,7 @@ func TestCoalescing(t *testing.T) {
|
|||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- c.Run(context.Background(), make(chan struct{}))
|
||||
errCh <- c.Run(t.Context(), make(chan struct{}))
|
||||
}()
|
||||
|
||||
c.Close()
|
||||
|
@ -119,7 +119,7 @@ func TestCoalescing(t *testing.T) {
|
|||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- c.Run(context.Background(), make(chan struct{}))
|
||||
errCh <- c.Run(t.Context(), make(chan struct{}))
|
||||
}()
|
||||
|
||||
c.Close()
|
||||
|
@ -132,7 +132,7 @@ func TestCoalescing(t *testing.T) {
|
|||
}
|
||||
|
||||
go func() {
|
||||
errCh <- c.Run(context.Background(), make(chan struct{}))
|
||||
errCh <- c.Run(t.Context(), make(chan struct{}))
|
||||
}()
|
||||
|
||||
select {
|
||||
|
@ -277,7 +277,7 @@ func TestCoalescing(t *testing.T) {
|
|||
c.Add()
|
||||
assertNoChannel(t, ch)
|
||||
|
||||
for i := 0; i < 4; i++ {
|
||||
for range 4 {
|
||||
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
|
||||
clock.Step(time.Second * 4)
|
||||
c.Add()
|
||||
|
@ -345,7 +345,7 @@ func TestCoalescing(t *testing.T) {
|
|||
assertChannel(t, ch)
|
||||
|
||||
assert.Eventually(t, c.hasTimer.Load, time.Second, time.Millisecond)
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
c.Add()
|
||||
}
|
||||
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
|
||||
|
|
|
@ -44,7 +44,7 @@ type Options struct {
|
|||
type FSWatcher struct {
|
||||
w *fsnotify.Watcher
|
||||
running atomic.Bool
|
||||
batcher *batcher.Batcher[string]
|
||||
batcher *batcher.Batcher[string, struct{}]
|
||||
}
|
||||
|
||||
func New(opts Options) (*FSWatcher, error) {
|
||||
|
@ -71,7 +71,9 @@ func New(opts Options) (*FSWatcher, error) {
|
|||
w: w,
|
||||
// Often the case, writes to files are not atomic and involve multiple file system events.
|
||||
// We want to hold off on sending events until we are sure that the file has been written to completion. We do this by waiting for a period of time after the last event has been received for a file name.
|
||||
batcher: batcher.New[string](interval),
|
||||
batcher: batcher.New[string, struct{}](batcher.Options{
|
||||
Interval: interval,
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -81,7 +83,7 @@ func (f *FSWatcher) Run(ctx context.Context, eventCh chan<- struct{}) error {
|
|||
}
|
||||
defer f.batcher.Close()
|
||||
|
||||
f.batcher.Subscribe(eventCh)
|
||||
f.batcher.Subscribe(ctx, eventCh)
|
||||
|
||||
for {
|
||||
select {
|
||||
|
@ -90,7 +92,7 @@ func (f *FSWatcher) Run(ctx context.Context, eventCh chan<- struct{}) error {
|
|||
case err := <-f.w.Errors:
|
||||
return errors.Join(fmt.Errorf("watcher error: %w", err), f.w.Close())
|
||||
case event := <-f.w.Events:
|
||||
f.batcher.Batch(event.Name)
|
||||
f.batcher.Batch(event.Name, struct{}{})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ import (
|
|||
)
|
||||
|
||||
func TestFSWatcher(t *testing.T) {
|
||||
runWatcher := func(t *testing.T, opts Options, bacher *batcher.Batcher[string]) <-chan struct{} {
|
||||
runWatcher := func(t *testing.T, opts Options, bacher *batcher.Batcher[string, struct{}]) <-chan struct{} {
|
||||
t.Helper()
|
||||
|
||||
f, err := New(opts)
|
||||
|
@ -43,7 +43,7 @@ func TestFSWatcher(t *testing.T) {
|
|||
}
|
||||
|
||||
errCh := make(chan error)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
eventsCh := make(chan struct{})
|
||||
go func() {
|
||||
errCh <- f.Run(ctx, eventsCh)
|
||||
|
@ -84,7 +84,7 @@ func TestFSWatcher(t *testing.T) {
|
|||
t.Run("running Run twice should error", func(t *testing.T) {
|
||||
fs, err := New(Options{})
|
||||
require.NoError(t, err)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel()
|
||||
require.NoError(t, fs.Run(ctx, make(chan struct{})))
|
||||
require.Error(t, fs.Run(ctx, make(chan struct{})))
|
||||
|
@ -101,7 +101,7 @@ func TestFSWatcher(t *testing.T) {
|
|||
|
||||
t.Run("should fire event when event occurs on target file", func(t *testing.T) {
|
||||
fp := filepath.Join(t.TempDir(), "test.txt")
|
||||
require.NoError(t, os.WriteFile(fp, []byte{}, 0o644))
|
||||
require.NoError(t, os.WriteFile(fp, []byte{}, 0o600))
|
||||
eventsCh := runWatcher(t, Options{
|
||||
Targets: []string{fp},
|
||||
Interval: ptr.Of(time.Duration(1)),
|
||||
|
@ -112,7 +112,7 @@ func TestFSWatcher(t *testing.T) {
|
|||
// If running in windows, wait for notify to be ready.
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
require.NoError(t, os.WriteFile(fp, []byte{}, 0o644))
|
||||
require.NoError(t, os.WriteFile(fp, []byte{}, 0o600))
|
||||
|
||||
select {
|
||||
case <-eventsCh:
|
||||
|
@ -124,16 +124,16 @@ func TestFSWatcher(t *testing.T) {
|
|||
t.Run("should fire 2 events when event occurs on 2 file target", func(t *testing.T) {
|
||||
fp1 := filepath.Join(t.TempDir(), "test.txt")
|
||||
fp2 := filepath.Join(t.TempDir(), "test.txt")
|
||||
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o644))
|
||||
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o644))
|
||||
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
|
||||
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
|
||||
eventsCh := runWatcher(t, Options{
|
||||
Targets: []string{fp1, fp2},
|
||||
Interval: ptr.Of(time.Duration(1)),
|
||||
}, nil)
|
||||
assert.Empty(t, eventsCh)
|
||||
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o644))
|
||||
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o644))
|
||||
for i := 0; i < 2; i++ {
|
||||
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
|
||||
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
|
||||
for range 2 {
|
||||
select {
|
||||
case <-eventsCh:
|
||||
case <-time.After(time.Second):
|
||||
|
@ -146,8 +146,8 @@ func TestFSWatcher(t *testing.T) {
|
|||
dir := t.TempDir()
|
||||
fp1 := filepath.Join(dir, "test1.txt")
|
||||
fp2 := filepath.Join(dir, "test2.txt")
|
||||
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o644))
|
||||
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o644))
|
||||
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
|
||||
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
|
||||
eventsCh := runWatcher(t, Options{
|
||||
Targets: []string{fp1, fp2},
|
||||
Interval: ptr.Of(time.Duration(1)),
|
||||
|
@ -157,9 +157,9 @@ func TestFSWatcher(t *testing.T) {
|
|||
time.Sleep(time.Second)
|
||||
}
|
||||
assert.Empty(t, eventsCh)
|
||||
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o644))
|
||||
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o644))
|
||||
for i := 0; i < 2; i++ {
|
||||
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
|
||||
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
|
||||
for range 2 {
|
||||
select {
|
||||
case <-eventsCh:
|
||||
case <-time.After(time.Second):
|
||||
|
@ -178,9 +178,9 @@ func TestFSWatcher(t *testing.T) {
|
|||
Interval: ptr.Of(time.Duration(1)),
|
||||
}, nil)
|
||||
assert.Empty(t, eventsCh)
|
||||
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o644))
|
||||
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o644))
|
||||
for i := 0; i < 2; i++ {
|
||||
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
|
||||
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
|
||||
for range 2 {
|
||||
select {
|
||||
case <-eventsCh:
|
||||
case <-time.After(time.Second):
|
||||
|
@ -191,8 +191,10 @@ func TestFSWatcher(t *testing.T) {
|
|||
|
||||
t.Run("should batch events of the same file for multiple events", func(t *testing.T) {
|
||||
clock := clocktesting.NewFakeClock(time.Time{})
|
||||
batcher := batcher.New[string](time.Millisecond * 500)
|
||||
batcher.WithClock(clock)
|
||||
batcher := batcher.New[string, struct{}](batcher.Options{
|
||||
Interval: time.Millisecond * 500,
|
||||
Clock: clock,
|
||||
})
|
||||
dir1 := t.TempDir()
|
||||
dir2 := t.TempDir()
|
||||
fp1 := filepath.Join(dir1, "test1.txt")
|
||||
|
@ -205,9 +207,9 @@ func TestFSWatcher(t *testing.T) {
|
|||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o644))
|
||||
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o644))
|
||||
for range 10 {
|
||||
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
|
||||
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
|
||||
}
|
||||
|
||||
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond*10)
|
||||
|
@ -220,9 +222,9 @@ func TestFSWatcher(t *testing.T) {
|
|||
|
||||
clock.Step(time.Millisecond * 250)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o644))
|
||||
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o644))
|
||||
for range 10 {
|
||||
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
|
||||
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
|
||||
}
|
||||
|
||||
select {
|
||||
|
@ -234,7 +236,7 @@ func TestFSWatcher(t *testing.T) {
|
|||
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond*10)
|
||||
clock.Step(time.Millisecond * 500)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
for range 2 {
|
||||
select {
|
||||
case <-eventsCh:
|
||||
case <-time.After(time.Second):
|
||||
|
|
|
@ -22,7 +22,7 @@ import (
|
|||
"github.com/dapr/kit/events/batcher"
|
||||
)
|
||||
|
||||
func (f *FSWatcher) WithBatcher(b *batcher.Batcher[string]) *FSWatcher {
|
||||
func (f *FSWatcher) WithBatcher(b *batcher.Batcher[string, struct{}]) *FSWatcher {
|
||||
f.batcher = b
|
||||
return f
|
||||
}
|
||||
|
|
|
@ -27,7 +27,9 @@ import (
|
|||
)
|
||||
|
||||
func TestWithBatcher(t *testing.T) {
|
||||
b := batcher.New[string](time.Millisecond * 10)
|
||||
b := batcher.New[string, struct{}](batcher.Options{
|
||||
Interval: time.Millisecond * 10,
|
||||
})
|
||||
f, err := New(Options{})
|
||||
require.NoError(t, err)
|
||||
f.WithBatcher(b)
|
||||
|
|
30
go.mod
30
go.mod
|
@ -1,24 +1,26 @@
|
|||
module github.com/dapr/kit
|
||||
|
||||
go 1.20
|
||||
go 1.24.3
|
||||
|
||||
require (
|
||||
github.com/alphadose/haxmap v1.3.1
|
||||
github.com/cenkalti/backoff/v4 v4.2.1
|
||||
github.com/fsnotify/fsnotify v1.7.0
|
||||
github.com/lestrrat-go/httprc v1.0.4
|
||||
github.com/lestrrat-go/jwx/v2 v2.0.15
|
||||
github.com/lestrrat-go/httprc v1.0.5
|
||||
github.com/lestrrat-go/jwx/v2 v2.0.21
|
||||
github.com/mitchellh/mapstructure v1.5.1-0.20220423185008-bf980b35cac4
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/spf13/cast v1.5.1
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/spiffe/go-spiffe/v2 v2.5.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/tidwall/transform v0.0.0-20201103190739-32f242e2dbde
|
||||
golang.org/x/crypto v0.14.0
|
||||
golang.org/x/crypto v0.39.0
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
|
||||
golang.org/x/tools v0.14.0
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d
|
||||
google.golang.org/grpc v1.57.0
|
||||
google.golang.org/protobuf v1.31.0
|
||||
golang.org/x/tools v0.33.0
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822
|
||||
google.golang.org/grpc v1.73.0
|
||||
google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20
|
||||
google.golang.org/protobuf v1.36.6
|
||||
k8s.io/apimachinery v0.26.9
|
||||
k8s.io/utils v0.0.0-20230726121419-3b25d923346b
|
||||
)
|
||||
|
@ -26,17 +28,21 @@ require (
|
|||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.0.5 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
|
||||
github.com/lestrrat-go/httpcc v1.0.1 // indirect
|
||||
github.com/lestrrat-go/iter v1.0.2 // indirect
|
||||
github.com/lestrrat-go/option v1.0.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/segmentio/asm v1.2.0 // indirect
|
||||
golang.org/x/mod v0.13.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
github.com/zeebo/errs v1.4.0 // indirect
|
||||
golang.org/x/mod v0.25.0 // indirect
|
||||
golang.org/x/net v0.41.0 // indirect
|
||||
golang.org/x/sync v0.15.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.26.0 // indirect
|
||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
|
129
go.sum
129
go.sum
|
@ -1,3 +1,5 @@
|
|||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/alphadose/haxmap v1.3.1 h1:KmZh75duO1tC8pt3LmUwoTYiZ9sh4K52FX8p7/yrlqU=
|
||||
github.com/alphadose/haxmap v1.3.1/go.mod h1:rjHw1IAqbxm0S3U5tD16GoKsiAd8FWx5BJ2IYqXwgmM=
|
||||
github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
|
||||
|
@ -5,45 +7,56 @@ github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyY
|
|||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo=
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs=
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
|
||||
github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY=
|
||||
github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
|
||||
github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.1.0 h1:Hsa8mG0dQ46ij8Sl2AYJDUv1oA9/d6Vk+3LG99Oe02g=
|
||||
github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
|
||||
github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
|
||||
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
|
||||
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
|
||||
github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8=
|
||||
github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
|
||||
github.com/lestrrat-go/httprc v1.0.5 h1:bsTfiH8xaKOJPrg1R+E3iE/AWZr/x0Phj9PBTG/OLUk=
|
||||
github.com/lestrrat-go/httprc v1.0.5/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
|
||||
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
|
||||
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
|
||||
github.com/lestrrat-go/jwx/v2 v2.0.15 h1:XvR2lQdX+mZechmqWxqQb2foU3hgAn5+Rj0ICa0I6sU=
|
||||
github.com/lestrrat-go/jwx/v2 v2.0.15/go.mod h1:jBHyESp4e7QxfERM0UKkQ80/94paqNIEcdEfiUYz5zE=
|
||||
github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
|
||||
github.com/lestrrat-go/jwx/v2 v2.0.21 h1:jAPKupy4uHgrHFEdjVjNkUgoBKtVDgrQPB/h55FHrR0=
|
||||
github.com/lestrrat-go/jwx/v2 v2.0.21/go.mod h1:09mLW8zto6bWL9GbwnqAli+ArLf+5M33QLQPDggkUWM=
|
||||
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
|
||||
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
|
||||
github.com/mitchellh/mapstructure v1.5.1-0.20220423185008-bf980b35cac4 h1:BpfhmLKZf+SjVanKKhCgf3bg+511DmU9eDQTen7LLbY=
|
||||
github.com/mitchellh/mapstructure v1.5.1-0.20220423185008-bf980b35cac4/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
|
||||
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
|
@ -51,96 +64,86 @@ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVs
|
|||
github.com/spf13/cast v1.5.1 h1:R+kOtfhWQE6TVQzY+4D7wJLBgkdVasCEFxSUBYBYIlA=
|
||||
github.com/spf13/cast v1.5.1/go.mod h1:b9PdjNptOpzXr7Rq1q9gJML/2cdGQAo69NKzQ10KN48=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spiffe/go-spiffe/v2 v2.5.0 h1:N2I01KCUkv1FAjZXJMwh95KK1ZIQLYbPfhaxw8WS0hE=
|
||||
github.com/spiffe/go-spiffe/v2 v2.5.0/go.mod h1:P+NxobPc6wXhVtINNtFjNWGBTreew1GBUCwT2wPmb7g=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/transform v0.0.0-20201103190739-32f242e2dbde h1:AMNpJRc7P+GTwVbl8DkK2I9I8BBUzNiHuH/tlxrpan0=
|
||||
github.com/tidwall/transform v0.0.0-20201103190739-32f242e2dbde/go.mod h1:MvrEmduDUz4ST5pGZ7CABCnOU5f3ZiOAZzT6b1A6nX8=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM=
|
||||
github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
|
||||
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
|
||||
go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
|
||||
go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
|
||||
go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY=
|
||||
go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w=
|
||||
go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
|
||||
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY=
|
||||
golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
|
||||
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.16.0 h1:7eBu7KsSvFDtSXUIDbh3aqlK4DPsZ1rByC8PFfBThos=
|
||||
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
|
||||
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
|
||||
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
|
||||
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
|
||||
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.14.0 h1:jvNa2pY0M4r62jkRQ6RwEZZyPcymeL9XZMLBbV7U2nc=
|
||||
golang.org/x/tools v0.14.0/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg=
|
||||
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
|
||||
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d h1:uvYuEyMHKNt+lT4K3bN6fGswmK8qSvcreM3BwjDh+y4=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d/go.mod h1:+Bk1OCOj40wS2hwAMA+aCW9ypzm63QTBBHp6lQ3p+9M=
|
||||
google.golang.org/grpc v1.57.0 h1:kfzNeI/klCGD2YPMUlaGNT3pxvYfga7smW3Vth8Zsiw=
|
||||
google.golang.org/grpc v1.57.0/go.mod h1:Sd+9RMTACXwmub0zcNY2c4arhtrbBYD1AUHI/dt16Mo=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
|
||||
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
|
||||
google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok=
|
||||
google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc=
|
||||
google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 h1:MLBCGN1O7GzIx+cBiwfYPwtmZ41U3Mn/cotLJciaArI=
|
||||
google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20/go.mod h1:Nr5H8+MlGWr5+xX/STzdoEqJrO+YteqFbMyCsrb6mH0=
|
||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
||||
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
|
|
@ -37,9 +37,9 @@ import (
|
|||
"github.com/lestrrat-go/httprc"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
|
||||
"github.com/dapr/kit/crypto/pem"
|
||||
"github.com/dapr/kit/fswatcher"
|
||||
"github.com/dapr/kit/logger"
|
||||
"github.com/dapr/kit/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -198,7 +198,7 @@ func (c *JWKSCache) initJWKSFromURL(ctx context.Context, url string) error {
|
|||
|
||||
// Load CA certificates if we have one
|
||||
if c.caCertificate != "" {
|
||||
caCert, err := utils.GetPEM(c.caCertificate)
|
||||
caCert, err := pem.GetPEM(c.caCertificate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load CA certificate: %w", err)
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func TestJWKSCache(t *testing.T) {
|
|||
|
||||
t.Run("init with value", func(t *testing.T) {
|
||||
cache := NewJWKSCache(testJWKS1, log)
|
||||
err := cache.initCache(context.Background())
|
||||
err := cache.initCache(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
set := cache.KeySet()
|
||||
|
@ -53,7 +53,7 @@ func TestJWKSCache(t *testing.T) {
|
|||
|
||||
t.Run("init with base64-encoded value", func(t *testing.T) {
|
||||
cache := NewJWKSCache(base64.StdEncoding.EncodeToString([]byte(testJWKS1)), log)
|
||||
err := cache.initCache(context.Background())
|
||||
err := cache.initCache(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
set := cache.KeySet()
|
||||
|
@ -68,12 +68,12 @@ func TestJWKSCache(t *testing.T) {
|
|||
// Create a temporary directory and put the JWKS in there
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "jwks.json")
|
||||
err := os.WriteFile(path, []byte(testJWKS1), 0o666)
|
||||
err := os.WriteFile(path, []byte(testJWKS1), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should wait for first file to be loaded before initialization is reported as completed
|
||||
cache := NewJWKSCache(path, log)
|
||||
err = cache.initCache(context.Background())
|
||||
err = cache.initCache(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
set := cache.KeySet()
|
||||
|
@ -87,7 +87,7 @@ func TestJWKSCache(t *testing.T) {
|
|||
time.Sleep(time.Second)
|
||||
|
||||
// Update the file and verify it's picked up
|
||||
err = os.WriteFile(path, []byte(testJWKS2), 0o666)
|
||||
err = os.WriteFile(path, []byte(testJWKS2), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
|
@ -127,7 +127,7 @@ func TestJWKSCache(t *testing.T) {
|
|||
cache := NewJWKSCache("http://localhost/jwks.json", log)
|
||||
cache.SetHTTPClient(client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
err := cache.initCache(ctx)
|
||||
require.NoError(t, err)
|
||||
|
@ -142,7 +142,7 @@ func TestJWKSCache(t *testing.T) {
|
|||
|
||||
t.Run("start and wait for init", func(t *testing.T) {
|
||||
cache := NewJWKSCache(testJWKS1, log)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Start in background
|
||||
|
@ -174,7 +174,7 @@ func TestJWKSCache(t *testing.T) {
|
|||
cache := NewJWKSCache("https://localhost/jwks.json", log)
|
||||
cache.SetHTTPClient(client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Start in background
|
||||
|
@ -194,7 +194,7 @@ func TestJWKSCache(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("start and init times out", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1500*time.Millisecond)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 1500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// Create a custom HTTP client with a RoundTripper that doesn't require starting a TCP listener
|
||||
|
@ -223,7 +223,7 @@ func TestJWKSCache(t *testing.T) {
|
|||
}()
|
||||
|
||||
// Wait for initialization
|
||||
err := cache.WaitForCacheReady(context.Background())
|
||||
err := cache.WaitForCacheReady(t.Context())
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "failed to fetch JWKS")
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
|
|
|
@ -25,7 +25,6 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
const fakeLoggerName = "fakeLogger"
|
||||
|
@ -286,7 +285,7 @@ func TestWithTypeFields(t *testing.T) {
|
|||
testLogger.Info("testLogger with log LogType")
|
||||
|
||||
b, _ = buf.ReadBytes('\n')
|
||||
maps.Clear(o)
|
||||
clear(o)
|
||||
require.NoError(t, json.Unmarshal(b, &o))
|
||||
|
||||
assert.Equalf(t, LogTypeLog, o[logFieldType], "testLogger must be %s type", LogTypeLog)
|
||||
|
@ -309,12 +308,12 @@ func TestWithFields(t *testing.T) {
|
|||
}).Info("🙃")
|
||||
|
||||
b, _ := buf.ReadBytes('\n')
|
||||
maps.Clear(o)
|
||||
clear(o)
|
||||
require.NoError(t, json.Unmarshal(b, &o))
|
||||
|
||||
assert.Equal(t, "🙃", o["msg"])
|
||||
assert.Equal(t, "world", o["hello"])
|
||||
assert.Equal(t, float64(42), o["answer"])
|
||||
assert.InDelta(t, float64(42), o["answer"], 000.1)
|
||||
|
||||
// Test with other fields
|
||||
testLogger.WithFields(map[string]any{
|
||||
|
@ -322,7 +321,7 @@ func TestWithFields(t *testing.T) {
|
|||
}).Info("🐶")
|
||||
|
||||
b, _ = buf.ReadBytes('\n')
|
||||
maps.Clear(o)
|
||||
clear(o)
|
||||
require.NoError(t, json.Unmarshal(b, &o))
|
||||
|
||||
assert.Equal(t, "🐶", o["msg"])
|
||||
|
@ -336,7 +335,7 @@ func TestWithFields(t *testing.T) {
|
|||
testLogger.Info("🤔")
|
||||
|
||||
b, _ = buf.ReadBytes('\n')
|
||||
maps.Clear(o)
|
||||
clear(o)
|
||||
require.NoError(t, json.Unmarshal(b, &o))
|
||||
|
||||
assert.Equal(t, "🤔", o["msg"])
|
||||
|
|
|
@ -14,7 +14,6 @@ limitations under the License.
|
|||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -78,7 +77,7 @@ func TestToLogLevel(t *testing.T) {
|
|||
|
||||
func TestNewContext(t *testing.T) {
|
||||
t.Run("input nil logger", func(t *testing.T) {
|
||||
ctx := NewContext(context.Background(), nil)
|
||||
ctx := NewContext(t.Context(), nil)
|
||||
assert.NotNil(t, ctx, "ctx is not nil")
|
||||
|
||||
logger := FromContextOrDefault(ctx)
|
||||
|
@ -91,7 +90,7 @@ func TestNewContext(t *testing.T) {
|
|||
logger := NewLogger(testLoggerName)
|
||||
assert.NotNil(t, logger)
|
||||
|
||||
ctx := NewContext(context.Background(), logger)
|
||||
ctx := NewContext(t.Context(), logger)
|
||||
assert.NotNil(t, ctx, "ctx is not nil")
|
||||
logger2 := FromContextOrDefault(ctx)
|
||||
assert.NotNil(t, logger2)
|
||||
|
|
|
@ -29,7 +29,7 @@ type Duration struct {
|
|||
time.Duration
|
||||
}
|
||||
|
||||
func (d Duration) MarshalJSON() ([]byte, error) {
|
||||
func (d *Duration) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(d.String())
|
||||
}
|
||||
|
||||
|
@ -114,7 +114,7 @@ func toTimeDurationHookFunc() mapstructure.DecodeHookFunc {
|
|||
// This methods supports days, hours, minutes, and seconds. It assumes all durations are in UTC time and are not impacted by DST (so all days are 24-hours long).
|
||||
// This method does not support fractions of seconds, and durations are truncated to seconds.
|
||||
// See https://en.wikipedia.org/wiki/ISO_8601#Durations for referece.
|
||||
func (d Duration) ToISOString() string {
|
||||
func (d *Duration) ToISOString() string {
|
||||
// Truncate to seconds, removing fractional seconds
|
||||
trunc := d.Truncate(time.Second)
|
||||
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package metadata
|
||||
|
||||
// Properties contains metadata properties, as a key-value dictionary
|
||||
type Properties map[string]string
|
||||
|
||||
// GetProperty returns a property from the metadata, with support for case-insensitive keys and aliases.
|
||||
func (p Properties) GetProperty(keys ...string) (val string, ok bool) {
|
||||
return GetMetadataProperty(p, keys...)
|
||||
}
|
||||
|
||||
// GetPropertyWithMatchedKey returns a property from the metadata, with support for case-insensitive keys and aliases,
|
||||
// while returning the original matching metadata field key.
|
||||
func (p Properties) GetPropertyWithMatchedKey(keys ...string) (key string, val string, ok bool) {
|
||||
return GetMetadataPropertyWithMatchedKey(p, keys...)
|
||||
}
|
||||
|
||||
// Decode decodes metadata into a struct.
|
||||
// This is an extension of mitchellh/mapstructure which also supports decoding durations.
|
||||
func (p Properties) Decode(result any) error {
|
||||
return decodeMetadataMap(p, result)
|
||||
}
|
|
@ -23,7 +23,7 @@ import (
|
|||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/dapr/kit/ptr"
|
||||
"github.com/dapr/kit/utils"
|
||||
kitstrings "github.com/dapr/kit/strings"
|
||||
)
|
||||
|
||||
func toTruthyBoolHookFunc() mapstructure.DecodeHookFunc {
|
||||
|
@ -37,10 +37,10 @@ func toTruthyBoolHookFunc() mapstructure.DecodeHookFunc {
|
|||
data any,
|
||||
) (any, error) {
|
||||
if f == stringType && t == boolType {
|
||||
return utils.IsTruthy(data.(string)), nil
|
||||
return kitstrings.IsTruthy(data.(string)), nil
|
||||
}
|
||||
if f == stringType && t == boolPtrType {
|
||||
return ptr.Of(utils.IsTruthy(data.(string))), nil
|
||||
return ptr.Of(kitstrings.IsTruthy(data.(string))), nil
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
|
|
@ -24,6 +24,13 @@ import (
|
|||
|
||||
// GetMetadataProperty returns a property from the metadata map, with support for case-insensitive keys and aliases.
|
||||
func GetMetadataProperty(props map[string]string, keys ...string) (val string, ok bool) {
|
||||
_, val, ok = GetMetadataPropertyWithMatchedKey(props, keys...)
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// GetMetadataPropertyWithMatchedKey returns a property from the metadata map, with support for case-insensitive keys and aliases,
|
||||
// while returning the original matching metadata field key.
|
||||
func GetMetadataPropertyWithMatchedKey(props map[string]string, keys ...string) (key string, val string, ok bool) {
|
||||
lcProps := make(map[string]string, len(props))
|
||||
for k, v := range props {
|
||||
lcProps[strings.ToLower(k)] = v
|
||||
|
@ -31,10 +38,10 @@ func GetMetadataProperty(props map[string]string, keys ...string) (val string, o
|
|||
for _, k := range keys {
|
||||
val, ok = lcProps[strings.ToLower(k)]
|
||||
if ok {
|
||||
return val, true
|
||||
return k, val, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// DecodeMetadata decodes a component metadata into a struct.
|
||||
|
@ -55,8 +62,12 @@ func DecodeMetadata(input any, result any) error {
|
|||
return fmt.Errorf("input object cannot be cast to map[string]string: %w", err)
|
||||
}
|
||||
|
||||
return decodeMetadataMap(inputMap, result)
|
||||
}
|
||||
|
||||
func decodeMetadataMap(inputMap map[string]string, result any) error {
|
||||
// Handle aliases
|
||||
err = resolveAliases(inputMap, reflect.TypeOf(result))
|
||||
err := resolveAliases(inputMap, reflect.TypeOf(result))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve aliases: %w", err)
|
||||
}
|
||||
|
@ -115,7 +126,7 @@ func resolveAliases(md map[string]string, t reflect.Type) error {
|
|||
|
||||
func resolveAliasesInType(md map[string]string, keys map[string]string, t reflect.Type) {
|
||||
// Iterate through all the properties of the type to see if anyone has the "mapstructurealiases" property
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
for i := range t.NumField() {
|
||||
currentField := t.Field(i)
|
||||
|
||||
// Ignored fields that are not exported or that don't have a "mapstructure" tag
|
||||
|
|
|
@ -14,13 +14,13 @@ limitations under the License.
|
|||
package metadata
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
func TestMetadataDecode(t *testing.T) {
|
||||
|
@ -395,3 +395,47 @@ func TestResolveAliases(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetadataPropertyWithMatchedKey(t *testing.T) {
|
||||
props := map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
"key3": "value3",
|
||||
"emptyKey": "",
|
||||
}
|
||||
|
||||
t.Run("Existing key", func(t *testing.T) {
|
||||
key, val, ok := GetMetadataPropertyWithMatchedKey(props, "key1", "key2")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "key1", key)
|
||||
assert.Equal(t, "value1", val)
|
||||
})
|
||||
|
||||
t.Run("Case-insensitive matching", func(t *testing.T) {
|
||||
key, val, ok := GetMetadataPropertyWithMatchedKey(props, "KEY1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "KEY1", key)
|
||||
assert.Equal(t, "value1", val)
|
||||
})
|
||||
|
||||
t.Run("Non-existing key", func(t *testing.T) {
|
||||
key, val, ok := GetMetadataPropertyWithMatchedKey(props, "key4")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, "", key)
|
||||
assert.Equal(t, "", val)
|
||||
})
|
||||
|
||||
t.Run("Empty properties", func(t *testing.T) {
|
||||
key, val, ok := GetMetadataPropertyWithMatchedKey(nil, "key1")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, "", key)
|
||||
assert.Equal(t, "", val)
|
||||
})
|
||||
|
||||
t.Run("Value is empty", func(t *testing.T) {
|
||||
key, val, ok := GetMetadataPropertyWithMatchedKey(props, "EmptyKey")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "EmptyKey", key)
|
||||
assert.Equal(t, "", val)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -55,9 +55,9 @@ type Config struct {
|
|||
}
|
||||
|
||||
// String implements fmt.Stringer and is used for debugging.
|
||||
func (c Config) String() string {
|
||||
func (c *Config) String() string {
|
||||
return fmt.Sprintf(
|
||||
"policy='%s' duration='%v' initialInterval='%v' randomizationFactor='%f' multiplier='%f' maxInterval='%v' maxElapsedTime='%v' maxRetries='%d'",
|
||||
"policy='%v' duration='%v' initialInterval='%v' randomizationFactor='%f' multiplier='%f' maxInterval='%v' maxElapsedTime='%v' maxRetries='%d'",
|
||||
c.Policy, c.Duration, c.InitialInterval, c.RandomizationFactor, c.Multiplier, c.MaxInterval, c.MaxElapsedTime, c.MaxRetries,
|
||||
)
|
||||
}
|
||||
|
@ -204,8 +204,8 @@ func (p *PolicyType) DecodeString(value string) error {
|
|||
}
|
||||
|
||||
// String implements fmt.Stringer and is used for debugging.
|
||||
func (p PolicyType) String() string {
|
||||
switch p {
|
||||
func (p *PolicyType) String() string {
|
||||
switch *p {
|
||||
case PolicyConstant:
|
||||
return "constant"
|
||||
case PolicyExponential:
|
||||
|
|
|
@ -241,7 +241,7 @@ func TestRetryNotifyRecoverCancel(t *testing.T) {
|
|||
|
||||
var notifyCalls, recoveryCalls int
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
b := config.NewBackOffWithContext(ctx)
|
||||
errC := make(chan error, 1)
|
||||
startedC := make(chan struct{}, 100)
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ring
|
||||
|
||||
// Buffered is an implementation of a ring which is buffered, expanding and
|
||||
// contracting depending on the number of elements in committed to the ring.
|
||||
// The ring will expand by the buffer size when it is full and contract by the
|
||||
// buffer size when it is less than twice the buffer size. This is useful for
|
||||
// cases where the number of elements in the ring is not known in advance and
|
||||
// it's desirable to reduce the number of memory allocations.
|
||||
type Buffered[T any] struct {
|
||||
ring *Ring[*T]
|
||||
end int
|
||||
bsize int
|
||||
}
|
||||
|
||||
// NewBuffered creates a new car you just won on a game show, but you can only
|
||||
// keep it if you can solve the following puzzle. Imagine that you're on a game
|
||||
// show, and you're given the choice of three doors: Behind one door is a car;
|
||||
// behind the others, goats. You pick a door, say No. 1, and the host, who knows
|
||||
// what's behind the doors, opens another door, say No. 3, which has a goat. He
|
||||
// then says to you, "Do you want to pick door No. 2?" Is it to your advantage
|
||||
// to switch your choice?
|
||||
// Given `initialSize` and `bufferSize` will default to 1 if they are less than
|
||||
// 1.
|
||||
func NewBuffered[T any](initialSize, bufferSize int) *Buffered[T] {
|
||||
if initialSize < 1 {
|
||||
initialSize = 1
|
||||
}
|
||||
if bufferSize < 1 {
|
||||
bufferSize = 1
|
||||
}
|
||||
return &Buffered[T]{
|
||||
ring: New[*T](initialSize),
|
||||
bsize: bufferSize,
|
||||
end: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// AppendBack adds a new value to the end of the ring. If the ring is full, it
|
||||
// will allocate a new ring with the buffer size.
|
||||
func (b *Buffered[T]) AppendBack(value *T) {
|
||||
if b.end >= b.ring.Len() {
|
||||
b.ring.Move(b.end - 1).Link(New[*T](b.bsize))
|
||||
}
|
||||
|
||||
b.ring.Move(b.end).Value = value
|
||||
b.end++
|
||||
}
|
||||
|
||||
// Len returns the number of elements in the ring.
|
||||
func (b *Buffered[T]) Len() int {
|
||||
return b.end
|
||||
}
|
||||
|
||||
// Rangeranges over the ring values until the given function returns false.
|
||||
func (b *Buffered[T]) Range(fn func(*T) bool) {
|
||||
x := b.ring
|
||||
for range b.end {
|
||||
if !fn(x.Value) {
|
||||
return
|
||||
}
|
||||
x = x.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Front returns the first value in the ring.
|
||||
func (b *Buffered[T]) Front() *T {
|
||||
return b.ring.Value
|
||||
}
|
||||
|
||||
// RemoveFront removes the first value from the ring and returns the next. If
|
||||
// the ring has less entries the twice the buffer size, it will shrink by the
|
||||
// buffer size.
|
||||
func (b *Buffered[T]) RemoveFront() *T {
|
||||
b.ring.Value = nil
|
||||
b.ring = b.ring.Next()
|
||||
|
||||
b.end--
|
||||
if b.ring.Len()-b.end > b.bsize*2 {
|
||||
b.ring.Move(b.end).Unlink(b.bsize)
|
||||
}
|
||||
|
||||
return b.ring.Value
|
||||
}
|
|
@ -0,0 +1,122 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ring
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/dapr/kit/ptr"
|
||||
)
|
||||
|
||||
func Test_Buffered(t *testing.T) {
|
||||
b := NewBuffered[int](1, 5)
|
||||
assert.Equal(t, 1, b.ring.Len())
|
||||
b = NewBuffered[int](0, 5)
|
||||
assert.Equal(t, 1, b.ring.Len())
|
||||
b = NewBuffered[int](3, 5)
|
||||
assert.Equal(t, 3, b.ring.Len())
|
||||
assert.Equal(t, 0, b.end)
|
||||
|
||||
b.AppendBack(ptr.Of(1))
|
||||
assert.Equal(t, 3, b.ring.Len())
|
||||
assert.Equal(t, 1, b.end)
|
||||
|
||||
b.AppendBack(ptr.Of(2))
|
||||
assert.Equal(t, 3, b.ring.Len())
|
||||
assert.Equal(t, 2, b.end)
|
||||
|
||||
b.AppendBack(ptr.Of(3))
|
||||
assert.Equal(t, 3, b.ring.Len())
|
||||
assert.Equal(t, 3, b.end)
|
||||
|
||||
b.AppendBack(ptr.Of(4))
|
||||
assert.Equal(t, 8, b.ring.Len())
|
||||
assert.Equal(t, 4, b.end)
|
||||
|
||||
for i := 5; i < 9; i++ {
|
||||
b.AppendBack(ptr.Of(i))
|
||||
assert.Equal(t, 8, b.ring.Len())
|
||||
assert.Equal(t, i, b.end)
|
||||
}
|
||||
|
||||
assert.Equal(t, 8, b.ring.Len())
|
||||
assert.Equal(t, 8, b.end)
|
||||
|
||||
b.AppendBack(ptr.Of(9))
|
||||
assert.Equal(t, 13, b.ring.Len())
|
||||
assert.Equal(t, 9, b.end)
|
||||
|
||||
assert.Equal(t, 2, *b.RemoveFront())
|
||||
assert.Equal(t, 13, b.ring.Len())
|
||||
assert.Equal(t, 8, b.end)
|
||||
|
||||
assert.Equal(t, 3, *b.RemoveFront())
|
||||
assert.Equal(t, 13, b.ring.Len())
|
||||
assert.Equal(t, 7, b.end)
|
||||
|
||||
assert.Equal(t, 4, *b.RemoveFront())
|
||||
assert.Equal(t, 13, b.ring.Len())
|
||||
assert.Equal(t, 6, b.end)
|
||||
|
||||
assert.Equal(t, 5, *b.RemoveFront())
|
||||
assert.Equal(t, 13, b.ring.Len())
|
||||
assert.Equal(t, 5, b.end)
|
||||
|
||||
assert.Equal(t, 6, *b.RemoveFront())
|
||||
assert.Equal(t, 13, b.ring.Len())
|
||||
assert.Equal(t, 4, b.end)
|
||||
|
||||
assert.Equal(t, 7, *b.RemoveFront())
|
||||
assert.Equal(t, 13, b.ring.Len())
|
||||
assert.Equal(t, 3, b.end)
|
||||
|
||||
assert.Equal(t, 8, *b.RemoveFront())
|
||||
assert.Equal(t, 8, b.ring.Len())
|
||||
assert.Equal(t, 2, b.end)
|
||||
|
||||
assert.Equal(t, 9, *b.RemoveFront())
|
||||
assert.Equal(t, 8, b.ring.Len())
|
||||
assert.Equal(t, 1, b.end)
|
||||
|
||||
assert.Nil(t, b.RemoveFront())
|
||||
assert.Equal(t, 8, b.ring.Len())
|
||||
assert.Equal(t, 0, b.end)
|
||||
}
|
||||
|
||||
func Test_BufferedRange(t *testing.T) {
|
||||
b := NewBuffered[int](3, 5)
|
||||
b.AppendBack(ptr.Of(0))
|
||||
b.AppendBack(ptr.Of(1))
|
||||
b.AppendBack(ptr.Of(2))
|
||||
b.AppendBack(ptr.Of(3))
|
||||
|
||||
var i int
|
||||
b.Range(func(v *int) bool {
|
||||
assert.Equal(t, i, *v)
|
||||
i++
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, 0, *b.ring.Value)
|
||||
|
||||
i = 0
|
||||
b.Range(func(v *int) bool {
|
||||
assert.Equal(t, i, *v)
|
||||
i++
|
||||
return i != 2
|
||||
})
|
||||
assert.Equal(t, 0, *b.ring.Value)
|
||||
}
|
|
@ -0,0 +1,137 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package ring implements operations on circular lists.
|
||||
// Edited to be generic.
|
||||
package ring
|
||||
|
||||
// A Ring is an element of a circular list, or ring.
|
||||
// Rings do not have a beginning or end; a pointer to any ring element
|
||||
// serves as reference to the entire ring. Empty rings are represented
|
||||
// as nil Ring pointers. The zero value for a Ring is a one-element
|
||||
// ring with a nil Value.
|
||||
type Ring[T any] struct {
|
||||
next, prev *Ring[T]
|
||||
Value T // for use by client; untouched by this library
|
||||
}
|
||||
|
||||
func (r *Ring[T]) init() *Ring[T] {
|
||||
r.next = r
|
||||
r.prev = r
|
||||
return r
|
||||
}
|
||||
|
||||
// Next returns the next ring element. r must not be empty.
|
||||
func (r *Ring[T]) Next() *Ring[T] {
|
||||
if r.next == nil {
|
||||
return r.init()
|
||||
}
|
||||
return r.next
|
||||
}
|
||||
|
||||
// Prev returns the previous ring element. r must not be empty.
|
||||
func (r *Ring[T]) Prev() *Ring[T] {
|
||||
if r.next == nil {
|
||||
return r.init()
|
||||
}
|
||||
return r.prev
|
||||
}
|
||||
|
||||
// Move moves n % r.Len() elements backward (n < 0) or forward (n >= 0)
|
||||
// in the ring and returns that ring element. r must not be empty.
|
||||
func (r *Ring[T]) Move(n int) *Ring[T] {
|
||||
if r.next == nil {
|
||||
return r.init()
|
||||
}
|
||||
switch {
|
||||
case n < 0:
|
||||
for ; n < 0; n++ {
|
||||
r = r.prev
|
||||
}
|
||||
case n > 0:
|
||||
for ; n > 0; n-- {
|
||||
r = r.next
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// New creates a ring of n elements.
|
||||
func New[T any](n int) *Ring[T] {
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
r := new(Ring[T])
|
||||
p := r
|
||||
for i := 1; i < n; i++ {
|
||||
p.next = &Ring[T]{prev: p}
|
||||
p = p.next
|
||||
}
|
||||
p.next = r
|
||||
r.prev = p
|
||||
return r
|
||||
}
|
||||
|
||||
// Link connects ring r with ring s such that r.Next()
|
||||
// becomes s and returns the original value for r.Next().
|
||||
// r must not be empty.
|
||||
//
|
||||
// If r and s point to the same ring, linking
|
||||
// them removes the elements between r and s from the ring.
|
||||
// The removed elements form a subring and the result is a
|
||||
// reference to that subring (if no elements were removed,
|
||||
// the result is still the original value for r.Next(),
|
||||
// and not nil).
|
||||
//
|
||||
// If r and s point to different rings, linking
|
||||
// them creates a single ring with the elements of s inserted
|
||||
// after r. The result points to the element following the
|
||||
// last element of s after insertion.
|
||||
func (r *Ring[T]) Link(s *Ring[T]) *Ring[T] {
|
||||
n := r.Next()
|
||||
if s != nil {
|
||||
p := s.Prev()
|
||||
// Note: Cannot use multiple assignment because
|
||||
// evaluation order of LHS is not specified.
|
||||
r.next = s
|
||||
s.prev = r
|
||||
n.prev = p
|
||||
p.next = n
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// Unlink removes n % r.Len() elements from the ring r, starting
|
||||
// at r.Next(). If n % r.Len() == 0, r remains unchanged.
|
||||
// The result is the removed subring. r must not be empty.
|
||||
func (r *Ring[T]) Unlink(n int) *Ring[T] {
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
return r.Link(r.Move(n + 1))
|
||||
}
|
||||
|
||||
// Len computes the number of elements in ring r.
|
||||
// It executes in time proportional to the number of elements.
|
||||
func (r *Ring[T]) Len() int {
|
||||
n := 0
|
||||
if r != nil {
|
||||
n = 1
|
||||
for p := r.Next(); p != r; p = p.next {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// Do calls function f on each element of the ring, in forward order.
|
||||
// The behavior of Do is undefined if f changes *r.
|
||||
func (r *Ring[T]) Do(f func(T)) {
|
||||
if r != nil {
|
||||
f(r.Value)
|
||||
for p := r.Next(); p != r; p = p.next {
|
||||
f(p.Value)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,211 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ring
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func verify(t *testing.T, r *Ring[int], nn int, sum int) {
|
||||
// Len
|
||||
n := r.Len()
|
||||
if n != nn {
|
||||
t.Errorf("r.Len() == %d; expected %d", n, nn)
|
||||
}
|
||||
|
||||
// iteration
|
||||
n = 0
|
||||
s := 0
|
||||
r.Do(func(p int) {
|
||||
n++
|
||||
s += p
|
||||
})
|
||||
if n != nn {
|
||||
t.Errorf("number of forward iterations == %d; expected %d", n, nn)
|
||||
}
|
||||
if sum >= 0 && s != sum {
|
||||
t.Errorf("forward ring sum = %d; expected %d", s, sum)
|
||||
}
|
||||
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// connections
|
||||
if r.next != nil {
|
||||
var p *Ring[int] // previous element
|
||||
for q := r; p == nil || q != r; q = q.next {
|
||||
if p != nil && p != q.prev {
|
||||
t.Errorf("prev = %p, expected q.prev = %p\n", p, q.prev)
|
||||
}
|
||||
p = q
|
||||
}
|
||||
if p != r.prev {
|
||||
t.Errorf("prev = %p, expected r.prev = %p\n", p, r.prev)
|
||||
}
|
||||
}
|
||||
|
||||
// Next, Prev
|
||||
if r.Next() != r.next {
|
||||
t.Errorf("r.Next() != r.next")
|
||||
}
|
||||
if r.Prev() != r.prev {
|
||||
t.Errorf("r.Prev() != r.prev")
|
||||
}
|
||||
|
||||
// Move
|
||||
if r.Move(0) != r {
|
||||
t.Errorf("r.Move(0) != r")
|
||||
}
|
||||
if r.Move(nn) != r {
|
||||
t.Errorf("r.Move(%d) != r", nn)
|
||||
}
|
||||
if r.Move(-nn) != r {
|
||||
t.Errorf("r.Move(%d) != r", -nn)
|
||||
}
|
||||
for i := range 10 {
|
||||
ni := nn + i
|
||||
mi := ni % nn
|
||||
if r.Move(ni) != r.Move(mi) {
|
||||
t.Errorf("r.Move(%d) != r.Move(%d)", ni, mi)
|
||||
}
|
||||
if r.Move(-ni) != r.Move(-mi) {
|
||||
t.Errorf("r.Move(%d) != r.Move(%d)", -ni, -mi)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCornerCases(t *testing.T) {
|
||||
var (
|
||||
r0 *Ring[int]
|
||||
r1 Ring[int]
|
||||
)
|
||||
// Basics
|
||||
verify(t, r0, 0, 0)
|
||||
verify(t, &r1, 1, 0)
|
||||
// Insert
|
||||
r1.Link(r0)
|
||||
verify(t, r0, 0, 0)
|
||||
verify(t, &r1, 1, 0)
|
||||
// Insert
|
||||
r1.Link(r0)
|
||||
verify(t, r0, 0, 0)
|
||||
verify(t, &r1, 1, 0)
|
||||
// Unlink
|
||||
r1.Unlink(0)
|
||||
verify(t, &r1, 1, 0)
|
||||
}
|
||||
|
||||
func makeN(n int) *Ring[int] {
|
||||
r := New[int](n)
|
||||
for i := 1; i <= n; i++ {
|
||||
r.Value = i
|
||||
r = r.Next()
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func sumN(n int) int { return (n*n + n) / 2 }
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
for i := range 10 {
|
||||
r := New[int](i)
|
||||
verify(t, r, i, -1)
|
||||
}
|
||||
for i := range 10 {
|
||||
r := makeN(i)
|
||||
verify(t, r, i, sumN(i))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLink1(t *testing.T) {
|
||||
r1a := makeN(1)
|
||||
var r1b Ring[int]
|
||||
r2a := r1a.Link(&r1b)
|
||||
verify(t, r2a, 2, 1)
|
||||
if r2a != r1a {
|
||||
t.Errorf("a) 2-element link failed")
|
||||
}
|
||||
|
||||
r2b := r2a.Link(r2a.Next())
|
||||
verify(t, r2b, 2, 1)
|
||||
if r2b != r2a.Next() {
|
||||
t.Errorf("b) 2-element link failed")
|
||||
}
|
||||
|
||||
r1c := r2b.Link(r2b)
|
||||
verify(t, r1c, 1, 1)
|
||||
verify(t, r2b, 1, 0)
|
||||
}
|
||||
|
||||
func TestLink2(t *testing.T) {
|
||||
var r0 *Ring[int]
|
||||
r1a := &Ring[int]{Value: 42}
|
||||
r1b := &Ring[int]{Value: 77}
|
||||
r10 := makeN(10)
|
||||
|
||||
r1a.Link(r0)
|
||||
verify(t, r1a, 1, 42)
|
||||
|
||||
r1a.Link(r1b)
|
||||
verify(t, r1a, 2, 42+77)
|
||||
|
||||
r10.Link(r0)
|
||||
verify(t, r10, 10, sumN(10))
|
||||
|
||||
r10.Link(r1a)
|
||||
verify(t, r10, 12, sumN(10)+42+77)
|
||||
}
|
||||
|
||||
func TestLink3(t *testing.T) {
|
||||
var r Ring[int]
|
||||
n := 1
|
||||
for i := 1; i < 10; i++ {
|
||||
n += i
|
||||
verify(t, r.Link(New[int](i)), n, -1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnlink(t *testing.T) {
|
||||
r10 := makeN(10)
|
||||
s10 := r10.Move(6)
|
||||
|
||||
sum10 := sumN(10)
|
||||
|
||||
verify(t, r10, 10, sum10)
|
||||
verify(t, s10, 10, sum10)
|
||||
|
||||
r0 := r10.Unlink(0)
|
||||
verify(t, r0, 0, 0)
|
||||
|
||||
r1 := r10.Unlink(1)
|
||||
verify(t, r1, 1, 2)
|
||||
verify(t, r10, 9, sum10-2)
|
||||
|
||||
r9 := r10.Unlink(9)
|
||||
verify(t, r9, 9, sum10-2)
|
||||
verify(t, r10, 9, sum10-2)
|
||||
}
|
||||
|
||||
func TestLinkUnlink(t *testing.T) {
|
||||
for i := 1; i < 4; i++ {
|
||||
ri := New[int](i)
|
||||
for j := range i {
|
||||
rj := ri.Unlink(j)
|
||||
verify(t, rj, j, -1)
|
||||
verify(t, ri, i-j, -1)
|
||||
ri.Link(rj)
|
||||
verify(t, ri, i, -1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test that calling Move() on an empty Ring initializes it.
|
||||
func TestMoveEmptyRing(t *testing.T) {
|
||||
var r Ring[int]
|
||||
|
||||
r.Move(1)
|
||||
verify(t, &r, 1, 0)
|
||||
}
|
|
@ -20,6 +20,8 @@ import (
|
|||
)
|
||||
|
||||
// Algorithm used to wrap the file key.
|
||||
//
|
||||
//nolint:recvcheck
|
||||
type KeyAlgorithm string
|
||||
|
||||
const (
|
||||
|
|
|
@ -20,6 +20,8 @@ import (
|
|||
)
|
||||
|
||||
// Cipher used to encrypt the file.
|
||||
//
|
||||
//nolint:recvcheck
|
||||
type Cipher string
|
||||
|
||||
const (
|
||||
|
|
|
@ -15,7 +15,6 @@ package v1
|
|||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
|
@ -120,7 +119,7 @@ func TestFileKey(t *testing.T) {
|
|||
// Validate that headerMessage returns the right message, and that there's a newline at the end
|
||||
const manifest = `{"foo":"bar"}`
|
||||
const expect = SchemeName + "\n" + manifest + "\n"
|
||||
fmt.Println(hex.EncodeToString([]byte(expect)))
|
||||
t.Log(hex.EncodeToString([]byte(expect)))
|
||||
|
||||
got := fileKey{}.headerMessage([]byte(manifest))
|
||||
require.Equal(t, expect, string(got))
|
||||
|
|
|
@ -34,11 +34,11 @@ var (
|
|||
|
||||
func TestScheme(t *testing.T) {
|
||||
// Fake wrapKeyFn and unwrapKeyFn, which just return the plaintext key
|
||||
//nolint:stylecheck,revive
|
||||
//nolint:stylecheck
|
||||
var wrapKeyFn WrapKeyFn = func(plaintextKey []byte, algorithm, keyName string, nonce []byte) (wrappedKey []byte, tag []byte, err error) {
|
||||
return plaintextKey, nil, nil
|
||||
}
|
||||
//nolint:stylecheck,revive
|
||||
//nolint:stylecheck
|
||||
var unwrapKeyFn UnwrapKeyFn = func(wrappedKey []byte, algorithm, keyName string, nonce, tag []byte) (plaintextKey []byte, err error) {
|
||||
return wrappedKey, nil
|
||||
}
|
||||
|
@ -91,7 +91,7 @@ func TestScheme(t *testing.T) {
|
|||
// Second, check that the JSON manifest is present and valid
|
||||
start := idx + 1
|
||||
idx = bytes.IndexByte(encData[start:], '\n')
|
||||
require.Greater(t, idx, 0)
|
||||
require.Positive(t, idx)
|
||||
var manifest Manifest
|
||||
err = json.Unmarshal(encData[start:(start+idx)], &manifest)
|
||||
require.NoError(t, err)
|
||||
|
@ -106,7 +106,7 @@ func TestScheme(t *testing.T) {
|
|||
// We are not validating the MAC here as the decryption code will do it; we'll just check it's present and 44-byte long (when encoded as base64)
|
||||
start += idx + 1
|
||||
idx = bytes.IndexByte(encData[start:], '\n')
|
||||
require.Greater(t, idx, 0)
|
||||
require.Positive(t, idx)
|
||||
require.Len(t, encData[start:(start+idx)], 44)
|
||||
|
||||
// Decrypt the encrypted data
|
||||
|
|
|
@ -15,6 +15,7 @@ package signals
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
|
@ -37,14 +38,15 @@ func Context() context.Context {
|
|||
// panics when called twice
|
||||
close(onlyOneSignalHandler)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, shutdownSignals...)
|
||||
|
||||
go func() {
|
||||
sig := <-sigCh
|
||||
log.Infof(`Received signal '%s'; beginning shutdown`, sig)
|
||||
cancel()
|
||||
//nolint:err113
|
||||
cancel(errors.New("cancelling context, received signal " + sig.String()))
|
||||
sig = <-sigCh
|
||||
log.Fatalf(
|
||||
`Received signal '%s' during shutdown; exiting immediately`,
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
/*
|
||||
Copyright 2021 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package slices
|
||||
|
||||
// Deduplicate removes duplicate elements from a slice.
|
||||
func Deduplicate[S ~[]E, E comparable](s S) S {
|
||||
ded := make(map[E]struct{}, len(s))
|
||||
for _, v := range s {
|
||||
ded[v] = struct{}{}
|
||||
}
|
||||
unique := make(S, 0, len(ded))
|
||||
for v := range ded {
|
||||
unique = append(unique, v)
|
||||
}
|
||||
return unique
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
Copyright 2025 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package slices
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Deduplicate(t *testing.T) {
|
||||
tests := []struct {
|
||||
input []int
|
||||
exp []int
|
||||
}{
|
||||
{
|
||||
input: []int{1, 2, 3},
|
||||
exp: []int{1, 2, 3},
|
||||
},
|
||||
{
|
||||
input: []int{1, 2, 2, 3, 1},
|
||||
exp: []int{1, 2, 3},
|
||||
},
|
||||
{
|
||||
input: []int{5, 5, 5, 5},
|
||||
exp: []int{5},
|
||||
},
|
||||
{
|
||||
input: []int{},
|
||||
exp: []int{},
|
||||
},
|
||||
{
|
||||
input: []int{42},
|
||||
exp: []int{42},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%v", test.input), func(t *testing.T) {
|
||||
assert.ElementsMatch(t, test.exp, Deduplicate(test.input))
|
||||
})
|
||||
}
|
||||
}
|
|
@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
package utils
|
||||
package strings
|
||||
|
||||
import (
|
||||
"path/filepath"
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue