Compare commits

...

38 Commits

Author SHA1 Message Date
Fabian Martinez 8b780b4d81
Concurrency ctesting (#133)
* Adds concurrency/ctesting

Adds concurrency/ctesting package, used for concurrently running a set
of runners, and collect the results via a testing assertion.

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: Fabian Martinez <46371672+famarting@users.noreply.github.com>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
Signed-off-by: Fabian Martinez <46371672+famarting@users.noreply.github.com>
Co-authored-by: joshvanl <me@joshvanl.dev>
2025-07-17 11:07:48 -03:00
Javier Aliaga 9d4f384c57
feat: Add subsecond precision to jobs (#129)
* feat: Add subsecond precision to jobs

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: linting fixes

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: add tests

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: remove all precision to cron constant schedule

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: Time parse supports RFC3339nano

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: Fix TestFile_CurrentTrustAnchors flaky test

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: No need to fallback to RFC3339

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: Wait until file trustanchors is running

Signed-off-by: Javier Aliaga <javier@diagrid.io>

---------

Signed-off-by: Javier Aliaga <javier@diagrid.io>
2025-07-10 11:03:56 -03:00
Josh van Leeuwen 7c4cedad37
concurrency/dir: fix mkdir permissions (#131)
Fix Mkdir permissions by using 0o700 over `ModePerm`.

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-07-08 21:41:07 -05:00
Javier Aliaga 7409957e9e
Update deps (#130)
* chore: update go and golanci-lint versions

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: Fix linting issues

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: Remove nolint exclusion

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: Revert range to for loop

Signed-off-by: Javier Aliaga <javier@diagrid.io>

---------

Signed-off-by: Javier Aliaga <javier@diagrid.io>
2025-07-08 17:55:56 -03:00
Josh van Leeuwen 34f8820d2a
concurrency/runner: make cancel with cause (#128)
Useful for making concurrency runner closer errors more useful for
shutdowns.

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-07-03 09:51:31 -05:00
Joni Collinge 598b032bce
Fix deprecation comment to reference correct function (#125)
Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>
2025-06-16 17:06:11 +01:00
Josh van Leeuwen d7d50a1e1b
events/loop: drain queue on close (#124)
* events/loop: drain queue on close

Update looper to drain the Enqueue loop in the event that an error has
occurred when Enqueuing. Handle an error'd `Run` on `Close` by
respecting the `closedCh` channel.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Update loop.go

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-06-16 08:26:36 -05:00
Josh van Leeuwen baea626399
Update .golangci.yml to remove deprecations (#122)
* Update .golangci.yml  to remove deprecations

Signed-off-by: joshvanl <me@joshvanl.dev>

* Update .golangci.yml

Co-authored-by: Cassie Coyle <cassie.i.coyle@gmail.com>
Signed-off-by: Josh van Leeuwen <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
Signed-off-by: Josh van Leeuwen <me@joshvanl.dev>
Co-authored-by: Cassie Coyle <cassie.i.coyle@gmail.com>
2025-05-22 08:58:18 -05:00
Joni Collinge bc7dc566c4
Add JWT handling to spiffe package (#118)
* Adds JWT handling to spiffe

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Update log

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Add jwtbundle

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Update file watcher

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Clean up jwt spiffe

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* lint

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Update renewal behavior

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Updates based on joshvanl feedback

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* go mod tidy

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* lint

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* lint

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Move ready chan check to avoid race

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Add small delay after fs write to allow watcher to pick up change

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Resolve feedback

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Resolve feedback

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* lint

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

---------

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>
2025-05-16 13:15:56 +01:00
Josh van Leeuwen 98fe567235
events/loop: add reset (#120)
* events/loop: add reset

Update loop implementation is include functionality for Reset which is
useful when caching the loop struct for future use to reduce
allocations.

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-05-15 23:23:38 +01:00
Josh van Leeuwen e3d4a8f1b4
Add Copyright headers to env, and remove utils. (#103)
Chore to add Copyright headers to `env` files.

Move containing funcs and deletes `utils` package. `utils` packages are
generally a code smell, and better placed in a more descriptive package
name that gives context.

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-04-23 11:29:15 -03:00
Josh van Leeuwen 77af8ac182
events/loop & slices (#119)
* events/loop & slices

Adds a generic control loop implementation to `event/loop`.

Adds a new `slices` package that provides a generic slice de-duplication
func.

Makes events batcher and queue processer taker in Options.

Allows enqueuing multiple processor items in same func call.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Lint

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

* Elements match

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds buffer size option to events loop

Signed-off-by: joshvanl <me@joshvanl.dev>

* nit

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-04-23 11:28:21 -03:00
Josh van Leeuwen a3f06e444a
concurrency: Make runner closer take a logger (#116)
* concurrency: Make runner closer take a logger

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linter

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linter

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-04-10 15:34:54 -03:00
Josh van Leeuwen 39c4bf57bd
concurrency/lock: Adds context and outercancel locks (#115)
* concurrency/lock: Adds context and outercancel locks

Adds context lock which will cancel and return an error to Lock & RLock
if the given context cancels before the lock is achieved.

Adds outercancel lock which will cancel all RLocks in progress if the
outer Lock is called.

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

* Lint

Signed-off-by: joshvanl <me@joshvanl.dev>

* Fix tests

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-02-24 11:15:07 -08:00
Josh van Leeuwen 6271c8be59
Signals: Add cause to signal context cancel (#114)
* Signals: Add cause to signal context cancel

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-02-04 17:31:44 -08:00
Luis Rascão c46009f360
crypto/spiffe: adds a multi trust anchor selector (#113)
* crypto/spiffe: adds a multi trust anchor selector

Adds a new trust anchors provider that returns different trust
anchors depending on the requested trust domain out of a pre-loaded
set.

Signed-off-by: Luis Rascao <luis.rascao@gmail.com>

* fixup! crypto/spiffe: adds a multi trust anchor selector

Signed-off-by: Luis Rascao <luis.rascao@gmail.com>

---------

Signed-off-by: Luis Rascao <luis.rascao@gmail.com>
2025-01-28 21:07:41 -08:00
Josh van Leeuwen c90b807d32
concurrency/dir & WriteIdentityToFile (#112)
* concurrency/dir & WriteIdentityToFile

Adds concurrency/dir package to handle atomically writing files to a
directory.

Adds support for spiffe to optionally write the identity certificate,
private key and trust bundle to a given directory.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds Dir comment

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-01-21 06:16:49 -08:00
Josh van Leeuwen fb19570696
events/broadcaster Buffer val when writing to subscriber channel (#111)
Signed-off-by: joshvanl <me@joshvanl.dev>
2025-01-10 11:22:55 -08:00
Josh van Leeuwen 65ba3783f2
Adds events/broadcaster (#110)
* Adds events/broadcaster

Adds a generic buffered dmessage broadcaster which will relay a typed
message to a dynamic set of subscribers.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Lint

Signed-off-by: joshvanl <me@joshvanl.dev>

* Review comments

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-01-06 15:46:09 +00:00
Jake Engelberg 30e2c24840
Add error code and category funcs (#109) 2024-11-27 08:52:51 -08:00
Josh van Leeuwen 24b59a803d
Adds generic ring (#108)
* Adds generic ring

Adds a generic implementation of the stdblib ring buffer so that each
ring `Value` can be a concrete type.
https://pkg.go.dev/container/ring

Adds `Len() int` and `Keys() []K` func to the generic Map cmap.

Changes `events/queue` Processor `Queueable` to be an exported type. No
functional change, but consumed types should be exported.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds ring_test.go

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Update Do func to be typed

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds ring/buffered

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2024-11-20 13:48:57 -08:00
Josh van Leeuwen d37dc603d0
Update go to 1.23.1, golangci-lint 1.61.0 (#105)
* Update go to 1.23.1, golangci-lint 1.61.0

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds only new issues

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2024-10-21 13:52:18 -07:00
Josh van Leeuwen 866002abe6
Adds FIFO concurrent lock & generic concurrent Slice (#107)
* Adds FIFO concurrent lock & generic concurrent Slice

Adds a new concurrency/fifo package which implements a fifo mutex, as
well as a concurrently safe comparable indexed map of fifo mutexes.

Adds a simple generic concurrently safe slice implementation, which can
currently only grow.

Moves the map generic implementations in `/concurrency` to
`/concurrency/cmap`.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Move concurency/slice.go to concurency/slice/slice.go and add concurency/slice/string.go

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2024-10-15 06:03:26 -07:00
Josh van Leeuwen bc3a4f0fb4
events/queue: Don't return queue error when closed (#106)
`Enqueue` and `Dequeue` returned an error which was cumbersome for
consumers to need to check and deal with. The only time this error would
occur would be when the queue was closed, so instead of returning an
error we simply ignore the queue event.

Signed-off-by: joshvanl <me@joshvanl.dev>
2024-10-07 07:39:32 -07:00
Josh van Leeuwen 2d6ff15a97
Map: Adds concurrency/map (#104) 2024-09-23 21:10:40 -07:00
Elena Kolevska 3823663aa4
Adds a GetEnvIntWithRange utility function (#102)
* Adds a GetEnvIntWithRange utility function

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Fixes linter

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Update to using time duration instead of int for seconds

Signed-off-by: Elena Kolevska <elena@kolevska.com>

---------

Signed-off-by: Elena Kolevska <elena@kolevska.com>
2024-09-09 14:50:17 -07:00
Josh van Leeuwen 502671bade
concurrency/mutexmap Move Unlock to after operation (#101)
Signed-off-by: joshvanl <me@joshvanl.dev>
2024-09-09 14:49:03 -07:00
Josh van Leeuwen 26b564d9d0
MutexMap: Adds DeleteRUnlock and fixes RLock/RUnlock (#100)
Signed-off-by: joshvanl <me@joshvanl.dev>
2024-07-23 17:01:21 -07:00
Josh van Leeuwen 58c6d9df14
concurrency/mutexmap: Adds DeleteUnlock (#99)
Signed-off-by: joshvanl <me@joshvanl.dev>
2024-07-22 09:34:53 -07:00
Sam e2508d6e9e
fix(security): update vulnerabilities (#96)
* fix(security): update vulnerabilities

Signed-off-by: Samantha Coyle <sam@diagrid.io>

* style: make linter happy

Signed-off-by: Samantha Coyle <sam@diagrid.io>

* fix: add another fix for a sec vul

Signed-off-by: Samantha Coyle <sam@diagrid.io>

---------

Signed-off-by: Samantha Coyle <sam@diagrid.io>
2024-06-24 14:24:34 -07:00
Elena Kolevska 106329e583
Mutexmap (#95)
* Adds Mutex Map

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Adds an atomic map

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* More work on atomic map and mutex map

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Fixes, improvements and more tests

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Updates interface

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Linter

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Refactors atomic map to use generics

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* cleanups

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Apply suggestions from code review

Co-authored-by: Cassie Coyle <cassie.i.coyle@gmail.com>
Signed-off-by: Elena Kolevska <elena-kolevska@users.noreply.github.com>

* small reorg

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Adds ItemCount()

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Removes atomicmap in favour of haxmap

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* formats fix and adds comment

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Update concurrency/mutexmap.go

Co-authored-by: Josh van Leeuwen <me@joshvanl.dev>
Signed-off-by: Elena Kolevska <elena-kolevska@users.noreply.github.com>

* Uses built in `clear`

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Revert "Removes atomicmap in favour of haxmap"

This reverts commit 20ca9ad197.

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Uses clear() for atomic map too

Signed-off-by: Elena Kolevska <elena@kolevska.com>

---------

Signed-off-by: Elena Kolevska <elena@kolevska.com>
Signed-off-by: Elena Kolevska <elena-kolevska@users.noreply.github.com>
Co-authored-by: Cassie Coyle <cassie.i.coyle@gmail.com>
Co-authored-by: Josh van Leeuwen <me@joshvanl.dev>
2024-05-23 15:57:05 -07:00
Annu Singh ccffb60016
fixing a dead hyperlink (#94)
Signed-off-by: Annu Singh <annu.singh@terramate.io>
2024-04-16 13:28:40 -07:00
Josh van Leeuwen a3f906d609
Adds crypto/spiffe (#92)
* Adds crypto/spiffe

Adds spiffe package to crypto. This is a refactored version of the
existing `pkg/security` package. This new package is more modulated and
fuller test coverage.

This package has been moved so that it can be both imported by dapr &
components-contrib, as well as making the package more suitable for
further development to support X.509 Component auth.
https://github.com/dapr/proposals/pull/51

Also moves in `test/utils` from dapr to `crypto/test` for shared usage.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds crypto/spiffe/context

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2024-04-15 10:19:26 -07:00
lee 0c7cfce53d
Add placement error code (#91)
Signed-off-by: Lee Fowler <fowler.lee8@gmail.com>
2024-04-02 13:38:09 +03:00
Josh van Leeuwen 6c3b2ee1ef
Events: type Batcher value & ensure queue order (#89)
* Events: type Batcher value & ensure queue order

Update Batcher to allow for typed value types.

Update Batcher and Queue to execute values in order they were added.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Delay batcher to ensure key is sent in order

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2024-03-26 13:31:51 +02:00
Sam e33fbab745
feat: enable original key to be returned for metadata property fields (#88)
Signed-off-by: Samantha Coyle <sam@diagrid.io>
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
2024-03-06 07:26:01 -08:00
Mike Nguyen 050e34c9b9
chore: bump lestrrat-go/jwx/v2 from v2.0.15 to v2.0.20 (#86)
Signed-off-by: mikeee <hey@mike.ee>
2024-02-26 09:24:31 -08:00
Chaitanya Bhangale 9e733a35f1
Add cryptography error code (#84)
Co-authored-by: Chaitanya Bhangale <chaitanyabhangale@Chaitanyas-MacBook-Pro-2.local>
2024-02-19 13:42:03 -08:00
104 changed files with 6439 additions and 650 deletions

View File

@ -24,7 +24,7 @@ jobs:
GOOS: ${{ matrix.target_os }}
GOARCH: ${{ matrix.target_arch }}
GOPROXY: https://proxy.golang.org
GOLANGCI_LINT_VER: v1.55.1
GOLANGCI_LINT_VER: v1.64.8
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macOS-latest]

5
.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
**/.DS_Store
.idea
.vscode
.vs
vendor

View File

@ -4,7 +4,7 @@ run:
concurrency: 4
# timeout for analysis, e.g. 30s, 5m, default is 1m
deadline: 10m
timeout: 15m
# exit code when at least one issue was found, default is 1
issues-exit-code: 1
@ -15,29 +15,34 @@ run:
# list of build tags, all linters use it. Default is empty list.
build-tags:
- unit
# which dirs to skip: they won't be analyzed;
# can use regexp here: generated.*, regexp is applied on full path;
# default value is empty list, but next dirs are always skipped independently
# from this option's value:
# third_party$, testdata$, examples$, Godeps$, builtin$
skip-dirs:
- ^pkg.*client.*clientset.*versioned.*
- ^pkg.*client.*informers.*externalversions.*
- ^pkg.*proto.*
- allcomponents
- subtlecrypto
# which files to skip: they will be analyzed, but issues from them
# won't be reported. Default value is empty list, but there is
# no need to include all autogenerated files, we confidently recognize
# autogenerated files. If it's not please let us know.
skip-files:
# skip-files:
# - ".*\\.my\\.go$"
# - lib/bad.go
issues:
# which dirs to skip: they won't be analyzed;
# can use regexp here: generated.*, regexp is applied on full path;
# default value is empty list, but next dirs are always skipped independently
# from this option's value:
# third_party$, testdata$, examples$, Godeps$, builtin$
exclude-dirs:
- ^pkg.*client.*clientset.*versioned.*
- ^pkg.*client.*informers.*externalversions.*
- ^pkg.*proto.*
- pkg/proto
# output configuration options
output:
# colored-line-number|line-number|json|tab|checkstyle, default is "colored-line-number"
format: tab
formats:
- format: tab
# print lines of code with issue, default is true
print-issued-lines: true
@ -57,23 +62,19 @@ linters-settings:
# default is false: such cases aren't reported by default.
check-blank: false
# [deprecated] comma-separated list of pairs of the form pkg:regex
# the regex is used to ignore names within pkg. (default "fmt:.*").
# see https://github.com/kisielk/errcheck#the-deprecated-method for details
ignore: fmt:.*,io/ioutil:^Read.*
exclude-functions:
- fmt:.*
- io/ioutil:^Read.*
# path to a file containing a list of functions to exclude from checking
# see https://github.com/kisielk/errcheck#excluding-functions for details
exclude:
# exclude:
funlen:
lines: 60
statements: 40
govet:
# report about shadowed variables
check-shadowing: true
# settings per analyzer
settings:
printf: # analyzer name, run `go tool vet help` to see all analyzers
@ -86,13 +87,12 @@ linters-settings:
# enable or disable analyzers by name
enable:
- atomicalign
enable-all: false
disable:
- shadow
enable-all: false
disable-all: false
golint:
revive:
# minimal confidence for issues, default is 0.8
min-confidence: 0.8
confidence: 0.8
gofmt:
# simplify code: gofmt with `-s` option, true by default
simplify: true
@ -106,9 +106,6 @@ linters-settings:
gocognit:
# minimal code complexity to report, 30 by default (but we recommend 10-20)
min-complexity: 10
maligned:
# print struct with more effective memory layout or not, false by default
suggest-new: true
dupl:
# tokens count to trigger issue, 150 by default
threshold: 100
@ -121,55 +118,60 @@ linters-settings:
rules:
main:
deny:
- pkg: "github.com/Sirupsen/logrus"
desc: "must use github.com/dapr/kit/logger"
- pkg: "github.com/agrea/ptr"
desc: "must use github.com/dapr/kit/ptr"
- pkg: "go.uber.org/atomic"
desc: "must use sync/atomic"
- pkg: "golang.org/x/net/context"
desc: "must use context"
- pkg: "github.com/pkg/errors"
desc: "must use standard library (errors package and/or fmt.Errorf)"
- pkg: "github.com/go-chi/chi$"
desc: "must use github.com/go-chi/chi/v5"
- pkg: "github.com/cenkalti/backoff$"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/cenkalti/backoff/v2"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/cenkalti/backoff/v3"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/benbjohnson/clock"
desc: "must use k8s.io/utils/clock"
- pkg: "github.com/ghodss/yaml"
desc: "must use sigs.k8s.io/yaml"
- pkg: "gopkg.in/yaml.v2"
desc: "must use gopkg.in/yaml.v3"
- pkg: "github.com/golang-jwt/jwt"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v2"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v3"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v4"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/gogo/status"
desc: "must use google.golang.org/grpc/status"
- pkg: "github.com/gogo/protobuf"
desc: "must use google.golang.org/protobuf"
- pkg: "github.com/lestrrat-go/jwx/jwa"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/lestrrat-go/jwx/jwt"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/labstack/gommon/log"
desc: "must use github.com/dapr/kit/logger"
- pkg: "github.com/gobuffalo/logger"
desc: "must use github.com/dapr/kit/logger"
- pkg: "github.com/Sirupsen/logrus"
desc: "must use github.com/dapr/kit/logger"
- pkg: "github.com/agrea/ptr"
desc: "must use github.com/dapr/kit/ptr"
- pkg: "go.uber.org/atomic"
desc: "must use sync/atomic"
- pkg: "golang.org/x/net/context"
desc: "must use context"
- pkg: "github.com/pkg/errors"
desc: "must use standard library (errors package and/or fmt.Errorf)"
- pkg: "github.com/go-chi/chi$"
desc: "must use github.com/go-chi/chi/v5"
- pkg: "github.com/cenkalti/backoff$"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/cenkalti/backoff/v2"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/cenkalti/backoff/v3"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/benbjohnson/clock"
desc: "must use k8s.io/utils/clock"
- pkg: "github.com/ghodss/yaml"
desc: "must use sigs.k8s.io/yaml"
- pkg: "gopkg.in/yaml.v2"
desc: "must use gopkg.in/yaml.v3"
- pkg: "github.com/golang-jwt/jwt"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v2"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v3"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v4"
desc: "must use github.com/lestrrat-go/jwx/v2"
# pkg: Commonly auto-completed by gopls
- pkg: "github.com/gogo/status"
desc: "must use google.golang.org/grpc/status"
- pkg: "github.com/gogo/protobuf"
desc: "must use google.golang.org/protobuf"
- pkg: "github.com/lestrrat-go/jwx/jwa"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/lestrrat-go/jwx/jwt"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/labstack/gommon/log"
desc: "must use github.com/dapr/kit/logger"
- pkg: "github.com/gobuffalo/logger"
desc: "must use github.com/dapr/kit/logger"
- pkg: "k8s.io/utils/pointer"
desc: "must use github.com/dapr/kit/ptr"
- pkg: "k8s.io/utils/ptr"
desc: "must use github.com/dapr/kit/ptr"
misspell:
# Correct spellings using locale preferences for US or UK.
# Default is to use a neutral variety of English.
# Setting locale to US will correct the British spelling of 'colour' to 'color'.
locale: default
# locale: default
ignore-words:
- someword
lll:
@ -178,17 +180,9 @@ linters-settings:
line-length: 120
# tab width in spaces. Default to 1.
tab-width: 1
unparam:
# Inspect exported functions, default is false. Set to true if no external program/library imports your code.
# XXX: if you enable this setting, unparam will report a lot of false-positives in text editors:
# if it's called for subdir of a project it can't find external interfaces. All text editor integrations
# with golangci-lint call it on a directory with the changed file.
check-exported: false
nakedret:
# make an issue if func has more lines of code than this setting and it has naked returns; default is 30
max-func-lines: 30
nolintlint:
allow-unused: true
prealloc:
# XXX: we don't recommend using this linter before doing performance profiling.
# For most programs usage of prealloc will be a premature optimization.
@ -203,7 +197,6 @@ linters-settings:
# See https://go-critic.github.io/overview#checks-overview
# To check which checks are enabled run `GL_DEBUG=gocritic golangci-lint run`
# By default list of stable checks is used.
enabled-checks:
# Which checks should be disabled; can't be combined with 'enabled-checks'; default is empty
disabled-checks:
@ -251,63 +244,51 @@ linters-settings:
allow-assign-and-call: true
# Allow multiline assignments to be cuddled. Default is true.
allow-multiline-assign: true
# Allow case blocks to end with a whitespace.
allow-case-traling-whitespace: true
# Allow declarations (var) to be cuddled.
allow-cuddle-declarations: false
# If the number of lines in a case block is equal to or lager than this number,
# the case *must* end white a newline.
# https://github.com/bombsimon/wsl/blob/master/doc/configuration.md#force-case-trailing-whitespace
# Default: 0
force-case-trailing-whitespace: 1
linters:
fast: false
enable-all: true
disable:
# TODO Enforce the below linters later
- musttag
- dupl
- nonamedreturns
- errcheck
- funlen
- goconst
- gochecknoglobals
- gochecknoinits
- gocyclo
- gocognit
- nosnakecase
- varcheck
- structcheck
- deadcode
- godox
- interfacer
- lll
- maligned
- scopelint
- unparam
- wsl
- gomnd
- testpackage
- goerr113
- nestif
- nlreturn
- exhaustive
- exhaustruct
- noctx
- gci
- golint
- tparallel
- paralleltest
- wrapcheck
- tagliatelle
- ireturn
- exhaustive
- exhaustivestruct
- exhaustruct
- errchkjson
- contextcheck
- gomoddirectives
- godot
- cyclop
- varnamelen
- gosec
- tagalign
- errorlint
- forcetypeassert
- ifshort
- maintidx
- nilnil
- predeclared
@ -316,4 +297,13 @@ linters:
- wastedassign
- containedctx
- gosimple
- forbidigo
- nonamedreturns
- asasalint
- rowserrcheck
- sqlclosecheck
- inamedparam
- tagalign
- mnd
- canonicalheader
- err113
- fatcontext

View File

@ -78,6 +78,13 @@ test-race:
lint:
$(GOLANGCI_LINT) run --timeout=20m
################################################################################
# Target: lint-fix #
################################################################################
.PHONY: lint-fix
lint-fix:
$(GOLANGCI_LINT) run --timeout=20m --fix
################################################################################
# Target: go.mod #
################################################################################

View File

@ -32,7 +32,7 @@ func TestByteSlicePool(t *testing.T) {
assert.Equal(t, &bs, &bs2)
assert.Equal(t, minCap, cap(bs2))
for i := 0; i < minCap; i++ {
for range minCap {
bs2 = append(bs2, 0)
}

View File

@ -26,11 +26,7 @@ import (
"github.com/dapr/kit/logger"
)
var (
ErrManagerAlreadyClosed = errors.New("runner manager already closed")
log = logger.NewLogger("dapr.kit.concurrency")
)
var ErrManagerAlreadyClosed = errors.New("runner manager already closed")
// RunnerCloserManager is a RunnerManager that also implements Closing of the
// added closers once the main runners are done.
@ -64,7 +60,7 @@ type RunnerCloserManager struct {
// NewRunnerCloserManager creates a new RunnerCloserManager with the given
// grace period and runners.
// If gracePeriod is nil, the grace period is infinite.
func NewRunnerCloserManager(gracePeriod *time.Duration, runners ...Runner) *RunnerCloserManager {
func NewRunnerCloserManager(log logger.Logger, gracePeriod *time.Duration, runners ...Runner) *RunnerCloserManager {
c := &RunnerCloserManager{
mngr: NewRunnerManager(runners...),
clock: clock.RealClock{},

View File

@ -24,8 +24,12 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
clocktesting "k8s.io/utils/clock/testing"
"github.com/dapr/kit/logger"
)
var log = logger.NewLogger("test")
type mockCloser func() error
func (m mockCloser) Close() error {
@ -34,21 +38,21 @@ func (m mockCloser) Close() error {
func Test_RunnerClosterManager(t *testing.T) {
t.Run("runner with no tasks or closers should return nil", func(t *testing.T) {
require.NoError(t, NewRunnerCloserManager(nil).Run(context.Background()))
require.NoError(t, NewRunnerCloserManager(log, nil).Run(t.Context()))
})
t.Run("runner with a task that completes should return nil", func(t *testing.T) {
var i atomic.Int32
require.NoError(t, NewRunnerCloserManager(nil, func(context.Context) error {
require.NoError(t, NewRunnerCloserManager(log, nil, func(context.Context) error {
i.Add(1)
return nil
}).Run(context.Background()))
}).Run(t.Context()))
assert.Equal(t, int32(1), i.Load())
})
t.Run("runner with a task and closer that completes should return nil", func(t *testing.T) {
var i atomic.Int32
mngr := NewRunnerCloserManager(nil, func(context.Context) error {
mngr := NewRunnerCloserManager(log, nil, func(context.Context) error {
i.Add(1)
return nil
})
@ -56,13 +60,13 @@ func Test_RunnerClosterManager(t *testing.T) {
i.Add(1)
return nil
}))
require.NoError(t, mngr.Run(context.Background()))
require.NoError(t, mngr.Run(t.Context()))
assert.Equal(t, int32(2), i.Load())
})
t.Run("runner with multiple tasks and closers that complete should return nil", func(t *testing.T) {
var i atomic.Int32
mngr := NewRunnerCloserManager(nil,
mngr := NewRunnerCloserManager(log, nil,
func(context.Context) error {
i.Add(1)
return nil
@ -94,82 +98,82 @@ func Test_RunnerClosterManager(t *testing.T) {
}),
))
require.NoError(t, mngr.Run(context.Background()))
require.NoError(t, mngr.Run(t.Context()))
assert.Equal(t, int32(7), i.Load())
})
t.Run("a runner that errors should error but still call the closers", func(t *testing.T) {
var i atomic.Int32
mngr := NewRunnerCloserManager(nil,
func(ctx context.Context) error {
mngr := NewRunnerCloserManager(log, nil,
func(context.Context) error {
i.Add(1)
return errors.New("error")
},
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return nil
},
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return nil
},
)
require.NoError(t, mngr.AddCloser(
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return nil
},
))
require.EqualError(t, mngr.Run(context.Background()), "error")
require.EqualError(t, mngr.Run(t.Context()), "error")
assert.Equal(t, int32(4), i.Load())
})
t.Run("a runner that has closter errors should error", func(t *testing.T) {
var i atomic.Int32
mngr := NewRunnerCloserManager(nil,
func(ctx context.Context) error {
mngr := NewRunnerCloserManager(log, nil,
func(context.Context) error {
i.Add(1)
return nil
},
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return nil
},
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return nil
},
)
require.NoError(t, mngr.AddCloser(
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return errors.New("error")
},
))
require.EqualError(t, mngr.Run(context.Background()), "error")
require.EqualError(t, mngr.Run(t.Context()), "error")
assert.Equal(t, int32(4), i.Load())
})
t.Run("a runner with multiple errors should collect all errors (string match)", func(t *testing.T) {
var i atomic.Int32
mngr := NewRunnerCloserManager(nil,
func(ctx context.Context) error {
mngr := NewRunnerCloserManager(log, nil,
func(context.Context) error {
i.Add(1)
return errors.New("error")
},
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return errors.New("error")
},
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return errors.New("error")
},
)
require.NoError(t, mngr.AddCloser(
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return errors.New("closererror")
},
@ -183,7 +187,7 @@ func Test_RunnerClosterManager(t *testing.T) {
}),
))
err := mngr.Run(context.Background())
err := mngr.Run(t.Context())
require.Error(t, err)
require.ErrorContains(t, err, "error\nerror\nerror\nclosererror\nclosererror\nclosererror") //nolint:dupword
assert.Equal(t, int32(6), i.Load())
@ -191,22 +195,22 @@ func Test_RunnerClosterManager(t *testing.T) {
t.Run("a runner with multiple errors should collect all errors (unique)", func(t *testing.T) {
var i atomic.Int32
mngr := NewRunnerCloserManager(nil,
func(ctx context.Context) error {
mngr := NewRunnerCloserManager(log, nil,
func(context.Context) error {
i.Add(1)
return errors.New("error1")
},
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return errors.New("error2")
},
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return errors.New("error3")
},
)
require.NoError(t, mngr.AddCloser(
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return errors.New("closererror1")
},
@ -220,7 +224,7 @@ func Test_RunnerClosterManager(t *testing.T) {
}),
))
err := mngr.Run(context.Background())
err := mngr.Run(t.Context())
require.Error(t, err)
assert.ElementsMatch(t,
[]string{"error1", "error2", "error3", "closererror1", "closererror2", "closererror3"},
@ -231,26 +235,26 @@ func Test_RunnerClosterManager(t *testing.T) {
t.Run("should be able to add runner with New, Add and AddCloser", func(t *testing.T) {
var i atomic.Int32
mngr := NewRunnerCloserManager(nil,
func(ctx context.Context) error {
mngr := NewRunnerCloserManager(log, nil,
func(context.Context) error {
i.Add(1)
return nil
},
)
require.NoError(t, mngr.Add(
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return nil
},
))
require.NoError(t, mngr.Add(
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return nil
},
))
require.NoError(t, mngr.AddCloser(
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return nil
},
@ -261,14 +265,14 @@ func Test_RunnerClosterManager(t *testing.T) {
},
))
require.NoError(t, mngr.Run(context.Background()))
require.NoError(t, mngr.Run(t.Context()))
assert.Equal(t, int32(5), i.Load())
})
t.Run("when a runner returns, expect context to be cancelled for other runners, but not for closers returning", func(t *testing.T) {
var i atomic.Int32
mngr := NewRunnerCloserManager(nil,
func(ctx context.Context) error {
mngr := NewRunnerCloserManager(log, nil,
func(context.Context) error {
i.Add(1)
return nil
},
@ -295,7 +299,7 @@ func Test_RunnerClosterManager(t *testing.T) {
closer1Ch := make(chan struct{})
closer2Ch := make(chan struct{})
require.NoError(t, mngr.AddCloser(
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
close(closer1Ch)
return nil
@ -321,13 +325,13 @@ func Test_RunnerClosterManager(t *testing.T) {
},
))
require.NoError(t, mngr.Run(context.Background()))
require.NoError(t, mngr.Run(t.Context()))
assert.Equal(t, int32(6), i.Load())
})
t.Run("when a runner errors, expect context to be cancelled for other runners, but closers should still run", func(t *testing.T) {
var i atomic.Int32
mngr := NewRunnerCloserManager(nil,
mngr := NewRunnerCloserManager(log, nil,
func(ctx context.Context) error {
i.Add(1)
select {
@ -346,7 +350,7 @@ func Test_RunnerClosterManager(t *testing.T) {
}
return errors.New("error2")
},
func(ctx context.Context) error {
func(context.Context) error {
i.Add(1)
return errors.New("error3")
},
@ -373,7 +377,7 @@ func Test_RunnerClosterManager(t *testing.T) {
},
))
err := mngr.Run(context.Background())
err := mngr.Run(t.Context())
require.Error(t, err)
assert.ElementsMatch(t,
[]string{"error1", "error2", "error3", "closererror1", "closererror2"},
@ -384,45 +388,45 @@ func Test_RunnerClosterManager(t *testing.T) {
t.Run("a manger started twice should error", func(t *testing.T) {
var i atomic.Int32
m := NewRunnerCloserManager(nil, func(ctx context.Context) error {
m := NewRunnerCloserManager(log, nil, func(context.Context) error {
i.Add(1)
return nil
})
require.NoError(t, m.Run(context.Background()))
require.NoError(t, m.Run(t.Context()))
assert.Equal(t, int32(1), i.Load())
require.EqualError(t, m.Run(context.Background()), "runner manager already started")
require.EqualError(t, m.Run(t.Context()), "runner manager already started")
assert.Equal(t, int32(1), i.Load())
})
t.Run("a manger started twice should error", func(t *testing.T) {
var i atomic.Int32
m := NewRunnerCloserManager(nil, func(ctx context.Context) error {
m := NewRunnerCloserManager(log, nil, func(context.Context) error {
i.Add(1)
return nil
})
require.NoError(t, m.AddCloser(func(ctx context.Context) error {
require.NoError(t, m.AddCloser(func(context.Context) error {
i.Add(1)
return nil
}))
require.NoError(t, m.Run(context.Background()))
require.NoError(t, m.Run(t.Context()))
assert.Equal(t, int32(2), i.Load())
require.NoError(t, m.Close())
require.NoError(t, m.Close())
require.EqualError(t, m.Run(context.Background()), "runner manager already started")
require.EqualError(t, m.Run(t.Context()), "runner manager already started")
assert.Equal(t, int32(2), i.Load())
})
t.Run("adding a task to a started manager should error", func(t *testing.T) {
var i atomic.Int32
m := NewRunnerCloserManager(nil, func(ctx context.Context) error {
m := NewRunnerCloserManager(log, nil, func(context.Context) error {
i.Add(1)
return nil
})
require.NoError(t, m.Run(context.Background()))
require.NoError(t, m.Run(t.Context()))
assert.Equal(t, int32(1), i.Load())
err := m.Add(func(ctx context.Context) error {
err := m.Add(func(context.Context) error {
i.Add(1)
return nil
})
@ -433,14 +437,14 @@ func Test_RunnerClosterManager(t *testing.T) {
t.Run("adding a closer to a closing manager should error", func(t *testing.T) {
var i atomic.Int32
m := NewRunnerCloserManager(nil, func(ctx context.Context) error {
m := NewRunnerCloserManager(log, nil, func(context.Context) error {
i.Add(1)
return nil
})
require.NoError(t, m.Run(context.Background()))
require.NoError(t, m.Run(t.Context()))
assert.Equal(t, int32(1), i.Load())
require.NoError(t, m.Close())
err := m.AddCloser(func(ctx context.Context) error {
err := m.AddCloser(func(context.Context) error {
i.Add(1)
return nil
})
@ -450,19 +454,19 @@ func Test_RunnerClosterManager(t *testing.T) {
})
t.Run("if grace period is not given, should have no force shutdown", func(t *testing.T) {
mngr := NewRunnerCloserManager(nil)
mngr := NewRunnerCloserManager(log, nil)
assert.Empty(t, mngr.closers)
})
t.Run("if grace period is given, should have force shutdown", func(t *testing.T) {
dur := time.Second
mngr := NewRunnerCloserManager(&dur)
mngr := NewRunnerCloserManager(log, &dur)
assert.Len(t, mngr.closers, 1)
})
t.Run("if closing but grace period not reached, should return", func(t *testing.T) {
dur := time.Second
mngr := NewRunnerCloserManager(&dur)
mngr := NewRunnerCloserManager(log, &dur)
var i atomic.Int32
require.NoError(t, mngr.AddCloser(func() {
@ -482,7 +486,7 @@ func Test_RunnerClosterManager(t *testing.T) {
errCh := make(chan error)
go func() {
errCh <- mngr.Run(context.Background())
errCh <- mngr.Run(t.Context())
}()
select {
@ -505,13 +509,13 @@ func Test_RunnerClosterManager(t *testing.T) {
t.Run("if closing and grace period is reached, should force shutdown", func(t *testing.T) {
dur := time.Second
mngr := NewRunnerCloserManager(&dur)
mngr := NewRunnerCloserManager(log, &dur)
assert.Len(t, mngr.closers, 1)
clock := clocktesting.NewFakeClock(time.Now())
mngr.clock = clock
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
t.Cleanup(cancel)
fatalCalled := make(chan struct{})
@ -533,7 +537,7 @@ func Test_RunnerClosterManager(t *testing.T) {
}
})
go func() {
errCh <- mngr.Run(context.Background())
errCh <- mngr.Run(t.Context())
}()
assert.Eventually(t, func() bool {
@ -555,7 +559,7 @@ func TestClose(t *testing.T) {
t.Run("calling close should stop the main runner and call all closers", func(t *testing.T) {
var i atomic.Int32
runnerWaiting := make(chan struct{})
mngr := NewRunnerCloserManager(nil, func(ctx context.Context) error {
mngr := NewRunnerCloserManager(log, nil, func(ctx context.Context) error {
close(runnerWaiting)
<-ctx.Done()
i.Add(1)
@ -567,7 +571,7 @@ func TestClose(t *testing.T) {
errCh := make(chan error)
go func() {
errCh <- mngr.Run(context.Background())
errCh <- mngr.Run(t.Context())
}()
select {
@ -591,7 +595,7 @@ func TestClose(t *testing.T) {
t.Run("calling close should wait for all closers to return", func(t *testing.T) {
var i atomic.Int32
runnerWaiting := make(chan struct{})
mngr := NewRunnerCloserManager(nil, func(ctx context.Context) error {
mngr := NewRunnerCloserManager(log, nil, func(ctx context.Context) error {
close(runnerWaiting)
<-ctx.Done()
i.Add(1)
@ -625,7 +629,7 @@ func TestClose(t *testing.T) {
errCh := make(chan error)
go func() {
errCh <- mngr.Run(context.Background())
errCh <- mngr.Run(t.Context())
}()
select {
@ -667,7 +671,7 @@ func TestClose(t *testing.T) {
dur := time.Second
var i atomic.Int32
runnerWaiting := make(chan struct{})
mngr := NewRunnerCloserManager(&dur, func(ctx context.Context) error {
mngr := NewRunnerCloserManager(log, &dur, func(ctx context.Context) error {
close(runnerWaiting)
<-ctx.Done()
i.Add(1)
@ -710,7 +714,7 @@ func TestClose(t *testing.T) {
errCh := make(chan error)
go func() {
errCh <- mngr.Run(context.Background())
errCh <- mngr.Run(t.Context())
}()
select {
@ -754,7 +758,7 @@ func TestClose(t *testing.T) {
dur := time.Second
var i atomic.Int32
runnerWaiting := make(chan struct{})
mngr := NewRunnerCloserManager(&dur, func(ctx context.Context) error {
mngr := NewRunnerCloserManager(log, &dur, func(ctx context.Context) error {
close(runnerWaiting)
<-ctx.Done()
i.Add(1)
@ -772,7 +776,7 @@ func TestClose(t *testing.T) {
assert.Len(t, mngr.closers, 1)
returnClose := make(chan struct{})
for n := 0; n < 4; n++ {
for range 4 {
require.NoError(t, mngr.AddCloser(func() {
i.Add(1)
<-returnClose
@ -783,7 +787,7 @@ func TestClose(t *testing.T) {
errCh := make(chan error)
go func() {
errCh <- mngr.Run(context.Background())
errCh <- mngr.Run(t.Context())
}()
select {
@ -820,14 +824,14 @@ func TestClose(t *testing.T) {
})
t.Run("calling close should return the errors from the main runner and all closers", func(t *testing.T) {
mngr := NewRunnerCloserManager(nil,
func(ctx context.Context) error {
mngr := NewRunnerCloserManager(log, nil,
func(context.Context) error {
return errors.New("error1")
},
func(ctx context.Context) error {
func(context.Context) error {
return errors.New("error2")
},
func(ctx context.Context) error {
func(context.Context) error {
return errors.New("error3")
},
)
@ -846,7 +850,7 @@ func TestClose(t *testing.T) {
errCh := make(chan error)
go func() {
errCh <- mngr.Run(context.Background())
errCh <- mngr.Run(t.Context())
}()
var err error
@ -864,8 +868,8 @@ func TestClose(t *testing.T) {
t.Run("calling Close before Run should return immediately", func(t *testing.T) {
dur := time.Second
mngr := NewRunnerCloserManager(&dur,
func(ctx context.Context) error {
mngr := NewRunnerCloserManager(log, &dur,
func(context.Context) error {
return errors.New("error1")
},
)
@ -875,7 +879,7 @@ func TestClose(t *testing.T) {
require.NoError(t, mngr.Close())
require.NoError(t, mngr.Close())
assert.Equal(t, mngr.Run(context.Background()), errors.New("runner manager already started"))
assert.Equal(t, mngr.Run(t.Context()), errors.New("runner manager already started"))
})
}
@ -892,7 +896,7 @@ func TestAddCloser(t *testing.T) {
expErr: errors.Join(errors.New("unsupported closer type: int")),
},
"Add various supported closer types": {
closers: []any{new(mockCloser), func(ctx context.Context) error { return nil }, func() error { return nil }, func() {}},
closers: []any{new(mockCloser), func(context.Context) error { return nil }, func() error { return nil }, func() {}},
expErr: nil,
},
"Add combination of supported and unsupported closer types": {
@ -903,18 +907,18 @@ func TestAddCloser(t *testing.T) {
for name, test := range tests {
t.Run(name, func(t *testing.T) {
err := NewRunnerCloserManager(nil).AddCloser(test.closers...)
err := NewRunnerCloserManager(log, nil).AddCloser(test.closers...)
assert.Equalf(t, test.expErr, err, "%v", err)
})
}
t.Run("no error if adding a closer during main routine", func(t *testing.T) {
mngr := NewRunnerCloserManager(nil, func(ctx context.Context) error {
mngr := NewRunnerCloserManager(log, nil, func(ctx context.Context) error {
<-ctx.Done()
return nil
})
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
errCh <- mngr.Run(ctx)
@ -925,9 +929,9 @@ func TestAddCloser(t *testing.T) {
})
t.Run("should error if closing", func(t *testing.T) {
mngr := NewRunnerCloserManager(nil)
mngr := NewRunnerCloserManager(log, nil)
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
closerCh := make(chan struct{})
require.NoError(t, mngr.AddCloser(func() {
cancel()
@ -936,7 +940,7 @@ func TestAddCloser(t *testing.T) {
errCh := make(chan error)
go func() {
errCh <- mngr.Run(context.Background())
errCh <- mngr.Run(t.Context())
}()
select {
@ -968,15 +972,15 @@ func TestAddCloser(t *testing.T) {
})
t.Run("should error if manager already returned", func(t *testing.T) {
mngr := NewRunnerCloserManager(nil)
require.NoError(t, mngr.Run(context.Background()))
mngr := NewRunnerCloserManager(log, nil)
require.NoError(t, mngr.Run(t.Context()))
assert.Equal(t, mngr.AddCloser(nil), errors.New("runner manager already closed"))
})
}
func TestWaitUntilShutdown(t *testing.T) {
dur := time.Second * 3
mngr := NewRunnerCloserManager(&dur, func(ctx context.Context) error {
mngr := NewRunnerCloserManager(log, &dur, func(ctx context.Context) error {
<-ctx.Done()
return nil
})
@ -995,7 +999,7 @@ func TestWaitUntilShutdown(t *testing.T) {
<-returnClose
}))
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {

111
concurrency/cmap/atomic.go Normal file
View File

@ -0,0 +1,111 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmap
import (
"sync"
"golang.org/x/exp/constraints"
)
type AtomicValue[T constraints.Integer] struct {
lock sync.RWMutex
value T
}
func (a *AtomicValue[T]) Load() T {
a.lock.RLock()
defer a.lock.RUnlock()
return a.value
}
func (a *AtomicValue[T]) Store(v T) {
a.lock.Lock()
defer a.lock.Unlock()
a.value = v
}
func (a *AtomicValue[T]) Add(v T) T {
a.lock.Lock()
defer a.lock.Unlock()
a.value += v
return a.value
}
type Atomic[K comparable, T constraints.Integer] interface {
Get(key K) (*AtomicValue[T], bool)
GetOrCreate(key K, createT T) *AtomicValue[T]
Delete(key K)
ForEach(fn func(key K, value *AtomicValue[T]))
Clear()
}
type atomicMap[K comparable, T constraints.Integer] struct {
lock sync.RWMutex
items map[K]*AtomicValue[T]
}
func NewAtomic[K comparable, T constraints.Integer]() Atomic[K, T] {
return &atomicMap[K, T]{
items: make(map[K]*AtomicValue[T]),
}
}
func (a *atomicMap[K, T]) Get(key K) (*AtomicValue[T], bool) {
a.lock.RLock()
defer a.lock.RUnlock()
item, ok := a.items[key]
if !ok {
return nil, false
}
return item, true
}
func (a *atomicMap[K, T]) GetOrCreate(key K, createT T) *AtomicValue[T] {
a.lock.RLock()
item, ok := a.items[key]
a.lock.RUnlock()
if !ok {
a.lock.Lock()
// Double-check the key exists to avoid race condition
item, ok = a.items[key]
if !ok {
item = &AtomicValue[T]{value: createT}
a.items[key] = item
}
a.lock.Unlock()
}
return item
}
func (a *atomicMap[K, T]) Delete(key K) {
a.lock.Lock()
delete(a.items, key)
a.lock.Unlock()
}
func (a *atomicMap[K, T]) ForEach(fn func(key K, value *AtomicValue[T])) {
a.lock.RLock()
defer a.lock.RUnlock()
for k, v := range a.items {
fn(k, v)
}
}
func (a *atomicMap[K, T]) Clear() {
a.lock.Lock()
defer a.lock.Unlock()
clear(a.items)
}

View File

@ -0,0 +1,79 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmap
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAtomicInt32_New_Get_Delete(t *testing.T) {
m := NewAtomic[string, int32]().(*atomicMap[string, int32])
require.NotNil(t, m)
require.NotNil(t, m.items)
require.Empty(t, m.items)
t.Run("basic operations", func(t *testing.T) {
key := "key1"
value := int32(10)
// Initially, the key should not exist
_, ok := m.Get(key)
require.False(t, ok)
// Add a value and check it
m.GetOrCreate(key, 0).Store(value)
result, ok := m.Get(key)
require.True(t, ok)
assert.Equal(t, value, result.Load())
// Delete the key and check it no longer exists
m.Delete(key)
_, ok = m.Get(key)
require.False(t, ok)
})
t.Run("concurrent access multiple keys", func(t *testing.T) {
var wg sync.WaitGroup
keys := []string{"key1", "key2", "key3"}
iterations := 100
wg.Add(len(keys) * 2)
for _, key := range keys {
go func(k string) {
defer wg.Done()
for range iterations {
m.GetOrCreate(k, 0).Add(1)
}
}(key)
go func(k string) {
defer wg.Done()
for range iterations {
m.GetOrCreate(k, 0).Add(-1)
}
}(key)
}
wg.Wait()
for _, key := range keys {
val, ok := m.Get(key)
require.True(t, ok)
require.Equal(t, int32(0), val.Load())
}
})
}

99
concurrency/cmap/map.go Normal file
View File

@ -0,0 +1,99 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmap
import (
"sync"
)
// Map is a simple _typed_ map which is safe for concurrent use.
// Favoured over sync.Map as it is typed.
type Map[K comparable, T any] interface {
Clear()
Delete(key K)
Load(key K) (T, bool)
LoadAndDelete(key K) (T, bool)
Range(fn func(key K, value T) bool)
Store(key K, value T)
Len() int
Keys() []K
}
type mapimpl[K comparable, T any] struct {
lock sync.RWMutex
m map[K]T
}
func NewMap[K comparable, T any]() Map[K, T] {
return &mapimpl[K, T]{m: make(map[K]T)}
}
func (m *mapimpl[K, T]) Clear() {
m.lock.Lock()
defer m.lock.Unlock()
m.m = make(map[K]T)
}
func (m *mapimpl[K, T]) Delete(k K) {
m.lock.Lock()
defer m.lock.Unlock()
delete(m.m, k)
}
func (m *mapimpl[K, T]) Load(k K) (T, bool) {
m.lock.RLock()
defer m.lock.RUnlock()
v, ok := m.m[k]
return v, ok
}
func (m *mapimpl[K, T]) LoadAndDelete(k K) (T, bool) {
m.lock.Lock()
defer m.lock.Unlock()
v, ok := m.m[k]
delete(m.m, k)
return v, ok
}
func (m *mapimpl[K, T]) Range(fn func(K, T) bool) {
m.lock.RLock()
defer m.lock.RUnlock()
for k, v := range m.m {
if !fn(k, v) {
break
}
}
}
func (m *mapimpl[K, T]) Store(k K, v T) {
m.lock.Lock()
defer m.lock.Unlock()
m.m[k] = v
}
func (m *mapimpl[K, T]) Len() int {
m.lock.RLock()
defer m.lock.RUnlock()
return len(m.m)
}
func (m *mapimpl[K, T]) Keys() []K {
m.lock.Lock()
defer m.lock.Unlock()
keys := make([]K, 0, len(m.m))
for k := range m.m {
keys = append(keys, k)
}
return keys
}

150
concurrency/cmap/mutex.go Normal file
View File

@ -0,0 +1,150 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmap
import (
"sync"
)
// Mutex is an interface that defines a thread-safe map with keys of type T associated to
// read-write mutexes (sync.RWMutex), allowing for granular locking on a per-key basis.
// This can be useful for scenarios where fine-grained concurrency control is needed.
//
// Methods:
// - Lock(key T): Acquires an exclusive lock on the mutex associated with the given key.
// - Unlock(key T): Releases the exclusive lock on the mutex associated with the given key.
// - RLock(key T): Acquires a read lock on the mutex associated with the given key.
// - RUnlock(key T): Releases the read lock on the mutex associated with the given key.
// - Delete(key T): Removes the mutex associated with the given key from the map.
// - Clear(): Removes all mutexes from the map.
// - ItemCount() int: Returns the number of items (mutexes) in the map.
// - DeleteUnlock(key T): Removes the mutex associated with the given key from the map and releases the lock.
// - DeleteRUnlock(key T): Removes the mutex associated with the given key from the map and releases the read lock.
type Mutex[T comparable] interface {
Lock(key T)
Unlock(key T)
RLock(key T)
RUnlock(key T)
Delete(key T)
Clear()
ItemCount() int
DeleteUnlock(key T)
DeleteRUnlock(key T)
}
type mutex[T comparable] struct {
lock sync.RWMutex
items map[T]*sync.RWMutex
}
func NewMutex[T comparable]() Mutex[T] {
return &mutex[T]{
items: make(map[T]*sync.RWMutex),
}
}
func (a *mutex[T]) Lock(key T) {
a.lock.RLock()
mutex, ok := a.items[key]
a.lock.RUnlock()
if ok {
mutex.Lock()
return
}
a.lock.Lock()
mutex, ok = a.items[key]
if !ok {
mutex = &sync.RWMutex{}
a.items[key] = mutex
}
a.lock.Unlock()
mutex.Lock()
}
func (a *mutex[T]) Unlock(key T) {
a.lock.RLock()
mutex, ok := a.items[key]
if ok {
mutex.Unlock()
}
a.lock.RUnlock()
}
func (a *mutex[T]) RLock(key T) {
a.lock.RLock()
mutex, ok := a.items[key]
a.lock.RUnlock()
if ok {
mutex.RLock()
return
}
a.lock.Lock()
mutex, ok = a.items[key]
if !ok {
mutex = &sync.RWMutex{}
a.items[key] = mutex
}
a.lock.Unlock()
mutex.RLock()
}
func (a *mutex[T]) RUnlock(key T) {
a.lock.RLock()
mutex, ok := a.items[key]
if ok {
mutex.RUnlock()
}
a.lock.RUnlock()
}
func (a *mutex[T]) Delete(key T) {
a.lock.Lock()
delete(a.items, key)
a.lock.Unlock()
}
func (a *mutex[T]) DeleteUnlock(key T) {
a.lock.Lock()
mutex, ok := a.items[key]
if ok {
mutex.Unlock()
}
delete(a.items, key)
a.lock.Unlock()
}
func (a *mutex[T]) DeleteRUnlock(key T) {
a.lock.Lock()
mutex, ok := a.items[key]
if ok {
mutex.RUnlock()
}
delete(a.items, key)
a.lock.Unlock()
}
func (a *mutex[T]) Clear() {
a.lock.Lock()
clear(a.items)
a.lock.Unlock()
}
func (a *mutex[T]) ItemCount() int {
a.lock.Lock()
defer a.lock.Unlock()
return len(a.items)
}

View File

@ -0,0 +1,119 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmap
import (
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewMutex_Add_Delete(t *testing.T) {
mm := NewMutex[string]().(*mutex[string])
t.Run("New mutex map", func(t *testing.T) {
require.NotNil(t, mm)
require.NotNil(t, mm.items)
require.Empty(t, mm.items)
})
t.Run("Lock and unlock mutex", func(t *testing.T) {
mm.Lock("key1")
_, ok := mm.items["key1"]
require.True(t, ok)
mm.Unlock("key1")
})
t.Run("Concurrently lock and unlock mutexes", func(t *testing.T) {
var counter atomic.Int64
var wg sync.WaitGroup
numGoroutines := 10
wg.Add(numGoroutines)
// Concurrently lock and unlock for each key
for range numGoroutines {
go func() {
defer wg.Done()
mm.Lock("key1")
counter.Add(1)
mm.Unlock("key1")
}()
}
wg.Wait()
require.Equal(t, int64(10), counter.Load())
})
t.Run("RLock and RUnlock mutex", func(t *testing.T) {
mm.RLock("key1")
_, ok := mm.items["key1"]
require.True(t, ok)
mm.RUnlock("key1")
})
t.Run("Concurrently RLock and RUnlock mutexes", func(t *testing.T) {
var counter atomic.Int64
var wg sync.WaitGroup
numGoroutines := 10
wg.Add(numGoroutines * 2)
// Concurrently RLock and RUnlock for each key
for range numGoroutines {
go func() {
defer wg.Done()
mm.RLock("key1")
counter.Add(1)
}()
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
assert.Equal(ct, int64(10), counter.Load())
}, 5*time.Second, 10*time.Millisecond)
for range numGoroutines {
go func() {
defer wg.Done()
mm.RUnlock("key1")
}()
}
wg.Wait()
})
t.Run("Delete mutex", func(t *testing.T) {
mm.Lock("key1")
mm.Unlock("key1")
mm.Delete("key1")
_, ok := mm.items["key1"]
require.False(t, ok)
})
t.Run("Clear all mutexes, and check item count", func(t *testing.T) {
mm.Lock("key1")
mm.Unlock("key1")
mm.Lock("key2")
mm.Unlock("key2")
require.Equal(t, 2, mm.ItemCount())
mm.Clear()
require.Empty(t, mm.items)
})
}

View File

@ -0,0 +1,94 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ctesting
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/kit/concurrency"
"github.com/dapr/kit/concurrency/ctesting/internal"
)
type RunnerFn func(context.Context, assert.TestingT)
// Assert runs the provided test functions in parallel and asserts that they
// all pass.
func Assert(t *testing.T, runners ...RunnerFn) {
t.Helper()
if len(runners) == 0 {
require.Fail(t, "at least one runner function is required")
}
tt := internal.Assert(t)
ctx, cancel := context.WithCancelCause(t.Context())
t.Cleanup(func() { cancel(nil) })
doneCh := make(chan struct{}, len(runners))
for _, runner := range runners {
go func(rfn RunnerFn) {
rfn(ctx, tt)
if errs := tt.Errors(); len(errs) > 0 {
cancel(errors.Join(errs...))
}
doneCh <- struct{}{}
}(runner)
}
for range runners {
select {
case <-doneCh:
case <-t.Context().Done():
require.FailNow(t, "test context was cancelled before all runners completed")
}
}
for _, err := range tt.Errors() {
assert.NoError(t, err)
}
}
// AssertCleanup runs the provided test functions in parallel and asserts that they
// all pass, only after Cleanup,.
func AssertCleanup(t *testing.T, runners ...concurrency.Runner) {
t.Helper()
ctx, cancel := context.WithCancelCause(t.Context())
errCh := make(chan error, len(runners))
for _, runner := range runners {
go func(rfn concurrency.Runner) {
errCh <- rfn(ctx)
}(runner)
}
t.Cleanup(func() {
cancel(nil)
for range runners {
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(10 * time.Second):
assert.Fail(t, "timeout waiting for runner to stop")
}
}
})
}

View File

@ -0,0 +1,49 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package internal
import (
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
type Interface interface {
assert.TestingT
Errors() []error
}
type assertT struct {
t *testing.T
lock sync.Mutex
errs []error
}
func Assert(t *testing.T) Interface {
return &assertT{t: t}
}
func (a *assertT) Errorf(format string, args ...any) {
a.lock.Lock()
defer a.lock.Unlock()
a.errs = append(a.errs, fmt.Errorf(format, args...))
}
func (a *assertT) Errors() []error {
a.lock.Lock()
defer a.lock.Unlock()
return a.errs
}

90
concurrency/dir/dir.go Normal file
View File

@ -0,0 +1,90 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package dir
import (
"fmt"
"os"
"path/filepath"
"time"
"github.com/dapr/kit/logger"
)
type Options struct {
Log logger.Logger
Target string
}
// Dir atomically writes files to a given directory.
type Dir struct {
log logger.Logger
base string
target string
targetDir string
prev *string
}
func New(opts Options) *Dir {
return &Dir{
log: opts.Log,
base: filepath.Dir(opts.Target),
target: opts.Target,
targetDir: filepath.Base(opts.Target),
}
}
func (d *Dir) Write(files map[string][]byte) error {
newDir := filepath.Join(d.base, fmt.Sprintf("%d-%s", time.Now().UTC().UnixNano(), d.targetDir))
if err := os.MkdirAll(d.base, 0o700); err != nil {
return err
}
if err := os.MkdirAll(newDir, 0o700); err != nil {
return err
}
for file, b := range files {
path := filepath.Join(newDir, file)
if err := os.WriteFile(path, b, 0o600); err != nil {
return err
}
d.log.Infof("Written file %s", file)
}
if err := os.Symlink(newDir, d.target+".new"); err != nil {
return err
}
d.log.Infof("Syslink %s to %s.new", newDir, d.target)
if err := os.Rename(d.target+".new", d.target); err != nil {
return err
}
d.log.Infof("Atomic write to %s", d.target)
if d.prev != nil {
if err := os.RemoveAll(*d.prev); err != nil {
return err
}
}
d.prev = &newDir
return nil
}

62
concurrency/fifo/map.go Normal file
View File

@ -0,0 +1,62 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fifo
// Map is a map of mutexes whose locks are acquired in a FIFO order. The map is
// pruned automatically when all locks have been released for a key.
type Map[T comparable] interface {
Lock(key T)
Unlock(key T)
}
type mapItem struct {
ilen uint64
mutex *Mutex
}
type fifoMap[T comparable] struct {
lock *Mutex
items map[T]*mapItem
}
func NewMap[T comparable]() Map[T] {
return &fifoMap[T]{
lock: New(),
items: make(map[T]*mapItem),
}
}
func (a *fifoMap[T]) Lock(key T) {
a.lock.Lock()
m, ok := a.items[key]
if !ok {
m = &mapItem{mutex: New()}
a.items[key] = m
}
m.ilen++
a.lock.Unlock()
m.mutex.Lock()
}
func (a *fifoMap[T]) Unlock(key T) {
a.lock.Lock()
m := a.items[key]
m.ilen--
if m.ilen == 0 {
delete(a.items, key)
}
a.lock.Unlock()
m.mutex.Unlock()
}

View File

@ -0,0 +1,47 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fifo
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func Test_Map(t *testing.T) {
m := NewMap[string]().(*fifoMap[string])
assert.Empty(t, m.items)
m.Lock("key1")
assert.Len(t, m.items, 1)
assert.Equal(t, uint64(1), m.items["key1"].ilen)
go func() {
m.Lock("key1")
}()
assert.EventuallyWithT(t, func(c *assert.CollectT) {
m.lock.Lock()
assert.Equal(c, uint64(2), m.items["key1"].ilen)
m.lock.Unlock()
}, time.Second*3, time.Millisecond*10)
m.Unlock("key1")
assert.Equal(t, uint64(1), m.items["key1"].ilen)
m.Unlock("key1")
assert.Empty(t, m.items)
}

34
concurrency/fifo/mutex.go Normal file
View File

@ -0,0 +1,34 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fifo
// Mutex is a mutex lock whose lock and unlock operations are
// first-in-first-out (FIFO).
type Mutex struct {
lock chan struct{}
}
func New() *Mutex {
return &Mutex{
lock: make(chan struct{}, 1),
}
}
func (m *Mutex) Lock() {
m.lock <- struct{}{}
}
func (m *Mutex) Unlock() {
<-m.lock
}

View File

@ -0,0 +1,62 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package lock
import (
"context"
"sync"
)
// Context is a ready write mutex lock where Locking can return early with an
// error if the context is done. No error response means the lock is acquired.
type Context struct {
lock sync.RWMutex
locked chan struct{}
}
func NewContext() *Context {
return &Context{
locked: make(chan struct{}, 1),
}
}
func (c *Context) Lock(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case c.locked <- struct{}{}:
c.lock.Lock()
return nil
}
}
func (c *Context) Unlock() {
c.lock.Unlock()
<-c.locked
}
func (c *Context) RLock(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case c.locked <- struct{}{}:
c.lock.RLock()
return nil
}
}
func (c *Context) RUnlock() {
c.lock.RUnlock()
<-c.locked
}

View File

@ -0,0 +1,81 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package lock
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func Test_Context(t *testing.T) {
tests := map[string]struct {
name string
action func(l *Context) error
expectError bool
}{
"Successful Lock": {
action: func(l *Context) error {
return l.Lock(t.Context())
},
expectError: false,
},
"Lock with Context Timeout": {
action: func(l *Context) error {
l.Lock(t.Context())
ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond*50)
defer cancel()
return l.Lock(ctx)
},
expectError: true,
},
"Successful RLock": {
action: func(l *Context) error {
return l.RLock(t.Context())
},
expectError: false,
},
"RLock with Context Timeout": {
action: func(l *Context) error {
l.Lock(t.Context())
ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond*50)
defer cancel()
return l.RLock(ctx)
},
expectError: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
t.Parallel()
l := NewContext()
done := make(chan error)
go func() {
done <- test.action(l)
}()
select {
case err := <-done:
assert.Equal(t, (err != nil), test.expectError, "unexpected error, expected error: %v, got: %v", test.expectError, err)
case <-time.After(time.Second):
t.Errorf("test timed out")
}
})
}
}

View File

@ -0,0 +1,199 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package lock
import (
"context"
"errors"
"sync"
"time"
"github.com/dapr/kit/concurrency/fifo"
)
var errLockClosed = errors.New("lock closed")
type hold struct {
writeLock bool
rctx context.Context
respCh chan *holdresp
}
type holdresp struct {
rctx context.Context
cancel context.CancelFunc
err error
}
type OuterCancel struct {
ch chan *hold
cancelErr error
gracefulTimeout time.Duration
lock chan struct{}
wg sync.WaitGroup
rcancelLock sync.Mutex
rcancelx uint64
rcancels map[uint64]context.CancelFunc
closeCh chan struct{}
shutdownLock *fifo.Mutex
}
func NewOuterCancel(cancelErr error, gracefulTimeout time.Duration) *OuterCancel {
return &OuterCancel{
lock: make(chan struct{}, 1),
ch: make(chan *hold, 1),
rcancels: make(map[uint64]context.CancelFunc),
closeCh: make(chan struct{}),
shutdownLock: fifo.New(),
cancelErr: cancelErr,
gracefulTimeout: gracefulTimeout,
}
}
func (o *OuterCancel) Run(ctx context.Context) {
defer func() {
o.rcancelLock.Lock()
defer o.rcancelLock.Unlock()
for _, cancel := range o.rcancels {
go cancel()
}
}()
go func() {
<-ctx.Done()
close(o.closeCh)
}()
for {
select {
case <-o.closeCh:
return
case h := <-o.ch:
o.handleHold(h)
}
}
}
func (o *OuterCancel) handleHold(h *hold) {
if h.rctx != nil {
select {
case o.lock <- struct{}{}:
case <-h.rctx.Done():
h.respCh <- &holdresp{err: h.rctx.Err()}
return
}
} else {
o.lock <- struct{}{}
}
o.rcancelLock.Lock()
if h.writeLock {
for _, cancel := range o.rcancels {
go cancel()
}
o.rcancelx = 0
o.rcancelLock.Unlock()
o.wg.Wait()
h.respCh <- &holdresp{cancel: func() { <-o.lock }}
return
}
o.wg.Add(1)
var done bool
doneCh := make(chan bool)
rctx, cancel := context.WithCancelCause(h.rctx)
i := o.rcancelx
rcancel := func() {
o.rcancelLock.Lock()
if !done {
close(doneCh)
cancel(o.cancelErr)
delete(o.rcancels, i)
o.wg.Done()
done = true
}
o.rcancelLock.Unlock()
}
rcancelGrace := func() {
select {
case <-time.After(o.gracefulTimeout):
case <-o.closeCh:
case <-doneCh:
}
rcancel()
}
o.rcancels[i] = rcancelGrace
o.rcancelx++
o.rcancelLock.Unlock()
h.respCh <- &holdresp{rctx: rctx, cancel: rcancel}
<-o.lock
}
func (o *OuterCancel) Lock() context.CancelFunc {
h := hold{
writeLock: true,
respCh: make(chan *holdresp, 1),
}
select {
case <-o.closeCh:
o.shutdownLock.Lock()
return o.shutdownLock.Unlock
case o.ch <- &h:
}
select {
case <-o.closeCh:
o.shutdownLock.Lock()
return o.shutdownLock.Unlock
case resp := <-h.respCh:
return resp.cancel
}
}
func (o *OuterCancel) RLock(ctx context.Context) (context.Context, context.CancelFunc, error) {
h := hold{
writeLock: false,
rctx: ctx,
respCh: make(chan *holdresp, 1),
}
select {
case <-o.closeCh:
return nil, nil, errLockClosed
case <-ctx.Done():
return nil, nil, ctx.Err()
case o.ch <- &h:
}
select {
case <-o.closeCh:
return nil, nil, errLockClosed
case resp := <-h.respCh:
return resp.rctx, resp.cancel, resp.err
}
}

View File

@ -0,0 +1,225 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package lock
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_OuterCancel(t *testing.T) {
t.Parallel()
//nolint:err113
terr := errors.New("test")
t.Run("can rlock multiple times", func(t *testing.T) {
t.Parallel()
l := NewOuterCancel(terr, time.Second)
ctx, cancel := context.WithCancel(t.Context())
t.Cleanup(cancel)
go l.Run(ctx)
ctx1, c1, err := l.RLock(ctx)
require.NoError(t, err)
ctx2, c2, err := l.RLock(ctx)
require.NoError(t, err)
ctx3, c3, err := l.RLock(ctx)
require.NoError(t, err)
require.NoError(t, ctx1.Err())
require.NoError(t, ctx2.Err())
require.NoError(t, ctx3.Err())
c1()
require.Error(t, ctx1.Err())
require.NoError(t, ctx2.Err())
require.NoError(t, ctx3.Err())
c2()
require.Error(t, ctx1.Err())
require.Error(t, ctx2.Err())
require.NoError(t, ctx3.Err())
c3()
require.Error(t, ctx1.Err())
require.Error(t, ctx2.Err())
require.Error(t, ctx3.Err())
})
t.Run("rlock unlock removes cancel state", func(t *testing.T) {
t.Parallel()
l := NewOuterCancel(terr, time.Second)
ctx, cancel := context.WithCancel(t.Context())
t.Cleanup(cancel)
go l.Run(ctx)
_, c1, err := l.RLock(ctx)
require.NoError(t, err)
_, c2, err := l.RLock(ctx)
require.NoError(t, err)
_, c3, err := l.RLock(ctx)
require.NoError(t, err)
assert.Len(t, l.rcancels, 3)
c1()
assert.Len(t, l.rcancels, 2)
c2()
assert.Len(t, l.rcancels, 1)
c3()
assert.Empty(t, l.rcancels, 0)
})
t.Run("calling lock cancels all current rlocks", func(t *testing.T) {
t.Parallel()
l := NewOuterCancel(terr, time.Second)
ctx, cancel := context.WithCancel(t.Context())
t.Cleanup(cancel)
go l.Run(ctx)
ctx1, _, err := l.RLock(ctx)
require.NoError(t, err)
ctx2, _, err := l.RLock(ctx)
require.NoError(t, err)
ctx3, _, err := l.RLock(ctx)
require.NoError(t, err)
require.NoError(t, ctx1.Err())
require.NoError(t, ctx2.Err())
require.NoError(t, ctx3.Err())
mcancel := l.Lock()
require.Error(t, ctx1.Err())
require.Error(t, ctx2.Err())
require.Error(t, ctx3.Err())
mcancel()
assert.Empty(t, l.rcancels)
})
t.Run("rlock when closed should error", func(t *testing.T) {
t.Parallel()
l := NewOuterCancel(terr, time.Second)
ctx, cancel := context.WithCancel(t.Context())
cancel()
go l.Run(ctx)
select {
case <-l.closeCh:
case <-time.After(time.Second * 5):
assert.Fail(t, "expected close")
}
_, _, err := l.RLock(t.Context())
require.Error(t, err)
})
t.Run("lock continues to work after close", func(t *testing.T) {
t.Parallel()
l := NewOuterCancel(terr, time.Second)
ctx, cancel := context.WithCancel(t.Context())
cancel()
l.Run(ctx)
lcancel := l.Lock()
lcancel()
lcancel = l.Lock()
lcancel()
})
t.Run("rlock blocks until outter unlocks", func(t *testing.T) {
t.Parallel()
l := NewOuterCancel(terr, time.Second)
ctx, cancel := context.WithCancel(t.Context())
t.Cleanup(cancel)
go l.Run(ctx)
lcancel := l.Lock()
gotRLock := make(chan struct{})
errCh := make(chan error, 1)
go func() {
_, c1, err := l.RLock(ctx)
errCh <- err
t.Cleanup(c1)
close(gotRLock)
}()
t.Cleanup(func() {
require.NoError(t, <-errCh)
})
select {
case <-time.After(time.Millisecond * 500):
case <-gotRLock:
require.Fail(t, "unexpected rlock")
}
lcancel()
<-gotRLock
})
t.Run("lock blocks until outter unlocks", func(t *testing.T) {
t.Parallel()
l := NewOuterCancel(terr, time.Second)
ctx, cancel := context.WithCancel(t.Context())
t.Cleanup(cancel)
go l.Run(ctx)
lcancel := l.Lock()
gotLock := make(chan struct{})
go func() {
lockcancel := l.Lock()
t.Cleanup(lockcancel)
close(gotLock)
}()
select {
case <-time.After(time.Millisecond * 500):
case <-gotLock:
require.Fail(t, "unexpected rlock")
}
lcancel()
})
}

View File

@ -60,17 +60,12 @@ func (r *RunnerManager) Run(ctx context.Context) error {
return ErrManagerAlreadyStarted
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
errCh := make(chan error)
for _, runner := range r.runners {
go func(runner Runner) {
// Since the task returned, we need to cancel all other tasks.
// This is a noop if the parent context is already cancelled, or another
// task returned before this one.
defer cancel()
// Ignore context cancelled errors since errors from a runner manager
// will likely determine the exit code of the program.
// Context cancelled errors are also not really useful to the user in
@ -78,15 +73,20 @@ func (r *RunnerManager) Run(ctx context.Context) error {
rErr := runner(ctx)
if rErr != nil && !errors.Is(rErr, context.Canceled) {
errCh <- rErr
// Since the task returned, we need to cancel all other tasks.
// This is a noop if the parent context is already cancelled, or another
// task returned before this one.
cancel(rErr)
return
}
errCh <- nil
cancel(nil)
}(runner)
}
// Collect all errors
errObjs := make([]error, 0)
for i := 0; i < len(r.runners); i++ {
for range len(r.runners) {
err := <-errCh
if err != nil {
errObjs = append(errObjs, err)

View File

@ -27,7 +27,7 @@ import (
func Test_RunnerManager(t *testing.T) {
t.Run("runner with no tasks should return nil", func(t *testing.T) {
require.NoError(t, NewRunnerManager().Run(context.Background()))
require.NoError(t, NewRunnerManager().Run(t.Context()))
})
t.Run("runner with a task that completes should return nil", func(t *testing.T) {
@ -35,7 +35,7 @@ func Test_RunnerManager(t *testing.T) {
require.NoError(t, NewRunnerManager(func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return nil
}).Run(context.Background()))
}).Run(t.Context()))
assert.Equal(t, int32(1), i)
})
@ -54,7 +54,7 @@ func Test_RunnerManager(t *testing.T) {
atomic.AddInt32(&i, 1)
return nil
},
).Run(context.Background()))
).Run(t.Context()))
assert.Equal(t, int32(3), i)
})
@ -73,7 +73,7 @@ func Test_RunnerManager(t *testing.T) {
atomic.AddInt32(&i, 1)
return nil
},
).Run(context.Background()), "error")
).Run(t.Context()), "error")
assert.Equal(t, int32(3), i)
})
@ -92,7 +92,7 @@ func Test_RunnerManager(t *testing.T) {
atomic.AddInt32(&i, 1)
return errors.New("error")
},
).Run(context.Background())
).Run(t.Context())
require.Error(t, err)
require.ErrorContains(t, err, "error\nerror\nerror") //nolint:dupword
assert.Equal(t, int32(3), i)
@ -113,7 +113,7 @@ func Test_RunnerManager(t *testing.T) {
atomic.AddInt32(&i, 1)
return errors.New("error3")
},
).Run(context.Background())
).Run(t.Context())
require.Error(t, err)
assert.ElementsMatch(t, []string{"error1", "error2", "error3"}, strings.Split(err.Error(), "\n"))
assert.Equal(t, int32(3), i)
@ -139,7 +139,7 @@ func Test_RunnerManager(t *testing.T) {
return nil
},
))
require.NoError(t, mngr.Run(context.Background()))
require.NoError(t, mngr.Run(t.Context()))
assert.Equal(t, int32(3), i)
})
@ -168,7 +168,7 @@ func Test_RunnerManager(t *testing.T) {
}
return nil
},
).Run(context.Background()))
).Run(t.Context()))
assert.Equal(t, int32(3), i)
})
@ -197,7 +197,7 @@ func Test_RunnerManager(t *testing.T) {
atomic.AddInt32(&i, 1)
return errors.New("error3")
},
).Run(context.Background())
).Run(t.Context())
require.Error(t, err)
assert.ElementsMatch(t, []string{"error1", "error2", "error3"}, strings.Split(err.Error(), "\n"))
assert.Equal(t, int32(3), i)
@ -209,9 +209,9 @@ func Test_RunnerManager(t *testing.T) {
atomic.AddInt32(&i, 1)
return nil
})
require.NoError(t, m.Run(context.Background()))
require.NoError(t, m.Run(t.Context()))
assert.Equal(t, int32(1), i)
require.EqualError(t, m.Run(context.Background()), "runner manager already started")
require.EqualError(t, m.Run(t.Context()), "runner manager already started")
assert.Equal(t, int32(1), i)
})
@ -221,7 +221,7 @@ func Test_RunnerManager(t *testing.T) {
atomic.AddInt32(&i, 1)
return nil
})
require.NoError(t, m.Run(context.Background()))
require.NoError(t, m.Run(t.Context()))
assert.Equal(t, int32(1), i)
err := m.Add(func(ctx context.Context) error {
atomic.AddInt32(&i, 1)

View File

@ -0,0 +1,58 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package slice
import "sync"
// Slice is a concurrent safe types slice
type Slice[T any] interface {
Append(items ...T) int
Len() int
Slice() []T
Store(items ...T)
}
type slice[T any] struct {
lock sync.RWMutex
data []T
}
func New[T any]() Slice[T] {
return new(slice[T])
}
func (s *slice[T]) Append(items ...T) int {
s.lock.Lock()
defer s.lock.Unlock()
s.data = append(s.data, items...)
return len(s.data)
}
func (s *slice[T]) Len() int {
s.lock.RLock()
defer s.lock.RUnlock()
return len(s.data)
}
func (s *slice[T]) Slice() []T {
s.lock.RLock()
defer s.lock.RUnlock()
return s.data
}
func (s *slice[T]) Store(items ...T) {
s.lock.Lock()
defer s.lock.Unlock()
s.data = items
}

View File

@ -0,0 +1,18 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package slice
func String() Slice[string] {
return new(slice[string])
}

View File

@ -81,7 +81,7 @@ func decodeString(f reflect.Type, t reflect.Type, data any) (any, error) {
if t.Implements(typeStringDecoder) {
result = reflect.New(t.Elem()).Interface()
decoder = result.(StringDecoder)
} else if reflect.PtrTo(t).Implements(typeStringDecoder) {
} else if reflect.PointerTo(t).Implements(typeStringDecoder) {
result = reflect.New(t).Interface()
decoder = result.(StringDecoder)
}

View File

@ -52,6 +52,9 @@ func NewPool(ctx ...context.Context) *Pool {
go func() {
defer cancel()
defer p.lock.RUnlock()
//nolint:intrange
// for loops are evaluated on every loop while range are evaluated over a snapshot of the slice as it
// existed when the loop started
for i := 0; i < len(p.pool); i++ {
ch := p.pool[i]
p.lock.RUnlock()

View File

@ -37,7 +37,7 @@ func Test_Pool(t *testing.T) {
t.Run("a cancelled context given to pool, should have pool cancelled", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
cancel()
pool := NewPool(ctx)
select {
@ -49,10 +49,10 @@ func Test_Pool(t *testing.T) {
t.Run("a cancelled context given to pool, given a new context, should still have pool cancelled", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
cancel()
pool := NewPool(ctx)
pool.Add(context.Background())
pool.Add(t.Context())
select {
case <-pool.Done():
case <-time.After(time.Second):
@ -65,13 +65,13 @@ func Test_Pool(t *testing.T) {
var ctx [50]context.Context
var cancel [50]context.CancelFunc
ctx[0], cancel[0] = context.WithCancel(context.Background())
pool := NewPool(ctx[0])
ctxPool := make([]context.Context, 0, 50)
for i := 1; i < 50; i++ {
ctx[i], cancel[i] = context.WithCancel(context.Background())
pool.Add(ctx[i])
for i := range 50 {
ctx[i], cancel[i] = context.WithCancel(t.Context())
ctxPool = append(ctxPool, ctx[i])
}
pool := NewPool(ctxPool...)
//nolint:gosec
r := rand.New(rand.NewSource(time.Now().UnixNano()))
@ -80,7 +80,7 @@ func Test_Pool(t *testing.T) {
cancel[i], cancel[j] = cancel[j], cancel[i]
})
for i := 0; i < 50; i++ {
for i := range 50 {
select {
case <-pool.Done():
t.Error("expected context to not be cancelled")
@ -99,8 +99,8 @@ func Test_Pool(t *testing.T) {
t.Run("pool size will not increase if the given contexts have been cancelled", func(t *testing.T) {
t.Parallel()
ctx1, cancel1 := context.WithCancel(context.Background())
ctx2, cancel2 := context.WithCancel(context.Background())
ctx1, cancel1 := context.WithCancel(t.Context())
ctx2, cancel2 := context.WithCancel(t.Context())
pool := NewPool(ctx1, ctx2)
assert.Equal(t, 2, pool.Size())
@ -111,19 +111,19 @@ func Test_Pool(t *testing.T) {
case <-time.After(time.Second):
t.Error("expected context pool to be cancelled")
}
pool.Add(context.Background())
pool.Add(t.Context())
assert.Equal(t, 2, pool.Size())
})
t.Run("pool size will not increase if the pool has been closed", func(t *testing.T) {
t.Parallel()
ctx1 := context.Background()
ctx2 := context.Background()
ctx1 := t.Context()
ctx2 := t.Context()
pool := NewPool(ctx1, ctx2)
assert.Equal(t, 2, pool.Size())
pool.Cancel()
pool.Add(context.Background())
pool.Add(t.Context())
assert.Equal(t, 0, pool.Size())
select {
case <-pool.Done():
@ -131,4 +131,24 @@ func Test_Pool(t *testing.T) {
t.Error("expected context pool to be cancelled")
}
})
t.Run("wait for added context to be closed", func(t *testing.T) {
t.Parallel()
ctx1, cancel1 := context.WithCancel(t.Context())
pool := NewPool(ctx1)
ctx2, cancel2 := context.WithCancel(t.Context())
pool.Add(ctx2)
assert.Equal(t, 2, pool.Size())
cancel1()
select {
case <-pool.Done():
t.Error("expected context pool to not be cancelled")
case <-time.After(10 * time.Millisecond):
}
cancel2()
})
}

View File

@ -170,7 +170,7 @@ func TestChainDelayIfStillRunning(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
started, done = j.Started(), j.Done()
if started != 2 || done != 2 {
c.Errorf("expected both jobs done, got %v %v", started, done) //nolint:testifylint
c.Errorf("expected both jobs done, got %v %v", started, done)
}
}, 100*time.Millisecond, 10*time.Millisecond)
})
@ -230,7 +230,7 @@ func TestChainSkipIfStillRunning(t *testing.T) {
var j countJob
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
for i := 0; i < 11; i++ {
for range 11 {
go wrappedJob.Run()
}
assert.Eventually(t, j.clock.HasWaiters, 50*time.Millisecond, 10*time.Millisecond)
@ -248,7 +248,7 @@ func TestChainSkipIfStillRunning(t *testing.T) {
chain := NewChain(SkipIfStillRunning(DiscardLogger))
wrappedJob1 := chain.Then(&j1)
wrappedJob2 := chain.Then(&j2)
for i := 0; i < 11; i++ {
for range 11 {
go wrappedJob1.Run()
go wrappedJob2.Run()
}

View File

@ -14,31 +14,23 @@ You can check the original license at:
https://github.com/robfig/cron/blob/master/LICENSE
*/
//nolint
package cron
import "time"
// ConstantDelaySchedule represents a simple recurring duty cycle, e.g. "Every 5 minutes".
// It does not support jobs more frequent than once a second.
type ConstantDelaySchedule struct {
Delay time.Duration
}
// Every returns a crontab Schedule that activates once every duration.
// Delays of less than a second are not supported (will round up to 1 second).
// Any fields less than a Second are truncated.
func Every(duration time.Duration) ConstantDelaySchedule {
if duration < time.Second {
duration = time.Second
}
return ConstantDelaySchedule{
Delay: duration - time.Duration(duration.Nanoseconds())%time.Second,
Delay: duration,
}
}
// Next returns the next time this should be run.
// This rounds so that the next activation time will be on the second.
func (schedule ConstantDelaySchedule) Next(t time.Time) time.Time {
return t.Add(schedule.Delay - time.Duration(t.Nanosecond())*time.Nanosecond)
return t.Add(schedule.Delay)
}

View File

@ -14,7 +14,6 @@ You can check the original license at:
https://github.com/robfig/cron/blob/master/LICENSE
*/
//nolint
package cron
import (
@ -29,9 +28,12 @@ func TestConstantDelayNext(t *testing.T) {
expected string
}{
// Simple cases
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00:00.00000005 2012"},
{"Mon Jul 9 14:59 2012", 15 * time.Minute, "Mon Jul 9 15:14 2012"},
{"Mon Jul 9 14:59:59 2012", 15 * time.Minute, "Mon Jul 9 15:14:59 2012"},
{"Mon Jul 9 14:45:00 2012", 15 * time.Millisecond, "Mon Jul 9 14:45:00.015 2012"},
{"Mon Jul 9 14:45:00.015 2012", 15 * time.Millisecond, "Mon Jul 9 14:45:00.030 2012"},
{"Mon Jul 9 14:45:00.000000050 2012", 15 * time.Nanosecond, "Mon Jul 9 14:45:00.000000065 2012"},
// Wrap around hours
{"Mon Jul 9 15:45 2012", 35 * time.Minute, "Mon Jul 9 16:20 2012"},
@ -47,18 +49,6 @@ func TestConstantDelayNext(t *testing.T) {
// Wrap around minute, hour, day, month, and year
{"Mon Dec 31 23:59:45 2012", 15 * time.Second, "Tue Jan 1 00:00:00 2013"},
// Round to nearest second on the delay
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
// Round up to 1 second if the duration is less.
{"Mon Jul 9 14:45:00 2012", 15 * time.Millisecond, "Mon Jul 9 14:45:01 2012"},
// Round to nearest second when calculating the next time.
{"Mon Jul 9 14:45:00.005 2012", 15 * time.Minute, "Mon Jul 9 15:00 2012"},
// Round to nearest second for both.
{"Mon Jul 9 14:45:00.005 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
}
for _, c := range tests {

View File

@ -14,7 +14,6 @@ You can check the original license at:
https://github.com/robfig/cron/blob/master/LICENSE
*/
//nolint:dupword
package cron
import (
@ -35,7 +34,7 @@ import (
// for it to run. This amount is just slightly larger than 1 second to
// compensate for a few milliseconds of runtime.
//
//nolint:revive
const OneSecond = 1*time.Second + 50*time.Millisecond
type syncWriter struct {
@ -783,13 +782,68 @@ func TestMockClock(t *testing.T) {
})
cron.Start()
defer cron.Stop()
for i := 0; i <= 10; i++ {
for range 11 {
assert.Eventually(t, clk.HasWaiters, OneSecond, 10*time.Millisecond)
clk.Step(1 * time.Second)
}
assert.Equal(t, int64(10), counter.Load())
}
func TestMillisecond(t *testing.T) {
clk := clocktesting.NewFakeClock(time.Now())
cron := New(WithClock(clk))
counter1ms := atomic.Int64{}
counter15ms := atomic.Int64{}
counter100ms := atomic.Int64{}
cron.AddFunc("@every 1ms", func() {
counter1ms.Add(1)
})
cron.AddFunc("@every 15ms", func() {
counter15ms.Add(1)
})
cron.AddFunc("@every 100ms", func() {
counter100ms.Add(1)
})
cron.Start()
defer cron.Stop()
for range 1000 {
assert.Eventually(t, clk.HasWaiters, OneSecond, 1*time.Millisecond)
clk.Step(1 * time.Millisecond)
}
ctx := cron.Stop()
<-ctx.Done()
assert.Equal(t, int64(1000), counter1ms.Load())
assert.Equal(t, int64(66), counter15ms.Load())
assert.Equal(t, int64(10), counter100ms.Load())
}
func TestNanoseconds(t *testing.T) {
clk := clocktesting.NewFakeClock(time.Now())
cron := New(WithClock(clk))
counter100ns := atomic.Int64{}
cron.AddFunc("@every 100ns", func() {
counter100ns.Add(1)
})
cron.Start()
defer cron.Stop()
for range 500 {
assert.Eventually(t, clk.HasWaiters, OneSecond, 1*time.Millisecond)
clk.Step(5 * time.Nanosecond)
}
ctx := cron.Stop()
<-ctx.Done()
// 500 * 5 ns = 2500 ns
// 2500 every 100ns = 25
assert.Equal(t, int64(25), counter100ns.Load())
}
func TestMultiThreadedStartAndStop(*testing.T) {
cron := New()
go cron.Run()

View File

@ -1,4 +1,3 @@
//nolint
/*
This package is a fork of "github.com/robfig/cron/v3" that implements cron spec parser and job runner with support for mocking the time.
@ -36,7 +35,9 @@ them in their own goroutines.
# Time mocking
import (
clocktesting "k8s.io/utils/clock/testing"
)
clk := clocktesting.NewFakeClock(time.Now())

View File

@ -59,7 +59,7 @@ func TestWithVerboseLogger(t *testing.T) {
out := buf.String()
if !strings.Contains(out, "schedule,") ||
!strings.Contains(out, "run,") {
c.Errorf("expected to see some actions, got: %v", out) //nolint:testifylint
c.Errorf("expected to see some actions, got: %v", out)
}
}, time.Second, time.Millisecond*10)
}

View File

@ -14,7 +14,6 @@ You can check the original license at:
https://github.com/robfig/cron/blob/master/LICENSE
*/
//nolint
package cron
import (
@ -167,6 +166,8 @@ func TestParseSchedule(t *testing.T) {
{standardParser, "CRON_TZ=UTC 5 * * * *", every5min(time.UTC)},
{secondParser, "CRON_TZ=Asia/Tokyo 0 5 * * * *", every5min(tokyo)},
{secondParser, "@every 5m", ConstantDelaySchedule{5 * time.Minute}},
{secondParser, "@every 5ms", ConstantDelaySchedule{5 * time.Millisecond}},
{secondParser, "@every 5ns", ConstantDelaySchedule{5 * time.Nanosecond}},
{secondParser, "@midnight", midnight(time.Local)},
{secondParser, "TZ=UTC @midnight", midnight(time.UTC)},
{secondParser, "TZ=Asia/Tokyo @midnight", midnight(tokyo)},

View File

@ -214,9 +214,9 @@ func (aead *aesCBCAEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]b
}
// Computes the HMAC tag as per specs.
func (aead aesCBCAEAD) hmacTag(h hash.Hash, additionalData, nonce, ciphertext []byte, l int) []byte {
func (aead *aesCBCAEAD) hmacTag(h hash.Hash, additionalData, nonce, ciphertext []byte, l int) []byte {
al := make([]byte, 8)
binary.BigEndian.PutUint64(al, uint64(len(additionalData)<<3)) // In bits
binary.BigEndian.PutUint64(al, uint64(len(additionalData)<<3)) // #nosec G115 // In bits
h.Write(additionalData)
h.Write(nonce)

View File

@ -48,14 +48,14 @@ func Wrap(block cipher.Block, cek []byte) ([]byte, error) {
copy(r[i], cek[i*8:])
}
for j := 0; j <= 5; j++ {
for j := range 6 {
for i := 1; i <= n; i++ {
b := arrConcat(a, r[i-1])
block.Encrypt(b, b)
t := (n * j) + i
tBytes := make([]byte, 8)
binary.BigEndian.PutUint64(tBytes, uint64(t))
binary.BigEndian.PutUint64(tBytes, uint64(t)) // #nosec G115
copy(a, arrXor(b[:len(b)/2], tBytes))
copy(r[i-1], b[len(b)/2:])
@ -92,7 +92,7 @@ func Unwrap(block cipher.Block, cipherText []byte) ([]byte, error) {
for i := n; i >= 1; i-- {
t := (n * j) + i
tBytes := make([]byte, 8)
binary.BigEndian.PutUint64(tBytes, uint64(t))
binary.BigEndian.PutUint64(tBytes, uint64(t)) // #nosec G115
b := arrConcat(arrXor(a, tBytes), r[i-1])
block.Decrypt(b, b)

View File

@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
//nolint:nosnakecase,stylecheck,revive
//nolint:nosnakecase,stylecheck
package crypto
import (

211
crypto/pem/pem.go Normal file
View File

@ -0,0 +1,211 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package pem
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"os"
)
// DecodePEMCertificatesChain takes a PEM-encoded x509 certificates byte array
// and returns all certificates in a slice of x509.Certificate objects.
// Expects certificates to be a chain with leaf certificate to be first in the
// byte array.
func DecodePEMCertificatesChain(crtb []byte) ([]*x509.Certificate, error) {
certs, err := DecodePEMCertificates(crtb)
if err != nil {
return nil, err
}
for i := range len(certs) - 1 {
if certs[i].CheckSignatureFrom(certs[i+1]) != nil {
return nil, errors.New("certificate chain is not valid")
}
}
return certs, nil
}
// DecodePEMCertificatesChain takes a PEM-encoded x509 certificates byte array
// and returns all certificates in a slice of x509.Certificate objects.
func DecodePEMCertificates(crtb []byte) ([]*x509.Certificate, error) {
certs := []*x509.Certificate{}
for len(crtb) > 0 {
var err error
var cert *x509.Certificate
cert, crtb, err = decodeCertificatePEM(crtb)
if err != nil {
return nil, err
}
if cert != nil {
// it's a cert, add to pool
certs = append(certs, cert)
}
}
if len(certs) == 0 {
return nil, errors.New("no certificates found")
}
return certs, nil
}
func decodeCertificatePEM(crtb []byte) (*x509.Certificate, []byte, error) {
block, crtb := pem.Decode(crtb)
if block == nil {
return nil, nil, nil
}
if block.Type != "CERTIFICATE" {
return nil, nil, nil
}
c, err := x509.ParseCertificate(block.Bytes)
return c, crtb, err
}
// DecodePEMPrivateKey takes a key PEM byte array and returns an object that
// represents either an RSA or EC private key.
func DecodePEMPrivateKey(key []byte) (crypto.Signer, error) {
block, _ := pem.Decode(key)
if block == nil {
return nil, errors.New("key is not PEM encoded")
}
switch block.Type {
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(block.Bytes)
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(block.Bytes)
case "PRIVATE KEY":
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
return key.(crypto.Signer), nil
default:
return nil, fmt.Errorf("unsupported block type %s", block.Type)
}
}
// EncodePrivateKey will encode a private key into PEM format.
func EncodePrivateKey(key any) ([]byte, error) {
var (
keyBytes []byte
err error
blockType string
)
switch key := key.(type) {
case *ecdsa.PrivateKey, *ed25519.PrivateKey:
keyBytes, err = x509.MarshalPKCS8PrivateKey(key)
if err != nil {
return nil, err
}
blockType = "PRIVATE KEY"
default:
return nil, fmt.Errorf("unsupported key type %T", key)
}
return pem.EncodeToMemory(&pem.Block{
Type: blockType, Bytes: keyBytes,
}), nil
}
// EncodeX509 will encode a single *x509.Certificate into PEM format.
func EncodeX509(cert *x509.Certificate) ([]byte, error) {
caPem := bytes.NewBuffer([]byte{})
err := pem.Encode(caPem, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
if err != nil {
return nil, err
}
return caPem.Bytes(), nil
}
// EncodeX509Chain will encode a list of *x509.Certificates into a PEM format chain.
// Self-signed certificates are not included as per
// https://datatracker.ietf.org/doc/html/rfc5246#section-7.4.2
// Certificates are output in the order they're given; if the input is not ordered
// as specified in RFC5246 section 7.4.2, the resulting chain might not be valid
// for use in TLS.
func EncodeX509Chain(certs []*x509.Certificate) ([]byte, error) {
if len(certs) == 0 {
return nil, errors.New("no certificates in chain")
}
certPEM := bytes.NewBuffer([]byte{})
for _, cert := range certs {
if cert == nil {
continue
}
if cert.CheckSignatureFrom(cert) == nil {
// Don't include self-signed certificate
continue
}
err := pem.Encode(certPEM, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
if err != nil {
return nil, err
}
}
return certPEM.Bytes(), nil
}
// PublicKeysEqual compares two given public keys for equality.
// The definition of "equality" depends on the type of the public keys.
// Returns true if the keys are the same, false if they differ or an error if
// the key type of `a` cannot be determined.
func PublicKeysEqual(a, b crypto.PublicKey) (bool, error) {
switch pub := a.(type) {
case *rsa.PublicKey:
return pub.Equal(b), nil
case *ecdsa.PublicKey:
return pub.Equal(b), nil
case ed25519.PublicKey:
return pub.Equal(b), nil
default:
return false, fmt.Errorf("unrecognised public key type: %T", a)
}
}
// GetPEM loads a PEM-encoded file (certificate or key).
func GetPEM(val string) ([]byte, error) {
// If val is already a PEM-encoded string, return it as-is
if IsValidPEM(val) {
return []byte(val), nil
}
// Assume it's a file
pemBytes, err := os.ReadFile(val)
if err != nil {
return nil, fmt.Errorf("value is neither a valid file path or nor a valid PEM-encoded string: %w", err)
}
return pemBytes, nil
}
// IsValidPEM validates the provided input has PEM formatted block.
func IsValidPEM(val string) bool {
block, _ := pem.Decode([]byte(val))
return block != nil
}

View File

@ -0,0 +1,71 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package context
import (
"context"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/dapr/kit/crypto/spiffe"
)
type ctxkey int
const (
x509SvidKey ctxkey = iota
jwtSvidKey
)
// Deprecated: use WithX509 instead.
// With adds the x509 SVID source from the SPIFFE object to the context.
func With(ctx context.Context, spiffe *spiffe.SPIFFE) context.Context {
return context.WithValue(ctx, x509SvidKey, spiffe.X509SVIDSource())
}
// Deprecated: use X509From instead.
// From retrieves the x509 SVID source from the context.
func From(ctx context.Context) (x509svid.Source, bool) {
svid, ok := ctx.Value(x509SvidKey).(x509svid.Source)
return svid, ok
}
// WithX509 adds an x509 SVID source to the context.
func WithX509(ctx context.Context, source x509svid.Source) context.Context {
return context.WithValue(ctx, x509SvidKey, source)
}
// WithJWT adds a JWT SVID source to the context.
func WithJWT(ctx context.Context, source jwtsvid.Source) context.Context {
return context.WithValue(ctx, jwtSvidKey, source)
}
// X509From retrieves the x509 SVID source from the context.
func X509From(ctx context.Context) (x509svid.Source, bool) {
svid, ok := ctx.Value(x509SvidKey).(x509svid.Source)
return svid, ok
}
// JWTFrom retrieves the JWT SVID source from the context.
func JWTFrom(ctx context.Context) (jwtsvid.Source, bool) {
svid, ok := ctx.Value(jwtSvidKey).(jwtsvid.Source)
return svid, ok
}
// WithSpiffe adds both X509 and JWT SVID sources to the context.
func WithSpiffe(ctx context.Context, spiffe *spiffe.SPIFFE) context.Context {
ctx = WithX509(ctx, spiffe.X509SVIDSource())
return WithJWT(ctx, spiffe.JWTSVIDSource())
}

View File

@ -0,0 +1,64 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package context
import (
"context"
"testing"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/stretchr/testify/assert"
)
type mockX509Source struct{}
func (m *mockX509Source) GetX509SVID() (*x509svid.SVID, error) {
return nil, nil
}
type mockJWTSource struct{}
func (m *mockJWTSource) FetchJWTSVID(context.Context, jwtsvid.Params) (*jwtsvid.SVID, error) {
return nil, nil
}
func TestWithX509FromX509(t *testing.T) {
source := &mockX509Source{}
ctx := WithX509(t.Context(), source)
retrieved, ok := X509From(ctx)
assert.True(t, ok, "Failed to retrieve X509 source from context")
assert.Equal(t, x509svid.Source(source), retrieved, "Retrieved source does not match the original source")
}
func TestWithJWTFromJWT(t *testing.T) {
source := &mockJWTSource{}
ctx := WithJWT(t.Context(), source)
retrieved, ok := JWTFrom(ctx)
assert.True(t, ok, "Failed to retrieve JWT source from context")
assert.Equal(t, jwtsvid.Source(source), retrieved, "Retrieved source does not match the original source")
}
func TestWithFrom(t *testing.T) {
x509Source := &mockX509Source{}
ctx := WithX509(t.Context(), x509Source)
// Should be able to retrieve using the legacy From function
retrieved, ok := From(ctx)
assert.True(t, ok, "Failed to retrieve X509 source from context using legacy From")
assert.Equal(t, x509svid.Source(x509Source), retrieved, "Retrieved source does not match the original source using legacy From")
}

352
crypto/spiffe/spiffe.go Normal file
View File

@ -0,0 +1,352 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spiffe
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"k8s.io/utils/clock"
"github.com/dapr/kit/concurrency/dir"
"github.com/dapr/kit/crypto/pem"
"github.com/dapr/kit/crypto/spiffe/trustanchors"
"github.com/dapr/kit/logger"
)
const (
// renewalDivisor represents the divisor for calculating renewal time.
// A value of 2 means renewal at 50% of the validity period.
renewalDivisor = 2
)
// SVIDResponse represents the response from the SVID request function,
// containing both X.509 certificates and a JWT token.
type SVIDResponse struct {
X509Certificates []*x509.Certificate
JWT *string
}
// Identity contains both X.509 and JWT SVIDs for a workload.
type Identity struct {
X509SVID *x509svid.SVID
JWTSVID *jwtsvid.SVID
}
type (
// RequestSVIDFn is the function type that requests SVIDs from a SPIFFE server,
// returning both X.509 certificates and a JWT token.
RequestSVIDFn func(context.Context, []byte) (*SVIDResponse, error)
)
type Options struct {
Log logger.Logger
RequestSVIDFn RequestSVIDFn
// WriteIdentityToFile is used to write the identity private key and
// certificate chain to file. The certificate chain and private key will be
// written to the `tls.cert` and `tls.key` files respectively in the given
// directory.
WriteIdentityToFile *string
TrustAnchors trustanchors.Interface
}
// SPIFFE is a readable/writeable store of SPIFFE SVID credentials.
// Used to manage workload SVIDs, and share read-only interfaces to consumers.
type SPIFFE struct {
currentX509SVID *x509svid.SVID
currentJWTSVID *jwtsvid.SVID
requestSVIDFn RequestSVIDFn
dir *dir.Dir
trustAnchors trustanchors.Interface
log logger.Logger
lock sync.RWMutex
clock clock.Clock
running atomic.Bool
readyCh chan struct{}
}
func New(opts Options) *SPIFFE {
var sdir *dir.Dir
if opts.WriteIdentityToFile != nil {
sdir = dir.New(dir.Options{
Log: opts.Log,
Target: *opts.WriteIdentityToFile,
})
}
return &SPIFFE{
requestSVIDFn: opts.RequestSVIDFn,
dir: sdir,
trustAnchors: opts.TrustAnchors,
log: opts.Log,
clock: clock.RealClock{},
readyCh: make(chan struct{}),
}
}
func (s *SPIFFE) Run(ctx context.Context) error {
if !s.running.CompareAndSwap(false, true) {
return errors.New("already running")
}
s.lock.Lock()
s.log.Info("Fetching initial identity")
initialIdentity, err := s.fetchIdentity(ctx)
if err != nil {
close(s.readyCh)
s.lock.Unlock()
return fmt.Errorf("failed to retrieve the initial identity: %w", err)
}
s.currentX509SVID = initialIdentity.X509SVID
s.currentJWTSVID = initialIdentity.JWTSVID
close(s.readyCh)
s.lock.Unlock()
s.log.Infof("Security is initialized successfully")
s.runRotation(ctx)
return nil
}
// Ready blocks until SPIFFE is ready or the context is done which will return
// the context error.
func (s *SPIFFE) Ready(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-s.readyCh:
return nil
}
}
// logIdentityInfo creates a log message with expiry details for both X.509 and JWT SVIDs
func (s *SPIFFE) logIdentityInfo(prefix string, cert *x509.Certificate, jwtSVID *jwtsvid.SVID, renewTime *time.Time) {
msg := prefix + "; cert expires on: %s"
args := []any{cert.NotAfter.String()}
if jwtSVID != nil {
msg += ", jwt expires on: %s"
args = append(args, jwtSVID.Expiry.String())
}
if renewTime != nil {
msg += ", renewal at: %s"
args = append(args, renewTime.String())
}
s.log.Infof(msg, args...)
}
// runRotation starts up the manager responsible for renewing the workload identity
func (s *SPIFFE) runRotation(ctx context.Context) {
defer s.log.Debug("stopping workload identity expiry watcher")
s.lock.RLock()
cert := s.currentX509SVID.Certificates[0]
jwtSVID := s.currentJWTSVID
s.lock.RUnlock()
renewTime := calculateRenewalTime(time.Now(), cert, jwtSVID)
s.logIdentityInfo("Starting workload identity expiry watcher", cert, jwtSVID, renewTime)
for {
select {
case <-s.clock.After(min(time.Minute, renewTime.Sub(s.clock.Now()))):
if s.clock.Now().Before(*renewTime) {
continue
}
s.logIdentityInfo("Renewing workload identity", cert, jwtSVID, nil)
identity, err := s.fetchIdentity(ctx)
if err != nil {
s.log.Errorf("Error renewing identity, trying again in 10 seconds: %s", err)
select {
case <-s.clock.After(10 * time.Second):
continue
case <-ctx.Done():
return
}
}
s.lock.Lock()
s.currentX509SVID = identity.X509SVID
s.currentJWTSVID = identity.JWTSVID
cert = identity.X509SVID.Certificates[0]
jwtSVID = identity.JWTSVID
s.lock.Unlock()
renewTime = calculateRenewalTime(time.Now(), cert, jwtSVID)
s.logIdentityInfo("Successfully renewed workload identity", cert, jwtSVID, renewTime)
case <-ctx.Done():
return
}
}
}
// Returns both X.509 SVID and JWT SVID (if available).
func (s *SPIFFE) fetchIdentity(ctx context.Context) (*Identity, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("failed to generate private key: %w", err)
}
csrDER, err := x509.CreateCertificateRequest(rand.Reader, new(x509.CertificateRequest), key)
if err != nil {
return nil, fmt.Errorf("failed to create sidecar csr: %w", err)
}
svidResponse, err := s.requestSVIDFn(ctx, csrDER)
if err != nil {
return nil, err
}
if len(svidResponse.X509Certificates) == 0 {
return nil, errors.New("no certificates received from sentry")
}
spiffeID, err := x509svid.IDFromCert(svidResponse.X509Certificates[0])
if err != nil {
return nil, fmt.Errorf("error parsing spiffe id from newly signed certificate: %w", err)
}
identity := &Identity{
X509SVID: &x509svid.SVID{
ID: spiffeID,
Certificates: svidResponse.X509Certificates,
PrivateKey: key,
},
}
// If we have a JWT token, parse it and include it in the identity
if svidResponse.JWT != nil {
// we are using ParseInsecure here as the expectation is that the
// requestSVIDFn will have already parsed and validate the JWT SVID
// before returning it.
//
// we are parsing the token using our SPIFFE ID's trust domain
// as the audience as we expect the issuer to always include
// that as an audience since that ensures that the token is
// valid for us and our trust domain.
audiences := []string{spiffeID.TrustDomain().Name()}
jwtSvid, err := jwtsvid.ParseInsecure(*svidResponse.JWT, audiences)
if err != nil {
return nil, fmt.Errorf("failed to parse JWT SVID: %w", err)
}
identity.JWTSVID = jwtSvid
s.log.Infof("Successfully received JWT SVID with expiry: %s", jwtSvid.Expiry.String())
}
if s.dir != nil {
pkPEM, err := pem.EncodePrivateKey(key)
if err != nil {
return nil, err
}
certPEM, err := pem.EncodeX509Chain(svidResponse.X509Certificates)
if err != nil {
return nil, err
}
td, err := s.trustAnchors.CurrentTrustAnchors(ctx)
if err != nil {
return nil, err
}
files := map[string][]byte{
"key.pem": pkPEM,
"cert.pem": certPEM,
"ca.pem": td,
}
if svidResponse.JWT != nil {
files["jwt_svid.token"] = []byte(*svidResponse.JWT)
}
if err := s.dir.Write(files); err != nil {
return nil, err
}
}
return identity, nil
}
func (s *SPIFFE) X509SVIDSource() x509svid.Source {
return &svidSource{spiffe: s}
}
func (s *SPIFFE) JWTSVIDSource() jwtsvid.Source {
return &svidSource{spiffe: s}
}
// renewalTime is 50% through the certificate validity period.
func renewalTime(notBefore, notAfter time.Time) time.Time {
return notBefore.Add(notAfter.Sub(notBefore) / renewalDivisor)
}
// calculateRenewalTime returns the earlier renewal time between the X.509 certificate
// and JWT SVID (if available) to ensure timely renewal.
func calculateRenewalTime(now time.Time, cert *x509.Certificate, jwtSVID *jwtsvid.SVID) *time.Time {
certRenewal := renewalTime(cert.NotBefore, cert.NotAfter)
if jwtSVID == nil {
return &certRenewal
}
jwtRenewal := now.Add(jwtSVID.Expiry.Sub(now) / renewalDivisor)
if jwtRenewal.Before(certRenewal) {
return &jwtRenewal
}
return &certRenewal
}
// audiencesMatch checks if the SVID audiences contain all the requested audiences
func audiencesMatch(svidAudiences []string, requestedAudiences []string) bool {
if len(requestedAudiences) == 0 {
return true
}
// Create a map for faster lookup
audienceMap := make(map[string]struct{}, len(svidAudiences))
for _, audience := range svidAudiences {
audienceMap[audience] = struct{}{}
}
// Check if all requested audiences are in the SVID
for _, requested := range requestedAudiences {
if _, ok := audienceMap[requested]; !ok {
return false
}
}
return true
}

View File

@ -0,0 +1,269 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spiffe
import (
"context"
"crypto/x509"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
clocktesting "k8s.io/utils/clock/testing"
"github.com/dapr/kit/crypto/test"
"github.com/dapr/kit/logger"
)
func Test_renewalTime(t *testing.T) {
now := time.Now()
assert.Equal(t, now, renewalTime(now, now))
in1Min := now.Add(time.Minute)
in30 := now.Add(time.Second * 30)
assert.Equal(t, in30, renewalTime(now, in1Min))
}
func Test_calculateRenewalTime(t *testing.T) {
now := time.Now()
certShort := &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(10 * time.Hour),
}
certLong := &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(24 * time.Hour),
}
// Expected renewal times for certificates (50% of validity period)
certShortRenewal := now.Add(5 * time.Hour)
// Create JWT SVIDs with different expiry times
jwtEarlier := &jwtsvid.SVID{
Expiry: now.Add(8 * time.Hour),
}
jwtLater := &jwtsvid.SVID{
Expiry: now.Add(30 * time.Hour),
}
// Expected JWT renewal time (50% of remaining time)
jwtEarlierRenewal := now.Add(4 * time.Hour)
tests := []struct {
name string
cert *x509.Certificate
jwt *jwtsvid.SVID
expected time.Time
}{
{
name: "Certificate only",
cert: certShort,
jwt: nil,
expected: certShortRenewal,
},
{
name: "Certificate and JWT, JWT earlier",
cert: certLong,
jwt: jwtEarlier,
expected: jwtEarlierRenewal,
},
{
name: "Certificate and JWT, Certificate earlier",
cert: certShort,
jwt: jwtLater,
expected: certShortRenewal,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := calculateRenewalTime(now, tt.cert, tt.jwt)
assert.WithinDuration(t, tt.expected, *actual, time.Millisecond,
"Renewal time does not match expected value")
})
}
}
func Test_Run(t *testing.T) {
t.Run("should return error multiple Runs are called", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{
LeafID: spiffeid.RequireFromString("spiffe://example.com/foo/bar"),
})
ctx, cancel := context.WithCancel(t.Context())
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
return &SVIDResponse{
X509Certificates: []*x509.Certificate{pki.LeafCert},
}, nil
},
})
errCh := make(chan error)
go func() {
errCh <- s.Run(ctx)
}()
go func() {
errCh <- s.Run(ctx)
}()
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
assert.Fail(t, "Expected error")
}
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "First Run should have returned and returned no error ")
}
})
t.Run("should return error if initial fetch errors", func(t *testing.T) {
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
return nil, errors.New("this is an error")
},
})
require.Error(t, s.Run(t.Context()))
})
t.Run("should renew certificate when it has expired", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{
LeafID: spiffeid.RequireFromString("spiffe://example.com/foo/bar"),
})
var fetches atomic.Int32
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
fetches.Add(1)
return &SVIDResponse{
X509Certificates: []*x509.Certificate{pki.LeafCert},
}, nil
},
})
now := time.Now()
clock := clocktesting.NewFakeClock(now)
s.clock = clock
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
select {
case <-s.readyCh:
assert.Fail(t, "readyCh should not be closed")
default:
}
errCh <- s.Run(ctx)
}()
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
assert.Equal(t, int32(1), fetches.Load())
clock.Step(pki.LeafCert.NotAfter.Sub(now) / 2)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Equal(c, int32(2), fetches.Load())
}, time.Second, time.Millisecond)
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "First Run should have returned and returned no error ")
}
})
t.Run("if renewal failed, should try again in 10 seconds", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{
LeafID: spiffeid.RequireFromString("spiffe://example.com/foo/bar"),
})
respCert := []*x509.Certificate{pki.LeafCert}
var respErr error
var fetches atomic.Int32
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
fetches.Add(1)
return &SVIDResponse{
X509Certificates: respCert,
}, respErr
},
})
now := time.Now()
clock := clocktesting.NewFakeClock(now)
s.clock = clock
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
select {
case <-s.readyCh:
assert.Fail(t, "readyCh should not be closed")
default:
}
errCh <- s.Run(ctx)
}()
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
assert.Equal(t, int32(1), fetches.Load())
respCert = nil
respErr = errors.New("this is an error")
clock.Step(pki.LeafCert.NotAfter.Sub(now) / 2)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Equal(c, int32(2), fetches.Load())
}, time.Second, time.Millisecond)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second * 5)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
assert.Equal(t, int32(2), fetches.Load())
clock.Step(time.Second * 5)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(1)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Equal(c, int32(3), fetches.Load())
}, time.Second, time.Millisecond)
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "First Run should have returned and returned no error ")
}
})
}

View File

@ -0,0 +1,95 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spiffe
import (
"context"
"errors"
"fmt"
"strings"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
)
var (
errNoX509SVIDAvailable = errors.New("no X509 SVID available")
errNoJWTSVIDAvailable = errors.New("no JWT SVID available")
errAudienceRequired = errors.New("JWT audience is required")
)
// svidSource is an implementation of both go-spiffe x509svid.Source and jwtsvid.Source interfaces.
type svidSource struct {
spiffe *SPIFFE
}
// GetX509SVID returns the current X.509 certificate identity as a SPIFFE SVID.
// Implements the go-spiffe x509svid.Source interface.
func (s *svidSource) GetX509SVID() (*x509svid.SVID, error) {
s.spiffe.lock.RLock()
defer s.spiffe.lock.RUnlock()
<-s.spiffe.readyCh
svid := s.spiffe.currentX509SVID
if svid == nil {
return nil, errNoX509SVIDAvailable
}
return svid, nil
}
// audienceMismatchError is an error that contains information about mismatched audiences
type audienceMismatchError struct {
expected []string
actual []string
}
func (e *audienceMismatchError) Error() string {
return fmt.Sprintf("JWT SVID has different audiences than requested: expected %s, got %s",
strings.Join(e.expected, ", "), strings.Join(e.actual, ", "))
}
// FetchJWTSVID returns the current JWT SVID.
// Implements the go-spiffe jwtsvid.Source interface.
func (s *svidSource) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*jwtsvid.SVID, error) {
s.spiffe.lock.RLock()
defer s.spiffe.lock.RUnlock()
if params.Audience == "" {
return nil, errAudienceRequired
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-s.spiffe.readyCh:
}
svid := s.spiffe.currentJWTSVID
if svid == nil {
return nil, errNoJWTSVIDAvailable
}
// verify that the audience being requested is the same as the audience in the SVID
// WARN: we do not check extra audiences here.
if !audiencesMatch(svid.Audience, []string{params.Audience}) {
return nil, &audienceMismatchError{
expected: []string{params.Audience},
actual: svid.Audience,
}
}
return svid, nil
}

View File

@ -0,0 +1,190 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spiffe
import (
"context"
"sync"
"testing"
"time"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/stretchr/testify/require"
)
func Test_svidSource(*testing.T) {
var _ x509svid.Source = new(svidSource)
var _ jwtsvid.Source = new(svidSource)
}
// createMockJWTSVID creates a mock JWT SVID for testing
func createMockJWTSVID(audiences []string) (*jwtsvid.SVID, error) {
td, err := spiffeid.TrustDomainFromString("example.org")
if err != nil {
return nil, err
}
id, err := spiffeid.FromSegments(td, "workload")
if err != nil {
return nil, err
}
svid := &jwtsvid.SVID{
ID: id,
Audience: audiences,
Expiry: time.Now().Add(time.Hour),
}
return svid, nil
}
func TestFetchJWTSVID(t *testing.T) {
t.Run("should return error when audience is empty", func(t *testing.T) {
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{
Audience: "",
})
require.Nil(t, svid)
require.ErrorIs(t, err, errAudienceRequired)
})
t.Run("should return error when no JWT SVID available", func(t *testing.T) {
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: nil,
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{
Audience: "test-audience",
})
require.Nil(t, svid)
require.ErrorIs(t, err, errNoJWTSVIDAvailable)
})
t.Run("should return error when audience doesn't match", func(t *testing.T) {
// Create a mock SVID with a specific audience
mockJWTSVID, err := createMockJWTSVID([]string{"actual-audience"})
require.NoError(t, err)
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{
Audience: "requested-audience",
})
require.Nil(t, svid)
require.Error(t, err)
// Verify the specific error type and contents
audienceErr, ok := err.(*audienceMismatchError)
require.True(t, ok, "Expected audienceMismatchError")
require.Equal(t, "JWT SVID has different audiences than requested: expected requested-audience, got actual-audience", audienceErr.Error())
})
t.Run("should return JWT SVID when audience matches", func(t *testing.T) {
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience", "extra-audience"})
require.NoError(t, err)
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{
Audience: "test-audience",
})
require.NoError(t, err)
require.Equal(t, mockJWTSVID, svid)
})
t.Run("should wait for readyCh before checking SVID", func(t *testing.T) {
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience"})
require.NoError(t, err)
readyCh := make(chan struct{})
s := &svidSource{
spiffe: &SPIFFE{
readyCh: readyCh,
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
},
}
// Start goroutine to fetch SVID
ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer cancel()
resultCh := make(chan struct {
svid *jwtsvid.SVID
err error
})
go func() {
svid, err := s.FetchJWTSVID(ctx, jwtsvid.Params{
Audience: "test-audience",
})
resultCh <- struct {
svid *jwtsvid.SVID
err error
}{svid, err}
}()
// require that fetch is blocked
select {
case <-resultCh:
t.Fatal("FetchJWTSVID should be blocked until readyCh is closed")
case <-time.After(100 * time.Millisecond):
// Expected behavior - fetch is blocked
}
// Close readyCh to unblock fetch
close(readyCh)
// Now fetch should complete
select {
case result := <-resultCh:
require.NoError(t, result.err)
require.NotNil(t, result.svid)
case <-time.After(100 * time.Millisecond):
t.Fatal("FetchJWTSVID should have completed after readyCh was closed")
}
})
}

View File

@ -0,0 +1,300 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package file
import (
"context"
"errors"
"fmt"
"os"
"sync"
"sync/atomic"
"time"
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"k8s.io/utils/clock"
"github.com/dapr/kit/concurrency"
"github.com/dapr/kit/crypto/pem"
"github.com/dapr/kit/crypto/spiffe/trustanchors"
"github.com/dapr/kit/fswatcher"
"github.com/dapr/kit/logger"
)
var (
// ErrTrustAnchorsClosed is returned when an operation is performed on closed trust anchors.
ErrTrustAnchorsClosed = errors.New("trust anchors is closed")
// ErrFailedToReadTrustAnchorsFile is returned when the trust anchors file cannot be read.
ErrFailedToReadTrustAnchorsFile = errors.New("failed to read trust anchors file")
)
type Options struct {
Log logger.Logger
CAPath string
JwksPath *string
}
// file is a TrustAnchors implementation that uses a file as the source of trust
// anchors. The trust anchors will be updated when the file changes.
type file struct {
log logger.Logger
caPath string
jwksPath *string
x509Bundle *x509bundle.Bundle
jwtBundle *jwtbundle.Bundle
rootPEM []byte
// fswatcherInterval is the interval at which the trust anchors file changes
// are batched. Used for testing only, and 500ms otherwise.
fsWatcherInterval time.Duration
// initFileWatchInterval is the interval at which the trust anchors file is
// checked for the first time. Used for testing only, and 1 second otherwise.
initFileWatchInterval time.Duration
// subs is a list of channels to notify when the trust anchors are updated.
subs []chan<- struct{}
lock sync.RWMutex
clock clock.Clock
running atomic.Bool
readyCh chan struct{}
closeCh chan struct{}
caEvent chan struct{}
}
func From(opts Options) trustanchors.Interface {
return &file{
fsWatcherInterval: time.Millisecond * 500,
initFileWatchInterval: time.Second,
log: opts.Log,
caPath: opts.CAPath,
jwksPath: opts.JwksPath,
clock: clock.RealClock{},
readyCh: make(chan struct{}),
closeCh: make(chan struct{}),
caEvent: make(chan struct{}),
}
}
func (f *file) Run(ctx context.Context) error {
if !f.running.CompareAndSwap(false, true) {
return errors.New("trust anchors is already running")
}
defer close(f.closeCh)
for {
fs := []string{f.caPath}
if f.jwksPath != nil {
fs = append(fs, *f.jwksPath)
}
if found, err := filesExist(fs...); err != nil {
return err
} else if found {
break
}
// Trust anchors file not be provided yet, wait.
select {
case <-ctx.Done():
return fmt.Errorf("failed to find trust anchors file '%s': %w", f.caPath, ctx.Err())
case <-f.clock.After(f.initFileWatchInterval):
f.log.Warnf("Trust anchors file '%s' not found, waiting...", f.caPath)
}
}
f.log.Infof("Trust anchors file '%s' found", f.caPath)
if err := f.updateAnchors(ctx); err != nil {
return err
}
targets := []string{f.caPath}
if f.jwksPath != nil {
targets = append(targets, *f.jwksPath)
}
fs, err := fswatcher.New(fswatcher.Options{
Targets: targets,
Interval: &f.fsWatcherInterval,
})
if err != nil {
return fmt.Errorf("failed to create file watcher: %w", err)
}
close(f.readyCh)
f.log.Infof("Watching trust anchors file '%s' for changes", f.caPath)
if f.jwksPath != nil {
f.log.Infof("Watching JWT bundle file '%s' for changes", f.jwksPath)
}
return concurrency.NewRunnerManager(
func(ctx context.Context) error {
return fs.Run(ctx, f.caEvent)
},
func(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
case <-f.caEvent:
f.log.Info("Trust anchors file changed, reloading trust anchors")
if err = f.updateAnchors(ctx); err != nil {
return fmt.Errorf("%w: '%s': %v", ErrFailedToReadTrustAnchorsFile, f.caPath, err)
}
}
}
},
).Run(ctx)
}
func (f *file) CurrentTrustAnchors(ctx context.Context) ([]byte, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-f.closeCh:
return nil, ErrTrustAnchorsClosed
case <-f.readyCh:
}
f.lock.RLock()
defer f.lock.RUnlock()
rootPEM := make([]byte, len(f.rootPEM))
copy(rootPEM, f.rootPEM)
return rootPEM, nil
}
func (f *file) updateAnchors(ctx context.Context) error {
f.lock.Lock()
defer f.lock.Unlock()
rootPEMs, err := os.ReadFile(f.caPath)
if err != nil {
return fmt.Errorf("failed to read trust anchors file '%s': %w", f.caPath, err)
}
trustAnchorCerts, err := pem.DecodePEMCertificates(rootPEMs)
if err != nil {
return fmt.Errorf("failed to decode trust anchors: %w", err)
}
f.rootPEM = rootPEMs
f.x509Bundle = x509bundle.FromX509Authorities(spiffeid.TrustDomain{}, trustAnchorCerts)
if f.jwksPath != nil {
jwks, err := os.ReadFile(*f.jwksPath)
if err != nil {
return fmt.Errorf("failed to read JWT bundle file '%s': %w", *f.jwksPath, err)
}
jwtBundle, err := jwtbundle.Parse(spiffeid.TrustDomain{}, jwks)
if err != nil {
return fmt.Errorf("failed to parse JWT bundle: %w", err)
}
f.jwtBundle = jwtBundle
}
var wg sync.WaitGroup
defer wg.Wait()
wg.Add(len(f.subs))
for _, ch := range f.subs {
go func(chi chan<- struct{}) {
defer wg.Done()
select {
case chi <- struct{}{}:
case <-ctx.Done():
}
}(ch)
}
return nil
}
func (f *file) GetX509BundleForTrustDomain(_ spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
select {
case <-f.closeCh:
return nil, ErrTrustAnchorsClosed
case <-f.readyCh:
}
f.lock.RLock()
defer f.lock.RUnlock()
bundle := f.x509Bundle
return bundle, nil
}
func (f *file) GetJWTBundleForTrustDomain(_ spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
select {
case <-f.closeCh:
return nil, ErrTrustAnchorsClosed
case <-f.readyCh:
}
f.lock.RLock()
defer f.lock.RUnlock()
bundle := f.jwtBundle
return bundle, nil
}
func (f *file) Watch(ctx context.Context, ch chan<- []byte) {
f.lock.Lock()
sub := make(chan struct{}, 5)
f.subs = append(f.subs, sub)
f.lock.Unlock()
for {
select {
case <-ctx.Done():
return
case <-f.closeCh:
return
case <-sub:
f.lock.RLock()
rootPEM := make([]byte, len(f.rootPEM))
copy(rootPEM, f.rootPEM)
f.lock.RUnlock()
select {
case ch <- rootPEM:
case <-ctx.Done():
case <-f.closeCh:
}
}
}
}
func filesExist(paths ...string) (bool, error) {
for _, path := range paths {
if path == "" {
continue
}
if _, err := os.Stat(path); err != nil {
if errors.Is(err, os.ErrNotExist) {
return false, nil
}
return false, fmt.Errorf("failed to stat file '%s': %w", path, err)
}
}
return true, nil
}

View File

@ -0,0 +1,580 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package file
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/kit/crypto/test"
"github.com/dapr/kit/logger"
)
func TestFile_Run(t *testing.T) {
t.Run("if Run multiple times, expect error", func(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
errCh <- f.Run(ctx)
}()
go func() {
errCh <- f.Run(ctx)
}()
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
assert.Fail(t, "Expected error")
}
select {
case <-f.closeCh:
assert.Fail(t, "closeCh should not be closed")
default:
}
cancel()
select {
case err := <-errCh:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
assert.Fail(t, "First Run should have returned and returned no error ")
}
})
t.Run("if file is not found and context cancelled, should return ctx.Err", func(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
errCh <- f.Run(ctx)
}()
cancel()
select {
case err := <-errCh:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
assert.Fail(t, "First Run should have returned and returned no error ")
}
})
t.Run("if file found but is empty, should return error", func(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, nil, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
go func() {
errCh <- f.Run(t.Context())
}()
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected error")
}
})
t.Run("if file found but is only garbage data, expect error", func(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, []byte("garbage data"), 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
go func() {
errCh <- f.Run(t.Context())
}()
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected error")
}
})
t.Run("if file found but is only garbage data in root, expect error", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
root := pki.RootCertPEM[10:]
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, root, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
go func() {
errCh <- f.Run(t.Context())
}()
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected error")
}
})
t.Run("single root should be correctly parsed from file", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
go func() {
errCh <- f.Run(t.Context())
}()
select {
case <-f.readyCh:
case <-time.After(time.Second):
assert.Fail(t, "expected to be ready in time")
}
b, err := f.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, b)
})
t.Run("garbage data outside of root should be ignored", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
root := append(pki.RootCertPEM, []byte("garbage data")...)
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, root, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
go func() {
errCh <- f.Run(t.Context())
}()
select {
case <-f.readyCh:
case <-time.After(time.Second):
assert.Fail(t, "expected to be ready in time")
}
b, err := f.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(t, root, b)
})
t.Run("multiple roots should be parsed", func(t *testing.T) {
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
go func() {
errCh <- f.Run(t.Context())
}()
select {
case <-f.readyCh:
case <-time.After(time.Second):
assert.Fail(t, "expected to be ready in time")
}
b, err := f.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(t, roots, b)
})
t.Run("writing a bad root PEM file should make Run return error", func(t *testing.T) {
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
f.fsWatcherInterval = time.Millisecond
errCh := make(chan error)
go func() {
errCh <- f.Run(t.Context())
}()
select {
case <-f.readyCh:
case <-time.After(time.Second):
assert.Fail(t, "expected to be ready in time")
}
require.NoError(t, os.WriteFile(tmp, []byte("garbage data"), 0o600))
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected error to be returned from Run")
}
})
}
func TestFile_GetX509BundleForTrustDomain(t *testing.T) {
t.Run("Should return full PEM regardless given trust domain", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
root := append(pki.RootCertPEM, []byte("garbage data")...)
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, root, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
errCh := make(chan error)
ctx, cancel := context.WithCancel(t.Context())
go func() {
errCh <- ta.Run(ctx)
}()
t.Cleanup(func() {
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected Run to return")
}
})
trustDomain1, err := spiffeid.TrustDomainFromString("example.com")
require.NoError(t, err)
bundle, err := f.GetX509BundleForTrustDomain(trustDomain1)
require.NoError(t, err)
assert.Equal(t, f.x509Bundle, bundle)
b1, err := bundle.Marshal()
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, b1)
trustDomain2, err := spiffeid.TrustDomainFromString("another-example.org")
require.NoError(t, err)
bundle, err = f.GetX509BundleForTrustDomain(trustDomain2)
require.NoError(t, err)
assert.Equal(t, f.x509Bundle, bundle)
b2, err := bundle.Marshal()
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, b2)
})
}
func TestFile_Watch(t *testing.T) {
t.Run("should return when Run context has been cancelled", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
ctx, cancel := context.WithCancel(t.Context())
go func() {
errCh <- f.Run(ctx)
}()
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure f.Run has finished and running
watchDone := make(chan struct{})
go func() {
ta.Watch(t.Context(), make(chan []byte))
close(watchDone)
}()
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected error to be returned from Run")
}
select {
case <-watchDone:
case <-time.After(time.Second):
assert.Fail(t, "expected Watch to have returned")
}
})
t.Run("should return when given context has been cancelled", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
ctx1, cancel1 := context.WithCancel(t.Context())
go func() {
errCh <- f.Run(ctx1)
}()
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure f.Run has finished and running
watchDone := make(chan struct{})
ctx2, cancel2 := context.WithCancel(t.Context())
go func() {
ta.Watch(ctx2, make(chan []byte))
close(watchDone)
}()
cancel2()
select {
case <-watchDone:
case <-time.After(time.Second):
assert.Fail(t, "expected Watch to have returned")
}
cancel1()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected error to be returned from Run")
}
})
t.Run("should update Watch subscribers when root PEM has been changed", func(t *testing.T) {
pki1 := test.GenPKI(t, test.PKIOptions{})
pki2 := test.GenPKI(t, test.PKIOptions{})
pki3 := test.GenPKI(t, test.PKIOptions{})
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
f.fsWatcherInterval = time.Millisecond
errCh := make(chan error)
ctx, cancel := context.WithCancel(t.Context())
go func() {
errCh <- f.Run(ctx)
}()
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure f.Run has finished and running
select {
case <-f.readyCh:
case <-time.After(time.Second):
assert.Fail(t, "expected to be ready in time")
}
watchDone1, watchDone2 := make(chan struct{}), make(chan struct{})
tCh1, tCh2 := make(chan []byte), make(chan []byte)
go func() {
ta.Watch(t.Context(), tCh1)
close(watchDone1)
}()
go func() {
ta.Watch(t.Context(), tCh2)
close(watchDone2)
}()
//nolint:gocritic
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
for _, ch := range []chan []byte{tCh1, tCh2} {
select {
case b := <-ch:
assert.Equal(t, string(roots), string(b))
case <-time.After(time.Second):
assert.Fail(t, "failed to get subscribed file watch in time")
}
}
//nolint:gocritic
roots = append(pki1.RootCertPEM, append(pki2.RootCertPEM, pki3.RootCertPEM...)...)
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
for _, ch := range []chan []byte{tCh1, tCh2} {
select {
case b := <-ch:
assert.Equal(t, string(roots), string(b))
case <-time.After(time.Second):
assert.Fail(t, "failed to get subscribed file watch in time")
}
}
cancel()
for _, ch := range []chan struct{}{watchDone1, watchDone2} {
select {
case <-ch:
case <-time.After(time.Second):
assert.Fail(t, "expected Watch to have returned")
}
}
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected error to be returned from Run")
}
})
}
func TestFile_CurrentTrustAnchors(t *testing.T) {
t.Run("returns trust anchors as they change", func(t *testing.T) {
pki1, pki2, pki3 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
f.fsWatcherInterval = time.Millisecond
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
errCh <- f.Run(ctx)
}()
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure f.Run has finished and running
//nolint:gocritic
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure the file watcher has time to pick up the change
assert.EventuallyWithT(t, func(c *assert.CollectT) {
pem, err := ta.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(c, roots, pem)
}, time.Second, time.Millisecond)
//nolint:gocritic
roots = append(pki1.RootCertPEM, append(pki2.RootCertPEM, pki3.RootCertPEM...)...)
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure the file watcher has time to pick up the change
assert.EventuallyWithT(t, func(c *assert.CollectT) {
pem, err := ta.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(c, roots, pem)
}, time.Second, time.Millisecond)
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected error to be returned from Run")
}
})
}

View File

@ -0,0 +1,86 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package multi
import (
"context"
"errors"
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/dapr/kit/concurrency"
"github.com/dapr/kit/crypto/spiffe/trustanchors"
)
var (
ErrNotImplemented = errors.New("not implemented")
ErrTrustDomainNotFound = errors.New("trust domain not found")
)
type Options struct {
TrustAnchors map[spiffeid.TrustDomain]trustanchors.Interface
}
// multi is a TrustAnchors implementation which uses multiple trust anchors
// which are indexed by trust domain.
type multi struct {
trustAnchors map[spiffeid.TrustDomain]trustanchors.Interface
}
func From(opts Options) trustanchors.Interface {
return &multi{
trustAnchors: opts.TrustAnchors,
}
}
func (m *multi) Run(ctx context.Context) error {
r := concurrency.NewRunnerManager()
for _, ta := range m.trustAnchors {
if err := r.Add(ta.Run); err != nil {
return err
}
}
return r.Run(ctx)
}
func (m *multi) CurrentTrustAnchors(context.Context) ([]byte, error) {
return nil, ErrNotImplemented
}
func (m *multi) GetX509BundleForTrustDomain(td spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
for tad, ta := range m.trustAnchors {
if td.Compare(tad) == 0 {
return ta.GetX509BundleForTrustDomain(td)
}
}
return nil, ErrTrustDomainNotFound
}
func (m *multi) GetJWTBundleForTrustDomain(td spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
for tad, ta := range m.trustAnchors {
if td.Compare(tad) == 0 {
return ta.GetJWTBundleForTrustDomain(td)
}
}
return nil, ErrTrustDomainNotFound
}
func (m *multi) Watch(context.Context, chan<- []byte) {
return
}

View File

@ -0,0 +1,99 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package static
import (
"context"
"errors"
"fmt"
"sync/atomic"
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/dapr/kit/crypto/pem"
"github.com/dapr/kit/crypto/spiffe/trustanchors"
)
// static is a TrustAcnhors implementation that uses a static list of trust
// anchors.
type static struct {
x509Bundle *x509bundle.Bundle
jwtBundle *jwtbundle.Bundle
anchors []byte
running atomic.Bool
closeCh chan struct{}
}
type Options struct {
Anchors []byte
Jwks []byte
}
func From(opts Options) (trustanchors.Interface, error) {
// Create empty trust domain for now
emptyTD := spiffeid.TrustDomain{}
var jwtBundle *jwtbundle.Bundle
if opts.Jwks != nil {
var err error
jwtBundle, err = jwtbundle.Parse(emptyTD, opts.Jwks)
if err != nil {
return nil, fmt.Errorf("failed to create JWT bundle: %w", err)
}
}
trustAnchorCerts, err := pem.DecodePEMCertificates(opts.Anchors)
if err != nil {
return nil, fmt.Errorf("failed to decode trust anchors: %w", err)
}
return &static{
anchors: opts.Anchors,
x509Bundle: x509bundle.FromX509Authorities(emptyTD, trustAnchorCerts),
jwtBundle: jwtBundle,
closeCh: make(chan struct{}),
}, nil
}
func (s *static) CurrentTrustAnchors(context.Context) ([]byte, error) {
bundle := make([]byte, len(s.anchors))
copy(bundle, s.anchors)
return bundle, nil
}
func (s *static) Run(ctx context.Context) error {
if !s.running.CompareAndSwap(false, true) {
return errors.New("trust anchors source is already running")
}
<-ctx.Done()
close(s.closeCh)
return nil
}
func (s *static) GetX509BundleForTrustDomain(spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
return s.x509Bundle, nil
}
func (s *static) GetJWTBundleForTrustDomain(_ spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
return s.jwtBundle, nil
}
func (s *static) Watch(ctx context.Context, _ chan<- []byte) {
select {
case <-ctx.Done():
case <-s.closeCh:
}
}

View File

@ -0,0 +1,210 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package static
import (
"context"
"testing"
"time"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/kit/crypto/test"
)
func TestFromStatic(t *testing.T) {
t.Run("empty root should return error", func(t *testing.T) {
_, err := From(Options{})
require.Error(t, err)
})
t.Run("garbage data should return error", func(t *testing.T) {
_, err := From(Options{Anchors: []byte("garbage data")})
require.Error(t, err)
})
t.Run("just garbage data should return error", func(t *testing.T) {
_, err := From(Options{Anchors: []byte("garbage data")})
require.Error(t, err)
})
t.Run("garbage data in root should return error", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
root := pki.RootCertPEM[10:]
_, err := From(Options{Anchors: root})
require.Error(t, err)
})
t.Run("single root should be correctly parsed", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
taPEM, err := ta.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, taPEM)
})
t.Run("garbage data outside of root should be ignored", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
root := append(pki.RootCertPEM, []byte("garbage data")...)
ta, err := From(Options{Anchors: root})
require.NoError(t, err)
taPEM, err := ta.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(t, root, taPEM)
})
t.Run("multiple roots should be correctly parsed", func(t *testing.T) {
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
ta, err := From(Options{Anchors: roots})
require.NoError(t, err)
taPEM, err := ta.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(t, roots, taPEM)
})
}
func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
t.Run("Should return full PEM regardless given trust domain", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
root := append(pki.RootCertPEM, []byte("garbage data")...)
ta, err := From(Options{Anchors: root})
require.NoError(t, err)
s, ok := ta.(*static)
require.True(t, ok)
trustDomain1, err := spiffeid.TrustDomainFromString("example.com")
require.NoError(t, err)
bundle, err := s.GetX509BundleForTrustDomain(trustDomain1)
require.NoError(t, err)
assert.Equal(t, s.x509Bundle, bundle)
b1, err := bundle.Marshal()
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, b1)
trustDomain2, err := spiffeid.TrustDomainFromString("another-example.org")
require.NoError(t, err)
bundle, err = s.GetX509BundleForTrustDomain(trustDomain2)
require.NoError(t, err)
assert.Equal(t, s.x509Bundle, bundle)
b2, err := bundle.Marshal()
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, b2)
})
}
func TestStatic_Run(t *testing.T) {
t.Run("Run multiple times should return error", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
s, ok := ta.(*static)
require.True(t, ok)
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
errCh <- s.Run(ctx)
}()
go func() {
errCh <- s.Run(ctx)
}()
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
assert.Fail(t, "Expected error")
}
select {
case <-s.closeCh:
assert.Fail(t, "closeCh should not be closed")
default:
}
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "First Run should have returned and returned no error ")
}
})
}
func TestStatic_Watch(t *testing.T) {
t.Run("should return when context is cancelled", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
ctx, cancel := context.WithCancel(t.Context())
doneCh := make(chan struct{})
go func() {
ta.Watch(ctx, nil)
close(doneCh)
}()
cancel()
select {
case <-doneCh:
case <-time.After(time.Second):
assert.Fail(t, "Expected doneCh to be closed")
}
})
t.Run("should return when cancel is closed via closed Run", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
ctx, cancel := context.WithCancel(t.Context())
doneCh := make(chan struct{})
errCh := make(chan error)
go func() {
errCh <- ta.Run(ctx)
}()
go func() {
ta.Watch(t.Context(), nil)
close(doneCh)
}()
cancel()
select {
case <-doneCh:
case <-time.After(time.Second):
assert.Fail(t, "Expected doneCh to be closed")
}
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "Expected Run to return no error")
}
})
}

View File

@ -0,0 +1,42 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package trustanchors
import (
"context"
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
)
// Interface exposes a SPIFFE trust anchor from a source.
// Allows consumers to get the current trust anchor bundle, and subscribe to
// bundle updates.
type Interface interface {
// Source implements the SPIFFE trust anchor x509 bundle source.
x509bundle.Source
// Source implements the SPIFFE trust anchor jwt bundle source.
jwtbundle.Source
// CurrentTrustAnchors returns the current trust anchor PEM bundle.
CurrentTrustAnchors(ctx context.Context) ([]byte, error)
// Watch watches for changes to the trust domains and returns the PEM encoded
// trust domain roots.
// Returns when the given context is canceled.
Watch(ctx context.Context, ch chan<- []byte)
// Run starts the trust anchor source.
Run(ctx context.Context) error
}

240
crypto/test/pki.go Normal file
View File

@ -0,0 +1,240 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implieh.
See the License for the specific language governing permissions and
limitations under the License.
*/
package test
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"net/url"
"testing"
"time"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffegrpc/grpccredentials"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/examples/helloworld/helloworld"
"google.golang.org/grpc/peer"
)
type PKIOptions struct {
LeafDNS string
LeafID spiffeid.ID
ClientDNS string
ClientID spiffeid.ID
}
type PKI struct {
RootCertPEM []byte
RootCert *x509.Certificate
LeafCert *x509.Certificate
LeafCertPEM []byte
LeafPKPEM []byte
LeafPK crypto.Signer
ClientCertPEM []byte
ClientCert *x509.Certificate
ClientPKPEM []byte
ClientPK crypto.Signer
leafID spiffeid.ID
clientID spiffeid.ID
}
func GenPKI(t *testing.T, opts PKIOptions) PKI {
t.Helper()
pki, err := GenPKIError(opts)
require.NoError(t, err)
return pki
}
func GenPKIError(opts PKIOptions) (PKI, error) {
rootPK, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return PKI{}, err
}
rootCert := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "Dapr Test Root CA"},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
KeyUsage: x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
rootCertBytes, err := x509.CreateCertificate(rand.Reader, rootCert, rootCert, &rootPK.PublicKey, rootPK)
if err != nil {
return PKI{}, err
}
rootCert, err = x509.ParseCertificate(rootCertBytes)
if err != nil {
return PKI{}, err
}
rootCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: rootCertBytes})
leafCertPEM, leafPKPEM, leafCert, leafPK, err := genLeafCert(rootPK, rootCert, opts.LeafID, opts.LeafDNS)
if err != nil {
return PKI{}, err
}
clientCertPEM, clientPKPEM, clientCert, clientPK, err := genLeafCert(rootPK, rootCert, opts.ClientID, opts.ClientDNS)
if err != nil {
return PKI{}, err
}
return PKI{
RootCert: rootCert,
RootCertPEM: rootCertPEM,
LeafCertPEM: leafCertPEM,
LeafPKPEM: leafPKPEM,
LeafCert: leafCert,
LeafPK: leafPK,
ClientCertPEM: clientCertPEM,
ClientPKPEM: clientPKPEM,
ClientCert: clientCert,
ClientPK: clientPK,
leafID: opts.LeafID,
clientID: opts.ClientID,
}, nil
}
func (p PKI) ClientGRPCCtx(t *testing.T) context.Context {
t.Helper()
bundle := x509bundle.New(spiffeid.RequireTrustDomainFromString("example.org"))
bundle.AddX509Authority(p.RootCert)
serverSVID := &mockSVID{
bundle: bundle,
svid: &x509svid.SVID{
ID: p.leafID,
Certificates: []*x509.Certificate{p.LeafCert},
PrivateKey: p.LeafPK,
},
}
clientSVID := &mockSVID{
bundle: bundle,
svid: &x509svid.SVID{
ID: p.clientID,
Certificates: []*x509.Certificate{p.ClientCert},
PrivateKey: p.ClientPK,
},
}
server := grpc.NewServer(grpc.Creds(grpccredentials.MTLSServerCredentials(serverSVID, serverSVID, tlsconfig.AuthorizeAny())))
gs := new(greeterServer)
helloworld.RegisterGreeterServer(server, gs)
lis, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
go func() {
server.Serve(lis)
}()
//nolint:staticcheck
conn, err := grpc.DialContext(t.Context(), lis.Addr().String(),
grpc.WithTransportCredentials(grpccredentials.MTLSClientCredentials(clientSVID, clientSVID, tlsconfig.AuthorizeAny())),
)
require.NoError(t, err)
_, err = helloworld.NewGreeterClient(conn).SayHello(t.Context(), new(helloworld.HelloRequest))
require.NoError(t, err)
lis.Close()
server.Stop()
return gs.ctx
}
func genLeafCert(rootPK *ecdsa.PrivateKey, rootCert *x509.Certificate, id spiffeid.ID, dns string) ([]byte, []byte, *x509.Certificate, crypto.Signer, error) {
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, nil, nil, nil, err
}
pkBytes, err := x509.MarshalPKCS8PrivateKey(pk)
if err != nil {
return nil, nil, nil, nil, err
}
cert := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
x509.ExtKeyUsageClientAuth,
},
}
if len(dns) > 0 {
cert.DNSNames = []string{dns}
}
if !id.IsZero() {
cert.URIs = []*url.URL{id.URL()}
}
certBytes, err := x509.CreateCertificate(rand.Reader, cert, rootCert, &pk.PublicKey, rootPK)
if err != nil {
return nil, nil, nil, nil, err
}
cert, err = x509.ParseCertificate(certBytes)
if err != nil {
return nil, nil, nil, nil, err
}
pkPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: pkBytes})
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes})
return certPEM, pkPEM, cert, pk, nil
}
type mockSVID struct {
svid *x509svid.SVID
bundle *x509bundle.Bundle
}
func (m *mockSVID) GetX509BundleForTrustDomain(_ spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
return m.bundle, nil
}
func (m *mockSVID) GetX509SVID() (*x509svid.SVID, error) {
return m.svid, nil
}
type greeterServer struct {
helloworld.UnimplementedGreeterServer
ctx context.Context
}
func (s *greeterServer) SayHello(ctx context.Context, _ *helloworld.HelloRequest) (*helloworld.HelloReply, error) {
p, _ := peer.FromContext(ctx)
s.ctx = peer.NewContext(context.Background(), p)
return new(helloworld.HelloReply), nil
}

View File

@ -41,7 +41,7 @@ pHZ3vWGFAoGAc5Um3YYkhh2QScQBy5+kumH40LhFFy2ETznWEp0tS2NwmTfTm/Nl
Sg+Ct2nOw93cIhwDjWyoilkIapuuX2obY+sUc3kj2ugU+hONfuBStsF020IPP1sk
A9okIZVbz8ycqcjaBiNc4+TeiXED1K7bV9Kg+A9lxDxfGRybJ1/ECWA=
-----END RSA PRIVATE KEY-----
`
` // #nosec G101
privateKeyRSAPKCS8 = `-----BEGIN PRIVATE KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDcjaZ0griZFG77
LAytiNRnMHG3Q2UBUusyEaVomxvLs9ZMyIullWKnhIEP0bCcJTRMYUPuTb7u1+zT
@ -128,7 +128,7 @@ MHcCAQEEIOcFe4Q6ardS97ml2tV4+194nmlfQPh8o9ir/qsacEozoAoGCCqGSM49
AwEHoUQDQgAEUMn1c2ioMNi2DqvC8hdBVUERFZ97eVFsNVcQIgR0Hsq5PVrQ/dQ4
uI5u97b6k4wXHYFXMvPmsW1T6qZAE9bB3Q==
-----END EC PRIVATE KEY-----
`
` // #nosec G101
privateKeyP256PKCS8 = `-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQg5wV7hDpqt1L3uaXa
1Xj7X3ieaV9A+Hyj2Kv+qxpwSjOhRANCAARQyfVzaKgw2LYOq8LyF0FVQREVn3t5

42
env/env.go vendored Normal file
View File

@ -0,0 +1,42 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package env
import (
"fmt"
"os"
"time"
)
// GetDurationWithRange returns the time.Duration value of the environment variable specified by `envVar`.
// If the environment variable is not set, it returns `defaultValue`.
// If the value is set but is not valid (not a valid time.Duration or falls outside the specified range
// [minValue, maxValue] inclusively), it returns `defaultValue` and an error.
func GetDurationWithRange(envVar string, defaultValue, min, max time.Duration) (time.Duration, error) {
v := os.Getenv(envVar)
if v == "" {
return defaultValue, nil
}
val, err := time.ParseDuration(v)
if err != nil {
return defaultValue, fmt.Errorf("invalid time.Duration value %s for the %s env variable: %w", val, envVar, err)
}
if val < min || val > max {
return defaultValue, fmt.Errorf("invalid value for the %s env variable: value should be between %s and %s, got %s", envVar, min, max, val)
}
return val, nil
}

96
env/env_test.go vendored Normal file
View File

@ -0,0 +1,96 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package env
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestGetIntWithRangeWrongValues(t *testing.T) {
testValues := []struct {
name string
envVarVal string
min time.Duration
max time.Duration
error string
}{
{
"should error if value is not a valid time.Duration",
"0.5",
time.Second,
2 * time.Second,
"invalid time.Duration value 0s for the MY_ENV env variable",
},
{
"should error if value is lower than 1s",
"0s",
time.Second,
10 * time.Second,
"value should be between 1s and 10s",
},
{
"should error if value is higher than 10s",
"2m",
time.Second,
10 * time.Second,
"value should be between 1s and 10s",
},
}
defaultValue := 3 * time.Second
for _, tt := range testValues {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("MY_ENV", tt.envVarVal)
val, err := GetDurationWithRange("MY_ENV", defaultValue, tt.min, tt.max)
require.Error(t, err)
require.Contains(t, err.Error(), tt.error)
require.Equal(t, defaultValue, val)
})
}
}
func TestGetEnvDurationWithRangeValidValues(t *testing.T) {
testValues := []struct {
name string
envVarVal string
result time.Duration
}{
{
"should return default value if env variable is not set",
"",
3 * time.Second,
},
{
"should return result is env variable value is valid",
"4s",
4 * time.Second,
},
}
for _, tt := range testValues {
t.Run(tt.name, func(t *testing.T) {
if tt.envVarVal != "" {
t.Setenv("MY_ENV", tt.envVarVal)
}
val, err := GetDurationWithRange("MY_ENV", 3*time.Second, time.Second, 5*time.Second)
require.NoError(t, err)
require.Equal(t, tt.result, val)
})
}
}

View File

@ -1,6 +1,6 @@
# Errors
The standardizing of errors to be used in Dapr based on the gRPC Richer Error Model and [accepted dapr/proposal](https://github.com/dapr/proposals/blob/main/0009-BCIRS-error-handling-codes.md).
The standardizing of errors to be used in Dapr based on the gRPC Richer Error Model and [accepted dapr/proposal](https://github.com/dapr/proposals/blob/main/20230511-BCIRS-error-handling-codes.md).
## Usage

View File

@ -24,11 +24,13 @@ const (
CodePrefixStateStore = "DAPR_STATE_"
CodePrefixPubSub = "DAPR_PUBSUB_"
CodePrefixBindings = "DAPR_BINDING_"
CodePrefixSecretStore = "DAPR_SECRET_"
CodePrefixSecretStore = "DAPR_SECRET_" // #nosec G101
CodePrefixConfigurationStore = "DAPR_CONFIGURATION_"
CodePrefixLock = "DAPR_LOCK_"
CodePrefixNameResolution = "DAPR_NAME_RESOLUTION_"
CodePrefixMiddleware = "DAPR_MIDDLEWARE_"
CodePrefixCryptography = "DAPR_CRYPTOGRAPHY_"
CodePrefixPlacement = "DAPR_PLACEMENT_"
// State
CodePostfixGetStateFailed = "GET_STATE_FAILED"

View File

@ -56,11 +56,14 @@ type Error struct {
// Tag is a string identifying the error, used with HTTP responses only.
tag string
// Category is a string identifying the category of the error (i.e. "actor", "job", "pubsub), used for error code metrics only.
category string
}
// ErrorBuilder is used to build the error
type ErrorBuilder struct {
err Error
err *Error
}
// errorJSON is used to build the error for the HTTP Methods json output
@ -84,13 +87,31 @@ func (e *Error) GrpcStatusCode() grpcCodes.Code {
return e.grpcCode
}
// ErrorCode returns the error code from the error, prioritizing the legacy Error.Tag, otherwise the ErrorInfo.Reason
func (e *Error) ErrorCode() string {
errorCode := e.tag
for _, detail := range e.details {
if _, ok := detail.(*errdetails.ErrorInfo); ok {
if _, errInfoReason := convertErrorDetails(detail, *e); errInfoReason != "" {
return errInfoReason
}
}
}
return errorCode
}
// Category returns the error code's category
func (e *Error) Category() string {
return e.category
}
// Error implements the error interface.
func (e Error) Error() string {
func (e *Error) Error() string {
return e.String()
}
// String returns the string representation.
func (e Error) String() string {
func (e *Error) String() string {
return fmt.Sprintf(errStringFormat, e.grpcCode.String(), e.message)
}
@ -119,9 +140,9 @@ func FromError(err error) (*Error, bool) {
return nil, false
}
var kitErr Error
var kitErr *Error
if errors.As(err, &kitErr) {
return &kitErr, true
return kitErr, true
}
return nil, false
@ -130,7 +151,7 @@ func FromError(err error) (*Error, bool) {
/*** GRPC Methods ***/
// GRPCStatus returns the gRPC status.Status object.
func (e Error) GRPCStatus() *status.Status {
func (e *Error) GRPCStatus() *status.Status {
stat := status.New(e.grpcCode, e.message)
// convert details from proto.Msg -> protoiface.MsgV1
@ -157,7 +178,7 @@ func (e Error) GRPCStatus() *status.Status {
/*** HTTP Methods ***/
// JSONErrorValue implements the errorResponseValue interface.
func (e Error) JSONErrorValue() []byte {
func (e *Error) JSONErrorValue() []byte {
grpcStatus := e.GRPCStatus().Proto()
// Make httpCode human readable
@ -179,7 +200,7 @@ func (e Error) JSONErrorValue() []byte {
if len(details) > 0 {
errJSON.Details = make([]any, len(details))
for i, detail := range details {
detailMap, errorCode := convertErrorDetails(detail, e)
detailMap, errorCode := convertErrorDetails(detail, *e)
errJSON.Details[i] = detailMap
// If there is an errorCode, update the overall ErrorCode
@ -334,14 +355,15 @@ ErrorBuilder
**************************************/
// NewBuilder create a new ErrorBuilder using the supplied required error fields
func NewBuilder(grpcCode grpcCodes.Code, httpCode int, message string, tag string) *ErrorBuilder {
func NewBuilder(grpcCode grpcCodes.Code, httpCode int, message string, tag string, category string) *ErrorBuilder {
return &ErrorBuilder{
err: Error{
err: &Error{
details: make([]proto.Message, 0),
grpcCode: grpcCode,
httpCode: httpCode,
message: message,
tag: tag,
category: category,
},
}
}

View File

@ -45,11 +45,12 @@ func TestError_HTTPStatusCode(t *testing.T) {
httpStatusCode,
"Test Msg",
"SOME_ERROR",
"some_category",
).
WithErrorInfo("fake", map[string]string{"fake": "test"}).
Build()
err, ok := kitErr.(Error)
err, ok := kitErr.(*Error)
require.True(t, ok, httpStatusCode, err.HTTPStatusCode())
}
@ -60,11 +61,12 @@ func TestError_GrpcStatusCode(t *testing.T) {
http.StatusTeapot,
"Test Msg",
"SOME_ERROR",
"some_category",
).
WithErrorInfo("fake", map[string]string{"fake": "test"}).
Build()
err, ok := kitErr.(Error)
err, ok := kitErr.(*Error)
require.True(t, ok, grpcStatusCode, err.GrpcStatusCode())
}
@ -125,6 +127,7 @@ func TestError_Error(t *testing.T) {
http.StatusTeapot,
"Msg",
"SOME_ERROR",
"some_category",
).WithErrorInfo("fake", map[string]string{"fake": "test"}),
fields: fields{
message: "Msg",
@ -139,6 +142,7 @@ func TestError_Error(t *testing.T) {
http.StatusTeapot,
"Msg",
"SOME_ERROR",
"some_category",
).WithErrorInfo("fake", map[string]string{"fake": "test"}),
fields: fields{
message: "Msg",
@ -152,6 +156,7 @@ func TestError_Error(t *testing.T) {
http.StatusTeapot,
"Msg",
"SOME_ERROR",
"some_category",
).WithErrorInfo("fake", map[string]string{"fake": "test"}),
fields: fields{
grpcCode: grpcCodes.Canceled,
@ -166,7 +171,7 @@ func TestError_Error(t *testing.T) {
t.Errorf("got = %v, want %v", got, tt.want)
}
err, ok := kitErr.(Error)
err, ok := kitErr.(*Error)
require.True(t, ok, err.Is(kitErr))
})
}
@ -181,11 +186,12 @@ func TestErrorBuilder_WithErrorInfo(t *testing.T) {
Metadata: metadata,
}
expected := Error{
expected := &Error{
grpcCode: grpcCodes.ResourceExhausted,
httpCode: http.StatusTeapot,
message: "fake_message",
tag: "DAPR_FAKE_TAG",
category: "some_category",
details: []proto.Message{
details,
},
@ -196,6 +202,7 @@ func TestErrorBuilder_WithErrorInfo(t *testing.T) {
http.StatusTeapot,
"fake_message",
"DAPR_FAKE_TAG",
"some_category",
).
WithErrorInfo(reason, metadata)
@ -222,6 +229,7 @@ func TestErrorBuilder_WithDetails(t *testing.T) {
httpCode int
message string
tag string
category string
}
type args struct {
@ -232,7 +240,7 @@ func TestErrorBuilder_WithDetails(t *testing.T) {
name string
fields fields
args args
want Error
want *Error
}{
{
name: "Has_Multiple_Details",
@ -255,7 +263,7 @@ func TestErrorBuilder_WithDetails(t *testing.T) {
Description: "test_description",
},
}},
want: Error{
want: &Error{
grpcCode: grpcCodes.ResourceExhausted,
httpCode: http.StatusTeapot,
message: "fake_message",
@ -283,6 +291,7 @@ func TestErrorBuilder_WithDetails(t *testing.T) {
test.fields.httpCode,
test.fields.message,
test.fields.tag,
test.fields.category,
).WithDetails(test.args.a...)
assert.Equal(t, test.want, kitErr.Build())
@ -292,7 +301,7 @@ func TestErrorBuilder_WithDetails(t *testing.T) {
func TestWithErrorHelp(t *testing.T) {
// Initialize the Error struct with some default values
err := NewBuilder(grpcCodes.InvalidArgument, http.StatusBadRequest, "Internal error", "INTERNAL_ERROR")
err := NewBuilder(grpcCodes.InvalidArgument, http.StatusBadRequest, "Internal error", "INTERNAL_ERROR", "some_category")
// Define test data for the help links
links := []*errdetails.Help_Link{
@ -319,7 +328,7 @@ func TestWithErrorHelp(t *testing.T) {
func TestWithErrorFieldViolation(t *testing.T) {
// Initialize the Error struct with some default values
err := NewBuilder(grpcCodes.InvalidArgument, http.StatusBadRequest, "Internal error", "INTERNAL_ERROR")
err := NewBuilder(grpcCodes.InvalidArgument, http.StatusBadRequest, "Internal error", "INTERNAL_ERROR", "some_category")
// Define test data for the field violation
fieldName := "testField"
@ -348,6 +357,7 @@ func TestError_JSONErrorValue(t *testing.T) {
httpCode int
message string
tag string
category string
}
tests := []struct {
@ -657,7 +667,7 @@ func TestError_JSONErrorValue(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
kitErr := NewBuilder(test.fields.grpcCode, test.fields.httpCode, test.fields.message, test.fields.tag).
kitErr := NewBuilder(test.fields.grpcCode, test.fields.httpCode, test.fields.message, test.fields.tag, test.fields.category).
WithDetails(test.fields.details...)
got := kitErr.err.JSONErrorValue()
@ -705,6 +715,7 @@ func TestError_GRPCStatus(t *testing.T) {
httpCode int
message string
tag string
category string
}
tests := []struct {
@ -769,6 +780,7 @@ func TestError_GRPCStatus(t *testing.T) {
test.fields.httpCode,
test.fields.message,
test.fields.tag,
test.fields.category,
).WithDetails(test.fields.details...)
got := kitErr.err.GRPCStatus()
@ -787,9 +799,10 @@ func TestErrorBuilder_Build(t *testing.T) {
http.StatusTeapot,
"Test Msg",
"SOME_ERROR",
"some_category",
).WithErrorInfo("fake", map[string]string{"fake": "test"}).Build()
builtErr, ok := built.(Error)
builtErr, ok := built.(*Error)
require.True(t, ok)
containsErrorInfo := false
@ -803,6 +816,33 @@ func TestErrorBuilder_Build(t *testing.T) {
}
assert.True(t, containsErrorInfo)
assert.Equal(t, "SOME_ERROR", builtErr.ErrorCode())
})
t.Run("With_ErrorInfo (legacy tag absent)", func(t *testing.T) {
built := NewBuilder(
grpcCodes.ResourceExhausted,
http.StatusTeapot,
"Test Msg",
"",
"some_category",
).WithErrorInfo("SOME_ERROR", map[string]string{"fake": "test"}).Build()
builtErr, ok := built.(*Error)
require.True(t, ok)
containsErrorInfo := false
for _, detail := range builtErr.details {
_, isErrInfo := detail.(*errdetails.ErrorInfo)
if isErrInfo {
containsErrorInfo = true
break
}
}
assert.True(t, containsErrorInfo)
assert.Equal(t, "SOME_ERROR", builtErr.ErrorCode())
})
t.Run("Without_ErrorInfo", func(t *testing.T) {
@ -811,6 +851,7 @@ func TestErrorBuilder_Build(t *testing.T) {
http.StatusTeapot,
"Test Msg",
"SOME_ERROR",
"some_category",
)
assert.PanicsWithValue(t, "Must include ErrorInfo in error details.", func() {
@ -949,7 +990,7 @@ func TestFromError(t *testing.T) {
t.Errorf("Expected result to be nil and ok to be false, got result: %v, ok: %t", result, ok)
}
kitErr := Error{
kitErr := &Error{
grpcCode: grpcCodes.ResourceExhausted,
httpCode: http.StatusTeapot,
message: "fake_message",
@ -958,8 +999,8 @@ func TestFromError(t *testing.T) {
}
result, ok = FromError(kitErr)
if !ok || !reflect.DeepEqual(result, &kitErr) {
t.Errorf("Expected result to be %#v and ok to be true, got result: %#v, ok: %t", &kitErr, result, ok)
if !ok || !reflect.DeepEqual(result, kitErr) {
t.Errorf("Expected result to be %#v and ok to be true, got result: %#v, ok: %t", kitErr, result, ok)
}
var nonKitError error
@ -970,7 +1011,7 @@ func TestFromError(t *testing.T) {
wrapped := fmt.Errorf("wrapped: %w", kitErr)
result, ok = FromError(wrapped)
if !ok || !reflect.DeepEqual(result, &kitErr) {
t.Errorf("Expected result to be %#v and ok to be true, got result: %#v, ok: %t", &kitErr, result, ok)
if !ok || !reflect.DeepEqual(result, kitErr) {
t.Errorf("Expected result to be %#v and ok to be true, got result: %#v, ok: %t", kitErr, result, ok)
}
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package batcher
import (
"context"
"sync"
"sync/atomic"
"time"
@ -23,14 +24,25 @@ import (
"github.com/dapr/kit/events/queue"
)
type eventCh[T any] struct {
id int
ch chan<- T
}
type Options struct {
Interval time.Duration
Clock clock.Clock
}
// Batcher is a one to many event batcher. It batches events and sends them to
// the added event channel subscribers. Events are sent to the channels after
// the interval has elapsed. If events with the same key are received within
// the interval, the timer is reset.
type Batcher[T comparable] struct {
interval time.Duration
eventChs []chan<- struct{}
queue *queue.Processor[T, *item[T]]
type Batcher[K comparable, T any] struct {
interval time.Duration
eventChs []*eventCh[T]
queue *queue.Processor[K, *item[K, T]]
currentID int
clock clock.Clock
lock sync.Mutex
@ -40,85 +52,129 @@ type Batcher[T comparable] struct {
}
// New creates a new Batcher with the given interval and key type.
func New[T comparable](interval time.Duration) *Batcher[T] {
b := &Batcher[T]{
interval: interval,
clock: clock.RealClock{},
func New[K comparable, T any](opts Options) *Batcher[K, T] {
cl := opts.Clock
if cl == nil {
cl = clock.RealClock{}
}
b := &Batcher[K, T]{
interval: opts.Interval,
clock: cl,
closeCh: make(chan struct{}),
}
b.queue = queue.NewProcessor[T, *item[T]](b.execute)
b.queue = queue.NewProcessor[K, *item[K, T]](queue.Options[K, *item[K, T]]{
ExecuteFn: b.execute,
Clock: opts.Clock,
})
return b
}
// WithClock sets the clock used by the batcher. Used for testing.
func (b *Batcher[T]) WithClock(clock clock.Clock) {
b.queue.WithClock(clock)
b.clock = clock
}
// Subscribe adds a new event channel subscriber. If the batcher is closed, the
// subscriber is silently dropped.
func (b *Batcher[T]) Subscribe(eventCh ...chan<- struct{}) {
func (b *Batcher[K, T]) Subscribe(ctx context.Context, ch ...chan<- T) {
b.lock.Lock()
defer b.lock.Unlock()
if b.closed.Load() {
return
for _, c := range ch {
b.subscribe(ctx, c)
}
b.eventChs = append(b.eventChs, eventCh...)
}
func (b *Batcher[T]) execute(_ *item[T]) {
func (b *Batcher[K, T]) subscribe(ctx context.Context, ch chan<- T) {
if b.closed.Load() {
return
}
id := b.currentID
b.currentID++
bufferedCh := make(chan T, 50)
b.eventChs = append(b.eventChs, &eventCh[T]{
id: id,
ch: bufferedCh,
})
b.wg.Add(1)
go func() {
defer func() {
b.lock.Lock()
close(ch)
for i, eventCh := range b.eventChs {
if eventCh.id == id {
b.eventChs = append(b.eventChs[:i], b.eventChs[i+1:]...)
break
}
}
b.lock.Unlock()
b.wg.Done()
}()
for {
select {
case <-ctx.Done():
return
case <-b.closeCh:
return
case env := <-bufferedCh:
select {
case ch <- env:
case <-ctx.Done():
case <-b.closeCh:
}
}
}
}()
}
func (b *Batcher[K, T]) execute(i *item[K, T]) {
b.lock.Lock()
defer b.lock.Unlock()
if b.closed.Load() {
return
}
b.wg.Add(len(b.eventChs))
for _, eventCh := range b.eventChs {
go func(eventCh chan<- struct{}) {
defer b.wg.Done()
select {
case eventCh <- struct{}{}:
case <-b.closeCh:
}
}(eventCh)
for _, ev := range b.eventChs {
select {
case ev.ch <- i.value:
case <-b.closeCh:
}
}
}
// Batch adds the given key to the batcher. If an event for this key is already
// active, the timer is reset. If the batcher is closed, the key is silently
// dropped.
func (b *Batcher[T]) Batch(key T) {
b.queue.Enqueue(&item[T]{
key: key,
ttl: b.clock.Now().Add(b.interval),
func (b *Batcher[K, T]) Batch(key K, value T) {
b.queue.Enqueue(&item[K, T]{
key: key,
value: value,
ttl: b.clock.Now().Add(b.interval),
})
}
// Close closes the batcher. It blocks until all events have been sent to the
// subscribers. The batcher will be a no-op after this call.
func (b *Batcher[T]) Close() {
func (b *Batcher[K, T]) Close() {
defer b.wg.Wait()
b.queue.Close()
b.lock.Lock()
if b.closed.CompareAndSwap(false, true) {
close(b.closeCh)
}
b.lock.Unlock()
b.queue.Close()
}
// item implements queue.queueable.
type item[T comparable] struct {
key T
ttl time.Time
type item[K comparable, T any] struct {
key K
value T
ttl time.Time
}
func (b *item[T]) Key() T {
func (b *item[K, T]) Key() K {
return b.key
}
func (b *item[T]) ScheduledTime() time.Time {
func (b *item[K, T]) ScheduledTime() time.Time {
return b.ttl
}

View File

@ -25,24 +25,25 @@ func TestNew(t *testing.T) {
t.Parallel()
interval := time.Millisecond * 10
b := New[string](interval)
assert.Equal(t, interval, b.interval)
b := New[string, struct{}](Options{Interval: interval})
assert.False(t, b.closed.Load())
}
func TestWithClock(t *testing.T) {
b := New[string](time.Millisecond * 10)
fakeClock := testingclock.NewFakeClock(time.Now())
b.WithClock(fakeClock)
b := New[string, struct{}](Options{
Interval: time.Millisecond * 10,
Clock: fakeClock,
})
assert.Equal(t, fakeClock, b.clock)
}
func TestSubscribe(t *testing.T) {
t.Parallel()
b := New[string](time.Millisecond * 10)
b := New[string, struct{}](Options{Interval: time.Millisecond * 10})
ch := make(chan struct{})
b.Subscribe(ch)
b.Subscribe(t.Context(), ch)
assert.Len(t, b.eventChs, 1)
}
@ -50,22 +51,24 @@ func TestBatch(t *testing.T) {
t.Parallel()
fakeClock := testingclock.NewFakeClock(time.Now())
b := New[string](time.Millisecond * 10)
b.WithClock(fakeClock)
b := New[string, struct{}](Options{
Interval: time.Millisecond * 10,
Clock: fakeClock,
})
ch1 := make(chan struct{})
ch2 := make(chan struct{})
ch3 := make(chan struct{})
b.Subscribe(ch1, ch2)
b.Subscribe(ch3)
b.Subscribe(t.Context(), ch1, ch2)
b.Subscribe(t.Context(), ch3)
b.Batch("key1")
b.Batch("key1")
b.Batch("key1")
b.Batch("key1")
b.Batch("key2")
b.Batch("key2")
b.Batch("key3")
b.Batch("key3")
b.Batch("key1", struct{}{})
b.Batch("key1", struct{}{})
b.Batch("key1", struct{}{})
b.Batch("key1", struct{}{})
b.Batch("key2", struct{}{})
b.Batch("key2", struct{}{})
b.Batch("key3", struct{}{})
b.Batch("key3", struct{}{})
assert.Eventually(t, func() bool {
return fakeClock.HasWaiters()
@ -91,7 +94,7 @@ func TestBatch(t *testing.T) {
fakeClock.Step(time.Millisecond * 5)
for i := 0; i < 3; i++ {
for range 3 {
for _, ch := range []chan struct{}{ch1, ch2, ch3} {
select {
case <-ch:
@ -100,16 +103,48 @@ func TestBatch(t *testing.T) {
}
}
}
t.Run("ensure items are received in order with latest value", func(t *testing.T) {
fakeClock := testingclock.NewFakeClock(time.Now())
b := New[int, int](Options{
Interval: time.Millisecond * 10,
Clock: fakeClock,
})
t.Cleanup(b.Close)
ch1 := make(chan int, 10)
ch2 := make(chan int, 10)
ch3 := make(chan int, 10)
b.Subscribe(t.Context(), ch1, ch2)
b.Subscribe(t.Context(), ch3)
for i := range 10 {
b.Batch(i, i)
b.Batch(i, i+1)
b.Batch(i, i+2)
fakeClock.Step(time.Millisecond * 10)
}
for _, ch := range []chan int{ch1} {
for i := range 10 {
select {
case v := <-ch:
assert.Equal(t, i+2, v)
case <-time.After(time.Second):
assert.Fail(t, "should be triggered")
}
}
}
})
}
func TestClose(t *testing.T) {
t.Parallel()
b := New[string](time.Millisecond * 10)
b := New[string, struct{}](Options{Interval: time.Millisecond * 10})
ch := make(chan struct{})
b.Subscribe(ch)
b.Subscribe(t.Context(), ch)
assert.Len(t, b.eventChs, 1)
b.Batch("key1")
b.Batch("key1", struct{}{})
b.Close()
assert.True(t, b.closed.Load())
}
@ -117,9 +152,9 @@ func TestClose(t *testing.T) {
func TestSubscribeAfterClose(t *testing.T) {
t.Parallel()
b := New[string](time.Millisecond * 10)
b := New[string, struct{}](Options{Interval: time.Millisecond * 10})
b.Close()
ch := make(chan struct{})
b.Subscribe(ch)
b.Subscribe(t.Context(), ch)
assert.Empty(t, b.eventChs)
}

View File

@ -0,0 +1,132 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package broadcaster
import (
"context"
"sync"
"sync/atomic"
)
const bufferSize = 10
type eventCh[T any] struct {
id uint64
ch chan<- T
closeEventCh chan struct{}
}
type Broadcaster[T any] struct {
eventChs []*eventCh[T]
currentID uint64
lock sync.Mutex
wg sync.WaitGroup
closeCh chan struct{}
closed atomic.Bool
}
// New creates a new Broadcaster with the given interval and key type.
func New[T any]() *Broadcaster[T] {
return &Broadcaster[T]{
closeCh: make(chan struct{}),
}
}
// Subscribe adds a new event channel subscriber. If the batcher is closed, the
// subscriber is silently dropped.
func (b *Broadcaster[T]) Subscribe(ctx context.Context, ch ...chan<- T) {
b.lock.Lock()
defer b.lock.Unlock()
for _, c := range ch {
b.subscribe(ctx, c)
}
}
func (b *Broadcaster[T]) subscribe(ctx context.Context, ch chan<- T) {
if b.closed.Load() {
return
}
id := b.currentID
b.currentID++
bufferedCh := make(chan T, bufferSize)
closeEventCh := make(chan struct{})
b.eventChs = append(b.eventChs, &eventCh[T]{
id: id,
ch: bufferedCh,
closeEventCh: closeEventCh,
})
b.wg.Add(1)
go func() {
defer func() {
close(closeEventCh)
b.lock.Lock()
for i, eventCh := range b.eventChs {
if eventCh.id == id {
b.eventChs = append(b.eventChs[:i], b.eventChs[i+1:]...)
break
}
}
b.lock.Unlock()
b.wg.Done()
}()
for {
select {
case <-ctx.Done():
return
case <-b.closeCh:
return
case val := <-bufferedCh:
select {
case <-ctx.Done():
return
case <-b.closeCh:
return
case ch <- val:
}
}
}
}()
}
// Broadcast sends the given value to all subscribers.
func (b *Broadcaster[T]) Broadcast(value T) {
b.lock.Lock()
defer b.lock.Unlock()
if b.closed.Load() {
return
}
for _, ev := range b.eventChs {
select {
case <-ev.closeEventCh:
case ev.ch <- value:
case <-b.closeCh:
}
}
}
// Close closes the Broadcaster. It blocks until all events have been sent to
// the subscribers. The Broadcaster will be a no-op after this call.
func (b *Broadcaster[T]) Close() {
defer b.wg.Wait()
b.lock.Lock()
if b.closed.CompareAndSwap(false, true) {
close(b.closeCh)
}
b.lock.Unlock()
}

65
events/loop/fake/fake.go Normal file
View File

@ -0,0 +1,65 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fake
import (
"context"
"github.com/dapr/kit/events/loop"
)
type Fake[T any] struct {
runFn func(context.Context) error
enqueueFn func(T)
closeFn func(T)
}
func New[T any]() *Fake[T] {
return &Fake[T]{
runFn: func(context.Context) error { return nil },
enqueueFn: func(T) {},
closeFn: func(T) {},
}
}
func (f *Fake[T]) WithRun(fn func(context.Context) error) *Fake[T] {
f.runFn = fn
return f
}
func (f *Fake[T]) WithEnqueue(fn func(T)) *Fake[T] {
f.enqueueFn = fn
return f
}
func (f *Fake[T]) WithClose(fn func(T)) *Fake[T] {
f.closeFn = fn
return f
}
func (f *Fake[T]) Run(ctx context.Context) error {
return f.runFn(ctx)
}
func (f *Fake[T]) Enqueue(t T) {
f.enqueueFn(t)
}
func (f *Fake[T]) Close(t T) {
f.closeFn(t)
}
func (f *Fake[T]) Reset(loop.Handler[T], uint64) loop.Interface[T] {
return f
}

View File

@ -0,0 +1,24 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fake
import (
"testing"
"github.com/dapr/kit/events/loop"
)
func Test_Fake(*testing.T) {
var _ loop.Interface[int] = New[int]()
}

111
events/loop/loop.go Normal file
View File

@ -0,0 +1,111 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package loop
import (
"context"
"sync"
)
type Handler[T any] interface {
Handle(ctx context.Context, t T) error
}
type Interface[T any] interface {
Run(ctx context.Context) error
Enqueue(t T)
Close(t T)
Reset(h Handler[T], size uint64) Interface[T]
}
type loop[T any] struct {
queue chan T
handler Handler[T]
closed bool
closeCh chan struct{}
lock sync.RWMutex
}
func New[T any](h Handler[T], size uint64) Interface[T] {
return &loop[T]{
queue: make(chan T, size),
handler: h,
closeCh: make(chan struct{}),
}
}
func Empty[T any]() Interface[T] {
return new(loop[T])
}
func (l *loop[T]) Run(ctx context.Context) error {
defer close(l.closeCh)
for {
req, ok := <-l.queue
if !ok {
return nil
}
if err := l.handler.Handle(ctx, req); err != nil {
return err
}
}
}
func (l *loop[T]) Enqueue(req T) {
l.lock.RLock()
defer l.lock.RUnlock()
if l.closed {
return
}
select {
case l.queue <- req:
case <-l.closeCh:
}
}
func (l *loop[T]) Close(req T) {
l.lock.Lock()
l.closed = true
select {
case l.queue <- req:
case <-l.closeCh:
}
close(l.queue)
l.lock.Unlock()
<-l.closeCh
}
func (l *loop[T]) Reset(h Handler[T], size uint64) Interface[T] {
if l == nil {
return New[T](h, size)
}
l.lock.Lock()
defer l.lock.Unlock()
l.closed = false
l.closeCh = make(chan struct{})
l.handler = h
// TODO: @joshvanl: use a ring buffer so that we don't need to reallocate and
// improve performance.
l.queue = make(chan T, size)
return l
}

View File

@ -43,7 +43,9 @@ func ExampleProcessor() {
}
// Create the processor
processor := NewProcessor[string, *queueableItem](executeFn)
processor := NewProcessor[string, *queueableItem](Options[string, *queueableItem]{
ExecuteFn: executeFn,
})
// Add items to the processor, in any order, using Enqueue
processor.Enqueue(&queueableItem{Name: "item1", ExecutionTime: time.Now().Add(500 * time.Millisecond)})
@ -57,7 +59,7 @@ func ExampleProcessor() {
// Using Dequeue allows removing an item from the queue
processor.Dequeue("item4")
for i := 0; i < 3; i++ {
for range 3 {
fmt.Println(<-executed)
}
// Output:

View File

@ -14,7 +14,6 @@ limitations under the License.
package queue
import (
"errors"
"sync"
"sync/atomic"
"time"
@ -22,11 +21,13 @@ import (
kclock "k8s.io/utils/clock"
)
// ErrProcessorStopped is returned when the processor is not running.
var ErrProcessorStopped = errors.New("processor is stopped")
type Options[K comparable, T Queueable[K]] struct {
ExecuteFn func(r T)
Clock kclock.Clock
}
// Processor manages the queue of items and processes them at the correct time.
type Processor[K comparable, T queueable[K]] struct {
type Processor[K comparable, T Queueable[K]] struct {
executeFn func(r T)
queue queue[K, T]
clock kclock.Clock
@ -40,48 +41,51 @@ type Processor[K comparable, T queueable[K]] struct {
// NewProcessor returns a new Processor object.
// executeFn is the callback invoked when the item is to be executed; this will be invoked in a background goroutine.
func NewProcessor[K comparable, T queueable[K]](executeFn func(r T)) *Processor[K, T] {
func NewProcessor[K comparable, T Queueable[K]](opts Options[K, T]) *Processor[K, T] {
cl := opts.Clock
if cl == nil {
cl = kclock.RealClock{}
}
return &Processor[K, T]{
executeFn: executeFn,
executeFn: opts.ExecuteFn,
queue: newQueue[K, T](),
processorRunningCh: make(chan struct{}, 1),
stopCh: make(chan struct{}),
resetCh: make(chan struct{}, 1),
clock: kclock.RealClock{},
clock: cl,
}
}
// WithClock sets the clock used by the processor. Used for testing.
func (p *Processor[K, T]) WithClock(clock kclock.Clock) *Processor[K, T] {
p.clock = clock
return p
}
// Enqueue adds a new item to the queue.
// Enqueue adds a new items to the queue.
// If a item with the same ID already exists, it'll be replaced.
func (p *Processor[K, T]) Enqueue(r T) error {
func (p *Processor[K, T]) Enqueue(rs ...T) {
if p.stopped.Load() {
return ErrProcessorStopped
return
}
p.lock.Lock()
defer p.lock.Unlock()
for _, r := range rs {
p.enqueue(r)
}
}
func (p *Processor[K, T]) enqueue(r T) {
// Insert or replace the item in the queue
// If the item added or replaced is the first one in the queue, we need to know that
p.lock.Lock()
peek, ok := p.queue.Peek()
isFirst := (ok && peek.Key() == r.Key()) // This is going to be true if the item being replaced is the first one in the queue
p.queue.Insert(r, true)
peek, _ = p.queue.Peek() // No need to check for "ok" here because we know this will return an item
isFirst = isFirst || (peek == r) // This is also going to be true if the item just added landed at the front of the queue
p.process(isFirst)
p.lock.Unlock()
return nil
}
// Dequeue removes a item from the queue.
func (p *Processor[K, T]) Dequeue(key K) error {
func (p *Processor[K, T]) Dequeue(key K) {
if p.stopped.Load() {
return ErrProcessorStopped
return
}
// We need to check if this is the next item in the queue, as that requires stopping the processor
@ -93,8 +97,6 @@ func (p *Processor[K, T]) Dequeue(key K) error {
p.process(true)
}
p.lock.Unlock()
return nil
}
// Close stops the processor.
@ -226,5 +228,5 @@ func (p *Processor[K, T]) execute(r T) {
return
}
go p.executeFn(r)
p.executeFn(r)
}

View File

@ -31,10 +31,12 @@ func TestProcessor(t *testing.T) {
// Create the processor
clock := clocktesting.NewFakeClock(time.Now())
executeCh := make(chan *queueableItem)
processor := NewProcessor[string](func(r *queueableItem) {
executeCh <- r
processor := NewProcessor[string, *queueableItem](Options[string, *queueableItem]{
ExecuteFn: func(r *queueableItem) {
executeCh <- r
},
Clock: clock,
})
processor.clock = clock
assertExecutedItem := func(t *testing.T) *queueableItem {
t.Helper()
@ -63,10 +65,9 @@ func TestProcessor(t *testing.T) {
t.Run("enqueue items", func(t *testing.T) {
for i := 1; i <= 5; i++ {
err := processor.Enqueue(
processor.Enqueue(
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
)
require.NoError(t, err)
}
// Advance tickers by 500ms to start
@ -83,8 +84,7 @@ func TestProcessor(t *testing.T) {
t.Run("enqueue item to be executed right away", func(t *testing.T) {
r := newTestItem(1, clock.Now())
err := processor.Enqueue(r)
require.NoError(t, err)
processor.Enqueue(r)
clock.Step(500 * time.Millisecond)
@ -95,10 +95,9 @@ func TestProcessor(t *testing.T) {
t.Run("enqueue item at the front of the queue", func(t *testing.T) {
// Enqueue 4 items
for i := 1; i <= 4; i++ {
err := processor.Enqueue(
processor.Enqueue(
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
)
require.NoError(t, err)
}
assert.Eventually(t, clock.HasWaiters, time.Second, 100*time.Millisecond)
@ -111,10 +110,9 @@ func TestProcessor(t *testing.T) {
assert.Equal(t, "1", received.Name)
// Add a new item at the front of the queue
err := processor.Enqueue(
processor.Enqueue(
newTestItem(99, clock.Now()),
)
require.NoError(t, err)
// Advance tickers and assert messages are coming in order
for i := 1; i <= 4; i++ {
@ -136,19 +134,16 @@ func TestProcessor(t *testing.T) {
// Enqueue 5 items
for i := 1; i <= 5; i++ {
err := processor.Enqueue(
processor.Enqueue(
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
)
require.NoError(t, err)
}
assert.Equal(t, 5, processor.queue.Len())
// Dequeue items 2 and 4
// Note that this is a string because it's the key
err := processor.Dequeue("2")
require.NoError(t, err)
err = processor.Dequeue("4")
require.NoError(t, err)
processor.Dequeue("2")
processor.Dequeue("4")
assert.Equal(t, 3, processor.queue.Len())
@ -173,10 +168,9 @@ func TestProcessor(t *testing.T) {
t.Run("dequeue item from the front of the queue", func(t *testing.T) {
// Enqueue 6 items
for i := 1; i <= 6; i++ {
err := processor.Enqueue(
processor.Enqueue(
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
)
require.NoError(t, err)
}
// Advance tickers and assert messages are coming in order
@ -187,8 +181,7 @@ func TestProcessor(t *testing.T) {
if i == 2 || i == 5 {
// Dequeue the item at the front of the queue
// Note that this is a string because it's the key
err := processor.Dequeue(strconv.Itoa(i))
require.NoError(t, err)
processor.Dequeue(strconv.Itoa(i))
// Skip items that have been removed
t.Logf("Should not receive signal %d", i)
@ -206,15 +199,13 @@ func TestProcessor(t *testing.T) {
t.Run("replace item", func(t *testing.T) {
// Enqueue 5 items
for i := 1; i <= 5; i++ {
err := processor.Enqueue(
processor.Enqueue(
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
)
require.NoError(t, err)
}
// Replace item 4, bumping its priority down
err := processor.Enqueue(newTestItem(4, clock.Now().Add(6*time.Second)))
require.NoError(t, err)
processor.Enqueue(newTestItem(4, clock.Now().Add(6*time.Second)))
// Advance tickers and assert messages are coming in order
for i := 1; i <= 6; i++ {
@ -241,10 +232,9 @@ func TestProcessor(t *testing.T) {
t.Run("replace item at the front of the queue", func(t *testing.T) {
// Enqueue 5 items
for i := 1; i <= 5; i++ {
err := processor.Enqueue(
processor.Enqueue(
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
)
require.NoError(t, err)
}
// Advance tickers and assert messages are coming in order
@ -253,8 +243,7 @@ func TestProcessor(t *testing.T) {
if i == 2 {
// Replace item 2, bumping its priority down, while it's at the front of the queue
err := processor.Enqueue(newTestItem(2, clock.Now().Add(5*time.Second)))
require.NoError(t, err)
processor.Enqueue(newTestItem(2, clock.Now().Add(5*time.Second)))
// This item has been pushed down
t.Logf("Should not receive signal %d now", i)
@ -282,13 +271,12 @@ func TestProcessor(t *testing.T) {
)
now := clock.Now()
wg := sync.WaitGroup{}
for i := 0; i < count; i++ {
for i := range count {
wg.Add(1)
go func(i int) {
defer wg.Done()
execTime := now.Add(time.Second * time.Duration(rand.Intn(maxDelay))) //nolint:gosec
err := processor.Enqueue(newTestItem(i, execTime))
require.NoError(t, err)
processor.Enqueue(newTestItem(i, execTime))
}(i)
}
wg.Wait()
@ -324,7 +312,7 @@ func TestProcessor(t *testing.T) {
close(doneCh)
// Ensure all items are true
for i := 0; i < count; i++ {
for i := range count {
assert.Truef(t, collected[i], "item %d not received", i)
}
})
@ -332,10 +320,9 @@ func TestProcessor(t *testing.T) {
t.Run("stop processor", func(t *testing.T) {
// Enqueue 5 items
for i := 1; i <= 5; i++ {
err := processor.Enqueue(
processor.Enqueue(
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
)
require.NoError(t, err)
}
assert.Eventually(t, clock.HasWaiters, time.Second, 100*time.Millisecond)
@ -348,10 +335,8 @@ func TestProcessor(t *testing.T) {
assertNoExecutedItem(t)
// Enqueuing and dequeueing should fail
err := processor.Enqueue(newTestItem(99, clock.Now()))
require.ErrorIs(t, err, ErrProcessorStopped)
err = processor.Dequeue("99")
require.ErrorIs(t, err, ErrProcessorStopped)
processor.Enqueue(newTestItem(99, clock.Now()))
processor.Dequeue("99")
// Stopping again is a nop (should not crash)
require.NoError(t, processor.Close())
@ -364,10 +349,12 @@ func TestClose(t *testing.T) {
// Create the processor
clock := clocktesting.NewFakeClock(time.Now())
executeCh := make(chan *queueableItem)
processor := NewProcessor[string](func(r *queueableItem) {
executeCh <- r
processor := NewProcessor[string, *queueableItem](Options[string, *queueableItem]{
ExecuteFn: func(r *queueableItem) {
executeCh <- r
},
Clock: clock,
})
processor.clock = clock
processor.Enqueue(newTestItem(1, clock.Now().Add(time.Second)))
processor.Enqueue(newTestItem(2, clock.Now().Add(time.Second*2)))
@ -415,7 +402,7 @@ func TestClose(t *testing.T) {
default:
}
for i := 0; i < 3; i++ {
for range 3 {
select {
case err := <-closeCh:
require.NoError(t, err)

View File

@ -18,8 +18,8 @@ import (
"time"
)
// queueable is the interface for items that can be added to the queue.
type queueable[T comparable] interface {
// Queueable is the interface for items that can be added to the queue.
type Queueable[T comparable] interface {
comparable
Key() T
ScheduledTime() time.Time
@ -29,13 +29,13 @@ type queueable[T comparable] interface {
// It acts as a "priority queue", in which items are added in order of when they're scheduled.
// Internally, it uses a heap (from container/heap) that allows Insert and Pop operations to be completed in O(log N) time (where N is the queue's length).
// Note: methods in this struct are not safe for concurrent use. Callers should use locks to ensure consistency.
type queue[K comparable, T queueable[K]] struct {
type queue[K comparable, T Queueable[K]] struct {
heap *queueHeap[K, T]
items map[K]*queueItem[K, T]
}
// newQueue creates a new queue.
func newQueue[K comparable, T queueable[K]]() queue[K, T] {
func newQueue[K comparable, T Queueable[K]]() queue[K, T] {
return queue[K, T]{
heap: new(queueHeap[K, T]),
items: make(map[K]*queueItem[K, T]),
@ -122,14 +122,14 @@ func (p *queue[K, T]) Update(r T) {
heap.Fix(p.heap, item.index)
}
type queueItem[K comparable, T queueable[K]] struct {
type queueItem[K comparable, T Queueable[K]] struct {
value T
// The index of the item in the heap. This is maintained by the heap.Interface methods.
index int
}
type queueHeap[K comparable, T queueable[K]] []*queueItem[K, T]
type queueHeap[K comparable, T Queueable[K]] []*queueItem[K, T]
func (pq queueHeap[K, T]) Len() int {
return len(pq)

View File

@ -39,7 +39,7 @@ func TestCoalescing(t *testing.T) {
ch := make(chan struct{})
errCh := make(chan error)
go func() {
errCh <- c.Run(context.Background(), ch)
errCh <- c.Run(t.Context(), ch)
}()
t.Cleanup(func() {
@ -78,7 +78,7 @@ func TestCoalescing(t *testing.T) {
c, err := NewCoalescing(OptionsCoalescing{})
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
errCh <- c.Run(ctx, make(chan struct{}))
@ -100,7 +100,7 @@ func TestCoalescing(t *testing.T) {
errCh := make(chan error)
go func() {
errCh <- c.Run(context.Background(), make(chan struct{}))
errCh <- c.Run(t.Context(), make(chan struct{}))
}()
c.Close()
@ -119,7 +119,7 @@ func TestCoalescing(t *testing.T) {
errCh := make(chan error)
go func() {
errCh <- c.Run(context.Background(), make(chan struct{}))
errCh <- c.Run(t.Context(), make(chan struct{}))
}()
c.Close()
@ -132,7 +132,7 @@ func TestCoalescing(t *testing.T) {
}
go func() {
errCh <- c.Run(context.Background(), make(chan struct{}))
errCh <- c.Run(t.Context(), make(chan struct{}))
}()
select {
@ -277,7 +277,7 @@ func TestCoalescing(t *testing.T) {
c.Add()
assertNoChannel(t, ch)
for i := 0; i < 4; i++ {
for range 4 {
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second * 4)
c.Add()
@ -345,7 +345,7 @@ func TestCoalescing(t *testing.T) {
assertChannel(t, ch)
assert.Eventually(t, c.hasTimer.Load, time.Second, time.Millisecond)
for i := 0; i < 10; i++ {
for range 10 {
c.Add()
}
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)

View File

@ -44,7 +44,7 @@ type Options struct {
type FSWatcher struct {
w *fsnotify.Watcher
running atomic.Bool
batcher *batcher.Batcher[string]
batcher *batcher.Batcher[string, struct{}]
}
func New(opts Options) (*FSWatcher, error) {
@ -71,7 +71,9 @@ func New(opts Options) (*FSWatcher, error) {
w: w,
// Often the case, writes to files are not atomic and involve multiple file system events.
// We want to hold off on sending events until we are sure that the file has been written to completion. We do this by waiting for a period of time after the last event has been received for a file name.
batcher: batcher.New[string](interval),
batcher: batcher.New[string, struct{}](batcher.Options{
Interval: interval,
}),
}, nil
}
@ -81,7 +83,7 @@ func (f *FSWatcher) Run(ctx context.Context, eventCh chan<- struct{}) error {
}
defer f.batcher.Close()
f.batcher.Subscribe(eventCh)
f.batcher.Subscribe(ctx, eventCh)
for {
select {
@ -90,7 +92,7 @@ func (f *FSWatcher) Run(ctx context.Context, eventCh chan<- struct{}) error {
case err := <-f.w.Errors:
return errors.Join(fmt.Errorf("watcher error: %w", err), f.w.Close())
case event := <-f.w.Events:
f.batcher.Batch(event.Name)
f.batcher.Batch(event.Name, struct{}{})
}
}
}

View File

@ -32,7 +32,7 @@ import (
)
func TestFSWatcher(t *testing.T) {
runWatcher := func(t *testing.T, opts Options, bacher *batcher.Batcher[string]) <-chan struct{} {
runWatcher := func(t *testing.T, opts Options, bacher *batcher.Batcher[string, struct{}]) <-chan struct{} {
t.Helper()
f, err := New(opts)
@ -43,7 +43,7 @@ func TestFSWatcher(t *testing.T) {
}
errCh := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
eventsCh := make(chan struct{})
go func() {
errCh <- f.Run(ctx, eventsCh)
@ -84,7 +84,7 @@ func TestFSWatcher(t *testing.T) {
t.Run("running Run twice should error", func(t *testing.T) {
fs, err := New(Options{})
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
cancel()
require.NoError(t, fs.Run(ctx, make(chan struct{})))
require.Error(t, fs.Run(ctx, make(chan struct{})))
@ -101,7 +101,7 @@ func TestFSWatcher(t *testing.T) {
t.Run("should fire event when event occurs on target file", func(t *testing.T) {
fp := filepath.Join(t.TempDir(), "test.txt")
require.NoError(t, os.WriteFile(fp, []byte{}, 0o644))
require.NoError(t, os.WriteFile(fp, []byte{}, 0o600))
eventsCh := runWatcher(t, Options{
Targets: []string{fp},
Interval: ptr.Of(time.Duration(1)),
@ -112,7 +112,7 @@ func TestFSWatcher(t *testing.T) {
// If running in windows, wait for notify to be ready.
time.Sleep(time.Second)
}
require.NoError(t, os.WriteFile(fp, []byte{}, 0o644))
require.NoError(t, os.WriteFile(fp, []byte{}, 0o600))
select {
case <-eventsCh:
@ -124,16 +124,16 @@ func TestFSWatcher(t *testing.T) {
t.Run("should fire 2 events when event occurs on 2 file target", func(t *testing.T) {
fp1 := filepath.Join(t.TempDir(), "test.txt")
fp2 := filepath.Join(t.TempDir(), "test.txt")
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o644))
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o644))
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
eventsCh := runWatcher(t, Options{
Targets: []string{fp1, fp2},
Interval: ptr.Of(time.Duration(1)),
}, nil)
assert.Empty(t, eventsCh)
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o644))
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o644))
for i := 0; i < 2; i++ {
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
for range 2 {
select {
case <-eventsCh:
case <-time.After(time.Second):
@ -146,8 +146,8 @@ func TestFSWatcher(t *testing.T) {
dir := t.TempDir()
fp1 := filepath.Join(dir, "test1.txt")
fp2 := filepath.Join(dir, "test2.txt")
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o644))
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o644))
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
eventsCh := runWatcher(t, Options{
Targets: []string{fp1, fp2},
Interval: ptr.Of(time.Duration(1)),
@ -157,9 +157,9 @@ func TestFSWatcher(t *testing.T) {
time.Sleep(time.Second)
}
assert.Empty(t, eventsCh)
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o644))
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o644))
for i := 0; i < 2; i++ {
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
for range 2 {
select {
case <-eventsCh:
case <-time.After(time.Second):
@ -178,9 +178,9 @@ func TestFSWatcher(t *testing.T) {
Interval: ptr.Of(time.Duration(1)),
}, nil)
assert.Empty(t, eventsCh)
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o644))
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o644))
for i := 0; i < 2; i++ {
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
for range 2 {
select {
case <-eventsCh:
case <-time.After(time.Second):
@ -191,8 +191,10 @@ func TestFSWatcher(t *testing.T) {
t.Run("should batch events of the same file for multiple events", func(t *testing.T) {
clock := clocktesting.NewFakeClock(time.Time{})
batcher := batcher.New[string](time.Millisecond * 500)
batcher.WithClock(clock)
batcher := batcher.New[string, struct{}](batcher.Options{
Interval: time.Millisecond * 500,
Clock: clock,
})
dir1 := t.TempDir()
dir2 := t.TempDir()
fp1 := filepath.Join(dir1, "test1.txt")
@ -205,9 +207,9 @@ func TestFSWatcher(t *testing.T) {
time.Sleep(time.Second)
}
for i := 0; i < 10; i++ {
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o644))
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o644))
for range 10 {
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
}
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond*10)
@ -220,9 +222,9 @@ func TestFSWatcher(t *testing.T) {
clock.Step(time.Millisecond * 250)
for i := 0; i < 10; i++ {
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o644))
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o644))
for range 10 {
require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
}
select {
@ -234,7 +236,7 @@ func TestFSWatcher(t *testing.T) {
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond*10)
clock.Step(time.Millisecond * 500)
for i := 0; i < 2; i++ {
for range 2 {
select {
case <-eventsCh:
case <-time.After(time.Second):

View File

@ -22,7 +22,7 @@ import (
"github.com/dapr/kit/events/batcher"
)
func (f *FSWatcher) WithBatcher(b *batcher.Batcher[string]) *FSWatcher {
func (f *FSWatcher) WithBatcher(b *batcher.Batcher[string, struct{}]) *FSWatcher {
f.batcher = b
return f
}

View File

@ -27,7 +27,9 @@ import (
)
func TestWithBatcher(t *testing.T) {
b := batcher.New[string](time.Millisecond * 10)
b := batcher.New[string, struct{}](batcher.Options{
Interval: time.Millisecond * 10,
})
f, err := New(Options{})
require.NoError(t, err)
f.WithBatcher(b)

30
go.mod
View File

@ -1,24 +1,26 @@
module github.com/dapr/kit
go 1.20
go 1.24.3
require (
github.com/alphadose/haxmap v1.3.1
github.com/cenkalti/backoff/v4 v4.2.1
github.com/fsnotify/fsnotify v1.7.0
github.com/lestrrat-go/httprc v1.0.4
github.com/lestrrat-go/jwx/v2 v2.0.15
github.com/lestrrat-go/httprc v1.0.5
github.com/lestrrat-go/jwx/v2 v2.0.21
github.com/mitchellh/mapstructure v1.5.1-0.20220423185008-bf980b35cac4
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cast v1.5.1
github.com/stretchr/testify v1.8.4
github.com/spiffe/go-spiffe/v2 v2.5.0
github.com/stretchr/testify v1.10.0
github.com/tidwall/transform v0.0.0-20201103190739-32f242e2dbde
golang.org/x/crypto v0.14.0
golang.org/x/crypto v0.39.0
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
golang.org/x/tools v0.14.0
google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d
google.golang.org/grpc v1.57.0
google.golang.org/protobuf v1.31.0
golang.org/x/tools v0.33.0
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822
google.golang.org/grpc v1.73.0
google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20
google.golang.org/protobuf v1.36.6
k8s.io/apimachinery v0.26.9
k8s.io/utils v0.0.0-20230726121419-3b25d923346b
)
@ -26,17 +28,21 @@ require (
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/go-jose/go-jose/v4 v4.0.5 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/iter v1.0.2 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
golang.org/x/mod v0.13.0 // indirect
golang.org/x/sys v0.13.0 // indirect
github.com/zeebo/errs v1.4.0 // indirect
golang.org/x/mod v0.25.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/sync v0.15.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

129
go.sum
View File

@ -1,3 +1,5 @@
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/alphadose/haxmap v1.3.1 h1:KmZh75duO1tC8pt3LmUwoTYiZ9sh4K52FX8p7/yrlqU=
github.com/alphadose/haxmap v1.3.1/go.mod h1:rjHw1IAqbxm0S3U5tD16GoKsiAd8FWx5BJ2IYqXwgmM=
github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
@ -5,45 +7,56 @@ github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyY
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY=
github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.1.0 h1:Hsa8mG0dQ46ij8Sl2AYJDUv1oA9/d6Vk+3LG99Oe02g=
github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8=
github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
github.com/lestrrat-go/httprc v1.0.5 h1:bsTfiH8xaKOJPrg1R+E3iE/AWZr/x0Phj9PBTG/OLUk=
github.com/lestrrat-go/httprc v1.0.5/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
github.com/lestrrat-go/jwx/v2 v2.0.15 h1:XvR2lQdX+mZechmqWxqQb2foU3hgAn5+Rj0ICa0I6sU=
github.com/lestrrat-go/jwx/v2 v2.0.15/go.mod h1:jBHyESp4e7QxfERM0UKkQ80/94paqNIEcdEfiUYz5zE=
github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lestrrat-go/jwx/v2 v2.0.21 h1:jAPKupy4uHgrHFEdjVjNkUgoBKtVDgrQPB/h55FHrR0=
github.com/lestrrat-go/jwx/v2 v2.0.21/go.mod h1:09mLW8zto6bWL9GbwnqAli+ArLf+5M33QLQPDggkUWM=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/mitchellh/mapstructure v1.5.1-0.20220423185008-bf980b35cac4 h1:BpfhmLKZf+SjVanKKhCgf3bg+511DmU9eDQTen7LLbY=
github.com/mitchellh/mapstructure v1.5.1-0.20220423185008-bf980b35cac4/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
@ -51,96 +64,86 @@ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVs
github.com/spf13/cast v1.5.1 h1:R+kOtfhWQE6TVQzY+4D7wJLBgkdVasCEFxSUBYBYIlA=
github.com/spf13/cast v1.5.1/go.mod h1:b9PdjNptOpzXr7Rq1q9gJML/2cdGQAo69NKzQ10KN48=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spiffe/go-spiffe/v2 v2.5.0 h1:N2I01KCUkv1FAjZXJMwh95KK1ZIQLYbPfhaxw8WS0hE=
github.com/spiffe/go-spiffe/v2 v2.5.0/go.mod h1:P+NxobPc6wXhVtINNtFjNWGBTreew1GBUCwT2wPmb7g=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/transform v0.0.0-20201103190739-32f242e2dbde h1:AMNpJRc7P+GTwVbl8DkK2I9I8BBUzNiHuH/tlxrpan0=
github.com/tidwall/transform v0.0.0-20201103190739-32f242e2dbde/go.mod h1:MvrEmduDUz4ST5pGZ7CABCnOU5f3ZiOAZzT6b1A6nX8=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM=
github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY=
go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg=
go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o=
go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w=
go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY=
golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.16.0 h1:7eBu7KsSvFDtSXUIDbh3aqlK4DPsZ1rByC8PFfBThos=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.14.0 h1:jvNa2pY0M4r62jkRQ6RwEZZyPcymeL9XZMLBbV7U2nc=
golang.org/x/tools v0.14.0/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d h1:uvYuEyMHKNt+lT4K3bN6fGswmK8qSvcreM3BwjDh+y4=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d/go.mod h1:+Bk1OCOj40wS2hwAMA+aCW9ypzm63QTBBHp6lQ3p+9M=
google.golang.org/grpc v1.57.0 h1:kfzNeI/klCGD2YPMUlaGNT3pxvYfga7smW3Vth8Zsiw=
google.golang.org/grpc v1.57.0/go.mod h1:Sd+9RMTACXwmub0zcNY2c4arhtrbBYD1AUHI/dt16Mo=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok=
google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc=
google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 h1:MLBCGN1O7GzIx+cBiwfYPwtmZ41U3Mn/cotLJciaArI=
google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20/go.mod h1:Nr5H8+MlGWr5+xX/STzdoEqJrO+YteqFbMyCsrb6mH0=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -37,9 +37,9 @@ import (
"github.com/lestrrat-go/httprc"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/dapr/kit/crypto/pem"
"github.com/dapr/kit/fswatcher"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/utils"
)
const (
@ -198,7 +198,7 @@ func (c *JWKSCache) initJWKSFromURL(ctx context.Context, url string) error {
// Load CA certificates if we have one
if c.caCertificate != "" {
caCert, err := utils.GetPEM(c.caCertificate)
caCert, err := pem.GetPEM(c.caCertificate)
if err != nil {
return fmt.Errorf("failed to load CA certificate: %w", err)
}

View File

@ -40,7 +40,7 @@ func TestJWKSCache(t *testing.T) {
t.Run("init with value", func(t *testing.T) {
cache := NewJWKSCache(testJWKS1, log)
err := cache.initCache(context.Background())
err := cache.initCache(t.Context())
require.NoError(t, err)
set := cache.KeySet()
@ -53,7 +53,7 @@ func TestJWKSCache(t *testing.T) {
t.Run("init with base64-encoded value", func(t *testing.T) {
cache := NewJWKSCache(base64.StdEncoding.EncodeToString([]byte(testJWKS1)), log)
err := cache.initCache(context.Background())
err := cache.initCache(t.Context())
require.NoError(t, err)
set := cache.KeySet()
@ -68,12 +68,12 @@ func TestJWKSCache(t *testing.T) {
// Create a temporary directory and put the JWKS in there
dir := t.TempDir()
path := filepath.Join(dir, "jwks.json")
err := os.WriteFile(path, []byte(testJWKS1), 0o666)
err := os.WriteFile(path, []byte(testJWKS1), 0o600)
require.NoError(t, err)
// Should wait for first file to be loaded before initialization is reported as completed
cache := NewJWKSCache(path, log)
err = cache.initCache(context.Background())
err = cache.initCache(t.Context())
require.NoError(t, err)
set := cache.KeySet()
@ -87,7 +87,7 @@ func TestJWKSCache(t *testing.T) {
time.Sleep(time.Second)
// Update the file and verify it's picked up
err = os.WriteFile(path, []byte(testJWKS2), 0o666)
err = os.WriteFile(path, []byte(testJWKS2), 0o600)
require.NoError(t, err)
assert.Eventually(t, func() bool {
@ -127,7 +127,7 @@ func TestJWKSCache(t *testing.T) {
cache := NewJWKSCache("http://localhost/jwks.json", log)
cache.SetHTTPClient(client)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
defer cancel()
err := cache.initCache(ctx)
require.NoError(t, err)
@ -142,7 +142,7 @@ func TestJWKSCache(t *testing.T) {
t.Run("start and wait for init", func(t *testing.T) {
cache := NewJWKSCache(testJWKS1, log)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
defer cancel()
// Start in background
@ -174,7 +174,7 @@ func TestJWKSCache(t *testing.T) {
cache := NewJWKSCache("https://localhost/jwks.json", log)
cache.SetHTTPClient(client)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
defer cancel()
// Start in background
@ -194,7 +194,7 @@ func TestJWKSCache(t *testing.T) {
})
t.Run("start and init times out", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1500*time.Millisecond)
ctx, cancel := context.WithTimeout(t.Context(), 1500*time.Millisecond)
defer cancel()
// Create a custom HTTP client with a RoundTripper that doesn't require starting a TCP listener
@ -223,7 +223,7 @@ func TestJWKSCache(t *testing.T) {
}()
// Wait for initialization
err := cache.WaitForCacheReady(context.Background())
err := cache.WaitForCacheReady(t.Context())
require.Error(t, err)
require.ErrorContains(t, err, "failed to fetch JWKS")
require.ErrorIs(t, err, context.DeadlineExceeded)

View File

@ -25,7 +25,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
)
const fakeLoggerName = "fakeLogger"
@ -286,7 +285,7 @@ func TestWithTypeFields(t *testing.T) {
testLogger.Info("testLogger with log LogType")
b, _ = buf.ReadBytes('\n')
maps.Clear(o)
clear(o)
require.NoError(t, json.Unmarshal(b, &o))
assert.Equalf(t, LogTypeLog, o[logFieldType], "testLogger must be %s type", LogTypeLog)
@ -309,12 +308,12 @@ func TestWithFields(t *testing.T) {
}).Info("🙃")
b, _ := buf.ReadBytes('\n')
maps.Clear(o)
clear(o)
require.NoError(t, json.Unmarshal(b, &o))
assert.Equal(t, "🙃", o["msg"])
assert.Equal(t, "world", o["hello"])
assert.Equal(t, float64(42), o["answer"])
assert.InDelta(t, float64(42), o["answer"], 000.1)
// Test with other fields
testLogger.WithFields(map[string]any{
@ -322,7 +321,7 @@ func TestWithFields(t *testing.T) {
}).Info("🐶")
b, _ = buf.ReadBytes('\n')
maps.Clear(o)
clear(o)
require.NoError(t, json.Unmarshal(b, &o))
assert.Equal(t, "🐶", o["msg"])
@ -336,7 +335,7 @@ func TestWithFields(t *testing.T) {
testLogger.Info("🤔")
b, _ = buf.ReadBytes('\n')
maps.Clear(o)
clear(o)
require.NoError(t, json.Unmarshal(b, &o))
assert.Equal(t, "🤔", o["msg"])

View File

@ -14,7 +14,6 @@ limitations under the License.
package logger
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
@ -78,7 +77,7 @@ func TestToLogLevel(t *testing.T) {
func TestNewContext(t *testing.T) {
t.Run("input nil logger", func(t *testing.T) {
ctx := NewContext(context.Background(), nil)
ctx := NewContext(t.Context(), nil)
assert.NotNil(t, ctx, "ctx is not nil")
logger := FromContextOrDefault(ctx)
@ -91,7 +90,7 @@ func TestNewContext(t *testing.T) {
logger := NewLogger(testLoggerName)
assert.NotNil(t, logger)
ctx := NewContext(context.Background(), logger)
ctx := NewContext(t.Context(), logger)
assert.NotNil(t, ctx, "ctx is not nil")
logger2 := FromContextOrDefault(ctx)
assert.NotNil(t, logger2)

View File

@ -29,7 +29,7 @@ type Duration struct {
time.Duration
}
func (d Duration) MarshalJSON() ([]byte, error) {
func (d *Duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
@ -114,7 +114,7 @@ func toTimeDurationHookFunc() mapstructure.DecodeHookFunc {
// This methods supports days, hours, minutes, and seconds. It assumes all durations are in UTC time and are not impacted by DST (so all days are 24-hours long).
// This method does not support fractions of seconds, and durations are truncated to seconds.
// See https://en.wikipedia.org/wiki/ISO_8601#Durations for referece.
func (d Duration) ToISOString() string {
func (d *Duration) ToISOString() string {
// Truncate to seconds, removing fractional seconds
trunc := d.Truncate(time.Second)

34
metadata/properties.go Normal file
View File

@ -0,0 +1,34 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package metadata
// Properties contains metadata properties, as a key-value dictionary
type Properties map[string]string
// GetProperty returns a property from the metadata, with support for case-insensitive keys and aliases.
func (p Properties) GetProperty(keys ...string) (val string, ok bool) {
return GetMetadataProperty(p, keys...)
}
// GetPropertyWithMatchedKey returns a property from the metadata, with support for case-insensitive keys and aliases,
// while returning the original matching metadata field key.
func (p Properties) GetPropertyWithMatchedKey(keys ...string) (key string, val string, ok bool) {
return GetMetadataPropertyWithMatchedKey(p, keys...)
}
// Decode decodes metadata into a struct.
// This is an extension of mitchellh/mapstructure which also supports decoding durations.
func (p Properties) Decode(result any) error {
return decodeMetadataMap(p, result)
}

View File

@ -23,7 +23,7 @@ import (
"github.com/mitchellh/mapstructure"
"github.com/dapr/kit/ptr"
"github.com/dapr/kit/utils"
kitstrings "github.com/dapr/kit/strings"
)
func toTruthyBoolHookFunc() mapstructure.DecodeHookFunc {
@ -37,10 +37,10 @@ func toTruthyBoolHookFunc() mapstructure.DecodeHookFunc {
data any,
) (any, error) {
if f == stringType && t == boolType {
return utils.IsTruthy(data.(string)), nil
return kitstrings.IsTruthy(data.(string)), nil
}
if f == stringType && t == boolPtrType {
return ptr.Of(utils.IsTruthy(data.(string))), nil
return ptr.Of(kitstrings.IsTruthy(data.(string))), nil
}
return data, nil
}

View File

@ -24,6 +24,13 @@ import (
// GetMetadataProperty returns a property from the metadata map, with support for case-insensitive keys and aliases.
func GetMetadataProperty(props map[string]string, keys ...string) (val string, ok bool) {
_, val, ok = GetMetadataPropertyWithMatchedKey(props, keys...)
return val, ok
}
// GetMetadataPropertyWithMatchedKey returns a property from the metadata map, with support for case-insensitive keys and aliases,
// while returning the original matching metadata field key.
func GetMetadataPropertyWithMatchedKey(props map[string]string, keys ...string) (key string, val string, ok bool) {
lcProps := make(map[string]string, len(props))
for k, v := range props {
lcProps[strings.ToLower(k)] = v
@ -31,10 +38,10 @@ func GetMetadataProperty(props map[string]string, keys ...string) (val string, o
for _, k := range keys {
val, ok = lcProps[strings.ToLower(k)]
if ok {
return val, true
return k, val, true
}
}
return "", false
return "", "", false
}
// DecodeMetadata decodes a component metadata into a struct.
@ -55,8 +62,12 @@ func DecodeMetadata(input any, result any) error {
return fmt.Errorf("input object cannot be cast to map[string]string: %w", err)
}
return decodeMetadataMap(inputMap, result)
}
func decodeMetadataMap(inputMap map[string]string, result any) error {
// Handle aliases
err = resolveAliases(inputMap, reflect.TypeOf(result))
err := resolveAliases(inputMap, reflect.TypeOf(result))
if err != nil {
return fmt.Errorf("failed to resolve aliases: %w", err)
}
@ -115,7 +126,7 @@ func resolveAliases(md map[string]string, t reflect.Type) error {
func resolveAliasesInType(md map[string]string, keys map[string]string, t reflect.Type) {
// Iterate through all the properties of the type to see if anyone has the "mapstructurealiases" property
for i := 0; i < t.NumField(); i++ {
for i := range t.NumField() {
currentField := t.Field(i)
// Ignored fields that are not exported or that don't have a "mapstructure" tag

View File

@ -14,13 +14,13 @@ limitations under the License.
package metadata
import (
"maps"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
)
func TestMetadataDecode(t *testing.T) {
@ -395,3 +395,47 @@ func TestResolveAliases(t *testing.T) {
})
}
}
func TestGetMetadataPropertyWithMatchedKey(t *testing.T) {
props := map[string]string{
"key1": "value1",
"key2": "value2",
"key3": "value3",
"emptyKey": "",
}
t.Run("Existing key", func(t *testing.T) {
key, val, ok := GetMetadataPropertyWithMatchedKey(props, "key1", "key2")
assert.True(t, ok)
assert.Equal(t, "key1", key)
assert.Equal(t, "value1", val)
})
t.Run("Case-insensitive matching", func(t *testing.T) {
key, val, ok := GetMetadataPropertyWithMatchedKey(props, "KEY1")
assert.True(t, ok)
assert.Equal(t, "KEY1", key)
assert.Equal(t, "value1", val)
})
t.Run("Non-existing key", func(t *testing.T) {
key, val, ok := GetMetadataPropertyWithMatchedKey(props, "key4")
assert.False(t, ok)
assert.Equal(t, "", key)
assert.Equal(t, "", val)
})
t.Run("Empty properties", func(t *testing.T) {
key, val, ok := GetMetadataPropertyWithMatchedKey(nil, "key1")
assert.False(t, ok)
assert.Equal(t, "", key)
assert.Equal(t, "", val)
})
t.Run("Value is empty", func(t *testing.T) {
key, val, ok := GetMetadataPropertyWithMatchedKey(props, "EmptyKey")
assert.True(t, ok)
assert.Equal(t, "EmptyKey", key)
assert.Equal(t, "", val)
})
}

View File

@ -55,9 +55,9 @@ type Config struct {
}
// String implements fmt.Stringer and is used for debugging.
func (c Config) String() string {
func (c *Config) String() string {
return fmt.Sprintf(
"policy='%s' duration='%v' initialInterval='%v' randomizationFactor='%f' multiplier='%f' maxInterval='%v' maxElapsedTime='%v' maxRetries='%d'",
"policy='%v' duration='%v' initialInterval='%v' randomizationFactor='%f' multiplier='%f' maxInterval='%v' maxElapsedTime='%v' maxRetries='%d'",
c.Policy, c.Duration, c.InitialInterval, c.RandomizationFactor, c.Multiplier, c.MaxInterval, c.MaxElapsedTime, c.MaxRetries,
)
}
@ -204,8 +204,8 @@ func (p *PolicyType) DecodeString(value string) error {
}
// String implements fmt.Stringer and is used for debugging.
func (p PolicyType) String() string {
switch p {
func (p *PolicyType) String() string {
switch *p {
case PolicyConstant:
return "constant"
case PolicyExponential:

View File

@ -241,7 +241,7 @@ func TestRetryNotifyRecoverCancel(t *testing.T) {
var notifyCalls, recoveryCalls int
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
b := config.NewBackOffWithContext(ctx)
errC := make(chan error, 1)
startedC := make(chan struct{}, 100)

96
ring/buffered.go Normal file
View File

@ -0,0 +1,96 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ring
// Buffered is an implementation of a ring which is buffered, expanding and
// contracting depending on the number of elements in committed to the ring.
// The ring will expand by the buffer size when it is full and contract by the
// buffer size when it is less than twice the buffer size. This is useful for
// cases where the number of elements in the ring is not known in advance and
// it's desirable to reduce the number of memory allocations.
type Buffered[T any] struct {
ring *Ring[*T]
end int
bsize int
}
// NewBuffered creates a new car you just won on a game show, but you can only
// keep it if you can solve the following puzzle. Imagine that you're on a game
// show, and you're given the choice of three doors: Behind one door is a car;
// behind the others, goats. You pick a door, say No. 1, and the host, who knows
// what's behind the doors, opens another door, say No. 3, which has a goat. He
// then says to you, "Do you want to pick door No. 2?" Is it to your advantage
// to switch your choice?
// Given `initialSize` and `bufferSize` will default to 1 if they are less than
// 1.
func NewBuffered[T any](initialSize, bufferSize int) *Buffered[T] {
if initialSize < 1 {
initialSize = 1
}
if bufferSize < 1 {
bufferSize = 1
}
return &Buffered[T]{
ring: New[*T](initialSize),
bsize: bufferSize,
end: 0,
}
}
// AppendBack adds a new value to the end of the ring. If the ring is full, it
// will allocate a new ring with the buffer size.
func (b *Buffered[T]) AppendBack(value *T) {
if b.end >= b.ring.Len() {
b.ring.Move(b.end - 1).Link(New[*T](b.bsize))
}
b.ring.Move(b.end).Value = value
b.end++
}
// Len returns the number of elements in the ring.
func (b *Buffered[T]) Len() int {
return b.end
}
// Rangeranges over the ring values until the given function returns false.
func (b *Buffered[T]) Range(fn func(*T) bool) {
x := b.ring
for range b.end {
if !fn(x.Value) {
return
}
x = x.Next()
}
}
// Front returns the first value in the ring.
func (b *Buffered[T]) Front() *T {
return b.ring.Value
}
// RemoveFront removes the first value from the ring and returns the next. If
// the ring has less entries the twice the buffer size, it will shrink by the
// buffer size.
func (b *Buffered[T]) RemoveFront() *T {
b.ring.Value = nil
b.ring = b.ring.Next()
b.end--
if b.ring.Len()-b.end > b.bsize*2 {
b.ring.Move(b.end).Unlink(b.bsize)
}
return b.ring.Value
}

122
ring/buffered_test.go Normal file
View File

@ -0,0 +1,122 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ring
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/dapr/kit/ptr"
)
func Test_Buffered(t *testing.T) {
b := NewBuffered[int](1, 5)
assert.Equal(t, 1, b.ring.Len())
b = NewBuffered[int](0, 5)
assert.Equal(t, 1, b.ring.Len())
b = NewBuffered[int](3, 5)
assert.Equal(t, 3, b.ring.Len())
assert.Equal(t, 0, b.end)
b.AppendBack(ptr.Of(1))
assert.Equal(t, 3, b.ring.Len())
assert.Equal(t, 1, b.end)
b.AppendBack(ptr.Of(2))
assert.Equal(t, 3, b.ring.Len())
assert.Equal(t, 2, b.end)
b.AppendBack(ptr.Of(3))
assert.Equal(t, 3, b.ring.Len())
assert.Equal(t, 3, b.end)
b.AppendBack(ptr.Of(4))
assert.Equal(t, 8, b.ring.Len())
assert.Equal(t, 4, b.end)
for i := 5; i < 9; i++ {
b.AppendBack(ptr.Of(i))
assert.Equal(t, 8, b.ring.Len())
assert.Equal(t, i, b.end)
}
assert.Equal(t, 8, b.ring.Len())
assert.Equal(t, 8, b.end)
b.AppendBack(ptr.Of(9))
assert.Equal(t, 13, b.ring.Len())
assert.Equal(t, 9, b.end)
assert.Equal(t, 2, *b.RemoveFront())
assert.Equal(t, 13, b.ring.Len())
assert.Equal(t, 8, b.end)
assert.Equal(t, 3, *b.RemoveFront())
assert.Equal(t, 13, b.ring.Len())
assert.Equal(t, 7, b.end)
assert.Equal(t, 4, *b.RemoveFront())
assert.Equal(t, 13, b.ring.Len())
assert.Equal(t, 6, b.end)
assert.Equal(t, 5, *b.RemoveFront())
assert.Equal(t, 13, b.ring.Len())
assert.Equal(t, 5, b.end)
assert.Equal(t, 6, *b.RemoveFront())
assert.Equal(t, 13, b.ring.Len())
assert.Equal(t, 4, b.end)
assert.Equal(t, 7, *b.RemoveFront())
assert.Equal(t, 13, b.ring.Len())
assert.Equal(t, 3, b.end)
assert.Equal(t, 8, *b.RemoveFront())
assert.Equal(t, 8, b.ring.Len())
assert.Equal(t, 2, b.end)
assert.Equal(t, 9, *b.RemoveFront())
assert.Equal(t, 8, b.ring.Len())
assert.Equal(t, 1, b.end)
assert.Nil(t, b.RemoveFront())
assert.Equal(t, 8, b.ring.Len())
assert.Equal(t, 0, b.end)
}
func Test_BufferedRange(t *testing.T) {
b := NewBuffered[int](3, 5)
b.AppendBack(ptr.Of(0))
b.AppendBack(ptr.Of(1))
b.AppendBack(ptr.Of(2))
b.AppendBack(ptr.Of(3))
var i int
b.Range(func(v *int) bool {
assert.Equal(t, i, *v)
i++
return true
})
assert.Equal(t, 0, *b.ring.Value)
i = 0
b.Range(func(v *int) bool {
assert.Equal(t, i, *v)
i++
return i != 2
})
assert.Equal(t, 0, *b.ring.Value)
}

137
ring/ring.go Normal file
View File

@ -0,0 +1,137 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ring implements operations on circular lists.
// Edited to be generic.
package ring
// A Ring is an element of a circular list, or ring.
// Rings do not have a beginning or end; a pointer to any ring element
// serves as reference to the entire ring. Empty rings are represented
// as nil Ring pointers. The zero value for a Ring is a one-element
// ring with a nil Value.
type Ring[T any] struct {
next, prev *Ring[T]
Value T // for use by client; untouched by this library
}
func (r *Ring[T]) init() *Ring[T] {
r.next = r
r.prev = r
return r
}
// Next returns the next ring element. r must not be empty.
func (r *Ring[T]) Next() *Ring[T] {
if r.next == nil {
return r.init()
}
return r.next
}
// Prev returns the previous ring element. r must not be empty.
func (r *Ring[T]) Prev() *Ring[T] {
if r.next == nil {
return r.init()
}
return r.prev
}
// Move moves n % r.Len() elements backward (n < 0) or forward (n >= 0)
// in the ring and returns that ring element. r must not be empty.
func (r *Ring[T]) Move(n int) *Ring[T] {
if r.next == nil {
return r.init()
}
switch {
case n < 0:
for ; n < 0; n++ {
r = r.prev
}
case n > 0:
for ; n > 0; n-- {
r = r.next
}
}
return r
}
// New creates a ring of n elements.
func New[T any](n int) *Ring[T] {
if n <= 0 {
return nil
}
r := new(Ring[T])
p := r
for i := 1; i < n; i++ {
p.next = &Ring[T]{prev: p}
p = p.next
}
p.next = r
r.prev = p
return r
}
// Link connects ring r with ring s such that r.Next()
// becomes s and returns the original value for r.Next().
// r must not be empty.
//
// If r and s point to the same ring, linking
// them removes the elements between r and s from the ring.
// The removed elements form a subring and the result is a
// reference to that subring (if no elements were removed,
// the result is still the original value for r.Next(),
// and not nil).
//
// If r and s point to different rings, linking
// them creates a single ring with the elements of s inserted
// after r. The result points to the element following the
// last element of s after insertion.
func (r *Ring[T]) Link(s *Ring[T]) *Ring[T] {
n := r.Next()
if s != nil {
p := s.Prev()
// Note: Cannot use multiple assignment because
// evaluation order of LHS is not specified.
r.next = s
s.prev = r
n.prev = p
p.next = n
}
return n
}
// Unlink removes n % r.Len() elements from the ring r, starting
// at r.Next(). If n % r.Len() == 0, r remains unchanged.
// The result is the removed subring. r must not be empty.
func (r *Ring[T]) Unlink(n int) *Ring[T] {
if n <= 0 {
return nil
}
return r.Link(r.Move(n + 1))
}
// Len computes the number of elements in ring r.
// It executes in time proportional to the number of elements.
func (r *Ring[T]) Len() int {
n := 0
if r != nil {
n = 1
for p := r.Next(); p != r; p = p.next {
n++
}
}
return n
}
// Do calls function f on each element of the ring, in forward order.
// The behavior of Do is undefined if f changes *r.
func (r *Ring[T]) Do(f func(T)) {
if r != nil {
f(r.Value)
for p := r.Next(); p != r; p = p.next {
f(p.Value)
}
}
}

211
ring/ring_test.go Normal file
View File

@ -0,0 +1,211 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ring
import (
"testing"
)
func verify(t *testing.T, r *Ring[int], nn int, sum int) {
// Len
n := r.Len()
if n != nn {
t.Errorf("r.Len() == %d; expected %d", n, nn)
}
// iteration
n = 0
s := 0
r.Do(func(p int) {
n++
s += p
})
if n != nn {
t.Errorf("number of forward iterations == %d; expected %d", n, nn)
}
if sum >= 0 && s != sum {
t.Errorf("forward ring sum = %d; expected %d", s, sum)
}
if r == nil {
return
}
// connections
if r.next != nil {
var p *Ring[int] // previous element
for q := r; p == nil || q != r; q = q.next {
if p != nil && p != q.prev {
t.Errorf("prev = %p, expected q.prev = %p\n", p, q.prev)
}
p = q
}
if p != r.prev {
t.Errorf("prev = %p, expected r.prev = %p\n", p, r.prev)
}
}
// Next, Prev
if r.Next() != r.next {
t.Errorf("r.Next() != r.next")
}
if r.Prev() != r.prev {
t.Errorf("r.Prev() != r.prev")
}
// Move
if r.Move(0) != r {
t.Errorf("r.Move(0) != r")
}
if r.Move(nn) != r {
t.Errorf("r.Move(%d) != r", nn)
}
if r.Move(-nn) != r {
t.Errorf("r.Move(%d) != r", -nn)
}
for i := range 10 {
ni := nn + i
mi := ni % nn
if r.Move(ni) != r.Move(mi) {
t.Errorf("r.Move(%d) != r.Move(%d)", ni, mi)
}
if r.Move(-ni) != r.Move(-mi) {
t.Errorf("r.Move(%d) != r.Move(%d)", -ni, -mi)
}
}
}
func TestCornerCases(t *testing.T) {
var (
r0 *Ring[int]
r1 Ring[int]
)
// Basics
verify(t, r0, 0, 0)
verify(t, &r1, 1, 0)
// Insert
r1.Link(r0)
verify(t, r0, 0, 0)
verify(t, &r1, 1, 0)
// Insert
r1.Link(r0)
verify(t, r0, 0, 0)
verify(t, &r1, 1, 0)
// Unlink
r1.Unlink(0)
verify(t, &r1, 1, 0)
}
func makeN(n int) *Ring[int] {
r := New[int](n)
for i := 1; i <= n; i++ {
r.Value = i
r = r.Next()
}
return r
}
func sumN(n int) int { return (n*n + n) / 2 }
func TestNew(t *testing.T) {
for i := range 10 {
r := New[int](i)
verify(t, r, i, -1)
}
for i := range 10 {
r := makeN(i)
verify(t, r, i, sumN(i))
}
}
func TestLink1(t *testing.T) {
r1a := makeN(1)
var r1b Ring[int]
r2a := r1a.Link(&r1b)
verify(t, r2a, 2, 1)
if r2a != r1a {
t.Errorf("a) 2-element link failed")
}
r2b := r2a.Link(r2a.Next())
verify(t, r2b, 2, 1)
if r2b != r2a.Next() {
t.Errorf("b) 2-element link failed")
}
r1c := r2b.Link(r2b)
verify(t, r1c, 1, 1)
verify(t, r2b, 1, 0)
}
func TestLink2(t *testing.T) {
var r0 *Ring[int]
r1a := &Ring[int]{Value: 42}
r1b := &Ring[int]{Value: 77}
r10 := makeN(10)
r1a.Link(r0)
verify(t, r1a, 1, 42)
r1a.Link(r1b)
verify(t, r1a, 2, 42+77)
r10.Link(r0)
verify(t, r10, 10, sumN(10))
r10.Link(r1a)
verify(t, r10, 12, sumN(10)+42+77)
}
func TestLink3(t *testing.T) {
var r Ring[int]
n := 1
for i := 1; i < 10; i++ {
n += i
verify(t, r.Link(New[int](i)), n, -1)
}
}
func TestUnlink(t *testing.T) {
r10 := makeN(10)
s10 := r10.Move(6)
sum10 := sumN(10)
verify(t, r10, 10, sum10)
verify(t, s10, 10, sum10)
r0 := r10.Unlink(0)
verify(t, r0, 0, 0)
r1 := r10.Unlink(1)
verify(t, r1, 1, 2)
verify(t, r10, 9, sum10-2)
r9 := r10.Unlink(9)
verify(t, r9, 9, sum10-2)
verify(t, r10, 9, sum10-2)
}
func TestLinkUnlink(t *testing.T) {
for i := 1; i < 4; i++ {
ri := New[int](i)
for j := range i {
rj := ri.Unlink(j)
verify(t, rj, j, -1)
verify(t, ri, i-j, -1)
ri.Link(rj)
verify(t, ri, i, -1)
}
}
}
// Test that calling Move() on an empty Ring initializes it.
func TestMoveEmptyRing(t *testing.T) {
var r Ring[int]
r.Move(1)
verify(t, &r, 1, 0)
}

View File

@ -20,6 +20,8 @@ import (
)
// Algorithm used to wrap the file key.
//
//nolint:recvcheck
type KeyAlgorithm string
const (

View File

@ -20,6 +20,8 @@ import (
)
// Cipher used to encrypt the file.
//
//nolint:recvcheck
type Cipher string
const (

View File

@ -15,7 +15,6 @@ package v1
import (
"encoding/hex"
"fmt"
"reflect"
"testing"
@ -120,7 +119,7 @@ func TestFileKey(t *testing.T) {
// Validate that headerMessage returns the right message, and that there's a newline at the end
const manifest = `{"foo":"bar"}`
const expect = SchemeName + "\n" + manifest + "\n"
fmt.Println(hex.EncodeToString([]byte(expect)))
t.Log(hex.EncodeToString([]byte(expect)))
got := fileKey{}.headerMessage([]byte(manifest))
require.Equal(t, expect, string(got))

View File

@ -34,11 +34,11 @@ var (
func TestScheme(t *testing.T) {
// Fake wrapKeyFn and unwrapKeyFn, which just return the plaintext key
//nolint:stylecheck,revive
//nolint:stylecheck
var wrapKeyFn WrapKeyFn = func(plaintextKey []byte, algorithm, keyName string, nonce []byte) (wrappedKey []byte, tag []byte, err error) {
return plaintextKey, nil, nil
}
//nolint:stylecheck,revive
//nolint:stylecheck
var unwrapKeyFn UnwrapKeyFn = func(wrappedKey []byte, algorithm, keyName string, nonce, tag []byte) (plaintextKey []byte, err error) {
return wrappedKey, nil
}
@ -91,7 +91,7 @@ func TestScheme(t *testing.T) {
// Second, check that the JSON manifest is present and valid
start := idx + 1
idx = bytes.IndexByte(encData[start:], '\n')
require.Greater(t, idx, 0)
require.Positive(t, idx)
var manifest Manifest
err = json.Unmarshal(encData[start:(start+idx)], &manifest)
require.NoError(t, err)
@ -106,7 +106,7 @@ func TestScheme(t *testing.T) {
// We are not validating the MAC here as the decryption code will do it; we'll just check it's present and 44-byte long (when encoded as base64)
start += idx + 1
idx = bytes.IndexByte(encData[start:], '\n')
require.Greater(t, idx, 0)
require.Positive(t, idx)
require.Len(t, encData[start:(start+idx)], 44)
// Decrypt the encrypted data

View File

@ -15,6 +15,7 @@ package signals
import (
"context"
"errors"
"os"
"os/signal"
@ -37,14 +38,15 @@ func Context() context.Context {
// panics when called twice
close(onlyOneSignalHandler)
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancelCause(context.Background())
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, shutdownSignals...)
go func() {
sig := <-sigCh
log.Infof(`Received signal '%s'; beginning shutdown`, sig)
cancel()
//nolint:err113
cancel(errors.New("cancelling context, received signal " + sig.String()))
sig = <-sigCh
log.Fatalf(
`Received signal '%s' during shutdown; exiting immediately`,

27
slices/slices.go Normal file
View File

@ -0,0 +1,27 @@
/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package slices
// Deduplicate removes duplicate elements from a slice.
func Deduplicate[S ~[]E, E comparable](s S) S {
ded := make(map[E]struct{}, len(s))
for _, v := range s {
ded[v] = struct{}{}
}
unique := make(S, 0, len(ded))
for v := range ded {
unique = append(unique, v)
}
return unique
}

55
slices/slices_test.go Normal file
View File

@ -0,0 +1,55 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package slices
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_Deduplicate(t *testing.T) {
tests := []struct {
input []int
exp []int
}{
{
input: []int{1, 2, 3},
exp: []int{1, 2, 3},
},
{
input: []int{1, 2, 2, 3, 1},
exp: []int{1, 2, 3},
},
{
input: []int{5, 5, 5, 5},
exp: []int{5},
},
{
input: []int{},
exp: []int{},
},
{
input: []int{42},
exp: []int{42},
},
}
for _, test := range tests {
t.Run(fmt.Sprintf("%v", test.input), func(t *testing.T) {
assert.ElementsMatch(t, test.exp, Deduplicate(test.input))
})
}
}

View File

@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
package strings
import (
"path/filepath"

Some files were not shown because too many files have changed in this diff Show More