Compare commits

...

56 Commits

Author SHA1 Message Date
Fabian Martinez 8b780b4d81
Concurrency ctesting (#133)
* Adds concurrency/ctesting

Adds the concurrency/ctesting package, used for concurrently running a set
of runners and collecting the results via a testing assertion.

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: Fabian Martinez <46371672+famarting@users.noreply.github.com>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
Signed-off-by: Fabian Martinez <46371672+famarting@users.noreply.github.com>
Co-authored-by: joshvanl <me@joshvanl.dev>
2025-07-17 11:07:48 -03:00
Javier Aliaga 9d4f384c57
feat: Add subsecond precision to jobs (#129)
* feat: Add subsecond precision to jobs

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: linting fixes

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: add tests

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: remove all precision to cron constant schedule

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: Time parse supports RFC3339nano

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: Fix TestFile_CurrentTrustAnchors flaky test

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: No need to fallback to RFC3339

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: Wait until file trustanchors is running

Signed-off-by: Javier Aliaga <javier@diagrid.io>

---------

Signed-off-by: Javier Aliaga <javier@diagrid.io>
2025-07-10 11:03:56 -03:00
Josh van Leeuwen 7c4cedad37
concurrency/dir: fix mkdir permissions (#131)
Fix Mkdir permissions by using 0o700 over `ModePerm`.

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-07-08 21:41:07 -05:00
Javier Aliaga 7409957e9e
Update deps (#130)
* chore: update go and golangci-lint versions

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: Fix linting issues

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: Remove nolint exclusion

Signed-off-by: Javier Aliaga <javier@diagrid.io>

* chore: Revert range to for loop

Signed-off-by: Javier Aliaga <javier@diagrid.io>

---------

Signed-off-by: Javier Aliaga <javier@diagrid.io>
2025-07-08 17:55:56 -03:00
Josh van Leeuwen 34f8820d2a
concurrency/runner: make cancel with cause (#128)
Makes the concurrency runner's closer errors more useful for
shutdowns.

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-07-03 09:51:31 -05:00
Joni Collinge 598b032bce
Fix deprecation comment to reference correct function (#125)
Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>
2025-06-16 17:06:11 +01:00
Josh van Leeuwen d7d50a1e1b
events/loop: drain queue on close (#124)
* events/loop: drain queue on close

Updates the looper to drain the Enqueue loop in the event that an error
has occurred while Enqueuing. Handles an errored `Run` on `Close` by
respecting the `closedCh` channel.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Update loop.go

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-06-16 08:26:36 -05:00
Josh van Leeuwen baea626399
Update .golangci.yml to remove deprecations (#122)
* Update .golangci.yml to remove deprecations

Signed-off-by: joshvanl <me@joshvanl.dev>

* Update .golangci.yml

Co-authored-by: Cassie Coyle <cassie.i.coyle@gmail.com>
Signed-off-by: Josh van Leeuwen <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
Signed-off-by: Josh van Leeuwen <me@joshvanl.dev>
Co-authored-by: Cassie Coyle <cassie.i.coyle@gmail.com>
2025-05-22 08:58:18 -05:00
Joni Collinge bc7dc566c4
Add JWT handling to spiffe package (#118)
* Adds JWT handling to spiffe

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Update log

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Add jwtbundle

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Update file watcher

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Clean up jwt spiffe

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* lint

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Update renewal behavior

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Updates based on joshvanl feedback

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* go mod tidy

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* lint

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* lint

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Move ready chan check to avoid race

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Add small delay after fs write to allow watcher to pick up change

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Resolve feedback

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* Resolve feedback

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

* lint

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>

---------

Signed-off-by: Jonathan Collinge <jonathancollinge@live.com>
2025-05-16 13:15:56 +01:00
Josh van Leeuwen 98fe567235
events/loop: add reset (#120)
* events/loop: add reset

Updates the loop implementation to include Reset functionality, which is
useful when caching the loop struct for future use to reduce
allocations.

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-05-15 23:23:38 +01:00
Josh van Leeuwen e3d4a8f1b4
Add Copyright headers to env, and remove utils. (#103)
Chore to add Copyright headers to `env` files.

Moves the contained funcs and deletes the `utils` package. `utils` packages are
generally a code smell; such funcs are better placed in a more descriptive
package whose name gives context.

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-04-23 11:29:15 -03:00
Josh van Leeuwen 77af8ac182
events/loop & slices (#119)
* events/loop & slices

Adds a generic control loop implementation to `event/loop`.

Adds a new `slices` package that provides a generic slice de-duplication
func.

Makes the events batcher and queue processor take in Options.

Allows enqueuing multiple processor items in the same func call.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Lint

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

* Elements match

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds buffer size option to events loop

Signed-off-by: joshvanl <me@joshvanl.dev>

* nit

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-04-23 11:28:21 -03:00
Josh van Leeuwen a3f06e444a
concurrency: Make runner closer take a logger (#116)
* concurrency: Make runner closer take a logger

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linter

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linter

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-04-10 15:34:54 -03:00
Josh van Leeuwen 39c4bf57bd
concurrency/lock: Adds context and outercancel locks (#115)
* concurrency/lock: Adds context and outercancel locks

Adds context lock which will cancel and return an error to Lock & RLock
if the given context cancels before the lock is achieved.

Adds outercancel lock which will cancel all RLocks in progress if the
outer Lock is called.

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

* Lint

Signed-off-by: joshvanl <me@joshvanl.dev>

* Fix tests

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-02-24 11:15:07 -08:00
Josh van Leeuwen 6271c8be59
Signals: Add cause to signal context cancel (#114)
* Signals: Add cause to signal context cancel

Signed-off-by: joshvanl <me@joshvanl.dev>

* lint

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-02-04 17:31:44 -08:00
Luis Rascão c46009f360
crypto/spiffe: adds a multi trust anchor selector (#113)
* crypto/spiffe: adds a multi trust anchor selector

Adds a new trust anchors provider that returns different trust
anchors depending on the requested trust domain out of a pre-loaded
set.

Signed-off-by: Luis Rascao <luis.rascao@gmail.com>

* fixup! crypto/spiffe: adds a multi trust anchor selector

Signed-off-by: Luis Rascao <luis.rascao@gmail.com>

---------

Signed-off-by: Luis Rascao <luis.rascao@gmail.com>
2025-01-28 21:07:41 -08:00
Josh van Leeuwen c90b807d32
concurrency/dir & WriteIdentityToFile (#112)
* concurrency/dir & WriteIdentityToFile

Adds concurrency/dir package to handle atomically writing files to a
directory.

Adds support for spiffe to optionally write the identity certificate,
private key and trust bundle to a given directory.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds Dir comment

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-01-21 06:16:49 -08:00
Josh van Leeuwen fb19570696
events/broadcaster Buffer val when writing to subscriber channel (#111)
Signed-off-by: joshvanl <me@joshvanl.dev>
2025-01-10 11:22:55 -08:00
Josh van Leeuwen 65ba3783f2
Adds events/broadcaster (#110)
* Adds events/broadcaster

Adds a generic buffered message broadcaster which will relay a typed
message to a dynamic set of subscribers.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Lint

Signed-off-by: joshvanl <me@joshvanl.dev>

* Review comments

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2025-01-06 15:46:09 +00:00
Jake Engelberg 30e2c24840
Add error code and category funcs (#109) 2024-11-27 08:52:51 -08:00
Josh van Leeuwen 24b59a803d
Adds generic ring (#108)
* Adds generic ring

Adds a generic implementation of the stdlib ring buffer so that each
ring `Value` can be a concrete type.
https://pkg.go.dev/container/ring

Adds `Len() int` and `Keys() []K` func to the generic Map cmap.

Changes `events/queue` Processor `Queueable` to be an exported type. No
functional change, but consumed types should be exported.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds ring_test.go

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Update Do func to be typed

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds ring/buffered

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2024-11-20 13:48:57 -08:00
Josh van Leeuwen d37dc603d0
Update go to 1.23.1, golangci-lint 1.61.0 (#105)
* Update go to 1.23.1, golangci-lint 1.61.0

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds only new issues

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2024-10-21 13:52:18 -07:00
Josh van Leeuwen 866002abe6
Adds FIFO concurrent lock & generic concurrent Slice (#107)
* Adds FIFO concurrent lock & generic concurrent Slice

Adds a new concurrency/fifo package which implements a fifo mutex, as
well as a concurrently safe comparable indexed map of fifo mutexes.

Adds a simple generic concurrently safe slice implementation, which can
currently only grow.

Moves the map generic implementations in `/concurrency` to
`/concurrency/cmap`.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Move concurrency/slice.go to concurrency/slice/slice.go and add concurrency/slice/string.go

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2024-10-15 06:03:26 -07:00
Josh van Leeuwen bc3a4f0fb4
events/queue: Don't return queue error when closed (#106)
`Enqueue` and `Dequeue` returned an error which was cumbersome for
consumers to need to check and deal with. The only time this error would
occur would be when the queue was closed, so instead of returning an
error we simply ignore the queue event.

Signed-off-by: joshvanl <me@joshvanl.dev>
2024-10-07 07:39:32 -07:00
Josh van Leeuwen 2d6ff15a97
Map: Adds concurrency/map (#104) 2024-09-23 21:10:40 -07:00
Elena Kolevska 3823663aa4
Adds a GetEnvIntWithRange utility function (#102)
* Adds a GetEnvIntWithRange utility function

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Fixes linter

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Update to using time duration instead of int for seconds

Signed-off-by: Elena Kolevska <elena@kolevska.com>

---------

Signed-off-by: Elena Kolevska <elena@kolevska.com>
2024-09-09 14:50:17 -07:00
Josh van Leeuwen 502671bade
concurrency/mutexmap Move Unlock to after operation (#101)
Signed-off-by: joshvanl <me@joshvanl.dev>
2024-09-09 14:49:03 -07:00
Josh van Leeuwen 26b564d9d0
MutexMap: Adds DeleteRUnlock and fixes RLock/RUnlock (#100)
Signed-off-by: joshvanl <me@joshvanl.dev>
2024-07-23 17:01:21 -07:00
Josh van Leeuwen 58c6d9df14
concurrency/mutexmap: Adds DeleteUnlock (#99)
Signed-off-by: joshvanl <me@joshvanl.dev>
2024-07-22 09:34:53 -07:00
Sam e2508d6e9e
fix(security): update vulnerabilities (#96)
* fix(security): update vulnerabilities

Signed-off-by: Samantha Coyle <sam@diagrid.io>

* style: make linter happy

Signed-off-by: Samantha Coyle <sam@diagrid.io>

* fix: add another fix for a sec vul

Signed-off-by: Samantha Coyle <sam@diagrid.io>

---------

Signed-off-by: Samantha Coyle <sam@diagrid.io>
2024-06-24 14:24:34 -07:00
Elena Kolevska 106329e583
Mutexmap (#95)
* Adds Mutex Map

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Adds an atomic map

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* More work on atomic map and mutex map

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Fixes, improvements and more tests

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Updates interface

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Linter

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Refactors atomic map to use generics

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* cleanups

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Apply suggestions from code review

Co-authored-by: Cassie Coyle <cassie.i.coyle@gmail.com>
Signed-off-by: Elena Kolevska <elena-kolevska@users.noreply.github.com>

* small reorg

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Adds ItemCount()

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Removes atomicmap in favour of haxmap

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* formats fix and adds comment

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Update concurrency/mutexmap.go

Co-authored-by: Josh van Leeuwen <me@joshvanl.dev>
Signed-off-by: Elena Kolevska <elena-kolevska@users.noreply.github.com>

* Uses built in `clear`

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Revert "Removes atomicmap in favour of haxmap"

This reverts commit 20ca9ad197.

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Uses clear() for atomic map too

Signed-off-by: Elena Kolevska <elena@kolevska.com>

---------

Signed-off-by: Elena Kolevska <elena@kolevska.com>
Signed-off-by: Elena Kolevska <elena-kolevska@users.noreply.github.com>
Co-authored-by: Cassie Coyle <cassie.i.coyle@gmail.com>
Co-authored-by: Josh van Leeuwen <me@joshvanl.dev>
2024-05-23 15:57:05 -07:00
Annu Singh ccffb60016
fixing a dead hyperlink (#94)
Signed-off-by: Annu Singh <annu.singh@terramate.io>
2024-04-16 13:28:40 -07:00
Josh van Leeuwen a3f906d609
Adds crypto/spiffe (#92)
* Adds crypto/spiffe

Adds spiffe package to crypto. This is a refactored version of the
existing `pkg/security` package. The new package is more modular and
has fuller test coverage.

This package has been moved so that it can be both imported by dapr &
components-contrib, as well as making the package more suitable for
further development to support X.509 Component auth.
https://github.com/dapr/proposals/pull/51

Also moves in `test/utils` from dapr to `crypto/test` for shared usage.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds crypto/spiffe/context

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2024-04-15 10:19:26 -07:00
lee 0c7cfce53d
Add placement error code (#91)
Signed-off-by: Lee Fowler <fowler.lee8@gmail.com>
2024-04-02 13:38:09 +03:00
Josh van Leeuwen 6c3b2ee1ef
Events: type Batcher value & ensure queue order (#89)
* Events: type Batcher value & ensure queue order

Update Batcher to allow for typed value types.

Update Batcher and Queue to execute values in order they were added.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Delay batcher to ensure key is sent in order

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2024-03-26 13:31:51 +02:00
Sam e33fbab745
feat: enable original key to be returned for metadata property fields (#88)
Signed-off-by: Samantha Coyle <sam@diagrid.io>
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
2024-03-06 07:26:01 -08:00
Mike Nguyen 050e34c9b9
chore: bump lestrrat-go/jwx/v2 from v2.0.15 to v2.0.20 (#86)
Signed-off-by: mikeee <hey@mike.ee>
2024-02-26 09:24:31 -08:00
Chaitanya Bhangale 9e733a35f1
Add cryptography error code (#84)
Co-authored-by: Chaitanya Bhangale <chaitanyabhangale@Chaitanyas-MacBook-Pro-2.local>
2024-02-19 13:42:03 -08:00
Josh van Leeuwen 858719eb78
Change `events/batcher` to use `events/queue` as backend. (#82)
* events/batcher: use events/queue as queue backend

Signed-off-by: joshvanl <me@joshvanl.dev>

* Make events/queue/queue key type comparable

Signed-off-by: joshvanl <me@joshvanl.dev>

* Explicitly define NewProcessor generic type in test

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2024-01-15 09:08:33 -08:00
Alessandro (Ale) Segala c24d1d28cf
JWKSCache: add option to set CA certificate to trust (#81)
This is helpful when the JWKS is located on an HTTPS endpoint and the certificate is signed by a custom CA.

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
2024-01-11 10:59:16 -08:00
Alessandro (Ale) Segala 77f7f031c9
Add `ttlcache` package (#80)
* Add `ttlcache` package

This implements an in-memory cache with a TTL for automatically expiring records

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* Add delete method

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

---------

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
2023-12-27 07:25:56 -08:00
Elena Kolevska abe711ef62
Updates path in README.md (#79)
Signed-off-by: Elena Kolevska <elena-kolevska@users.noreply.github.com>
2023-12-20 08:21:41 -08:00
Cassie Coyle fd317d255e
Add `errors` package (#77)
* Move dapr/concurrency to kit (#72)

* Move dapr/concurrency to kit

Does not include any code change

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* Fixed copyright year

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* Improved memory usage in error collection

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

---------

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* Move `pkg/signals` from dapr/dapr to kit (#70)

No code changes

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* Move dapr/utils/streams to kit (#68)

* Move dapr/utils/streams to kit

No code changes

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* 💄

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* Lint

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

---------

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* Migrate metadata decoder from components-contrib to kit (#74)

* Migrate metadata decoder from components-contrib to kit

Required creating the `utils` package for utils.IsTruthy too (ported from runtime)

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* Lint

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

---------

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* initial standardized err pkg: errfmt

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* tweaks to error pkg and update tests. need to confirm reason

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* cleanup test

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* add new func for err. change to protojson for http. need to figure out grpc status tho

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* update status name

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* wip: update JSONErrorValue

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* Updates err to json conversion. Organises error messages and codes

Signed-off-by: Elena Kolevska <elena@kolevska.com>
Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* add type to http json output. tests are a WIP

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* add all details, update tests, prefixes/postfixes

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* add README

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* use strongly-typed struct for errJSON

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* update README

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* Adds the option to add a help link detail and a field violation detail

Signed-off-by: Elena Kolevska <elena@kolevska.com>
Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* Update fswatcher to use /events/batcher (#75)

* Update fswatcher to use /events/batcher

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Add sleep to wait for windows fsnotify to become ready

Signed-off-by: joshvanl <me@joshvanl.dev>

* Increase time for event to be received to 1 second

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* Adds tests for WithHelp and err.WithFieldViolation

Signed-off-by: Elena Kolevska <elena@kolevska.com>
Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* rebase and update proto field access to rebased code

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* gofumpt

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* errJson -> errJSON and update proto field access

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* HttpCode -> HTTPCode per CI warnings

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* rm reason since its not used

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* update return type in README example

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* use builder, add errorInfo check to Build(), update and add tests for new funcs, add getters for grpc/http codes

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* appease CI

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* update README

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* make GRPCStatus val receiver

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* Update messages.go

Signed-off-by: Artur Souza <asouza.pro@gmail.com>

* add test to ensure we have a switch for all google err_detail types

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* rebase and update log

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* re-export ErrorBuilder

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* Update errors/errors.go

Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Signed-off-by: Cassie Coyle <cassie.i.coyle@gmail.com>

* use ast pkg to dynamically grab our errTypes in the switch case instead of hard coding it

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* appease CI

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* add FromError func

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

* account for error wrapping

Signed-off-by: Cassandra Coyle <cassie@diagrid.io>

---------

Signed-off-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Signed-off-by: Cassandra Coyle <cassie@diagrid.io>
Signed-off-by: Elena Kolevska <elena@kolevska.com>
Signed-off-by: joshvanl <me@joshvanl.dev>
Signed-off-by: Cassie Coyle <cassie.i.coyle@gmail.com>
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Elena Kolevska <elena@kolevska.com>
Co-authored-by: Josh van Leeuwen <me@joshvanl.dev>
2023-12-19 08:42:36 -08:00
Josh van Leeuwen df64d3a144
Update fswatcher to use /events/batcher (#75)
* Update fswatcher to use /events/batcher

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Add sleep to wait for windows fsnotify to become ready

Signed-off-by: joshvanl <me@joshvanl.dev>

* Increase time for event to be received to 1 second

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2023-11-15 16:36:20 -08:00
Alessandro (Ale) Segala 0e1fd37fc4
Migrate metadata decoder from components-contrib to kit (#74)
* Migrate metadata decoder from components-contrib to kit

Required creating the `utils` package for utils.IsTruthy too (ported from runtime)

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* Lint

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

---------

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
2023-10-31 14:15:30 -07:00
Alessandro (Ale) Segala 2e939bc273
Move dapr/utils/streams to kit (#68)
* Move dapr/utils/streams to kit

No code changes

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* 💄

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* Lint

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

---------

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
2023-10-31 13:12:15 -07:00
Alessandro (Ale) Segala 76c6281dda
Move `pkg/signals` from dapr/dapr to kit (#70)
No code changes

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
2023-10-30 11:56:03 -07:00
Alessandro (Ale) Segala 49532df126
Move dapr/concurrency to kit (#72)
* Move dapr/concurrency to kit

Does not include any code change

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* Fixed copyright year

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* Improved memory usage in error collection

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

---------

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
2023-10-30 11:55:52 -07:00
Alessandro (Ale) Segala c0ebd07f3a
Migrate dapr/utils/byteslicepool to kit (#73)
No code changes (aside from those needed to appease the linter)

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
2023-10-30 11:55:45 -07:00
Alessandro (Ale) Segala 2d30434d91
Update linter to 1.55.1 and fix linter errors (#71)
* Update linter to 1.55.1 and fix linter errors

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* Also disable tagalign linter

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

---------

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
2023-10-30 11:55:29 -07:00
Alessandro (Ale) Segala a0df11f512
Tiny: use `Load` instead of `CompareAndSwap` for detecting recoveries (#69)
Should be _slightly_ faster

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
2023-10-30 16:18:57 +00:00
Josh van Leeuwen 4f434ebc3f
Adds events/ratelimiting/coalescing (#64)
* Adds events/ratelimiting/coalescing

events/ratelimiting is a new utility package that provides helpers which
can be used to rate limit events.

Coalescing is a new helper that exponentially rate limits events. It
will coalesce events into a single event if they occur within the same
rate limiting window. Coalescing also has the option to forcibly fire an
event when the number of events reaches a certain threshold, added to
prevent events from never being fired in a high-throughput scenario.

Signed-off-by: joshvanl <me@joshvanl.dev>
Signed-off-by: Josh van Leeuwen <me@joshvanl.dev>
Signed-off-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
2023-10-25 13:15:16 -07:00
Josh van Leeuwen 55bfe3b570
Adds `--build-tags=unit` to golang lint config file (#67)
* Adds `--build-tags=unit` to golang lint GitHub Action

Signed-off-by: joshvanl <me@joshvanl.dev>

* Move golang-list workflow build-tags CLI args to .golangci.yml

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2023-10-24 08:07:56 -07:00
Josh van Leeuwen 099b0404d1
Use mocked clock in cron tests, rather than wall clock (#66)
PR updates the cron tests to use a mocked clock instead of the wall clock
in order to speed up the tests from ~40s to ~1.5s. This also has the
benefit of controlling time so that tests are more deterministic.

Signed-off-by: joshvanl <me@joshvanl.dev>
2023-10-23 09:53:28 -07:00
Josh van Leeuwen 549b95799f
Fixing running tests on aarch64 arm machines (#65)
* Fixing running tests on aarch64 arm machines

Signed-off-by: joshvanl <me@joshvanl.dev>

* Use different Makefile target for doing go test with race enabled

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
2023-10-22 06:51:12 -07:00
Roberto Rojas a043330f5d
Dapr Error Handling/Codes - Add Internal Shared Functionality (#57)
Signed-off-by: robertojrojas <robertojrojas@gmail.com>
Signed-off-by: Roberto Rojas <robertojrojas@gmail.com>
Signed-off-by: Josh van Leeuwen <me@joshvanl.dev>
Signed-off-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Josh van Leeuwen <me@joshvanl.dev>
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
2023-09-14 12:04:04 -07:00
137 changed files with 13455 additions and 1002 deletions


@@ -21,11 +21,10 @@ jobs:
name: Build ${{ matrix.target_os }}_${{ matrix.target_arch }} binaries
runs-on: ${{ matrix.os }}
env:
GOVER: "1.20"
GOOS: ${{ matrix.target_os }}
GOARCH: ${{ matrix.target_arch }}
GOPROXY: https://proxy.golang.org
GOLANGCI_LINT_VER: v1.51.2
GOLANGCI_LINT_VER: v1.64.8
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macOS-latest]
@@ -43,15 +42,15 @@ jobs:
- os: macOS-latest
target_arch: arm
steps:
- name: Set up Go ${{ env.GOVER }}
uses: actions/setup-go@v1
with:
go-version: ${{ env.GOVER }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version-file: 'go.mod'
- name: Run golangci-lint
if: matrix.target_arch == 'amd64' && matrix.target_os == 'linux'
uses: golangci/golangci-lint-action@v3.2.0
uses: golangci/golangci-lint-action@v3
with:
version: ${{ env.GOLANGCI_LINT_VER }}
skip-cache: true
@@ -65,4 +64,4 @@ jobs:
run: make test
- name: Codecov
if: matrix.target_arch == 'amd64' && matrix.target_os == 'linux'
uses: codecov/codecov-action@v1
uses: codecov/codecov-action@v3

.gitignore vendored (new file, 5 additions)

@@ -0,0 +1,5 @@
**/.DS_Store
.idea
.vscode
.vs
vendor


@@ -4,7 +4,7 @@ run:
concurrency: 4
# timeout for analysis, e.g. 30s, 5m, default is 1m
deadline: 10m
timeout: 15m
# exit code when at least one issue was found, default is 1
issues-exit-code: 1
@ -13,31 +13,36 @@ run:
tests: true
# list of build tags, all linters use it. Default is empty list.
#build-tags:
# - mytag
# which dirs to skip: they won't be analyzed;
# can use regexp here: generated.*, regexp is applied on full path;
# default value is empty list, but next dirs are always skipped independently
# from this option's value:
# third_party$, testdata$, examples$, Godeps$, builtin$
skip-dirs:
- ^pkg.*client.*clientset.*versioned.*
- ^pkg.*client.*informers.*externalversions.*
- ^pkg.*proto.*
build-tags:
- unit
- allcomponents
- subtlecrypto
# which files to skip: they will be analyzed, but issues from them
# won't be reported. Default value is empty list, but there is
# no need to include all autogenerated files, we confidently recognize
# autogenerated files. If it's not please let us know.
skip-files:
# skip-files:
# - ".*\\.my\\.go$"
# - lib/bad.go
issues:
# which dirs to skip: they won't be analyzed;
# can use regexp here: generated.*, regexp is applied on full path;
# default value is empty list, but next dirs are always skipped independently
# from this option's value:
# third_party$, testdata$, examples$, Godeps$, builtin$
exclude-dirs:
- ^pkg.*client.*clientset.*versioned.*
- ^pkg.*client.*informers.*externalversions.*
- ^pkg.*proto.*
- pkg/proto
# output configuration options
output:
# colored-line-number|line-number|json|tab|checkstyle, default is "colored-line-number"
format: tab
formats:
- format: tab
# print lines of code with issue, default is true
print-issued-lines: true
@ -57,23 +62,19 @@ linters-settings:
# default is false: such cases aren't reported by default.
check-blank: false
# [deprecated] comma-separated list of pairs of the form pkg:regex
# the regex is used to ignore names within pkg. (default "fmt:.*").
# see https://github.com/kisielk/errcheck#the-deprecated-method for details
ignore: fmt:.*,io/ioutil:^Read.*
exclude-functions:
- fmt:.*
- io/ioutil:^Read.*
# path to a file containing a list of functions to exclude from checking
# see https://github.com/kisielk/errcheck#excluding-functions for details
exclude:
# exclude:
funlen:
lines: 60
statements: 40
govet:
# report about shadowed variables
check-shadowing: true
# settings per analyzer
settings:
printf: # analyzer name, run `go tool vet help` to see all analyzers
@ -86,13 +87,12 @@ linters-settings:
# enable or disable analyzers by name
enable:
- atomicalign
enable-all: false
disable:
- shadow
enable-all: false
disable-all: false
golint:
revive:
# minimal confidence for issues, default is 0.8
min-confidence: 0.8
confidence: 0.8
gofmt:
# simplify code: gofmt with `-s` option, true by default
simplify: true
@ -106,9 +106,6 @@ linters-settings:
gocognit:
# minimal code complexity to report, 30 by default (but we recommend 10-20)
min-complexity: 10
maligned:
# print struct with more effective memory layout or not, false by default
suggest-new: true
dupl:
# tokens count to trigger issue, 150 by default
threshold: 100
@ -118,36 +115,63 @@ linters-settings:
# minimal occurrences count to trigger, 3 by default
min-occurrences: 5
depguard:
list-type: denylist
include-go-root: false
packages-with-error-message:
- "github.com/Sirupsen/logrus": "must use github.com/dapr/kit/logger"
- "github.com/agrea/ptr": "must use github.com/dapr/kit/ptr"
- "go.uber.org/atomic": "must use sync/atomic"
- "golang.org/x/net/context": "must use context"
- "github.com/pkg/errors": "must use standard library (errors package and/or fmt.Errorf)"
- "github.com/go-chi/chi$": "must use github.com/go-chi/chi/v5"
- "github.com/cenkalti/backoff$": "must use github.com/cenkalti/backoff/v4"
- "github.com/cenkalti/backoff/v2": "must use github.com/cenkalti/backoff/v4"
- "github.com/cenkalti/backoff/v3": "must use github.com/cenkalti/backoff/v4"
- "github.com/benbjohnson/clock": "must use k8s.io/utils/clock"
- "github.com/ghodss/yaml": "must use sigs.k8s.io/yaml"
- "gopkg.in/yaml.v2": "must use gopkg.in/yaml.v3"
- "github.com/golang-jwt/jwt": "must use github.com/lestrrat-go/jwx/v2"
- "github.com/golang-jwt/jwt/v2": "must use github.com/lestrrat-go/jwx/v2"
- "github.com/golang-jwt/jwt/v3": "must use github.com/lestrrat-go/jwx/v2"
- "github.com/golang-jwt/jwt/v4": "must use github.com/lestrrat-go/jwx/v2"
- "github.com/gogo/status": "must use google.golang.org/grpc/status"
- "github.com/gogo/protobuf": "must use google.golang.org/protobuf"
- "github.com/lestrrat-go/jwx/jwa": "must use github.com/lestrrat-go/jwx/v2"
- "github.com/lestrrat-go/jwx/jwt": "must use github.com/lestrrat-go/jwx/v2"
- "github.com/labstack/gommon/log": "must use github.com/dapr/kit/logger"
- "github.com/gobuffalo/logger": "must use github.com/dapr/kit/logger"
rules:
main:
deny:
- pkg: "github.com/Sirupsen/logrus"
desc: "must use github.com/dapr/kit/logger"
- pkg: "github.com/agrea/ptr"
desc: "must use github.com/dapr/kit/ptr"
- pkg: "go.uber.org/atomic"
desc: "must use sync/atomic"
- pkg: "golang.org/x/net/context"
desc: "must use context"
- pkg: "github.com/pkg/errors"
desc: "must use standard library (errors package and/or fmt.Errorf)"
- pkg: "github.com/go-chi/chi$"
desc: "must use github.com/go-chi/chi/v5"
- pkg: "github.com/cenkalti/backoff$"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/cenkalti/backoff/v2"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/cenkalti/backoff/v3"
desc: "must use github.com/cenkalti/backoff/v4"
- pkg: "github.com/benbjohnson/clock"
desc: "must use k8s.io/utils/clock"
- pkg: "github.com/ghodss/yaml"
desc: "must use sigs.k8s.io/yaml"
- pkg: "gopkg.in/yaml.v2"
desc: "must use gopkg.in/yaml.v3"
- pkg: "github.com/golang-jwt/jwt"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v2"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v3"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/golang-jwt/jwt/v4"
desc: "must use github.com/lestrrat-go/jwx/v2"
# pkg: Commonly auto-completed by gopls
- pkg: "github.com/gogo/status"
desc: "must use google.golang.org/grpc/status"
- pkg: "github.com/gogo/protobuf"
desc: "must use google.golang.org/protobuf"
- pkg: "github.com/lestrrat-go/jwx/jwa"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/lestrrat-go/jwx/jwt"
desc: "must use github.com/lestrrat-go/jwx/v2"
- pkg: "github.com/labstack/gommon/log"
desc: "must use github.com/dapr/kit/logger"
- pkg: "github.com/gobuffalo/logger"
desc: "must use github.com/dapr/kit/logger"
- pkg: "k8s.io/utils/pointer"
desc: "must use github.com/dapr/kit/ptr"
- pkg: "k8s.io/utils/ptr"
desc: "must use github.com/dapr/kit/ptr"
misspell:
# Correct spellings using locale preferences for US or UK.
# Default is to use a neutral variety of English.
# Setting locale to US will correct the British spelling of 'colour' to 'color'.
locale: default
# locale: default
ignore-words:
- someword
lll:
@ -156,23 +180,9 @@ linters-settings:
line-length: 120
# tab width in spaces. Default to 1.
tab-width: 1
unused:
# treat code as a program (not a library) and report unused exported identifiers; default is false.
# XXX: if you enable this setting, unused will report a lot of false-positives in text editors:
# if it's called for subdir of a project it can't find funcs usages. All text editor integrations
# with golangci-lint call it on a directory with the changed file.
check-exported: false
unparam:
# Inspect exported functions, default is false. Set to true if no external program/library imports your code.
# XXX: if you enable this setting, unparam will report a lot of false-positives in text editors:
# if it's called for subdir of a project it can't find external interfaces. All text editor integrations
# with golangci-lint call it on a directory with the changed file.
check-exported: false
nakedret:
# make an issue if func has more lines of code than this setting and it has naked returns; default is 30
max-func-lines: 30
nolintlint:
allow-unused: true
prealloc:
# XXX: we don't recommend using this linter before doing performance profiling.
# For most programs usage of prealloc will be a premature optimization.
@ -187,7 +197,6 @@ linters-settings:
# See https://go-critic.github.io/overview#checks-overview
# To check which checks are enabled run `GL_DEBUG=gocritic golangci-lint run`
# By default list of stable checks is used.
enabled-checks:
# Which checks should be disabled; can't be combined with 'enabled-checks'; default is empty
disabled-checks:
@ -235,61 +244,51 @@ linters-settings:
allow-assign-and-call: true
# Allow multiline assignments to be cuddled. Default is true.
allow-multiline-assign: true
# Allow case blocks to end with a whitespace.
allow-case-traling-whitespace: true
# Allow declarations (var) to be cuddled.
allow-cuddle-declarations: false
# If the number of lines in a case block is equal to or larger than this number,
# the case *must* end with a newline.
# https://github.com/bombsimon/wsl/blob/master/doc/configuration.md#force-case-trailing-whitespace
# Default: 0
force-case-trailing-whitespace: 1
linters:
fast: false
enable-all: true
disable:
# TODO Enforce the below linters later
- musttag
- dupl
- nonamedreturns
- errcheck
- funlen
- gochecknoglobals
- gochecknoinits
- gocyclo
- gocognit
- nosnakecase
- varcheck
- structcheck
- deadcode
- godox
- interfacer
- lll
- maligned
- scopelint
- unparam
- wsl
- gomnd
- testpackage
- goerr113
- nestif
- nlreturn
- exhaustive
- exhaustruct
- noctx
- gci
- golint
- tparallel
- paralleltest
- wrapcheck
- tagliatelle
- ireturn
- exhaustive
- exhaustivestruct
- exhaustruct
- errchkjson
- contextcheck
- gomoddirectives
- godot
- cyclop
- varnamelen
- gosec
- errorlint
- forcetypeassert
- ifshort
- maintidx
- nilnil
- predeclared
@ -298,4 +297,13 @@ linters:
- wastedassign
- containedctx
- gosimple
- forbidigo
- nonamedreturns
- asasalint
- rowserrcheck
- sqlclosecheck
- inamedparam
- tagalign
- mnd
- canonicalheader
- err113
- fatcontext


@ -31,6 +31,10 @@ else ifeq ($(shell echo $(LOCAL_ARCH) | head -c 5),armv8)
TARGET_ARCH_LOCAL=arm64
else ifeq ($(shell echo $(LOCAL_ARCH) | head -c 4),armv)
TARGET_ARCH_LOCAL=arm
else ifeq ($(shell echo $(LOCAL_ARCH) | head -c 5),arm64)
TARGET_ARCH_LOCAL=arm64
else ifeq ($(shell echo $(LOCAL_ARCH) | head -c 7),aarch64)
TARGET_ARCH_LOCAL=arm64
else
TARGET_ARCH_LOCAL=amd64
endif
@ -61,16 +65,26 @@ endif
################################################################################
.PHONY: test
test:
go test ./... $(COVERAGE_OPTS) $(BUILDMODE)
go test -tags unit ./... $(COVERAGE_OPTS) $(BUILDMODE)
.PHONY: test-race
test-race:
CGO_ENABLED=1 go test -race -tags unit ./... $(COVERAGE_OPTS) $(BUILDMODE)
################################################################################
# Target: lint #
################################################################################
.PHONY: lint
lint:
# Due to https://github.com/golangci/golangci-lint/issues/580, we need to add --fix for windows
$(GOLANGCI_LINT) run --timeout=20m
################################################################################
# Target: lint-fix #
################################################################################
.PHONY: lint-fix
lint-fix:
$(GOLANGCI_LINT) run --timeout=20m --fix
################################################################################
# Target: go.mod #
################################################################################


@ -0,0 +1,85 @@
/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package byteslicepool
import (
"sync"
)
/*
Originally based on https://github.com/xdg-go/zzz-slice-recycling
Copyright (C) 2019 by David A. Golden
License (Apache2): https://github.com/xdg-go/zzz-slice-recycling/blob/master/LICENSE
*/
// ByteSlicePool is a wrapper around sync.Pool to get []byte objects with a given capacity.
type ByteSlicePool struct {
MinCap int
pool *sync.Pool
}
// NewByteSlicePool returns a new ByteSlicePool object.
func NewByteSlicePool(minCap int) *ByteSlicePool {
return &ByteSlicePool{
MinCap: minCap,
pool: &sync.Pool{},
}
}
// Get a slice from the pool.
// The capacity parameter is used only if we need to allocate a new byte slice; there's no guarantee a slice retrieved from the pool will have at least that capacity.
func (sp ByteSlicePool) Get(capacity int) []byte {
bp := sp.pool.Get()
if bp == nil {
if capacity < sp.MinCap {
capacity = sp.MinCap
}
return make([]byte, 0, capacity)
}
buf := bp.([]byte)
// This will be optimized by the compiler
for i := range buf {
buf[i] = 0
}
return buf[:0]
}
// Put a slice back in the pool.
func (sp ByteSlicePool) Put(bs []byte) {
// The linter here complains because we're putting a slice rather than a pointer in the pool.
// The complaint is valid, because doing so does cause an allocation for the local copy of the slice header.
// However, this is ok for us because given how we use ByteSlicePool, we can't keep around the pointer we took out.
// See this thread for some discussion: https://github.com/dominikh/go-tools/issues/1336
//nolint:staticcheck
sp.pool.Put(bs)
}
// Resize a byte slice, making sure that it has enough capacity for a given size.
func (sp ByteSlicePool) Resize(orig []byte, size int) []byte {
if size < cap(orig) {
return orig[0:size]
}
// Allocate a new byte slice and discard the old, too-small one so it can be garbage collected
temp := make([]byte, size, max(size, cap(orig)*2))
copy(temp, orig)
return temp
}
func max(x, y int) int {
if x < y {
return y
}
return x
}
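The pool above boils down to three moves: Get a zero-length slice with at least `MinCap` capacity, Put it back when done, and Resize with a doubling strategy. A minimal self-contained sketch of that pattern (illustrative names, not the package API; the real package also zeroes recycled bytes before reuse, which this skips):

```go
package main

import (
	"fmt"
	"sync"
)

const minCap = 32

var pool sync.Pool

// get returns a zero-length slice with at least minCap capacity,
// recycling one from the pool when available.
func get(capacity int) []byte {
	if v := pool.Get(); v != nil {
		return v.([]byte)[:0]
	}
	if capacity < minCap {
		capacity = minCap
	}
	return make([]byte, 0, capacity)
}

// resize grows orig to size, at least doubling capacity on reallocation.
func resize(orig []byte, size int) []byte {
	if size < cap(orig) {
		return orig[:size]
	}
	// Too small: allocate a larger slice and copy the old contents.
	temp := make([]byte, size, max(size, cap(orig)*2))
	copy(temp, orig)
	return temp
}

func main() {
	bs := get(16)
	fmt.Println(len(bs), cap(bs)) // 0 32 (minCap wins over the request)
	bs = resize(bs, 48)
	fmt.Println(len(bs), cap(bs)) // 48 64 (capacity doubled, not just 48)
	pool.Put(bs[:0]) //nolint:staticcheck // slice-in-pool, as discussed above
}
```

The doubling on reallocation amortizes repeated small growths, the same trade-off `append` makes internally.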


@ -0,0 +1,59 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package byteslicepool
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestByteSlicePool(t *testing.T) {
minCap := 32
pool := NewByteSlicePool(minCap)
bs := pool.Get(minCap)
assert.Empty(t, bs)
assert.Equal(t, minCap, cap(bs))
pool.Put(bs)
bs2 := pool.Get(minCap)
assert.Equal(t, &bs, &bs2)
assert.Equal(t, minCap, cap(bs2))
for range minCap {
bs2 = append(bs2, 0)
}
// Less than minCap
// Capacity will not change after resize
size2 := 16
bs2 = pool.Resize(bs2, size2)
assert.Equal(t, size2, len(bs2)) //nolint:testifylint
assert.Equal(t, minCap, cap(bs2))
// Less than twice the minCap
// Will automatically expand to twice the original capacity
size3 := 48
bs2 = pool.Resize(bs2, size3)
assert.Equal(t, size3, len(bs2)) //nolint:testifylint
assert.Equal(t, minCap*2, cap(bs2))
// More than twice the minCap
// Will automatically expand to the specified size
size4 := 128
bs2 = pool.Resize(bs2, size4)
assert.Equal(t, size4, len(bs2)) //nolint:testifylint
assert.Equal(t, size4, cap(bs2))
}

concurrency/closer.go Normal file (214 lines)

@ -0,0 +1,214 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package concurrency
import (
"context"
"errors"
"fmt"
"io"
"sync/atomic"
"time"
"k8s.io/utils/clock"
"github.com/dapr/kit/logger"
)
var ErrManagerAlreadyClosed = errors.New("runner manager already closed")
// RunnerCloserManager is a RunnerManager that also implements Closing of the
// added closers once the main runners are done.
type RunnerCloserManager struct {
// mngr implements the main RunnerManager.
mngr *RunnerManager
// closers are the closers to be closed once the main runners are done.
closers []func() error
// retErr is the error returned by the main runners and closers. Used to
// return the resulting error from Close().
retErr error
// fatalShutdownFn is called if the grace period is exceeded.
// Defined if the grace period is not nil.
fatalShutdownFn func()
// closeFatalShutdown closes the fatal shutdown goroutine. Closing is a no-op
// if fatalShutdownFn is nil.
closeFatalShutdown chan struct{}
clock clock.Clock
running atomic.Bool
closing atomic.Bool
closed atomic.Bool
closeCh chan struct{}
stopped chan struct{}
}
// NewRunnerCloserManager creates a new RunnerCloserManager with the given
// grace period and runners.
// If gracePeriod is nil, the grace period is infinite.
func NewRunnerCloserManager(log logger.Logger, gracePeriod *time.Duration, runners ...Runner) *RunnerCloserManager {
c := &RunnerCloserManager{
mngr: NewRunnerManager(runners...),
clock: clock.RealClock{},
stopped: make(chan struct{}),
closeCh: make(chan struct{}),
closeFatalShutdown: make(chan struct{}),
}
if gracePeriod == nil {
log.Warn("Graceful shutdown timeout is infinite, will wait indefinitely to shutdown")
return c
}
c.fatalShutdownFn = func() {
log.Fatal("Graceful shutdown timeout exceeded, forcing shutdown")
}
c.AddCloser(func() {
log.Debugf("Graceful shutdown timeout: %s", *gracePeriod)
t := c.clock.NewTimer(*gracePeriod)
defer t.Stop()
select {
case <-t.C():
c.fatalShutdownFn()
case <-c.closeFatalShutdown:
}
})
return c
}
// Add implements RunnerManager.Add.
func (c *RunnerCloserManager) Add(runner ...Runner) error {
if c.running.Load() {
return ErrManagerAlreadyStarted
}
return c.mngr.Add(runner...)
}
// AddCloser adds a closer to the list of closers to be closed once the main
// runners are done.
func (c *RunnerCloserManager) AddCloser(closers ...any) error {
if c.closing.Load() {
return ErrManagerAlreadyClosed
}
c.mngr.lock.Lock()
defer c.mngr.lock.Unlock()
var errs []error
for _, cl := range closers {
switch v := cl.(type) {
case io.Closer:
c.closers = append(c.closers, v.Close)
case func(context.Context) error:
c.closers = append(c.closers, func() error {
// We use a background context here since the fatalShutdownFn will kill
// the program if the grace period is exceeded.
return v(context.Background())
})
case func() error:
c.closers = append(c.closers, v)
case func():
c.closers = append(c.closers, func() error {
v()
return nil
})
default:
errs = append(errs, fmt.Errorf("unsupported closer type: %T", v))
}
}
return errors.Join(errs...)
}
// Run implements RunnerManager.Run.
func (c *RunnerCloserManager) Run(ctx context.Context) error {
if !c.running.CompareAndSwap(false, true) {
return ErrManagerAlreadyStarted
}
// Signal the manager is stopped.
defer close(c.stopped)
// If the main runner has at least one runner, add a closer that will
// close the context once Close() is called.
if len(c.mngr.runners) > 0 {
c.mngr.Add(func(ctx context.Context) error {
select {
case <-ctx.Done():
case <-c.closeCh:
}
return nil
})
}
errCh := make(chan error, len(c.closers))
go func() {
errCh <- c.mngr.Run(ctx)
}()
rErr := <-errCh
c.mngr.lock.Lock()
defer c.mngr.lock.Unlock()
c.closing.Store(true)
errs := make([]error, len(c.closers)+1)
errs[0] = rErr
for _, closer := range c.closers {
go func(closer func() error) {
errCh <- closer()
}(closer)
}
// Wait for all closers to be done.
for i := 1; i < len(c.closers)+1; i++ {
// Close the fatal shutdown goroutine if all closers are done. This is a
// no-op if the fatal goroutine is not defined.
if i == len(c.closers) {
close(c.closeFatalShutdown)
}
errs[i] = <-errCh
}
c.retErr = errors.Join(errs...)
return c.retErr
}
// Close will close the main runners and then the closers.
func (c *RunnerCloserManager) Close() error {
if c.closed.CompareAndSwap(false, true) {
close(c.closeCh)
}
// If the manager is not running yet, we stop immediately.
if c.running.CompareAndSwap(false, true) {
close(c.stopped)
}
c.WaitUntilShutdown()
return c.retErr
}
// WaitUntilShutdown will block until the main runners and closers are done.
func (c *RunnerCloserManager) WaitUntilShutdown() {
<-c.stopped
}

concurrency/closer_test.go Normal file (1045 lines)

File diff suppressed because it is too large


@ -0,0 +1,23 @@
//go:build unit
// +build unit
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package concurrency
// WithFatalShutdown sets the fatal shutdown function for the closer manager.
// Used for testing.
func (c *RunnerCloserManager) WithFatalShutdown(fn func()) {
c.fatalShutdownFn = fn
}

concurrency/cmap/atomic.go Normal file (111 lines)

@ -0,0 +1,111 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmap
import (
"sync"
"golang.org/x/exp/constraints"
)
type AtomicValue[T constraints.Integer] struct {
lock sync.RWMutex
value T
}
func (a *AtomicValue[T]) Load() T {
a.lock.RLock()
defer a.lock.RUnlock()
return a.value
}
func (a *AtomicValue[T]) Store(v T) {
a.lock.Lock()
defer a.lock.Unlock()
a.value = v
}
func (a *AtomicValue[T]) Add(v T) T {
a.lock.Lock()
defer a.lock.Unlock()
a.value += v
return a.value
}
type Atomic[K comparable, T constraints.Integer] interface {
Get(key K) (*AtomicValue[T], bool)
GetOrCreate(key K, createT T) *AtomicValue[T]
Delete(key K)
ForEach(fn func(key K, value *AtomicValue[T]))
Clear()
}
type atomicMap[K comparable, T constraints.Integer] struct {
lock sync.RWMutex
items map[K]*AtomicValue[T]
}
func NewAtomic[K comparable, T constraints.Integer]() Atomic[K, T] {
return &atomicMap[K, T]{
items: make(map[K]*AtomicValue[T]),
}
}
func (a *atomicMap[K, T]) Get(key K) (*AtomicValue[T], bool) {
a.lock.RLock()
defer a.lock.RUnlock()
item, ok := a.items[key]
if !ok {
return nil, false
}
return item, true
}
func (a *atomicMap[K, T]) GetOrCreate(key K, createT T) *AtomicValue[T] {
a.lock.RLock()
item, ok := a.items[key]
a.lock.RUnlock()
if !ok {
a.lock.Lock()
// Double-check the key exists to avoid race condition
item, ok = a.items[key]
if !ok {
item = &AtomicValue[T]{value: createT}
a.items[key] = item
}
a.lock.Unlock()
}
return item
}
func (a *atomicMap[K, T]) Delete(key K) {
a.lock.Lock()
delete(a.items, key)
a.lock.Unlock()
}
func (a *atomicMap[K, T]) ForEach(fn func(key K, value *AtomicValue[T])) {
a.lock.RLock()
defer a.lock.RUnlock()
for k, v := range a.items {
fn(k, v)
}
}
func (a *atomicMap[K, T]) Clear() {
a.lock.Lock()
defer a.lock.Unlock()
clear(a.items)
}


@ -0,0 +1,79 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmap
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAtomicInt32_New_Get_Delete(t *testing.T) {
m := NewAtomic[string, int32]().(*atomicMap[string, int32])
require.NotNil(t, m)
require.NotNil(t, m.items)
require.Empty(t, m.items)
t.Run("basic operations", func(t *testing.T) {
key := "key1"
value := int32(10)
// Initially, the key should not exist
_, ok := m.Get(key)
require.False(t, ok)
// Add a value and check it
m.GetOrCreate(key, 0).Store(value)
result, ok := m.Get(key)
require.True(t, ok)
assert.Equal(t, value, result.Load())
// Delete the key and check it no longer exists
m.Delete(key)
_, ok = m.Get(key)
require.False(t, ok)
})
t.Run("concurrent access multiple keys", func(t *testing.T) {
var wg sync.WaitGroup
keys := []string{"key1", "key2", "key3"}
iterations := 100
wg.Add(len(keys) * 2)
for _, key := range keys {
go func(k string) {
defer wg.Done()
for range iterations {
m.GetOrCreate(k, 0).Add(1)
}
}(key)
go func(k string) {
defer wg.Done()
for range iterations {
m.GetOrCreate(k, 0).Add(-1)
}
}(key)
}
wg.Wait()
for _, key := range keys {
val, ok := m.Get(key)
require.True(t, ok)
require.Equal(t, int32(0), val.Load())
}
})
}

concurrency/cmap/map.go Normal file (99 lines)

@ -0,0 +1,99 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmap
import (
"sync"
)
// Map is a simple _typed_ map which is safe for concurrent use.
// Favoured over sync.Map as it is typed.
type Map[K comparable, T any] interface {
Clear()
Delete(key K)
Load(key K) (T, bool)
LoadAndDelete(key K) (T, bool)
Range(fn func(key K, value T) bool)
Store(key K, value T)
Len() int
Keys() []K
}
type mapimpl[K comparable, T any] struct {
lock sync.RWMutex
m map[K]T
}
func NewMap[K comparable, T any]() Map[K, T] {
return &mapimpl[K, T]{m: make(map[K]T)}
}
func (m *mapimpl[K, T]) Clear() {
m.lock.Lock()
defer m.lock.Unlock()
m.m = make(map[K]T)
}
func (m *mapimpl[K, T]) Delete(k K) {
m.lock.Lock()
defer m.lock.Unlock()
delete(m.m, k)
}
func (m *mapimpl[K, T]) Load(k K) (T, bool) {
m.lock.RLock()
defer m.lock.RUnlock()
v, ok := m.m[k]
return v, ok
}
func (m *mapimpl[K, T]) LoadAndDelete(k K) (T, bool) {
m.lock.Lock()
defer m.lock.Unlock()
v, ok := m.m[k]
delete(m.m, k)
return v, ok
}
func (m *mapimpl[K, T]) Range(fn func(K, T) bool) {
m.lock.RLock()
defer m.lock.RUnlock()
for k, v := range m.m {
if !fn(k, v) {
break
}
}
}
func (m *mapimpl[K, T]) Store(k K, v T) {
m.lock.Lock()
defer m.lock.Unlock()
m.m[k] = v
}
func (m *mapimpl[K, T]) Len() int {
m.lock.RLock()
defer m.lock.RUnlock()
return len(m.m)
}
func (m *mapimpl[K, T]) Keys() []K {
m.lock.Lock()
defer m.lock.Unlock()
keys := make([]K, 0, len(m.m))
for k := range m.m {
keys = append(keys, k)
}
return keys
}
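The motivation for this typed wrapper is that sync.Map stores `any`, so every Load needs a type assertion, while a generic RWMutex-guarded map returns T directly. A minimal self-contained sketch of the typed form (illustrative names, not the package API):

```go
package main

import (
	"fmt"
	"sync"
)

type typedMap[K comparable, T any] struct {
	mu sync.RWMutex
	m  map[K]T
}

func newTypedMap[K comparable, T any]() *typedMap[K, T] {
	return &typedMap[K, T]{m: make(map[K]T)}
}

// store sets k to v under the write lock.
func (t *typedMap[K, T]) store(k K, v T) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.m[k] = v
}

// load returns the value for k under a read lock, typed as T.
func (t *typedMap[K, T]) load(k K) (T, bool) {
	t.mu.RLock()
	defer t.mu.RUnlock()
	v, ok := t.m[k]
	return v, ok
}

func main() {
	tm := newTypedMap[string, int]()
	tm.store("count", 42)
	if v, ok := tm.load("count"); ok {
		// v is already an int; a sync.Map would force v.(int) here.
		fmt.Println(v + 1)
	}
}
```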

concurrency/cmap/mutex.go Normal file (150 lines)

@ -0,0 +1,150 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmap
import (
"sync"
)
// Mutex is an interface that defines a thread-safe map with keys of type T associated to
// read-write mutexes (sync.RWMutex), allowing for granular locking on a per-key basis.
// This can be useful for scenarios where fine-grained concurrency control is needed.
//
// Methods:
// - Lock(key T): Acquires an exclusive lock on the mutex associated with the given key.
// - Unlock(key T): Releases the exclusive lock on the mutex associated with the given key.
// - RLock(key T): Acquires a read lock on the mutex associated with the given key.
// - RUnlock(key T): Releases the read lock on the mutex associated with the given key.
// - Delete(key T): Removes the mutex associated with the given key from the map.
// - Clear(): Removes all mutexes from the map.
// - ItemCount() int: Returns the number of items (mutexes) in the map.
// - DeleteUnlock(key T): Removes the mutex associated with the given key from the map and releases the lock.
// - DeleteRUnlock(key T): Removes the mutex associated with the given key from the map and releases the read lock.
type Mutex[T comparable] interface {
Lock(key T)
Unlock(key T)
RLock(key T)
RUnlock(key T)
Delete(key T)
Clear()
ItemCount() int
DeleteUnlock(key T)
DeleteRUnlock(key T)
}
type mutex[T comparable] struct {
lock sync.RWMutex
items map[T]*sync.RWMutex
}
func NewMutex[T comparable]() Mutex[T] {
return &mutex[T]{
items: make(map[T]*sync.RWMutex),
}
}
func (a *mutex[T]) Lock(key T) {
a.lock.RLock()
mutex, ok := a.items[key]
a.lock.RUnlock()
if ok {
mutex.Lock()
return
}
a.lock.Lock()
mutex, ok = a.items[key]
if !ok {
mutex = &sync.RWMutex{}
a.items[key] = mutex
}
a.lock.Unlock()
mutex.Lock()
}
func (a *mutex[T]) Unlock(key T) {
a.lock.RLock()
mutex, ok := a.items[key]
if ok {
mutex.Unlock()
}
a.lock.RUnlock()
}
func (a *mutex[T]) RLock(key T) {
a.lock.RLock()
mutex, ok := a.items[key]
a.lock.RUnlock()
if ok {
mutex.RLock()
return
}
a.lock.Lock()
mutex, ok = a.items[key]
if !ok {
mutex = &sync.RWMutex{}
a.items[key] = mutex
}
a.lock.Unlock()
mutex.RLock()
}
func (a *mutex[T]) RUnlock(key T) {
a.lock.RLock()
mutex, ok := a.items[key]
if ok {
mutex.RUnlock()
}
a.lock.RUnlock()
}
func (a *mutex[T]) Delete(key T) {
a.lock.Lock()
delete(a.items, key)
a.lock.Unlock()
}
func (a *mutex[T]) DeleteUnlock(key T) {
a.lock.Lock()
mutex, ok := a.items[key]
if ok {
mutex.Unlock()
}
delete(a.items, key)
a.lock.Unlock()
}
func (a *mutex[T]) DeleteRUnlock(key T) {
a.lock.Lock()
mutex, ok := a.items[key]
if ok {
mutex.RUnlock()
}
delete(a.items, key)
a.lock.Unlock()
}
func (a *mutex[T]) Clear() {
a.lock.Lock()
clear(a.items)
a.lock.Unlock()
}
func (a *mutex[T]) ItemCount() int {
a.lock.Lock()
defer a.lock.Unlock()
return len(a.items)
}


@ -0,0 +1,119 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmap
import (
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewMutex_Add_Delete(t *testing.T) {
mm := NewMutex[string]().(*mutex[string])
t.Run("New mutex map", func(t *testing.T) {
require.NotNil(t, mm)
require.NotNil(t, mm.items)
require.Empty(t, mm.items)
})
t.Run("Lock and unlock mutex", func(t *testing.T) {
mm.Lock("key1")
_, ok := mm.items["key1"]
require.True(t, ok)
mm.Unlock("key1")
})
t.Run("Concurrently lock and unlock mutexes", func(t *testing.T) {
var counter atomic.Int64
var wg sync.WaitGroup
numGoroutines := 10
wg.Add(numGoroutines)
// Concurrently lock and unlock for each key
for range numGoroutines {
go func() {
defer wg.Done()
mm.Lock("key1")
counter.Add(1)
mm.Unlock("key1")
}()
}
wg.Wait()
require.Equal(t, int64(10), counter.Load())
})
t.Run("RLock and RUnlock mutex", func(t *testing.T) {
mm.RLock("key1")
_, ok := mm.items["key1"]
require.True(t, ok)
mm.RUnlock("key1")
})
t.Run("Concurrently RLock and RUnlock mutexes", func(t *testing.T) {
var counter atomic.Int64
var wg sync.WaitGroup
numGoroutines := 10
wg.Add(numGoroutines * 2)
// Concurrently RLock and RUnlock for each key
for range numGoroutines {
go func() {
defer wg.Done()
mm.RLock("key1")
counter.Add(1)
}()
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
assert.Equal(ct, int64(10), counter.Load())
}, 5*time.Second, 10*time.Millisecond)
for range numGoroutines {
go func() {
defer wg.Done()
mm.RUnlock("key1")
}()
}
wg.Wait()
})
t.Run("Delete mutex", func(t *testing.T) {
mm.Lock("key1")
mm.Unlock("key1")
mm.Delete("key1")
_, ok := mm.items["key1"]
require.False(t, ok)
})
t.Run("Clear all mutexes, and check item count", func(t *testing.T) {
mm.Lock("key1")
mm.Unlock("key1")
mm.Lock("key2")
mm.Unlock("key2")
require.Equal(t, 2, mm.ItemCount())
mm.Clear()
require.Empty(t, mm.items)
})
}

View File

@ -0,0 +1,94 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ctesting
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/kit/concurrency"
"github.com/dapr/kit/concurrency/ctesting/internal"
)
type RunnerFn func(context.Context, assert.TestingT)
// Assert runs the provided test functions in parallel and asserts that they
// all pass.
func Assert(t *testing.T, runners ...RunnerFn) {
t.Helper()
if len(runners) == 0 {
require.Fail(t, "at least one runner function is required")
}
tt := internal.Assert(t)
ctx, cancel := context.WithCancelCause(t.Context())
t.Cleanup(func() { cancel(nil) })
doneCh := make(chan struct{}, len(runners))
for _, runner := range runners {
go func(rfn RunnerFn) {
rfn(ctx, tt)
if errs := tt.Errors(); len(errs) > 0 {
cancel(errors.Join(errs...))
}
doneCh <- struct{}{}
}(runner)
}
for range runners {
select {
case <-doneCh:
case <-t.Context().Done():
require.FailNow(t, "test context was cancelled before all runners completed")
}
}
for _, err := range tt.Errors() {
assert.NoError(t, err)
}
}
// AssertCleanup runs the provided test functions in parallel and asserts that they
// all pass, but only during test Cleanup.
func AssertCleanup(t *testing.T, runners ...concurrency.Runner) {
t.Helper()
ctx, cancel := context.WithCancelCause(t.Context())
errCh := make(chan error, len(runners))
for _, runner := range runners {
go func(rfn concurrency.Runner) {
errCh <- rfn(ctx)
}(runner)
}
t.Cleanup(func() {
cancel(nil)
for range runners {
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(10 * time.Second):
assert.Fail(t, "timeout waiting for runner to stop")
}
}
})
}


View File

@ -0,0 +1,49 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package internal
import (
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
type Interface interface {
assert.TestingT
Errors() []error
}
type assertT struct {
t *testing.T
lock sync.Mutex
errs []error
}
func Assert(t *testing.T) Interface {
return &assertT{t: t}
}
func (a *assertT) Errorf(format string, args ...any) {
a.lock.Lock()
defer a.lock.Unlock()
a.errs = append(a.errs, fmt.Errorf(format, args...))
}
func (a *assertT) Errors() []error {
a.lock.Lock()
defer a.lock.Unlock()
return a.errs
}

concurrency/dir/dir.go Normal file
View File

@ -0,0 +1,90 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package dir
import (
"fmt"
"os"
"path/filepath"
"time"
"github.com/dapr/kit/logger"
)
type Options struct {
Log logger.Logger
Target string
}
// Dir atomically writes files to a given directory.
type Dir struct {
log logger.Logger
base string
target string
targetDir string
prev *string
}
func New(opts Options) *Dir {
return &Dir{
log: opts.Log,
base: filepath.Dir(opts.Target),
target: opts.Target,
targetDir: filepath.Base(opts.Target),
}
}
func (d *Dir) Write(files map[string][]byte) error {
newDir := filepath.Join(d.base, fmt.Sprintf("%d-%s", time.Now().UTC().UnixNano(), d.targetDir))
if err := os.MkdirAll(d.base, 0o700); err != nil {
return err
}
if err := os.MkdirAll(newDir, 0o700); err != nil {
return err
}
for file, b := range files {
path := filepath.Join(newDir, file)
if err := os.WriteFile(path, b, 0o600); err != nil {
return err
}
d.log.Infof("Written file %s", file)
}
if err := os.Symlink(newDir, d.target+".new"); err != nil {
return err
}
d.log.Infof("Symlinked %s to %s.new", newDir, d.target)
if err := os.Rename(d.target+".new", d.target); err != nil {
return err
}
d.log.Infof("Atomic write to %s", d.target)
if d.prev != nil {
if err := os.RemoveAll(*d.prev); err != nil {
return err
}
}
d.prev = &newDir
return nil
}

concurrency/fifo/map.go Normal file
View File

@ -0,0 +1,62 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fifo
// Map is a map of mutexes whose locks are acquired in a FIFO order. The map is
// pruned automatically when all locks have been released for a key.
type Map[T comparable] interface {
Lock(key T)
Unlock(key T)
}
type mapItem struct {
ilen uint64
mutex *Mutex
}
type fifoMap[T comparable] struct {
lock *Mutex
items map[T]*mapItem
}
func NewMap[T comparable]() Map[T] {
return &fifoMap[T]{
lock: New(),
items: make(map[T]*mapItem),
}
}
func (a *fifoMap[T]) Lock(key T) {
a.lock.Lock()
m, ok := a.items[key]
if !ok {
m = &mapItem{mutex: New()}
a.items[key] = m
}
m.ilen++
a.lock.Unlock()
m.mutex.Lock()
}
func (a *fifoMap[T]) Unlock(key T) {
a.lock.Lock()
m := a.items[key]
m.ilen--
if m.ilen == 0 {
delete(a.items, key)
}
a.lock.Unlock()
m.mutex.Unlock()
}

View File

@ -0,0 +1,47 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fifo
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func Test_Map(t *testing.T) {
m := NewMap[string]().(*fifoMap[string])
assert.Empty(t, m.items)
m.Lock("key1")
assert.Len(t, m.items, 1)
assert.Equal(t, uint64(1), m.items["key1"].ilen)
go func() {
m.Lock("key1")
}()
assert.EventuallyWithT(t, func(c *assert.CollectT) {
m.lock.Lock()
assert.Equal(c, uint64(2), m.items["key1"].ilen)
m.lock.Unlock()
}, time.Second*3, time.Millisecond*10)
m.Unlock("key1")
assert.Equal(t, uint64(1), m.items["key1"].ilen)
m.Unlock("key1")
assert.Empty(t, m.items)
}

concurrency/fifo/mutex.go Normal file
View File

@ -0,0 +1,34 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fifo
// Mutex is a mutex lock whose lock and unlock operations are
// first-in-first-out (FIFO).
type Mutex struct {
lock chan struct{}
}
func New() *Mutex {
return &Mutex{
lock: make(chan struct{}, 1),
}
}
func (m *Mutex) Lock() {
m.lock <- struct{}{}
}
func (m *Mutex) Unlock() {
<-m.lock
}

View File

@ -0,0 +1,62 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package lock
import (
"context"
"sync"
)
// Context is a read-write mutex lock whose Lock operations can return early
// with an error if the context is done. A nil error means the lock was
// acquired.
type Context struct {
lock sync.RWMutex
locked chan struct{}
}
func NewContext() *Context {
return &Context{
locked: make(chan struct{}, 1),
}
}
func (c *Context) Lock(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case c.locked <- struct{}{}:
c.lock.Lock()
return nil
}
}
func (c *Context) Unlock() {
c.lock.Unlock()
<-c.locked
}
func (c *Context) RLock(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case c.locked <- struct{}{}:
c.lock.RLock()
return nil
}
}
func (c *Context) RUnlock() {
c.lock.RUnlock()
<-c.locked
}

View File

@ -0,0 +1,81 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package lock
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func Test_Context(t *testing.T) {
tests := map[string]struct {
name string
action func(l *Context) error
expectError bool
}{
"Successful Lock": {
action: func(l *Context) error {
return l.Lock(t.Context())
},
expectError: false,
},
"Lock with Context Timeout": {
action: func(l *Context) error {
l.Lock(t.Context())
ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond*50)
defer cancel()
return l.Lock(ctx)
},
expectError: true,
},
"Successful RLock": {
action: func(l *Context) error {
return l.RLock(t.Context())
},
expectError: false,
},
"RLock with Context Timeout": {
action: func(l *Context) error {
l.Lock(t.Context())
ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond*50)
defer cancel()
return l.RLock(ctx)
},
expectError: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
t.Parallel()
l := NewContext()
done := make(chan error)
go func() {
done <- test.action(l)
}()
select {
case err := <-done:
assert.Equal(t, (err != nil), test.expectError, "unexpected error, expected error: %v, got: %v", test.expectError, err)
case <-time.After(time.Second):
t.Errorf("test timed out")
}
})
}
}

View File

@ -0,0 +1,199 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package lock
import (
"context"
"errors"
"sync"
"time"
"github.com/dapr/kit/concurrency/fifo"
)
var errLockClosed = errors.New("lock closed")
type hold struct {
writeLock bool
rctx context.Context
respCh chan *holdresp
}
type holdresp struct {
rctx context.Context
cancel context.CancelFunc
err error
}
type OuterCancel struct {
ch chan *hold
cancelErr error
gracefulTimeout time.Duration
lock chan struct{}
wg sync.WaitGroup
rcancelLock sync.Mutex
rcancelx uint64
rcancels map[uint64]context.CancelFunc
closeCh chan struct{}
shutdownLock *fifo.Mutex
}
func NewOuterCancel(cancelErr error, gracefulTimeout time.Duration) *OuterCancel {
return &OuterCancel{
lock: make(chan struct{}, 1),
ch: make(chan *hold, 1),
rcancels: make(map[uint64]context.CancelFunc),
closeCh: make(chan struct{}),
shutdownLock: fifo.New(),
cancelErr: cancelErr,
gracefulTimeout: gracefulTimeout,
}
}
func (o *OuterCancel) Run(ctx context.Context) {
defer func() {
o.rcancelLock.Lock()
defer o.rcancelLock.Unlock()
for _, cancel := range o.rcancels {
go cancel()
}
}()
go func() {
<-ctx.Done()
close(o.closeCh)
}()
for {
select {
case <-o.closeCh:
return
case h := <-o.ch:
o.handleHold(h)
}
}
}
func (o *OuterCancel) handleHold(h *hold) {
if h.rctx != nil {
select {
case o.lock <- struct{}{}:
case <-h.rctx.Done():
h.respCh <- &holdresp{err: h.rctx.Err()}
return
}
} else {
o.lock <- struct{}{}
}
o.rcancelLock.Lock()
if h.writeLock {
for _, cancel := range o.rcancels {
go cancel()
}
o.rcancelx = 0
o.rcancelLock.Unlock()
o.wg.Wait()
h.respCh <- &holdresp{cancel: func() { <-o.lock }}
return
}
o.wg.Add(1)
var done bool
doneCh := make(chan bool)
rctx, cancel := context.WithCancelCause(h.rctx)
i := o.rcancelx
rcancel := func() {
o.rcancelLock.Lock()
if !done {
close(doneCh)
cancel(o.cancelErr)
delete(o.rcancels, i)
o.wg.Done()
done = true
}
o.rcancelLock.Unlock()
}
rcancelGrace := func() {
select {
case <-time.After(o.gracefulTimeout):
case <-o.closeCh:
case <-doneCh:
}
rcancel()
}
o.rcancels[i] = rcancelGrace
o.rcancelx++
o.rcancelLock.Unlock()
h.respCh <- &holdresp{rctx: rctx, cancel: rcancel}
<-o.lock
}
func (o *OuterCancel) Lock() context.CancelFunc {
h := hold{
writeLock: true,
respCh: make(chan *holdresp, 1),
}
select {
case <-o.closeCh:
o.shutdownLock.Lock()
return o.shutdownLock.Unlock
case o.ch <- &h:
}
select {
case <-o.closeCh:
o.shutdownLock.Lock()
return o.shutdownLock.Unlock
case resp := <-h.respCh:
return resp.cancel
}
}
func (o *OuterCancel) RLock(ctx context.Context) (context.Context, context.CancelFunc, error) {
h := hold{
writeLock: false,
rctx: ctx,
respCh: make(chan *holdresp, 1),
}
select {
case <-o.closeCh:
return nil, nil, errLockClosed
case <-ctx.Done():
return nil, nil, ctx.Err()
case o.ch <- &h:
}
select {
case <-o.closeCh:
return nil, nil, errLockClosed
case resp := <-h.respCh:
return resp.rctx, resp.cancel, resp.err
}
}

View File

@ -0,0 +1,225 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package lock
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_OuterCancel(t *testing.T) {
t.Parallel()
//nolint:err113
terr := errors.New("test")
t.Run("can rlock multiple times", func(t *testing.T) {
t.Parallel()
l := NewOuterCancel(terr, time.Second)
ctx, cancel := context.WithCancel(t.Context())
t.Cleanup(cancel)
go l.Run(ctx)
ctx1, c1, err := l.RLock(ctx)
require.NoError(t, err)
ctx2, c2, err := l.RLock(ctx)
require.NoError(t, err)
ctx3, c3, err := l.RLock(ctx)
require.NoError(t, err)
require.NoError(t, ctx1.Err())
require.NoError(t, ctx2.Err())
require.NoError(t, ctx3.Err())
c1()
require.Error(t, ctx1.Err())
require.NoError(t, ctx2.Err())
require.NoError(t, ctx3.Err())
c2()
require.Error(t, ctx1.Err())
require.Error(t, ctx2.Err())
require.NoError(t, ctx3.Err())
c3()
require.Error(t, ctx1.Err())
require.Error(t, ctx2.Err())
require.Error(t, ctx3.Err())
})
t.Run("rlock unlock removes cancel state", func(t *testing.T) {
t.Parallel()
l := NewOuterCancel(terr, time.Second)
ctx, cancel := context.WithCancel(t.Context())
t.Cleanup(cancel)
go l.Run(ctx)
_, c1, err := l.RLock(ctx)
require.NoError(t, err)
_, c2, err := l.RLock(ctx)
require.NoError(t, err)
_, c3, err := l.RLock(ctx)
require.NoError(t, err)
assert.Len(t, l.rcancels, 3)
c1()
assert.Len(t, l.rcancels, 2)
c2()
assert.Len(t, l.rcancels, 1)
c3()
assert.Empty(t, l.rcancels)
})
t.Run("calling lock cancels all current rlocks", func(t *testing.T) {
t.Parallel()
l := NewOuterCancel(terr, time.Second)
ctx, cancel := context.WithCancel(t.Context())
t.Cleanup(cancel)
go l.Run(ctx)
ctx1, _, err := l.RLock(ctx)
require.NoError(t, err)
ctx2, _, err := l.RLock(ctx)
require.NoError(t, err)
ctx3, _, err := l.RLock(ctx)
require.NoError(t, err)
require.NoError(t, ctx1.Err())
require.NoError(t, ctx2.Err())
require.NoError(t, ctx3.Err())
mcancel := l.Lock()
require.Error(t, ctx1.Err())
require.Error(t, ctx2.Err())
require.Error(t, ctx3.Err())
mcancel()
assert.Empty(t, l.rcancels)
})
t.Run("rlock when closed should error", func(t *testing.T) {
t.Parallel()
l := NewOuterCancel(terr, time.Second)
ctx, cancel := context.WithCancel(t.Context())
cancel()
go l.Run(ctx)
select {
case <-l.closeCh:
case <-time.After(time.Second * 5):
assert.Fail(t, "expected close")
}
_, _, err := l.RLock(t.Context())
require.Error(t, err)
})
t.Run("lock continues to work after close", func(t *testing.T) {
t.Parallel()
l := NewOuterCancel(terr, time.Second)
ctx, cancel := context.WithCancel(t.Context())
cancel()
l.Run(ctx)
lcancel := l.Lock()
lcancel()
lcancel = l.Lock()
lcancel()
})
t.Run("rlock blocks until outer unlocks", func(t *testing.T) {
t.Parallel()
l := NewOuterCancel(terr, time.Second)
ctx, cancel := context.WithCancel(t.Context())
t.Cleanup(cancel)
go l.Run(ctx)
lcancel := l.Lock()
gotRLock := make(chan struct{})
errCh := make(chan error, 1)
go func() {
_, c1, err := l.RLock(ctx)
errCh <- err
t.Cleanup(c1)
close(gotRLock)
}()
t.Cleanup(func() {
require.NoError(t, <-errCh)
})
select {
case <-time.After(time.Millisecond * 500):
case <-gotRLock:
require.Fail(t, "unexpected rlock")
}
lcancel()
<-gotRLock
})
t.Run("lock blocks until outer unlocks", func(t *testing.T) {
t.Parallel()
l := NewOuterCancel(terr, time.Second)
ctx, cancel := context.WithCancel(t.Context())
t.Cleanup(cancel)
go l.Run(ctx)
lcancel := l.Lock()
gotLock := make(chan struct{})
go func() {
lockcancel := l.Lock()
t.Cleanup(lockcancel)
close(gotLock)
}()
select {
case <-time.After(time.Millisecond * 500):
case <-gotLock:
require.Fail(t, "unexpected lock")
}
lcancel()
})
}

concurrency/runner.go Normal file
View File

@ -0,0 +1,97 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package concurrency
import (
"context"
"errors"
"sync"
"sync/atomic"
)
var ErrManagerAlreadyStarted = errors.New("runner manager already started")
// Runner is a function that runs a task.
type Runner func(ctx context.Context) error
// RunnerManager is a manager for runners. It runs all runners in parallel and
// waits for all runners to finish. If any runner returns, the RunnerManager
// will stop all other runners and return any error.
type RunnerManager struct {
lock sync.Mutex
runners []Runner
running atomic.Bool
}
// NewRunnerManager creates a new RunnerManager.
func NewRunnerManager(runners ...Runner) *RunnerManager {
return &RunnerManager{
runners: runners,
}
}
// Add adds a new runner to the RunnerManager.
func (r *RunnerManager) Add(runner ...Runner) error {
if r.running.Load() {
return ErrManagerAlreadyStarted
}
r.lock.Lock()
defer r.lock.Unlock()
r.runners = append(r.runners, runner...)
return nil
}
// Run runs all runners in parallel and waits for all runners to finish. If any
// runner returns, the RunnerManager will stop all other runners and return any
// error.
func (r *RunnerManager) Run(ctx context.Context) error {
if !r.running.CompareAndSwap(false, true) {
return ErrManagerAlreadyStarted
}
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
errCh := make(chan error)
for _, runner := range r.runners {
go func(runner Runner) {
// Ignore context cancelled errors since errors from a runner manager
// will likely determine the exit code of the program.
// Context cancelled errors are also not really useful to the user in
// this situation.
rErr := runner(ctx)
if rErr != nil && !errors.Is(rErr, context.Canceled) {
errCh <- rErr
// Since the task returned, we need to cancel all other tasks.
// This is a noop if the parent context is already cancelled, or another
// task returned before this one.
cancel(rErr)
return
}
errCh <- nil
cancel(nil)
}(runner)
}
// Collect all errors
errObjs := make([]error, 0)
for range len(r.runners) {
err := <-errCh
if err != nil {
errObjs = append(errObjs, err)
}
}
return errors.Join(errObjs...)
}

concurrency/runner_test.go Normal file
View File

@ -0,0 +1,234 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package concurrency
import (
"context"
"errors"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_RunnerManager(t *testing.T) {
t.Run("runner with no tasks should return nil", func(t *testing.T) {
require.NoError(t, NewRunnerManager().Run(t.Context()))
})
t.Run("runner with a task that completes should return nil", func(t *testing.T) {
var i int32
require.NoError(t, NewRunnerManager(func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return nil
}).Run(t.Context()))
assert.Equal(t, int32(1), i)
})
t.Run("runner with multiple tasks that complete should return nil", func(t *testing.T) {
var i int32
require.NoError(t, NewRunnerManager(
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return nil
},
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return nil
},
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return nil
},
).Run(t.Context()))
assert.Equal(t, int32(3), i)
})
t.Run("a runner that errors should error", func(t *testing.T) {
var i int32
require.EqualError(t, NewRunnerManager(
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return errors.New("error")
},
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return nil
},
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return nil
},
).Run(t.Context()), "error")
assert.Equal(t, int32(3), i)
})
t.Run("a runner with multiple errors should collect all errors (string match)", func(t *testing.T) {
var i int32
err := NewRunnerManager(
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return errors.New("error")
},
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return errors.New("error")
},
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return errors.New("error")
},
).Run(t.Context())
require.Error(t, err)
require.ErrorContains(t, err, "error\nerror\nerror") //nolint:dupword
assert.Equal(t, int32(3), i)
})
t.Run("a runner with multiple errors should collect all errors (unique)", func(t *testing.T) {
var i int32
err := NewRunnerManager(
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return errors.New("error1")
},
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return errors.New("error2")
},
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return errors.New("error3")
},
).Run(t.Context())
require.Error(t, err)
assert.ElementsMatch(t, []string{"error1", "error2", "error3"}, strings.Split(err.Error(), "\n"))
assert.Equal(t, int32(3), i)
})
t.Run("should be able to add runner with both New and Add", func(t *testing.T) {
var i int32
mngr := NewRunnerManager(
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return nil
},
)
require.NoError(t, mngr.Add(
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return nil
},
))
require.NoError(t, mngr.Add(
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return nil
},
))
require.NoError(t, mngr.Run(t.Context()))
assert.Equal(t, int32(3), i)
})
t.Run("when a runner returns, expect context to be cancelled for other runners", func(t *testing.T) {
var i int32
require.NoError(t, NewRunnerManager(
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return nil
},
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
select {
case <-ctx.Done():
case <-time.After(time.Second):
t.Error("context should have been cancelled in time")
}
return nil
},
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
select {
case <-ctx.Done():
case <-time.After(time.Second):
t.Error("context should have been cancelled in time")
}
return nil
},
).Run(t.Context()))
assert.Equal(t, int32(3), i)
})
t.Run("when a runner errors, expect context to be cancelled for other runners", func(t *testing.T) {
var i int32
err := NewRunnerManager(
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
select {
case <-ctx.Done():
case <-time.After(time.Second):
t.Error("context should have been cancelled in time")
}
return errors.New("error1")
},
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
select {
case <-ctx.Done():
case <-time.After(time.Second):
t.Error("context should have been cancelled in time")
}
return errors.New("error2")
},
func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return errors.New("error3")
},
).Run(t.Context())
require.Error(t, err)
assert.ElementsMatch(t, []string{"error1", "error2", "error3"}, strings.Split(err.Error(), "\n"))
assert.Equal(t, int32(3), i)
})
t.Run("a manager started twice should error", func(t *testing.T) {
var i int32
m := NewRunnerManager(func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return nil
})
require.NoError(t, m.Run(t.Context()))
assert.Equal(t, int32(1), i)
require.EqualError(t, m.Run(t.Context()), "runner manager already started")
assert.Equal(t, int32(1), i)
})
t.Run("adding a task to a started manager should error", func(t *testing.T) {
var i int32
m := NewRunnerManager(func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return nil
})
require.NoError(t, m.Run(t.Context()))
assert.Equal(t, int32(1), i)
err := m.Add(func(ctx context.Context) error {
atomic.AddInt32(&i, 1)
return nil
})
require.Error(t, err)
require.ErrorIs(t, err, ErrManagerAlreadyStarted)
assert.Equal(t, int32(1), i)
})
}

View File

@ -0,0 +1,58 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package slice
import "sync"
// Slice is a concurrent safe types slice
type Slice[T any] interface {
Append(items ...T) int
Len() int
Slice() []T
Store(items ...T)
}
type slice[T any] struct {
lock sync.RWMutex
data []T
}
func New[T any]() Slice[T] {
return new(slice[T])
}
func (s *slice[T]) Append(items ...T) int {
s.lock.Lock()
defer s.lock.Unlock()
s.data = append(s.data, items...)
return len(s.data)
}
func (s *slice[T]) Len() int {
s.lock.RLock()
defer s.lock.RUnlock()
return len(s.data)
}
func (s *slice[T]) Slice() []T {
s.lock.RLock()
defer s.lock.RUnlock()
return s.data
}
func (s *slice[T]) Store(items ...T) {
s.lock.Lock()
defer s.lock.Unlock()
s.data = items
}

View File

@ -0,0 +1,18 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package slice
func String() Slice[string] {
return new(slice[string])
}

View File

@ -81,7 +81,7 @@ func decodeString(f reflect.Type, t reflect.Type, data any) (any, error) {
if t.Implements(typeStringDecoder) {
result = reflect.New(t.Elem()).Interface()
decoder = result.(StringDecoder)
} else if reflect.PtrTo(t).Implements(typeStringDecoder) {
} else if reflect.PointerTo(t).Implements(typeStringDecoder) {
result = reflect.New(t).Interface()
decoder = result.(StringDecoder)
}

View File

@ -21,6 +21,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/kit/config"
"github.com/dapr/kit/ptr"
@ -187,7 +188,7 @@ func TestDecode(t *testing.T) {
t.Run(name, func(t *testing.T) {
var actual testConfig
err := config.Decode(tc, &actual)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, expected, actual)
})
}
@ -231,61 +232,60 @@ func TestDecodeErrors(t *testing.T) {
"string": 1234,
"stringPtr": 1234,
}, &actual)
if assert.Error(t, err) {
errMsg := err.Error()
expectedNumErrors := 32
expectedPrefix := " error(s) decoding:"
assert.True(t, strings.HasPrefix(errMsg, fmt.Sprintf("%d%s", expectedNumErrors, expectedPrefix)), errMsg)
prefixIndex := strings.Index(errMsg, expectedPrefix)
if assert.True(t, prefixIndex != -1) {
errMsg = errMsg[prefixIndex+len(expectedPrefix):]
errMsg = strings.TrimSpace(errMsg)
errors := strings.Split(errMsg, "\n")
errorSet := make(map[string]struct{}, len(errors))
for _, e := range errors {
errorSet[e] = struct{}{}
}
expectedErrors := []string{
"* error decoding 'int': invalid int \"-badval\"",
"* error decoding 'intPtr': invalid int \"-badval\"",
"* error decoding 'int16': invalid int16 \"-badval\"",
"* error decoding 'int16Ptr': invalid int16 \"-badval\"",
"* error decoding 'int32': invalid int32 \"-badval\"",
"* error decoding 'int32Ptr': invalid int32 \"-badval\"",
"* error decoding 'int64': invalid int64 \"-badval\"",
"* error decoding 'int64Ptr': invalid int64 \"-badval\"",
"* error decoding 'int8': invalid int8 \"-badval\"",
"* error decoding 'int8Ptr': invalid int8 \"-badval\"",
"* error decoding 'uint': invalid uint \"-9999\"",
"* error decoding 'uintPtr': invalid uint \"-9999\"",
"* error decoding 'uint64': invalid uint64 \"-1234\"",
"* error decoding 'uint64Ptr': invalid uint64 \"-1234\"",
"* error decoding 'uint32': invalid uint32 \"-5678\"",
"* error decoding 'uint32Ptr': invalid uint32 \"-5678\"",
"* error decoding 'uint16': invalid uint16 \"-9012\"",
"* error decoding 'uint16Ptr': invalid uint16 \"-9012\"",
"* error decoding 'byte': invalid uint8 \"-1\"",
"* error decoding 'bytePtr': invalid uint8 \"-1\"",
"* error decoding 'float32': invalid float32 \"badval.5\"",
"* error decoding 'float32Ptr': invalid float32 \"badval.5\"",
"* error decoding 'float64': invalid float64 \"badval.5\"",
"* error decoding 'float64Ptr': invalid float64 \"badval.5\"",
"* error decoding 'duration': invalid duration \"badval\"",
"* error decoding 'durationPtr': invalid duration \"badval\"",
"* error decoding 'time': invalid time \"badval\"",
"* error decoding 'timePtr': invalid time \"badval\"",
"* error decoding 'decoded': invalid Decoded \"badval\": strconv.Atoi: parsing \"badval\": invalid syntax",
"* error decoding 'decodedPtr': invalid Decoded \"badval\": strconv.Atoi: parsing \"badval\": invalid syntax",
"* error decoding 'bool': invalid bool \"badval\"",
"* error decoding 'boolPtr': invalid bool \"badval\"",
}
for _, expectedError := range expectedErrors {
assert.Contains(t, errors, expectedError)
delete(errorSet, expectedError)
}
assert.Empty(t, errorSet)
}
require.Error(t, err)
errMsg := err.Error()
expectedNumErrors := 32
expectedPrefix := " error(s) decoding:"
assert.True(t, strings.HasPrefix(errMsg, fmt.Sprintf("%d%s", expectedNumErrors, expectedPrefix)), errMsg)
prefixIndex := strings.Index(errMsg, expectedPrefix)
require.NotEqual(t, -1, prefixIndex)
errMsg = errMsg[prefixIndex+len(expectedPrefix):]
errMsg = strings.TrimSpace(errMsg)
errors := strings.Split(errMsg, "\n")
errorSet := make(map[string]struct{}, len(errors))
for _, e := range errors {
errorSet[e] = struct{}{}
}
expectedErrors := []string{
"* error decoding 'int': invalid int \"-badval\"",
"* error decoding 'intPtr': invalid int \"-badval\"",
"* error decoding 'int16': invalid int16 \"-badval\"",
"* error decoding 'int16Ptr': invalid int16 \"-badval\"",
"* error decoding 'int32': invalid int32 \"-badval\"",
"* error decoding 'int32Ptr': invalid int32 \"-badval\"",
"* error decoding 'int64': invalid int64 \"-badval\"",
"* error decoding 'int64Ptr': invalid int64 \"-badval\"",
"* error decoding 'int8': invalid int8 \"-badval\"",
"* error decoding 'int8Ptr': invalid int8 \"-badval\"",
"* error decoding 'uint': invalid uint \"-9999\"",
"* error decoding 'uintPtr': invalid uint \"-9999\"",
"* error decoding 'uint64': invalid uint64 \"-1234\"",
"* error decoding 'uint64Ptr': invalid uint64 \"-1234\"",
"* error decoding 'uint32': invalid uint32 \"-5678\"",
"* error decoding 'uint32Ptr': invalid uint32 \"-5678\"",
"* error decoding 'uint16': invalid uint16 \"-9012\"",
"* error decoding 'uint16Ptr': invalid uint16 \"-9012\"",
"* error decoding 'byte': invalid uint8 \"-1\"",
"* error decoding 'bytePtr': invalid uint8 \"-1\"",
"* error decoding 'float32': invalid float32 \"badval.5\"",
"* error decoding 'float32Ptr': invalid float32 \"badval.5\"",
"* error decoding 'float64': invalid float64 \"badval.5\"",
"* error decoding 'float64Ptr': invalid float64 \"badval.5\"",
"* error decoding 'duration': invalid duration \"badval\"",
"* error decoding 'durationPtr': invalid duration \"badval\"",
"* error decoding 'time': invalid time \"badval\"",
"* error decoding 'timePtr': invalid time \"badval\"",
"* error decoding 'decoded': invalid Decoded \"badval\": strconv.Atoi: parsing \"badval\": invalid syntax",
"* error decoding 'decodedPtr': invalid Decoded \"badval\": strconv.Atoi: parsing \"badval\": invalid syntax",
"* error decoding 'bool': invalid bool \"badval\"",
"* error decoding 'boolPtr': invalid bool \"badval\"",
}
for _, expectedError := range expectedErrors {
assert.Contains(t, errors, expectedError)
delete(errorSet, expectedError)
}
assert.Empty(t, errorSet)
}
func getTimeVal() time.Time {

View File

@ -96,7 +96,7 @@ func TestNormalize(t *testing.T) {
actual, err := config.Normalize(tc.input)
if tc.err != "" {
require.Error(t, err)
assert.EqualError(t, err, tc.err)
require.EqualError(t, err, tc.err)
} else {
require.NoError(t, err)
}

View File

@ -17,6 +17,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/kit/config"
)
@ -60,9 +61,8 @@ func TestPrefixedBy(t *testing.T) {
t.Run(name, func(t *testing.T) {
actual, err := config.PrefixedBy(tc.input, tc.prefix)
if tc.err != "" {
if assert.Error(t, err) {
assert.Equal(t, tc.err, err.Error())
}
require.Error(t, err)
assert.Equal(t, tc.err, err.Error())
} else {
assert.Equal(t, tc.expected, actual, "unexpected output")
}

View File

@ -52,6 +52,9 @@ func NewPool(ctx ...context.Context) *Pool {
go func() {
defer cancel()
defer p.lock.RUnlock()
//nolint:intrange
// The condition of a classic for loop is re-evaluated on every iteration,
// whereas range evaluates the slice expression once, when the loop starts.
for i := 0; i < len(p.pool); i++ {
ch := p.pool[i]
p.lock.RUnlock()
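The `//nolint:intrange` comment above keeps the classic loop on purpose: a three-clause for loop re-evaluates its condition each iteration, while range captures the slice once. The difference is easy to demonstrate:

```go
package main

import "fmt"

func main() {
	// Classic for loop: len(s) is re-evaluated every iteration, so an
	// element appended mid-loop is also visited.
	s := []int{1, 2, 3}
	visited := 0
	for i := 0; i < len(s); i++ {
		visited++
		if i == 0 {
			s = append(s, 4) // grows the slice while iterating
		}
	}
	fmt.Println(visited) // 4

	// range: the slice expression is evaluated once, so the loop sees only
	// the three elements present when it started.
	s2 := []int{1, 2, 3}
	visited2 := 0
	for range s2 {
		visited2++
		s2 = append(s2, 4)
	}
	fmt.Println(visited2) // 3
}
```

This is why the pool, which may grow while being drained, needs the classic form.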

View File

@ -37,7 +37,7 @@ func Test_Pool(t *testing.T) {
t.Run("a cancelled context given to pool, should have pool cancelled", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
cancel()
pool := NewPool(ctx)
select {
@ -49,10 +49,10 @@ func Test_Pool(t *testing.T) {
t.Run("a cancelled context given to pool, given a new context, should still have pool cancelled", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
cancel()
pool := NewPool(ctx)
pool.Add(context.Background())
pool.Add(t.Context())
select {
case <-pool.Done():
case <-time.After(time.Second):
@ -65,13 +65,13 @@ func Test_Pool(t *testing.T) {
var ctx [50]context.Context
var cancel [50]context.CancelFunc
ctx[0], cancel[0] = context.WithCancel(context.Background())
pool := NewPool(ctx[0])
ctxPool := make([]context.Context, 0, 50)
for i := 1; i < 50; i++ {
ctx[i], cancel[i] = context.WithCancel(context.Background())
pool.Add(ctx[i])
for i := range 50 {
ctx[i], cancel[i] = context.WithCancel(t.Context())
ctxPool = append(ctxPool, ctx[i])
}
pool := NewPool(ctxPool...)
//nolint:gosec
r := rand.New(rand.NewSource(time.Now().UnixNano()))
@ -80,7 +80,7 @@ func Test_Pool(t *testing.T) {
cancel[i], cancel[j] = cancel[j], cancel[i]
})
for i := 0; i < 50; i++ {
for i := range 50 {
select {
case <-pool.Done():
t.Error("expected context to not be cancelled")
@ -99,8 +99,8 @@ func Test_Pool(t *testing.T) {
t.Run("pool size will not increase if the given contexts have been cancelled", func(t *testing.T) {
t.Parallel()
ctx1, cancel1 := context.WithCancel(context.Background())
ctx2, cancel2 := context.WithCancel(context.Background())
ctx1, cancel1 := context.WithCancel(t.Context())
ctx2, cancel2 := context.WithCancel(t.Context())
pool := NewPool(ctx1, ctx2)
assert.Equal(t, 2, pool.Size())
@ -111,19 +111,19 @@ func Test_Pool(t *testing.T) {
case <-time.After(time.Second):
t.Error("expected context pool to be cancelled")
}
pool.Add(context.Background())
pool.Add(t.Context())
assert.Equal(t, 2, pool.Size())
})
t.Run("pool size will not increase if the pool has been closed", func(t *testing.T) {
t.Parallel()
ctx1 := context.Background()
ctx2 := context.Background()
ctx1 := t.Context()
ctx2 := t.Context()
pool := NewPool(ctx1, ctx2)
assert.Equal(t, 2, pool.Size())
pool.Cancel()
pool.Add(context.Background())
pool.Add(t.Context())
assert.Equal(t, 0, pool.Size())
select {
case <-pool.Done():
@ -131,4 +131,24 @@ func Test_Pool(t *testing.T) {
t.Error("expected context pool to be cancelled")
}
})
t.Run("wait for added context to be closed", func(t *testing.T) {
t.Parallel()
ctx1, cancel1 := context.WithCancel(t.Context())
pool := NewPool(ctx1)
ctx2, cancel2 := context.WithCancel(t.Context())
pool.Add(ctx2)
assert.Equal(t, 2, pool.Size())
cancel1()
select {
case <-pool.Done():
t.Error("expected context pool to not be cancelled")
case <-time.After(10 * time.Millisecond):
}
cancel2()
})
}

View File

@ -23,6 +23,9 @@ import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
clocktesting "k8s.io/utils/clock/testing"
)
func appendingJob(slice *[]int, value int) Job {
@ -88,6 +91,7 @@ func TestChainRecover(t *testing.T) {
type countJob struct {
m sync.Mutex
started int
clock clocktesting.FakeClock
done int
delay time.Duration
}
@ -96,7 +100,7 @@ func (j *countJob) Run() {
j.m.Lock()
j.started++
j.m.Unlock()
time.Sleep(j.delay)
<-j.clock.After(j.delay)
j.m.Lock()
j.done++
j.m.Unlock()
@ -119,10 +123,11 @@ func TestChainDelayIfStillRunning(t *testing.T) {
var j countJob
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go wrappedJob.Run()
time.Sleep(50 * time.Millisecond) // Give the job 50 ms to complete.
if c := j.Done(); c != 1 {
t.Errorf("expected job run once, immediately, got %d", c)
}
assert.Eventually(t, j.clock.HasWaiters, 100*time.Millisecond, 10*time.Millisecond)
j.clock.Step(1)
assert.Eventually(t, func() bool {
return j.Done() == 1
}, 100*time.Millisecond, 10*time.Millisecond)
})
t.Run("second run immediate if first done", func(t *testing.T) {
@ -130,13 +135,13 @@ func TestChainDelayIfStillRunning(t *testing.T) {
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
time.Sleep(10 * time.Millisecond)
go wrappedJob.Run()
}()
time.Sleep(100 * time.Millisecond) // Give both jobs 100 ms to complete.
if c := j.Done(); c != 2 {
t.Errorf("expected job run twice, immediately, got %d", c)
}
assert.Eventually(t, j.clock.HasWaiters, 100*time.Millisecond, 10*time.Millisecond)
assert.Eventually(t, func() bool {
j.clock.Step(1)
return j.Done() == 2
}, 100*time.Millisecond, 10*time.Millisecond)
})
t.Run("second run delayed if first not done", func(t *testing.T) {
@ -145,24 +150,29 @@ func TestChainDelayIfStillRunning(t *testing.T) {
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
time.Sleep(10 * time.Millisecond)
go wrappedJob.Run()
}()
// After 50 ms, the first job is still in progress, and the second job was
// run but should be waiting for it to finish.
time.Sleep(50 * time.Millisecond)
assert.Eventually(t, j.clock.HasWaiters, 50*time.Millisecond, 10*time.Millisecond)
j.clock.Step(50 * time.Millisecond)
started, done := j.Started(), j.Done()
if started != 1 || done != 0 {
t.Error("expected first job started, but not finished, got", started, done)
}
// Verify that the second job completes.
time.Sleep(200 * time.Millisecond)
started, done = j.Started(), j.Done()
if started != 2 || done != 2 {
t.Error("expected both jobs done, got", started, done)
}
assert.Eventually(t, j.clock.HasWaiters, 50*time.Millisecond, 10*time.Millisecond)
j.clock.Step(50 * time.Millisecond)
assert.Eventually(t, j.clock.HasWaiters, 50*time.Millisecond, 10*time.Millisecond)
j.clock.Step(200 * time.Millisecond)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
started, done = j.Started(), j.Done()
if started != 2 || done != 2 {
c.Errorf("expected both jobs done, got %v %v", started, done)
}
}, 100*time.Millisecond, 10*time.Millisecond)
})
}
@ -171,24 +181,25 @@ func TestChainSkipIfStillRunning(t *testing.T) {
var j countJob
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go wrappedJob.Run()
time.Sleep(50 * time.Millisecond) // Give the job 50ms to complete.
if c := j.Done(); c != 1 {
t.Errorf("expected job run once, immediately, got %d", c)
}
assert.Eventually(t, j.clock.HasWaiters, 50*time.Millisecond, 10*time.Millisecond)
j.clock.Step(1)
assert.Eventually(t, func() bool { return j.Done() == 1 }, 50*time.Millisecond, 10*time.Millisecond)
})
t.Run("second run immediate if first done", func(t *testing.T) {
var j countJob
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
time.Sleep(10 * time.Millisecond)
go wrappedJob.Run()
}()
time.Sleep(100 * time.Millisecond) // Give both jobs 100ms to complete.
if c := j.Done(); c != 2 {
t.Errorf("expected job run twice, immediately, got %d", c)
}
go wrappedJob.Run()
assert.Eventually(t, j.clock.HasWaiters, 50*time.Millisecond, 10*time.Millisecond)
j.clock.Step(1)
assert.Eventually(t, func() bool { return j.Done() == 1 }, 100*time.Millisecond, 10*time.Millisecond)
go wrappedJob.Run()
assert.Eventually(t, j.clock.HasWaiters, 50*time.Millisecond, 10*time.Millisecond)
j.clock.Step(1)
assert.Eventually(t, func() bool { return j.Done() == 2 }, 100*time.Millisecond, 10*time.Millisecond)
})
t.Run("second run skipped if first not done", func(t *testing.T) {
@ -197,38 +208,37 @@ func TestChainSkipIfStillRunning(t *testing.T) {
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
time.Sleep(10 * time.Millisecond)
go wrappedJob.Run()
}()
// After 50ms, the first job is still in progress, and the second job was
// already skipped.
time.Sleep(50 * time.Millisecond)
started, done := j.Started(), j.Done()
if started != 1 || done != 0 {
t.Error("expected first job started, but not finished, got", started, done)
}
assert.Eventually(t, j.clock.HasWaiters, 50*time.Millisecond, 10*time.Millisecond)
j.clock.Step(50 * time.Millisecond)
assert.Eventually(t, func() bool {
return j.Started() == 1 && j.Done() == 0
}, 50*time.Millisecond, 10*time.Millisecond)
// Verify that the first job completes and second does not run.
time.Sleep(200 * time.Millisecond)
started, done = j.Started(), j.Done()
if started != 1 || done != 1 {
t.Error("expected second job skipped, got", started, done)
}
j.clock.Step(200 * time.Millisecond)
assert.Eventually(t, func() bool {
return j.Started() == 1 && j.Done() == 1
}, 50*time.Millisecond, 10*time.Millisecond)
})
t.Run("skip 10 jobs on rapid fire", func(t *testing.T) {
var j countJob
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
for i := 0; i < 11; i++ {
for range 11 {
go wrappedJob.Run()
}
time.Sleep(200 * time.Millisecond)
done := j.Done()
if done != 1 {
t.Error("expected 1 jobs executed, 10 jobs dropped, got", done)
}
assert.Eventually(t, j.clock.HasWaiters, 50*time.Millisecond, 10*time.Millisecond)
j.clock.Step(200 * time.Millisecond)
assert.False(t, j.clock.HasWaiters())
assert.Eventually(t, func() bool {
return j.Started() == 1 && j.Done() == 1
}, 50*time.Millisecond, 10*time.Millisecond)
})
t.Run("different jobs independent", func(t *testing.T) {
@ -238,17 +248,16 @@ func TestChainSkipIfStillRunning(t *testing.T) {
chain := NewChain(SkipIfStillRunning(DiscardLogger))
wrappedJob1 := chain.Then(&j1)
wrappedJob2 := chain.Then(&j2)
for i := 0; i < 11; i++ {
for range 11 {
go wrappedJob1.Run()
go wrappedJob2.Run()
}
time.Sleep(100 * time.Millisecond)
var (
done1 = j1.Done()
done2 = j2.Done()
)
if done1 != 1 || done2 != 1 {
t.Error("expected both jobs executed once, got", done1, "and", done2)
}
assert.Eventually(t, j1.clock.HasWaiters, 50*time.Millisecond, 10*time.Millisecond)
assert.Eventually(t, j2.clock.HasWaiters, 50*time.Millisecond, 10*time.Millisecond)
j1.clock.Step(10 * time.Millisecond)
j2.clock.Step(10 * time.Millisecond)
assert.Eventually(t, func() bool {
return j1.Started() == 1 && j1.Done() == 1
}, 50*time.Millisecond, 10*time.Millisecond)
})
}

View File

@ -14,31 +14,23 @@ You can check the original license at:
https://github.com/robfig/cron/blob/master/LICENSE
*/
//nolint
package cron
import "time"
// ConstantDelaySchedule represents a simple recurring duty cycle, e.g. "Every 5 minutes".
// It does not support jobs more frequent than once a second.
type ConstantDelaySchedule struct {
Delay time.Duration
}
// Every returns a crontab Schedule that activates once every duration.
// Delays of less than a second are not supported (will round up to 1 second).
// Any fields less than a Second are truncated.
func Every(duration time.Duration) ConstantDelaySchedule {
if duration < time.Second {
duration = time.Second
}
return ConstantDelaySchedule{
Delay: duration - time.Duration(duration.Nanoseconds())%time.Second,
Delay: duration,
}
}
// Next returns the next time this should be run.
// This rounds so that the next activation time will be on the second.
func (schedule ConstantDelaySchedule) Next(t time.Time) time.Time {
return t.Add(schedule.Delay - time.Duration(t.Nanosecond())*time.Nanosecond)
return t.Add(schedule.Delay)
}

View File

@ -14,7 +14,6 @@ You can check the original license at:
https://github.com/robfig/cron/blob/master/LICENSE
*/
//nolint
package cron
import (
@ -29,9 +28,12 @@ func TestConstantDelayNext(t *testing.T) {
expected string
}{
// Simple cases
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00:00.00000005 2012"},
{"Mon Jul 9 14:59 2012", 15 * time.Minute, "Mon Jul 9 15:14 2012"},
{"Mon Jul 9 14:59:59 2012", 15 * time.Minute, "Mon Jul 9 15:14:59 2012"},
{"Mon Jul 9 14:45:00 2012", 15 * time.Millisecond, "Mon Jul 9 14:45:00.015 2012"},
{"Mon Jul 9 14:45:00.015 2012", 15 * time.Millisecond, "Mon Jul 9 14:45:00.030 2012"},
{"Mon Jul 9 14:45:00.000000050 2012", 15 * time.Nanosecond, "Mon Jul 9 14:45:00.000000065 2012"},
// Wrap around hours
{"Mon Jul 9 15:45 2012", 35 * time.Minute, "Mon Jul 9 16:20 2012"},
@ -47,18 +49,6 @@ func TestConstantDelayNext(t *testing.T) {
// Wrap around minute, hour, day, month, and year
{"Mon Dec 31 23:59:45 2012", 15 * time.Second, "Tue Jan 1 00:00:00 2013"},
// Round to nearest second on the delay
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
// Round up to 1 second if the duration is less.
{"Mon Jul 9 14:45:00 2012", 15 * time.Millisecond, "Mon Jul 9 14:45:01 2012"},
// Round to nearest second when calculating the next time.
{"Mon Jul 9 14:45:00.005 2012", 15 * time.Minute, "Mon Jul 9 15:00 2012"},
// Round to nearest second for both.
{"Mon Jul 9 14:45:00.005 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
}
for _, c := range tests {

View File

@ -59,7 +59,7 @@ type Job interface {
type Schedule interface {
// Next returns the next activation time, later than the given time.
// Next is invoked initially, and then each time the job is run.
Next(time.Time) time.Time
Next(t time.Time) time.Time
}
// EntryID identifies an entry within a Cron instance
@ -140,7 +140,7 @@ func New(opts ...Option) *Cron {
running: false,
runningMu: sync.Mutex{},
logger: DefaultLogger,
location: time.Local,
location: time.Local, //nolint:gosmopolitan
parser: standardParser,
clk: clock.RealClock{},
}
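The tests rely on `WithClock` to inject a fake clock into `New`. A minimal, self-contained sketch of that functional-options pattern (the names below are illustrative, not the kit API):

```go
package main

import "fmt"

// Clock is a trivial stand-in; the real code uses k8s.io/utils/clock.
type Clock interface{ Name() string }

type realClock struct{}

func (realClock) Name() string { return "real" }

type fakeClock struct{}

func (fakeClock) Name() string { return "fake" }

type Cron struct{ clk Clock }

// Option mutates the Cron before New returns it.
type Option func(*Cron)

// WithClock overrides the default real clock, e.g. in tests.
func WithClock(c Clock) Option {
	return func(cr *Cron) { cr.clk = c }
}

func New(opts ...Option) *Cron {
	c := &Cron{clk: realClock{}} // default, as in the diff above
	for _, o := range opts {
		o(c)
	}
	return c
}

func main() {
	fmt.Println(New().clk.Name())                       // real
	fmt.Println(New(WithClock(fakeClock{})).clk.Name()) // fake
}
```

Production callers get the real clock by default; tests pass `WithClock` to make time fully controllable.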

View File

@ -14,20 +14,19 @@ You can check the original license at:
https://github.com/robfig/cron/blob/master/LICENSE
*/
//nolint:dupword
package cron
import (
"bytes"
"fmt"
"log"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
clocktesting "k8s.io/utils/clock/testing"
)
@ -35,7 +34,7 @@ import (
// for it to run. This amount is just slightly larger than 1 second to
// compensate for a few milliseconds of runtime.
//
//nolint:revive
const OneSecond = 1*time.Second + 50*time.Millisecond
type syncWriter struct {
@ -62,22 +61,23 @@ func newBufLogger(sw *syncWriter) Logger {
}
func TestFuncPanicRecovery(t *testing.T) {
clock := clocktesting.NewFakeClock(time.Now())
var buf syncWriter
cron := New(WithParser(secondParser),
WithChain(Recover(newBufLogger(&buf))))
WithChain(Recover(newBufLogger(&buf))),
WithClock(clock),
)
cron.Start()
defer cron.Stop()
cron.AddFunc("* * * * * ?", func() {
panic("YOLO")
})
select {
case <-time.After(OneSecond):
if !strings.Contains(buf.String(), "YOLO") {
t.Error("expected a panic to be logged, got none")
}
return
}
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Contains(c, buf.String(), "YOLO")
}, OneSecond, 10*time.Millisecond)
}
type DummyJob struct{}
@ -88,26 +88,27 @@ func (d DummyJob) Run() {
func TestJobPanicRecovery(t *testing.T) {
var job DummyJob
var buf syncWriter
clock := clocktesting.NewFakeClock(time.Now())
cron := New(WithParser(secondParser),
WithChain(Recover(newBufLogger(&buf))))
WithChain(Recover(newBufLogger(&buf))),
WithClock(clock),
)
cron.Start()
defer cron.Stop()
cron.AddJob("* * * * * ?", job)
select {
case <-time.After(OneSecond):
if !strings.Contains(buf.String(), "YOLO") {
t.Error("expected a panic to be logged, got none")
}
return
}
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Contains(c, buf.String(), "YOLO")
}, OneSecond, 10*time.Millisecond)
}
// Start and stop cron with no entries.
func TestNoEntries(t *testing.T) {
cron := newWithSeconds()
cron, _ := newWithSeconds()
cron.Start()
select {
@ -122,13 +123,14 @@ func TestStopCausesJobsToNotRun(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron, clock := newWithSeconds()
cron.Start()
cron.Stop()
cron.AddFunc("* * * * * ?", func() { wg.Done() })
select {
case <-time.After(OneSecond):
case <-time.After(time.Millisecond * 100):
assert.False(t, clock.HasWaiters())
// No job ran!
case <-wait(wg):
t.Fatal("expected stopped cron does not run any job")
@ -140,11 +142,14 @@ func TestAddBeforeRunning(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron, clock := newWithSeconds()
cron.AddFunc("* * * * * ?", func() { wg.Done() })
cron.Start()
defer cron.Stop()
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
// Give cron 2 seconds to run our job (which is always activated).
select {
case <-time.After(OneSecond):
@ -153,16 +158,19 @@ func TestAddBeforeRunning(t *testing.T) {
}
}
// Start cron, add a job, expect it runs.
func TestAddWhileRunning(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron, clock := newWithSeconds()
cron.Start()
defer cron.Stop()
cron.AddFunc("* * * * * ?", func() { wg.Done() })
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
select {
case <-time.After(OneSecond):
t.Fatal("expected job runs")
@ -170,19 +178,21 @@ func TestAddWhileRunning(t *testing.T) {
}
}
// Test for #34. Adding a job after calling start results in multiple job invocations
func TestAddWhileRunningWithDelay(t *testing.T) {
cron := newWithSeconds()
cron, clock := newWithSeconds()
cron.Start()
defer cron.Stop()
time.Sleep(5 * time.Second)
clock.Step(OneSecond * 5)
var calls int64
cron.AddFunc("* * * * * *", func() { atomic.AddInt64(&calls, 1) })
<-time.After(OneSecond)
if atomic.LoadInt64(&calls) != 1 {
t.Errorf("called %d times, expected 1\n", calls)
}
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
assert.Eventually(t, func() bool {
return atomic.LoadInt64(&calls) == 1
}, OneSecond, 10*time.Millisecond)
}
// Add a job, remove a job, start cron, expect nothing runs.
@ -190,55 +200,72 @@ func TestRemoveBeforeRunning(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron, clock := newWithSeconds()
id, _ := cron.AddFunc("* * * * * ?", func() { wg.Done() })
cron.Remove(id)
cron.Start()
defer cron.Stop()
clock.Step(OneSecond)
select {
case <-time.After(OneSecond):
case <-time.After(time.Millisecond * 100):
// Success, shouldn't run
assert.False(t, clock.HasWaiters())
case <-wait(wg):
t.FailNow()
}
}
// Start cron, add a job, remove it, expect it doesn't run.
func TestRemoveWhileRunning(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron, clock := newWithSeconds()
cron.Start()
defer cron.Stop()
id, _ := cron.AddFunc("* * * * * ?", func() { wg.Done() })
id, err := cron.AddFunc("* * * * * ?", func() { wg.Done() })
require.NoError(t, err)
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
cron.Remove(id)
assert.Eventually(t, func() bool {
return !clock.HasWaiters()
}, OneSecond, 10*time.Millisecond)
select {
case <-time.After(OneSecond):
case <-time.After(time.Millisecond * 100):
case <-wait(wg):
t.FailNow()
}
}
// Test timing with Entries.
func TestSnapshotEntries(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := New()
clock := clocktesting.NewFakeClock(time.Now())
cron := New(WithClock(clock))
cron.AddFunc("@every 2s", func() { wg.Done() })
cron.Start()
defer cron.Stop()
// Cron should fire in 2 seconds. After 1 second, call Entries.
select {
case <-time.After(OneSecond):
cron.Entries()
}
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
cron.Entries()
// Even though Entries was called, the cron should fire at the 2 second mark.
select {
case <-time.After(time.Millisecond * 100):
case <-wait(wg):
}
clock.Step(OneSecond)
select {
case <-time.After(OneSecond):
t.Error("expected job runs at 2 second mark")
@ -254,7 +281,7 @@ func TestMultipleEntries(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
cron := newWithSeconds()
cron, clock := newWithSeconds()
cron.AddFunc("0 0 0 1 1 ?", func() {})
cron.AddFunc("* * * * * ?", func() { wg.Done() })
id1, _ := cron.AddFunc("* * * * * ?", func() { t.Fatal() })
@ -267,6 +294,9 @@ func TestMultipleEntries(t *testing.T) {
cron.Remove(id2)
defer cron.Stop()
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
select {
case <-time.After(OneSecond):
t.Error("expected job run in proper order")
@ -274,12 +304,12 @@ func TestMultipleEntries(t *testing.T) {
}
}
// Test running the same job twice.
func TestRunningJobTwice(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
cron := newWithSeconds()
cron, clock := newWithSeconds()
cron.AddFunc("0 0 0 1 1 ?", func() {})
cron.AddFunc("0 0 0 31 12 ?", func() {})
cron.AddFunc("* * * * * ?", func() { wg.Done() })
@ -287,6 +317,11 @@ func TestRunningJobTwice(t *testing.T) {
cron.Start()
defer cron.Stop()
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
select {
case <-time.After(2 * OneSecond):
t.Error("expected job fires 2 times")
@ -298,7 +333,7 @@ func TestRunningMultipleSchedules(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
cron := newWithSeconds()
cron, clock := newWithSeconds()
cron.AddFunc("0 0 0 1 1 ?", func() {})
cron.AddFunc("0 0 0 31 12 ?", func() {})
cron.AddFunc("* * * * * ?", func() { wg.Done() })
@ -309,6 +344,9 @@ func TestRunningMultipleSchedules(t *testing.T) {
cron.Start()
defer cron.Stop()
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
select {
case <-time.After(2 * OneSecond):
t.Error("expected job fires 2 times")
@ -321,22 +359,21 @@ func TestLocalTimezone(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
now := time.Now()
// FIX: Issue #205
// This calculation doesn't work in seconds 58 or 59.
// Take the easy way out and sleep.
if now.Second() >= 58 {
time.Sleep(2 * time.Second)
now = time.Now()
}
now := time.Date(2016, 11, 8, 12, 0, 0, 0, time.Local)
spec := fmt.Sprintf("%d,%d %d %d %d %d ?",
now.Second()+1, now.Second()+2, now.Minute(), now.Hour(), now.Day(), now.Month())
cron := newWithSeconds()
cron, clock := newWithSeconds()
clock.SetTime(now)
cron.AddFunc(spec, func() { wg.Done() })
cron.Start()
defer cron.Stop()
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
select {
case <-time.After(OneSecond * 2):
t.Error("expected job fires 2 times")
@ -350,27 +387,28 @@ func TestNonLocalTimezone(t *testing.T) {
wg.Add(2)
loc, err := time.LoadLocation("Atlantic/Cape_Verde")
if err != nil {
fmt.Printf("Failed to load time zone Atlantic/Cape_Verde: %+v", err)
t.Fail()
require.NoError(t, err)
if loc == time.Local {
loc, err = time.LoadLocation("America/New_York")
require.NoError(t, err)
}
now := time.Now().In(loc)
// FIX: Issue #205
// This calculation doesn't work in seconds 58 or 59.
// Take the easy way out and sleep.
if now.Second() >= 58 {
time.Sleep(2 * time.Second)
now = time.Now().In(loc)
}
now := time.Date(2016, 11, 8, 12, 0, 0, 0, loc)
spec := fmt.Sprintf("%d,%d %d %d %d %d ?",
now.Second()+1, now.Second()+2, now.Minute(), now.Hour(), now.Day(), now.Month())
cron := New(WithLocation(loc), WithParser(secondParser))
clock := clocktesting.NewFakeClock(now)
cron := New(WithLocation(loc), WithParser(secondParser), WithClock(clock))
cron.AddFunc(spec, func() { wg.Done() })
cron.Start()
defer cron.Stop()
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
select {
case <-time.After(OneSecond * 2):
t.Error("expected job fires 2 times")
@ -380,7 +418,7 @@ func TestNonLocalTimezone(t *testing.T) {
// Test that calling stop before start silently returns without
// blocking the stop channel.
func TestStopWithoutStart(t *testing.T) {
func TestStopWithoutStart(*testing.T) {
cron := New()
cron.Stop()
}
@ -408,7 +446,7 @@ func TestBlockingRun(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron, clock := newWithSeconds()
cron.AddFunc("* * * * * ?", func() { wg.Done() })
unblockChan := make(chan struct{})
@ -419,6 +457,9 @@ func TestBlockingRun(t *testing.T) {
}()
defer cron.Stop()
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
select {
case <-time.After(OneSecond):
t.Error("expected job fires")
@ -432,7 +473,7 @@ func TestBlockingRun(t *testing.T) {
func TestStartNoop(t *testing.T) {
tickChan := make(chan struct{}, 2)
cron := newWithSeconds()
cron, clock := newWithSeconds()
cron.AddFunc("* * * * * ?", func() {
tickChan <- struct{}{}
})
@ -440,11 +481,17 @@ func TestStartNoop(t *testing.T) {
cron.Start()
defer cron.Stop()
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
// Wait for the first firing to ensure the runner is going
<-tickChan
cron.Start()
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
<-tickChan
// Fail if this job fires again in a short period, indicating a double-run
@ -460,7 +507,7 @@ func TestJob(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron, clock := newWithSeconds()
cron.AddJob("0 0 0 30 Feb ?", testJob{wg, "job0"})
cron.AddJob("0 0 0 1 1 ?", testJob{wg, "job1"})
job2, _ := cron.AddJob("* * * * * ?", testJob{wg, "job2"})
@ -479,6 +526,9 @@ func TestJob(t *testing.T) {
cron.Start()
defer cron.Stop()
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
select {
case <-time.After(OneSecond):
t.FailNow()
@ -524,7 +574,7 @@ func TestScheduleAfterRemoval(t *testing.T) {
var calls int
var mu sync.Mutex
cron := newWithSeconds()
cron, clock := newWithSeconds()
hourJob := cron.Schedule(Every(time.Hour), FuncJob(func() {}))
cron.Schedule(Every(time.Second), FuncJob(func() {
mu.Lock()
@ -534,7 +584,7 @@ func TestScheduleAfterRemoval(t *testing.T) {
wg1.Done()
calls++
case 1:
time.Sleep(750 * time.Millisecond)
<-clock.After(100 * time.Millisecond)
cron.Remove(hourJob)
calls++
case 2:
@ -548,10 +598,18 @@ func TestScheduleAfterRemoval(t *testing.T) {
cron.Start()
defer cron.Stop()
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
// the first run might be any length of time 0 - 1s, since the schedule
// rounds to the second. wait for the first run to true up.
wg1.Wait()
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
select {
case <-time.After(2 * OneSecond):
t.Error("expected job fires 2 times")
@ -567,21 +625,22 @@ func (*ZeroSchedule) Next(time.Time) time.Time {
// Tests that job without time does not run
func TestJobWithZeroTimeDoesNotRun(t *testing.T) {
cron := newWithSeconds()
cron, clock := newWithSeconds()
var calls int64
cron.AddFunc("* * * * * *", func() { atomic.AddInt64(&calls, 1) })
cron.Schedule(new(ZeroSchedule), FuncJob(func() { t.Error("expected zero task will not run") }))
cron.Start()
defer cron.Stop()
<-time.After(OneSecond)
if atomic.LoadInt64(&calls) != 1 {
t.Errorf("called %d times, expected 1\n", calls)
}
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
assert.Eventually(t, func() bool {
return atomic.LoadInt64(&calls) == 1
}, OneSecond, 10*time.Millisecond)
}
func TestStopAndWait(t *testing.T) {
t.Run("nothing running, returns immediately", func(t *testing.T) {
cron := newWithSeconds()
cron, _ := newWithSeconds()
cron.Start()
ctx := cron.Stop()
select {
@ -592,10 +651,9 @@ func TestStopAndWait(t *testing.T) {
})
t.Run("repeated calls to Stop", func(t *testing.T) {
cron := newWithSeconds()
cron, _ := newWithSeconds()
cron.Start()
_ = cron.Stop()
time.Sleep(time.Millisecond)
ctx := cron.Stop()
select {
case <-ctx.Done():
@ -605,13 +663,14 @@ func TestStopAndWait(t *testing.T) {
})
t.Run("a couple fast jobs added, still returns immediately", func(t *testing.T) {
cron := newWithSeconds()
cron, clock := newWithSeconds()
cron.AddFunc("* * * * * *", func() {})
cron.Start()
cron.AddFunc("* * * * * *", func() {})
cron.AddFunc("* * * * * *", func() {})
cron.AddFunc("* * * * * *", func() {})
time.Sleep(time.Second)
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
clock.Step(OneSecond)
ctx := cron.Stop()
select {
case <-ctx.Done():
@ -621,23 +680,32 @@ func TestStopAndWait(t *testing.T) {
})
t.Run("a couple fast jobs and a slow job added, waits for slow job", func(t *testing.T) {
cron := newWithSeconds()
funcClock := clocktesting.NewFakeClock(time.Now())
cron, clock := newWithSeconds()
cron.AddFunc("* * * * * *", func() {})
cron.Start()
cron.AddFunc("* * * * * *", func() { time.Sleep(2 * time.Second) })
cron.AddFunc("* * * * * *", func() { <-funcClock.After(OneSecond * 2) })
cron.AddFunc("* * * * * *", func() {})
time.Sleep(time.Second)
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
assert.False(t, funcClock.HasWaiters())
clock.Step(OneSecond)
assert.Eventually(t, funcClock.HasWaiters, OneSecond, 10*time.Millisecond)
funcClock.Step(OneSecond)
ctx := cron.Stop()
// Verify that it is not done for at least 750ms
// Verify that it is not done.
select {
case <-ctx.Done():
t.Error("context was done too early")
case <-time.After(750 * time.Millisecond):
case <-time.After(10 * time.Millisecond):
// expected, because the job sleeping for 1 second is still running
}
assert.False(t, clock.HasWaiters())
funcClock.Step(OneSecond)
// Verify that it IS done in the next 500ms (giving 250ms buffer)
select {
case <-ctx.Done():
@ -648,12 +716,19 @@ func TestStopAndWait(t *testing.T) {
})
t.Run("repeated calls to stop, waiting for completion and after", func(t *testing.T) {
cron := newWithSeconds()
cron, clock := newWithSeconds()
funcClock := clocktesting.NewFakeClock(clock.Now())
cron.AddFunc("* * * * * *", func() {})
cron.AddFunc("* * * * * *", func() { time.Sleep(2 * time.Second) })
cron.AddFunc("* * * * * *", func() { <-funcClock.After(OneSecond * 2) })
cron.Start()
cron.AddFunc("* * * * * *", func() {})
time.Sleep(time.Second)
assert.Eventually(t, clock.HasWaiters, OneSecond, 10*time.Millisecond)
assert.False(t, funcClock.HasWaiters())
clock.Step(OneSecond)
assert.Eventually(t, funcClock.HasWaiters, OneSecond, 10*time.Millisecond)
funcClock.Step(time.Millisecond * 1500)
ctx := cron.Stop()
ctx2 := cron.Stop()
@ -663,10 +738,14 @@ func TestStopAndWait(t *testing.T) {
t.Error("context was done too early")
case <-ctx2.Done():
t.Error("context2 was done too early")
case <-time.After(1500 * time.Millisecond):
case <-time.After(100 * time.Millisecond):
// expected, because the job sleeping for 2 seconds is still running
}
assert.False(t, clock.HasWaiters())
assert.True(t, funcClock.HasWaiters())
funcClock.Step(time.Millisecond * 600)
// Verify that it IS done in the next 1s (giving 500ms buffer)
select {
case <-ctx.Done():
@ -697,26 +776,77 @@ func TestStopAndWait(t *testing.T) {
func TestMockClock(t *testing.T) {
clk := clocktesting.NewFakeClock(time.Now())
cron := New(WithClock(clk))
counter := atomic.Uint64{}
counter := atomic.Int64{}
cron.AddFunc("@every 1s", func() {
counter.Add(1)
})
cron.Start()
defer cron.Stop()
for i := 0; i <= 10; i++ {
for range 11 {
assert.Eventually(t, clk.HasWaiters, OneSecond, 10*time.Millisecond)
clk.Step(1 * time.Second)
runtime.Gosched()
time.Sleep(100 * time.Millisecond)
}
if counter.Load() != 10 {
t.Errorf("expected 10 calls, got %d", counter.Load())
}
assert.Equal(t, int64(10), counter.Load())
}
func TestMultiThreadedStartAndStop(t *testing.T) {
func TestMillisecond(t *testing.T) {
clk := clocktesting.NewFakeClock(time.Now())
cron := New(WithClock(clk))
counter1ms := atomic.Int64{}
counter15ms := atomic.Int64{}
counter100ms := atomic.Int64{}
cron.AddFunc("@every 1ms", func() {
counter1ms.Add(1)
})
cron.AddFunc("@every 15ms", func() {
counter15ms.Add(1)
})
cron.AddFunc("@every 100ms", func() {
counter100ms.Add(1)
})
cron.Start()
defer cron.Stop()
for range 1000 {
assert.Eventually(t, clk.HasWaiters, OneSecond, 1*time.Millisecond)
clk.Step(1 * time.Millisecond)
}
ctx := cron.Stop()
<-ctx.Done()
assert.Equal(t, int64(1000), counter1ms.Load())
assert.Equal(t, int64(66), counter15ms.Load())
assert.Equal(t, int64(10), counter100ms.Load())
}
func TestNanoseconds(t *testing.T) {
clk := clocktesting.NewFakeClock(time.Now())
cron := New(WithClock(clk))
counter100ns := atomic.Int64{}
cron.AddFunc("@every 100ns", func() {
counter100ns.Add(1)
})
cron.Start()
defer cron.Stop()
for range 500 {
assert.Eventually(t, clk.HasWaiters, OneSecond, 1*time.Millisecond)
clk.Step(5 * time.Nanosecond)
}
ctx := cron.Stop()
<-ctx.Done()
// 500 * 5 ns = 2500 ns
// 2500 every 100ns = 25
assert.Equal(t, int64(25), counter100ns.Load())
}
func TestMultiThreadedStartAndStop(*testing.T) {
cron := New()
go cron.Run()
time.Sleep(2 * time.Millisecond)
cron.Stop()
}
@ -739,6 +869,7 @@ func stop(cron *Cron) chan bool {
}
// newWithSeconds returns a Cron with the seconds field enabled.
func newWithSeconds() *Cron {
return New(WithParser(secondParser), WithChain())
func newWithSeconds() (*Cron, *clocktesting.FakeClock) {
clock := clocktesting.NewFakeClock(time.Now())
return New(WithParser(secondParser), WithChain(), WithClock(clock)), clock
}

View File

@ -1,4 +1,3 @@
//nolint
/*
This package is a fork of "github.com/robfig/cron/v3" that implements cron spec parser and job runner with support for mocking the time.
@ -36,7 +35,9 @@ them in their own goroutines.
# Time mocking
import (
clocktesting "k8s.io/utils/clock/testing"
)
clk := clocktesting.NewFakeClock(time.Now())

View File

@ -14,7 +14,6 @@ You can check the original license at:
https://github.com/robfig/cron/blob/master/LICENSE
*/
//nolint
package cron
import (
@ -22,6 +21,9 @@ import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
clocktesting "k8s.io/utils/clock/testing"
)
func TestWithLocation(t *testing.T) {
@ -42,18 +44,22 @@ func TestWithParser(t *testing.T) {
func TestWithVerboseLogger(t *testing.T) {
var buf syncWriter
logger := log.New(&buf, "", log.LstdFlags)
c := New(WithLogger(VerbosePrintfLogger(logger)))
clock := clocktesting.NewFakeClock(time.Now())
c := New(WithLogger(VerbosePrintfLogger(logger)), WithClock(clock))
if c.logger.(printfLogger).logger != logger {
t.Error("expected provided logger")
}
c.AddFunc("@every 1s", func() {})
c.Start()
time.Sleep(OneSecond)
assert.Eventually(t, clock.HasWaiters, OneSecond, time.Millisecond*10)
clock.Step(OneSecond)
c.Stop()
out := buf.String()
if !strings.Contains(out, "schedule,") ||
!strings.Contains(out, "run,") {
t.Error("expected to see some actions, got:", out)
}
assert.EventuallyWithT(t, func(c *assert.CollectT) {
out := buf.String()
if !strings.Contains(out, "schedule,") ||
!strings.Contains(out, "run,") {
c.Errorf("expected to see some actions, got: %v", out)
}
}, time.Second, time.Millisecond*10)
}

View File

@ -14,7 +14,6 @@ You can check the original license at:
https://github.com/robfig/cron/blob/master/LICENSE
*/
//nolint
package cron
import (
@ -167,6 +166,8 @@ func TestParseSchedule(t *testing.T) {
{standardParser, "CRON_TZ=UTC 5 * * * *", every5min(time.UTC)},
{secondParser, "CRON_TZ=Asia/Tokyo 0 5 * * * *", every5min(tokyo)},
{secondParser, "@every 5m", ConstantDelaySchedule{5 * time.Minute}},
{secondParser, "@every 5ms", ConstantDelaySchedule{5 * time.Millisecond}},
{secondParser, "@every 5ns", ConstantDelaySchedule{5 * time.Nanosecond}},
{secondParser, "@midnight", midnight(time.Local)},
{secondParser, "TZ=UTC @midnight", midnight(time.UTC)},
{secondParser, "TZ=Asia/Tokyo @midnight", midnight(tokyo)},

View File

@ -14,7 +14,6 @@ You can check the original license at:
https://github.com/robfig/cron/blob/master/LICENSE
*/
//nolint
package cron
import (

View File

@ -214,9 +214,9 @@ func (aead *aesCBCAEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]b
}
// Computes the HMAC tag as per specs.
func (aead aesCBCAEAD) hmacTag(h hash.Hash, additionalData, nonce, ciphertext []byte, l int) []byte {
func (aead *aesCBCAEAD) hmacTag(h hash.Hash, additionalData, nonce, ciphertext []byte, l int) []byte {
al := make([]byte, 8)
binary.BigEndian.PutUint64(al, uint64(len(additionalData)<<3)) // In bits
binary.BigEndian.PutUint64(al, uint64(len(additionalData)<<3)) // #nosec G115 // In bits
h.Write(additionalData)
h.Write(nonce)

View File

@ -32,7 +32,7 @@ func TestNewAESCBCAEAD(t *testing.T) {
aead, err := NewAESCBC128SHA256(key)
require.NoError(t, err)
require.Equal(t, len(nonce), aead.NonceSize())
require.Len(t, nonce, aead.NonceSize())
require.Equal(t, 16, aead.Overhead())
gotCiphertext := aead.Seal(nil, nonce, plaintext, aad)
@ -52,7 +52,7 @@ func TestNewAESCBCAEAD(t *testing.T) {
aead, err := NewAESCBC192SHA384(key)
require.NoError(t, err)
require.Equal(t, len(nonce), aead.NonceSize())
require.Len(t, nonce, aead.NonceSize())
require.Equal(t, 24, aead.Overhead())
gotCiphertext := aead.Seal(nil, nonce, plaintext, aad)
@ -72,7 +72,7 @@ func TestNewAESCBCAEAD(t *testing.T) {
aead, err := NewAESCBC256SHA384(key)
require.NoError(t, err)
require.Equal(t, len(nonce), aead.NonceSize())
require.Len(t, nonce, aead.NonceSize())
require.Equal(t, 24, aead.Overhead())
gotCiphertext := aead.Seal(nil, nonce, plaintext, aad)
@ -92,7 +92,7 @@ func TestNewAESCBCAEAD(t *testing.T) {
aead, err := NewAESCBC256SHA512(key)
require.NoError(t, err)
require.Equal(t, len(nonce), aead.NonceSize())
require.Len(t, nonce, aead.NonceSize())
require.Equal(t, 32, aead.Overhead())
gotCiphertext := aead.Seal(nil, nonce, plaintext, aad)

View File

@ -48,14 +48,14 @@ func Wrap(block cipher.Block, cek []byte) ([]byte, error) {
copy(r[i], cek[i*8:])
}
for j := 0; j <= 5; j++ {
for j := range 6 {
for i := 1; i <= n; i++ {
b := arrConcat(a, r[i-1])
block.Encrypt(b, b)
t := (n * j) + i
tBytes := make([]byte, 8)
binary.BigEndian.PutUint64(tBytes, uint64(t))
binary.BigEndian.PutUint64(tBytes, uint64(t)) // #nosec G115
copy(a, arrXor(b[:len(b)/2], tBytes))
copy(r[i-1], b[len(b)/2:])
@ -92,7 +92,7 @@ func Unwrap(block cipher.Block, cipherText []byte) ([]byte, error) {
for i := n; i >= 1; i-- {
t := (n * j) + i
tBytes := make([]byte, 8)
binary.BigEndian.PutUint64(tBytes, uint64(t))
binary.BigEndian.PutUint64(tBytes, uint64(t)) // #nosec G115
b := arrConcat(arrXor(a, tBytes), r[i-1])
block.Decrypt(b, b)

View File

@ -85,12 +85,12 @@ func TestWrapRfc3394Vectors(t *testing.T) {
exp := mustHexDecode(v.Expected)
cipher, err := aes.NewCipher(kek)
if !assert.NoError(t, err, "NewCipher should not fail!") {
if !assert.NoError(t, err, "NewCipher should not fail!") { //nolint:testifylint
continue
}
actual, err := Wrap(cipher, data)
if !assert.NoError(t, err, "Wrap should not throw error with valid input") {
if !assert.NoError(t, err, "Wrap should not throw error with valid input") { //nolint:testifylint
continue
}
if !assert.Equal(t, exp, actual, "Wrap Mismatch: Actual wrapped ciphertext should equal expected for test case '%s'", v.Case) {
@ -98,7 +98,7 @@ func TestWrapRfc3394Vectors(t *testing.T) {
}
actualUnwrapped, err := Unwrap(cipher, actual)
if !assert.NoError(t, err, "Unwrap should not throw error with valid input") {
if !assert.NoError(t, err, "Unwrap should not throw error with valid input") { //nolint:testifylint
continue
}
if !assert.Equal(t, data, actualUnwrapped, "Unwrap Mismatch: Actual unwrapped ciphertext should equal the original data for test case '%s'", v.Case) {

View File

@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
//nolint:nosnakecase,stylecheck,revive
//nolint:nosnakecase,stylecheck
package crypto
import (

View File

@ -97,7 +97,7 @@ func TestPkcs7(t *testing.T) {
t.Run("Invalid length while unpadding", func(t *testing.T) {
unpadded, err := UnpadPKCS7([]byte("1234567890\x06\x06\x06\x06"), blockSize)
require.Error(t, err)
assert.ErrorIs(t, err, ErrInvalidPKCS7Padding)
require.ErrorIs(t, err, ErrInvalidPKCS7Padding)
assert.Nil(t, unpadded)
})
@ -112,7 +112,7 @@ func TestPkcs7(t *testing.T) {
for _, tt := range tests {
unpadded, err := UnpadPKCS7(tt, blockSize)
require.Error(t, err)
assert.ErrorIs(t, err, ErrInvalidPKCS7Padding)
require.ErrorIs(t, err, ErrInvalidPKCS7Padding)
assert.Nil(t, unpadded)
}
})
@ -120,12 +120,12 @@ func TestPkcs7(t *testing.T) {
t.Run("Invalid block size", func(t *testing.T) {
res, err := PadPKCS7([]byte("1234567890ABCDEF"), 260)
require.Error(t, err)
assert.ErrorIs(t, err, ErrInvalidPKCS7BlockSize)
require.ErrorIs(t, err, ErrInvalidPKCS7BlockSize)
assert.Nil(t, res)
res, err = UnpadPKCS7([]byte("1234567890ABCDEF"), 260)
require.Error(t, err)
assert.ErrorIs(t, err, ErrInvalidPKCS7BlockSize)
require.ErrorIs(t, err, ErrInvalidPKCS7BlockSize)
assert.Nil(t, res)
})

crypto/pem/pem.go Normal file
View File

@ -0,0 +1,211 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package pem
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"os"
)
// DecodePEMCertificatesChain takes a PEM-encoded x509 certificates byte array
// and returns all certificates in a slice of x509.Certificate objects.
// Expects the certificates to form a chain, with the leaf certificate
// first in the byte array.
func DecodePEMCertificatesChain(crtb []byte) ([]*x509.Certificate, error) {
certs, err := DecodePEMCertificates(crtb)
if err != nil {
return nil, err
}
for i := range len(certs) - 1 {
if certs[i].CheckSignatureFrom(certs[i+1]) != nil {
return nil, errors.New("certificate chain is not valid")
}
}
return certs, nil
}
// DecodePEMCertificates takes a PEM-encoded x509 certificates byte array
// and returns all certificates in a slice of x509.Certificate objects.
func DecodePEMCertificates(crtb []byte) ([]*x509.Certificate, error) {
certs := []*x509.Certificate{}
for len(crtb) > 0 {
var err error
var cert *x509.Certificate
cert, crtb, err = decodeCertificatePEM(crtb)
if err != nil {
return nil, err
}
if cert != nil {
// it's a cert, add to pool
certs = append(certs, cert)
}
}
if len(certs) == 0 {
return nil, errors.New("no certificates found")
}
return certs, nil
}
func decodeCertificatePEM(crtb []byte) (*x509.Certificate, []byte, error) {
block, crtb := pem.Decode(crtb)
if block == nil {
return nil, nil, nil
}
if block.Type != "CERTIFICATE" {
return nil, nil, nil
}
c, err := x509.ParseCertificate(block.Bytes)
return c, crtb, err
}
// DecodePEMPrivateKey takes a key PEM byte array and returns an object that
// represents either an RSA or EC private key.
func DecodePEMPrivateKey(key []byte) (crypto.Signer, error) {
block, _ := pem.Decode(key)
if block == nil {
return nil, errors.New("key is not PEM encoded")
}
switch block.Type {
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(block.Bytes)
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(block.Bytes)
case "PRIVATE KEY":
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
return key.(crypto.Signer), nil
default:
return nil, fmt.Errorf("unsupported block type %s", block.Type)
}
}
// EncodePrivateKey will encode a private key into PEM format.
func EncodePrivateKey(key any) ([]byte, error) {
var (
keyBytes []byte
err error
blockType string
)
switch key := key.(type) {
case *ecdsa.PrivateKey, *ed25519.PrivateKey:
keyBytes, err = x509.MarshalPKCS8PrivateKey(key)
if err != nil {
return nil, err
}
blockType = "PRIVATE KEY"
default:
return nil, fmt.Errorf("unsupported key type %T", key)
}
return pem.EncodeToMemory(&pem.Block{
Type: blockType, Bytes: keyBytes,
}), nil
}
// EncodeX509 will encode a single *x509.Certificate into PEM format.
func EncodeX509(cert *x509.Certificate) ([]byte, error) {
caPem := bytes.NewBuffer([]byte{})
err := pem.Encode(caPem, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
if err != nil {
return nil, err
}
return caPem.Bytes(), nil
}
// EncodeX509Chain will encode a list of *x509.Certificates into a PEM format chain.
// Self-signed certificates are not included as per
// https://datatracker.ietf.org/doc/html/rfc5246#section-7.4.2
// Certificates are output in the order they're given; if the input is not ordered
// as specified in RFC5246 section 7.4.2, the resulting chain might not be valid
// for use in TLS.
func EncodeX509Chain(certs []*x509.Certificate) ([]byte, error) {
if len(certs) == 0 {
return nil, errors.New("no certificates in chain")
}
certPEM := bytes.NewBuffer([]byte{})
for _, cert := range certs {
if cert == nil {
continue
}
if cert.CheckSignatureFrom(cert) == nil {
// Don't include self-signed certificate
continue
}
err := pem.Encode(certPEM, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
if err != nil {
return nil, err
}
}
return certPEM.Bytes(), nil
}
// PublicKeysEqual compares two given public keys for equality.
// The definition of "equality" depends on the type of the public keys.
// Returns true if the keys are the same, false if they differ or an error if
// the key type of `a` cannot be determined.
func PublicKeysEqual(a, b crypto.PublicKey) (bool, error) {
switch pub := a.(type) {
case *rsa.PublicKey:
return pub.Equal(b), nil
case *ecdsa.PublicKey:
return pub.Equal(b), nil
case ed25519.PublicKey:
return pub.Equal(b), nil
default:
return false, fmt.Errorf("unrecognised public key type: %T", a)
}
}
// GetPEM loads a PEM-encoded file (certificate or key).
func GetPEM(val string) ([]byte, error) {
// If val is already a PEM-encoded string, return it as-is
if IsValidPEM(val) {
return []byte(val), nil
}
// Assume it's a file
pemBytes, err := os.ReadFile(val)
if err != nil {
return nil, fmt.Errorf("value is neither a valid file path nor a valid PEM-encoded string: %w", err)
}
return pemBytes, nil
}
// IsValidPEM validates the provided input has PEM formatted block.
func IsValidPEM(val string) bool {
block, _ := pem.Decode([]byte(val))
return block != nil
}

View File

@ -0,0 +1,71 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package context
import (
"context"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/dapr/kit/crypto/spiffe"
)
type ctxkey int
const (
x509SvidKey ctxkey = iota
jwtSvidKey
)
// Deprecated: use WithX509 instead.
// With adds the x509 SVID source from the SPIFFE object to the context.
func With(ctx context.Context, spiffe *spiffe.SPIFFE) context.Context {
return context.WithValue(ctx, x509SvidKey, spiffe.X509SVIDSource())
}
// Deprecated: use X509From instead.
// From retrieves the x509 SVID source from the context.
func From(ctx context.Context) (x509svid.Source, bool) {
svid, ok := ctx.Value(x509SvidKey).(x509svid.Source)
return svid, ok
}
// WithX509 adds an x509 SVID source to the context.
func WithX509(ctx context.Context, source x509svid.Source) context.Context {
return context.WithValue(ctx, x509SvidKey, source)
}
// WithJWT adds a JWT SVID source to the context.
func WithJWT(ctx context.Context, source jwtsvid.Source) context.Context {
return context.WithValue(ctx, jwtSvidKey, source)
}
// X509From retrieves the x509 SVID source from the context.
func X509From(ctx context.Context) (x509svid.Source, bool) {
svid, ok := ctx.Value(x509SvidKey).(x509svid.Source)
return svid, ok
}
// JWTFrom retrieves the JWT SVID source from the context.
func JWTFrom(ctx context.Context) (jwtsvid.Source, bool) {
svid, ok := ctx.Value(jwtSvidKey).(jwtsvid.Source)
return svid, ok
}
// WithSpiffe adds both X509 and JWT SVID sources to the context.
func WithSpiffe(ctx context.Context, spiffe *spiffe.SPIFFE) context.Context {
ctx = WithX509(ctx, spiffe.X509SVIDSource())
return WithJWT(ctx, spiffe.JWTSVIDSource())
}

View File

@ -0,0 +1,64 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package context
import (
"context"
"testing"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/stretchr/testify/assert"
)
type mockX509Source struct{}
func (m *mockX509Source) GetX509SVID() (*x509svid.SVID, error) {
return nil, nil
}
type mockJWTSource struct{}
func (m *mockJWTSource) FetchJWTSVID(context.Context, jwtsvid.Params) (*jwtsvid.SVID, error) {
return nil, nil
}
func TestWithX509FromX509(t *testing.T) {
source := &mockX509Source{}
ctx := WithX509(t.Context(), source)
retrieved, ok := X509From(ctx)
assert.True(t, ok, "Failed to retrieve X509 source from context")
assert.Equal(t, x509svid.Source(source), retrieved, "Retrieved source does not match the original source")
}
func TestWithJWTFromJWT(t *testing.T) {
source := &mockJWTSource{}
ctx := WithJWT(t.Context(), source)
retrieved, ok := JWTFrom(ctx)
assert.True(t, ok, "Failed to retrieve JWT source from context")
assert.Equal(t, jwtsvid.Source(source), retrieved, "Retrieved source does not match the original source")
}
func TestWithFrom(t *testing.T) {
x509Source := &mockX509Source{}
ctx := WithX509(t.Context(), x509Source)
// Should be able to retrieve using the legacy From function
retrieved, ok := From(ctx)
assert.True(t, ok, "Failed to retrieve X509 source from context using legacy From")
assert.Equal(t, x509svid.Source(x509Source), retrieved, "Retrieved source does not match the original source using legacy From")
}

crypto/spiffe/spiffe.go Normal file
View File

@ -0,0 +1,352 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spiffe
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"k8s.io/utils/clock"
"github.com/dapr/kit/concurrency/dir"
"github.com/dapr/kit/crypto/pem"
"github.com/dapr/kit/crypto/spiffe/trustanchors"
"github.com/dapr/kit/logger"
)
const (
// renewalDivisor represents the divisor for calculating renewal time.
// A value of 2 means renewal at 50% of the validity period.
renewalDivisor = 2
)
// SVIDResponse represents the response from the SVID request function,
// containing both X.509 certificates and a JWT token.
type SVIDResponse struct {
X509Certificates []*x509.Certificate
JWT *string
}
// Identity contains both X.509 and JWT SVIDs for a workload.
type Identity struct {
X509SVID *x509svid.SVID
JWTSVID *jwtsvid.SVID
}
type (
// RequestSVIDFn is the function type that requests SVIDs from a SPIFFE server,
// returning both X.509 certificates and a JWT token.
RequestSVIDFn func(context.Context, []byte) (*SVIDResponse, error)
)
type Options struct {
Log logger.Logger
RequestSVIDFn RequestSVIDFn
// WriteIdentityToFile is used to write the identity private key and
// certificate chain to file. The certificate chain and private key will be
// written to the `tls.cert` and `tls.key` files respectively in the given
// directory.
WriteIdentityToFile *string
TrustAnchors trustanchors.Interface
}
// SPIFFE is a readable/writeable store of SPIFFE SVID credentials.
// Used to manage workload SVIDs, and share read-only interfaces to consumers.
type SPIFFE struct {
currentX509SVID *x509svid.SVID
currentJWTSVID *jwtsvid.SVID
requestSVIDFn RequestSVIDFn
dir *dir.Dir
trustAnchors trustanchors.Interface
log logger.Logger
lock sync.RWMutex
clock clock.Clock
running atomic.Bool
readyCh chan struct{}
}
func New(opts Options) *SPIFFE {
var sdir *dir.Dir
if opts.WriteIdentityToFile != nil {
sdir = dir.New(dir.Options{
Log: opts.Log,
Target: *opts.WriteIdentityToFile,
})
}
return &SPIFFE{
requestSVIDFn: opts.RequestSVIDFn,
dir: sdir,
trustAnchors: opts.TrustAnchors,
log: opts.Log,
clock: clock.RealClock{},
readyCh: make(chan struct{}),
}
}
func (s *SPIFFE) Run(ctx context.Context) error {
if !s.running.CompareAndSwap(false, true) {
return errors.New("already running")
}
s.lock.Lock()
s.log.Info("Fetching initial identity")
initialIdentity, err := s.fetchIdentity(ctx)
if err != nil {
close(s.readyCh)
s.lock.Unlock()
return fmt.Errorf("failed to retrieve the initial identity: %w", err)
}
s.currentX509SVID = initialIdentity.X509SVID
s.currentJWTSVID = initialIdentity.JWTSVID
close(s.readyCh)
s.lock.Unlock()
s.log.Infof("Security is initialized successfully")
s.runRotation(ctx)
return nil
}
// Ready blocks until SPIFFE is ready or the context is done, in which
// case the context error is returned.
func (s *SPIFFE) Ready(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-s.readyCh:
return nil
}
}
// logIdentityInfo creates a log message with expiry details for both X.509 and JWT SVIDs
func (s *SPIFFE) logIdentityInfo(prefix string, cert *x509.Certificate, jwtSVID *jwtsvid.SVID, renewTime *time.Time) {
msg := prefix + "; cert expires on: %s"
args := []any{cert.NotAfter.String()}
if jwtSVID != nil {
msg += ", jwt expires on: %s"
args = append(args, jwtSVID.Expiry.String())
}
if renewTime != nil {
msg += ", renewal at: %s"
args = append(args, renewTime.String())
}
s.log.Infof(msg, args...)
}
// runRotation starts up the manager responsible for renewing the workload identity
func (s *SPIFFE) runRotation(ctx context.Context) {
defer s.log.Debug("stopping workload identity expiry watcher")
s.lock.RLock()
cert := s.currentX509SVID.Certificates[0]
jwtSVID := s.currentJWTSVID
s.lock.RUnlock()
renewTime := calculateRenewalTime(time.Now(), cert, jwtSVID)
s.logIdentityInfo("Starting workload identity expiry watcher", cert, jwtSVID, renewTime)
for {
select {
case <-s.clock.After(min(time.Minute, renewTime.Sub(s.clock.Now()))):
if s.clock.Now().Before(*renewTime) {
continue
}
s.logIdentityInfo("Renewing workload identity", cert, jwtSVID, nil)
identity, err := s.fetchIdentity(ctx)
if err != nil {
s.log.Errorf("Error renewing identity, trying again in 10 seconds: %s", err)
select {
case <-s.clock.After(10 * time.Second):
continue
case <-ctx.Done():
return
}
}
s.lock.Lock()
s.currentX509SVID = identity.X509SVID
s.currentJWTSVID = identity.JWTSVID
cert = identity.X509SVID.Certificates[0]
jwtSVID = identity.JWTSVID
s.lock.Unlock()
renewTime = calculateRenewalTime(time.Now(), cert, jwtSVID)
s.logIdentityInfo("Successfully renewed workload identity", cert, jwtSVID, renewTime)
case <-ctx.Done():
return
}
}
}
// Returns both X.509 SVID and JWT SVID (if available).
func (s *SPIFFE) fetchIdentity(ctx context.Context) (*Identity, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("failed to generate private key: %w", err)
}
csrDER, err := x509.CreateCertificateRequest(rand.Reader, new(x509.CertificateRequest), key)
if err != nil {
return nil, fmt.Errorf("failed to create sidecar csr: %w", err)
}
svidResponse, err := s.requestSVIDFn(ctx, csrDER)
if err != nil {
return nil, err
}
if len(svidResponse.X509Certificates) == 0 {
return nil, errors.New("no certificates received from sentry")
}
spiffeID, err := x509svid.IDFromCert(svidResponse.X509Certificates[0])
if err != nil {
return nil, fmt.Errorf("error parsing spiffe id from newly signed certificate: %w", err)
}
identity := &Identity{
X509SVID: &x509svid.SVID{
ID: spiffeID,
Certificates: svidResponse.X509Certificates,
PrivateKey: key,
},
}
// If we have a JWT token, parse it and include it in the identity
if svidResponse.JWT != nil {
// we are using ParseInsecure here as the expectation is that the
// requestSVIDFn will have already parsed and validated the JWT SVID
// before returning it.
//
// we are parsing the token using our SPIFFE ID's trust domain
// as the audience as we expect the issuer to always include
// that as an audience since that ensures that the token is
// valid for us and our trust domain.
audiences := []string{spiffeID.TrustDomain().Name()}
jwtSvid, err := jwtsvid.ParseInsecure(*svidResponse.JWT, audiences)
if err != nil {
return nil, fmt.Errorf("failed to parse JWT SVID: %w", err)
}
identity.JWTSVID = jwtSvid
s.log.Infof("Successfully received JWT SVID with expiry: %s", jwtSvid.Expiry.String())
}
if s.dir != nil {
pkPEM, err := pem.EncodePrivateKey(key)
if err != nil {
return nil, err
}
certPEM, err := pem.EncodeX509Chain(svidResponse.X509Certificates)
if err != nil {
return nil, err
}
td, err := s.trustAnchors.CurrentTrustAnchors(ctx)
if err != nil {
return nil, err
}
files := map[string][]byte{
"key.pem": pkPEM,
"cert.pem": certPEM,
"ca.pem": td,
}
if svidResponse.JWT != nil {
files["jwt_svid.token"] = []byte(*svidResponse.JWT)
}
if err := s.dir.Write(files); err != nil {
return nil, err
}
}
return identity, nil
}
func (s *SPIFFE) X509SVIDSource() x509svid.Source {
return &svidSource{spiffe: s}
}
func (s *SPIFFE) JWTSVIDSource() jwtsvid.Source {
return &svidSource{spiffe: s}
}
// renewalTime is 50% through the certificate validity period.
func renewalTime(notBefore, notAfter time.Time) time.Time {
return notBefore.Add(notAfter.Sub(notBefore) / renewalDivisor)
}
// calculateRenewalTime returns the earlier renewal time between the X.509 certificate
// and JWT SVID (if available) to ensure timely renewal.
func calculateRenewalTime(now time.Time, cert *x509.Certificate, jwtSVID *jwtsvid.SVID) *time.Time {
certRenewal := renewalTime(cert.NotBefore, cert.NotAfter)
if jwtSVID == nil {
return &certRenewal
}
jwtRenewal := now.Add(jwtSVID.Expiry.Sub(now) / renewalDivisor)
if jwtRenewal.Before(certRenewal) {
return &jwtRenewal
}
return &certRenewal
}
// audiencesMatch checks if the SVID audiences contain all the requested audiences
func audiencesMatch(svidAudiences []string, requestedAudiences []string) bool {
if len(requestedAudiences) == 0 {
return true
}
// Create a map for faster lookup
audienceMap := make(map[string]struct{}, len(svidAudiences))
for _, audience := range svidAudiences {
audienceMap[audience] = struct{}{}
}
// Check if all requested audiences are in the SVID
for _, requested := range requestedAudiences {
if _, ok := audienceMap[requested]; !ok {
return false
}
}
return true
}
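audiencesMatch is a pure function, so it is easy to exercise in isolation. A runnable sketch with the helper copied verbatim from above:

```go
package main

import "fmt"

// audiencesMatch reports whether the SVID audiences contain every
// requested audience; an empty request always matches.
func audiencesMatch(svidAudiences []string, requestedAudiences []string) bool {
	if len(requestedAudiences) == 0 {
		return true
	}
	// Build a set for O(1) lookups.
	audienceMap := make(map[string]struct{}, len(svidAudiences))
	for _, audience := range svidAudiences {
		audienceMap[audience] = struct{}{}
	}
	for _, requested := range requestedAudiences {
		if _, ok := audienceMap[requested]; !ok {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(audiencesMatch([]string{"a", "b"}, []string{"a"})) // true: subset matches
	fmt.Println(audiencesMatch([]string{"a"}, []string{"a", "b"})) // false: "b" missing
	fmt.Println(audiencesMatch([]string{"a"}, nil))                // true: empty request
}
```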


@@ -0,0 +1,269 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spiffe
import (
"context"
"crypto/x509"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
clocktesting "k8s.io/utils/clock/testing"
"github.com/dapr/kit/crypto/test"
"github.com/dapr/kit/logger"
)
func Test_renewalTime(t *testing.T) {
now := time.Now()
assert.Equal(t, now, renewalTime(now, now))
in1Min := now.Add(time.Minute)
in30 := now.Add(time.Second * 30)
assert.Equal(t, in30, renewalTime(now, in1Min))
}
func Test_calculateRenewalTime(t *testing.T) {
now := time.Now()
certShort := &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(10 * time.Hour),
}
certLong := &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(24 * time.Hour),
}
// Expected renewal times for certificates (50% of validity period)
certShortRenewal := now.Add(5 * time.Hour)
// Create JWT SVIDs with different expiry times
jwtEarlier := &jwtsvid.SVID{
Expiry: now.Add(8 * time.Hour),
}
jwtLater := &jwtsvid.SVID{
Expiry: now.Add(30 * time.Hour),
}
// Expected JWT renewal time (50% of remaining time)
jwtEarlierRenewal := now.Add(4 * time.Hour)
tests := []struct {
name string
cert *x509.Certificate
jwt *jwtsvid.SVID
expected time.Time
}{
{
name: "Certificate only",
cert: certShort,
jwt: nil,
expected: certShortRenewal,
},
{
name: "Certificate and JWT, JWT earlier",
cert: certLong,
jwt: jwtEarlier,
expected: jwtEarlierRenewal,
},
{
name: "Certificate and JWT, Certificate earlier",
cert: certShort,
jwt: jwtLater,
expected: certShortRenewal,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := calculateRenewalTime(now, tt.cert, tt.jwt)
assert.WithinDuration(t, tt.expected, *actual, time.Millisecond,
"Renewal time does not match expected value")
})
}
}
func Test_Run(t *testing.T) {
t.Run("should return error when multiple Runs are called", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{
LeafID: spiffeid.RequireFromString("spiffe://example.com/foo/bar"),
})
ctx, cancel := context.WithCancel(t.Context())
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
return &SVIDResponse{
X509Certificates: []*x509.Certificate{pki.LeafCert},
}, nil
},
})
errCh := make(chan error)
go func() {
errCh <- s.Run(ctx)
}()
go func() {
errCh <- s.Run(ctx)
}()
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
assert.Fail(t, "Expected error")
}
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "First Run should have returned with no error")
}
})
t.Run("should return error if initial fetch errors", func(t *testing.T) {
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
return nil, errors.New("this is an error")
},
})
require.Error(t, s.Run(t.Context()))
})
t.Run("should renew certificate when it has expired", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{
LeafID: spiffeid.RequireFromString("spiffe://example.com/foo/bar"),
})
var fetches atomic.Int32
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
fetches.Add(1)
return &SVIDResponse{
X509Certificates: []*x509.Certificate{pki.LeafCert},
}, nil
},
})
now := time.Now()
clock := clocktesting.NewFakeClock(now)
s.clock = clock
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
select {
case <-s.readyCh:
assert.Fail(t, "readyCh should not be closed")
default:
}
errCh <- s.Run(ctx)
}()
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
assert.Equal(t, int32(1), fetches.Load())
clock.Step(pki.LeafCert.NotAfter.Sub(now) / 2)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Equal(c, int32(2), fetches.Load())
}, time.Second, time.Millisecond)
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "First Run should have returned with no error")
}
})
t.Run("if renewal failed, should try again in 10 seconds", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{
LeafID: spiffeid.RequireFromString("spiffe://example.com/foo/bar"),
})
respCert := []*x509.Certificate{pki.LeafCert}
var respErr error
var fetches atomic.Int32
s := New(Options{
Log: logger.NewLogger("test"),
RequestSVIDFn: func(context.Context, []byte) (*SVIDResponse, error) {
fetches.Add(1)
return &SVIDResponse{
X509Certificates: respCert,
}, respErr
},
})
now := time.Now()
clock := clocktesting.NewFakeClock(now)
s.clock = clock
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
select {
case <-s.readyCh:
assert.Fail(t, "readyCh should not be closed")
default:
}
errCh <- s.Run(ctx)
}()
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
assert.Equal(t, int32(1), fetches.Load())
respCert = nil
respErr = errors.New("this is an error")
clock.Step(pki.LeafCert.NotAfter.Sub(now) / 2)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Equal(c, int32(2), fetches.Load())
}, time.Second, time.Millisecond)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second * 5)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
assert.Equal(t, int32(2), fetches.Load())
clock.Step(time.Second * 5)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(1)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Equal(c, int32(3), fetches.Load())
}, time.Second, time.Millisecond)
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "First Run should have returned with no error")
}
})
}


@@ -0,0 +1,95 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spiffe
import (
"context"
"errors"
"fmt"
"strings"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
)
var (
errNoX509SVIDAvailable = errors.New("no X509 SVID available")
errNoJWTSVIDAvailable = errors.New("no JWT SVID available")
errAudienceRequired = errors.New("JWT audience is required")
)
// svidSource is an implementation of both go-spiffe x509svid.Source and jwtsvid.Source interfaces.
type svidSource struct {
spiffe *SPIFFE
}
// GetX509SVID returns the current X.509 certificate identity as a SPIFFE SVID.
// Implements the go-spiffe x509svid.Source interface.
func (s *svidSource) GetX509SVID() (*x509svid.SVID, error) {
s.spiffe.lock.RLock()
defer s.spiffe.lock.RUnlock()
<-s.spiffe.readyCh
svid := s.spiffe.currentX509SVID
if svid == nil {
return nil, errNoX509SVIDAvailable
}
return svid, nil
}
// audienceMismatchError is an error that contains information about mismatched audiences
type audienceMismatchError struct {
expected []string
actual []string
}
func (e *audienceMismatchError) Error() string {
return fmt.Sprintf("JWT SVID has different audiences than requested: expected %s, got %s",
strings.Join(e.expected, ", "), strings.Join(e.actual, ", "))
}
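The exact message format matters because the tests assert on it. A self-contained sketch, with the error type copied verbatim from above:

```go
package main

import (
	"fmt"
	"strings"
)

// audienceMismatchError reports the audiences that were requested
// versus those actually present in the JWT SVID.
type audienceMismatchError struct {
	expected []string
	actual   []string
}

func (e *audienceMismatchError) Error() string {
	return fmt.Sprintf("JWT SVID has different audiences than requested: expected %s, got %s",
		strings.Join(e.expected, ", "), strings.Join(e.actual, ", "))
}

func main() {
	err := &audienceMismatchError{
		expected: []string{"requested-audience"},
		actual:   []string{"actual-audience"},
	}
	fmt.Println(err.Error())
	// JWT SVID has different audiences than requested: expected requested-audience, got actual-audience
}
```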
// FetchJWTSVID returns the current JWT SVID.
// Implements the go-spiffe jwtsvid.Source interface.
func (s *svidSource) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*jwtsvid.SVID, error) {
s.spiffe.lock.RLock()
defer s.spiffe.lock.RUnlock()
if params.Audience == "" {
return nil, errAudienceRequired
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-s.spiffe.readyCh:
}
svid := s.spiffe.currentJWTSVID
if svid == nil {
return nil, errNoJWTSVIDAvailable
}
// verify that the audience being requested is the same as the audience in the SVID
// WARN: we do not check extra audiences here.
if !audiencesMatch(svid.Audience, []string{params.Audience}) {
return nil, &audienceMismatchError{
expected: []string{params.Audience},
actual: svid.Audience,
}
}
return svid, nil
}


@@ -0,0 +1,190 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spiffe
import (
"context"
"sync"
"testing"
"time"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/stretchr/testify/require"
)
func Test_svidSource(*testing.T) {
var _ x509svid.Source = new(svidSource)
var _ jwtsvid.Source = new(svidSource)
}
// createMockJWTSVID creates a mock JWT SVID for testing
func createMockJWTSVID(audiences []string) (*jwtsvid.SVID, error) {
td, err := spiffeid.TrustDomainFromString("example.org")
if err != nil {
return nil, err
}
id, err := spiffeid.FromSegments(td, "workload")
if err != nil {
return nil, err
}
svid := &jwtsvid.SVID{
ID: id,
Audience: audiences,
Expiry: time.Now().Add(time.Hour),
}
return svid, nil
}
func TestFetchJWTSVID(t *testing.T) {
t.Run("should return error when audience is empty", func(t *testing.T) {
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{
Audience: "",
})
require.Nil(t, svid)
require.ErrorIs(t, err, errAudienceRequired)
})
t.Run("should return error when no JWT SVID available", func(t *testing.T) {
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: nil,
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{
Audience: "test-audience",
})
require.Nil(t, svid)
require.ErrorIs(t, err, errNoJWTSVIDAvailable)
})
t.Run("should return error when audience doesn't match", func(t *testing.T) {
// Create a mock SVID with a specific audience
mockJWTSVID, err := createMockJWTSVID([]string{"actual-audience"})
require.NoError(t, err)
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{
Audience: "requested-audience",
})
require.Nil(t, svid)
require.Error(t, err)
// Verify the specific error type and contents
audienceErr, ok := err.(*audienceMismatchError)
require.True(t, ok, "Expected audienceMismatchError")
require.Equal(t, "JWT SVID has different audiences than requested: expected requested-audience, got actual-audience", audienceErr.Error())
})
t.Run("should return JWT SVID when audience matches", func(t *testing.T) {
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience", "extra-audience"})
require.NoError(t, err)
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
},
}
close(s.spiffe.readyCh) // Mark as ready
svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{
Audience: "test-audience",
})
require.NoError(t, err)
require.Equal(t, mockJWTSVID, svid)
})
t.Run("should wait for readyCh before checking SVID", func(t *testing.T) {
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience"})
require.NoError(t, err)
readyCh := make(chan struct{})
s := &svidSource{
spiffe: &SPIFFE{
readyCh: readyCh,
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
},
}
// Start goroutine to fetch SVID
ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer cancel()
resultCh := make(chan struct {
svid *jwtsvid.SVID
err error
})
go func() {
svid, err := s.FetchJWTSVID(ctx, jwtsvid.Params{
Audience: "test-audience",
})
resultCh <- struct {
svid *jwtsvid.SVID
err error
}{svid, err}
}()
// require that fetch is blocked
select {
case <-resultCh:
t.Fatal("FetchJWTSVID should be blocked until readyCh is closed")
case <-time.After(100 * time.Millisecond):
// Expected behavior - fetch is blocked
}
// Close readyCh to unblock fetch
close(readyCh)
// Now fetch should complete
select {
case result := <-resultCh:
require.NoError(t, result.err)
require.NotNil(t, result.svid)
case <-time.After(100 * time.Millisecond):
t.Fatal("FetchJWTSVID should have completed after readyCh was closed")
}
})
}


@@ -0,0 +1,300 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package file
import (
"context"
"errors"
"fmt"
"os"
"sync"
"sync/atomic"
"time"
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"k8s.io/utils/clock"
"github.com/dapr/kit/concurrency"
"github.com/dapr/kit/crypto/pem"
"github.com/dapr/kit/crypto/spiffe/trustanchors"
"github.com/dapr/kit/fswatcher"
"github.com/dapr/kit/logger"
)
var (
// ErrTrustAnchorsClosed is returned when an operation is performed on closed trust anchors.
ErrTrustAnchorsClosed = errors.New("trust anchors is closed")
// ErrFailedToReadTrustAnchorsFile is returned when the trust anchors file cannot be read.
ErrFailedToReadTrustAnchorsFile = errors.New("failed to read trust anchors file")
)
type Options struct {
Log logger.Logger
CAPath string
JwksPath *string
}
// file is a TrustAnchors implementation that uses a file as the source of trust
// anchors. The trust anchors will be updated when the file changes.
type file struct {
log logger.Logger
caPath string
jwksPath *string
x509Bundle *x509bundle.Bundle
jwtBundle *jwtbundle.Bundle
rootPEM []byte
// fsWatcherInterval is the interval at which trust anchors file changes
// are batched. Overridden in tests; defaults to 500ms.
fsWatcherInterval time.Duration
// initFileWatchInterval is the interval at which the trust anchors file is
// polled until it first exists. Overridden in tests; defaults to 1 second.
initFileWatchInterval time.Duration
// subs is a list of channels to notify when the trust anchors are updated.
subs []chan<- struct{}
lock sync.RWMutex
clock clock.Clock
running atomic.Bool
readyCh chan struct{}
closeCh chan struct{}
caEvent chan struct{}
}
func From(opts Options) trustanchors.Interface {
return &file{
fsWatcherInterval: time.Millisecond * 500,
initFileWatchInterval: time.Second,
log: opts.Log,
caPath: opts.CAPath,
jwksPath: opts.JwksPath,
clock: clock.RealClock{},
readyCh: make(chan struct{}),
closeCh: make(chan struct{}),
caEvent: make(chan struct{}),
}
}
func (f *file) Run(ctx context.Context) error {
if !f.running.CompareAndSwap(false, true) {
return errors.New("trust anchors is already running")
}
defer close(f.closeCh)
for {
fs := []string{f.caPath}
if f.jwksPath != nil {
fs = append(fs, *f.jwksPath)
}
if found, err := filesExist(fs...); err != nil {
return err
} else if found {
break
}
// Trust anchors file may not be provided yet, wait.
select {
case <-ctx.Done():
return fmt.Errorf("failed to find trust anchors file '%s': %w", f.caPath, ctx.Err())
case <-f.clock.After(f.initFileWatchInterval):
f.log.Warnf("Trust anchors file '%s' not found, waiting...", f.caPath)
}
}
f.log.Infof("Trust anchors file '%s' found", f.caPath)
if err := f.updateAnchors(ctx); err != nil {
return err
}
targets := []string{f.caPath}
if f.jwksPath != nil {
targets = append(targets, *f.jwksPath)
}
fs, err := fswatcher.New(fswatcher.Options{
Targets: targets,
Interval: &f.fsWatcherInterval,
})
if err != nil {
return fmt.Errorf("failed to create file watcher: %w", err)
}
close(f.readyCh)
f.log.Infof("Watching trust anchors file '%s' for changes", f.caPath)
if f.jwksPath != nil {
f.log.Infof("Watching JWT bundle file '%s' for changes", *f.jwksPath)
}
return concurrency.NewRunnerManager(
func(ctx context.Context) error {
return fs.Run(ctx, f.caEvent)
},
func(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
case <-f.caEvent:
f.log.Info("Trust anchors file changed, reloading trust anchors")
if err = f.updateAnchors(ctx); err != nil {
return fmt.Errorf("%w: '%s': %v", ErrFailedToReadTrustAnchorsFile, f.caPath, err)
}
}
}
},
).Run(ctx)
}
func (f *file) CurrentTrustAnchors(ctx context.Context) ([]byte, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-f.closeCh:
return nil, ErrTrustAnchorsClosed
case <-f.readyCh:
}
f.lock.RLock()
defer f.lock.RUnlock()
rootPEM := make([]byte, len(f.rootPEM))
copy(rootPEM, f.rootPEM)
return rootPEM, nil
}
func (f *file) updateAnchors(ctx context.Context) error {
f.lock.Lock()
defer f.lock.Unlock()
rootPEMs, err := os.ReadFile(f.caPath)
if err != nil {
return fmt.Errorf("failed to read trust anchors file '%s': %w", f.caPath, err)
}
trustAnchorCerts, err := pem.DecodePEMCertificates(rootPEMs)
if err != nil {
return fmt.Errorf("failed to decode trust anchors: %w", err)
}
f.rootPEM = rootPEMs
f.x509Bundle = x509bundle.FromX509Authorities(spiffeid.TrustDomain{}, trustAnchorCerts)
if f.jwksPath != nil {
jwks, err := os.ReadFile(*f.jwksPath)
if err != nil {
return fmt.Errorf("failed to read JWT bundle file '%s': %w", *f.jwksPath, err)
}
jwtBundle, err := jwtbundle.Parse(spiffeid.TrustDomain{}, jwks)
if err != nil {
return fmt.Errorf("failed to parse JWT bundle: %w", err)
}
f.jwtBundle = jwtBundle
}
var wg sync.WaitGroup
defer wg.Wait()
wg.Add(len(f.subs))
for _, ch := range f.subs {
go func(chi chan<- struct{}) {
defer wg.Done()
select {
case chi <- struct{}{}:
case <-ctx.Done():
}
}(ch)
}
return nil
}
func (f *file) GetX509BundleForTrustDomain(_ spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
select {
case <-f.closeCh:
return nil, ErrTrustAnchorsClosed
case <-f.readyCh:
}
f.lock.RLock()
defer f.lock.RUnlock()
bundle := f.x509Bundle
return bundle, nil
}
func (f *file) GetJWTBundleForTrustDomain(_ spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
select {
case <-f.closeCh:
return nil, ErrTrustAnchorsClosed
case <-f.readyCh:
}
f.lock.RLock()
defer f.lock.RUnlock()
bundle := f.jwtBundle
return bundle, nil
}
func (f *file) Watch(ctx context.Context, ch chan<- []byte) {
f.lock.Lock()
sub := make(chan struct{}, 5)
f.subs = append(f.subs, sub)
f.lock.Unlock()
for {
select {
case <-ctx.Done():
return
case <-f.closeCh:
return
case <-sub:
f.lock.RLock()
rootPEM := make([]byte, len(f.rootPEM))
copy(rootPEM, f.rootPEM)
f.lock.RUnlock()
select {
case ch <- rootPEM:
case <-ctx.Done():
case <-f.closeCh:
}
}
}
}
func filesExist(paths ...string) (bool, error) {
for _, path := range paths {
if path == "" {
continue
}
if _, err := os.Stat(path); err != nil {
if errors.Is(err, os.ErrNotExist) {
return false, nil
}
return false, fmt.Errorf("failed to stat file '%s': %w", path, err)
}
}
return true, nil
}


@@ -0,0 +1,580 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package file
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/kit/crypto/test"
"github.com/dapr/kit/logger"
)
func TestFile_Run(t *testing.T) {
t.Run("if Run multiple times, expect error", func(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
errCh <- f.Run(ctx)
}()
go func() {
errCh <- f.Run(ctx)
}()
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
assert.Fail(t, "Expected error")
}
select {
case <-f.closeCh:
assert.Fail(t, "closeCh should not be closed")
default:
}
cancel()
select {
case err := <-errCh:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
assert.Fail(t, "First Run should have returned with no error")
}
})
t.Run("if file is not found and context cancelled, should return ctx.Err", func(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
errCh <- f.Run(ctx)
}()
cancel()
select {
case err := <-errCh:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
assert.Fail(t, "First Run should have returned with no error")
}
})
t.Run("if file found but is empty, should return error", func(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, nil, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
go func() {
errCh <- f.Run(t.Context())
}()
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected error")
}
})
t.Run("if file found but is only garbage data, expect error", func(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, []byte("garbage data"), 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
go func() {
errCh <- f.Run(t.Context())
}()
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected error")
}
})
t.Run("if file found but is only garbage data in root, expect error", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
root := pki.RootCertPEM[10:]
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, root, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
go func() {
errCh <- f.Run(t.Context())
}()
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected error")
}
})
t.Run("single root should be correctly parsed from file", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
go func() {
errCh <- f.Run(t.Context())
}()
select {
case <-f.readyCh:
case <-time.After(time.Second):
assert.Fail(t, "expected to be ready in time")
}
b, err := f.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, b)
})
t.Run("garbage data outside of root should be ignored", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
root := append(pki.RootCertPEM, []byte("garbage data")...)
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, root, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
go func() {
errCh <- f.Run(t.Context())
}()
select {
case <-f.readyCh:
case <-time.After(time.Second):
assert.Fail(t, "expected to be ready in time")
}
b, err := f.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(t, root, b)
})
t.Run("multiple roots should be parsed", func(t *testing.T) {
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
go func() {
errCh <- f.Run(t.Context())
}()
select {
case <-f.readyCh:
case <-time.After(time.Second):
assert.Fail(t, "expected to be ready in time")
}
b, err := f.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(t, roots, b)
})
t.Run("writing a bad root PEM file should make Run return error", func(t *testing.T) {
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
f.fsWatcherInterval = time.Millisecond
errCh := make(chan error)
go func() {
errCh <- f.Run(t.Context())
}()
select {
case <-f.readyCh:
case <-time.After(time.Second):
assert.Fail(t, "expected to be ready in time")
}
require.NoError(t, os.WriteFile(tmp, []byte("garbage data"), 0o600))
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected error to be returned from Run")
}
})
}
func TestFile_GetX509BundleForTrustDomain(t *testing.T) {
t.Run("Should return full PEM regardless of the given trust domain", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
root := append(pki.RootCertPEM, []byte("garbage data")...)
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, root, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
errCh := make(chan error)
ctx, cancel := context.WithCancel(t.Context())
go func() {
errCh <- ta.Run(ctx)
}()
t.Cleanup(func() {
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected Run to return")
}
})
trustDomain1, err := spiffeid.TrustDomainFromString("example.com")
require.NoError(t, err)
bundle, err := f.GetX509BundleForTrustDomain(trustDomain1)
require.NoError(t, err)
assert.Equal(t, f.x509Bundle, bundle)
b1, err := bundle.Marshal()
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, b1)
trustDomain2, err := spiffeid.TrustDomainFromString("another-example.org")
require.NoError(t, err)
bundle, err = f.GetX509BundleForTrustDomain(trustDomain2)
require.NoError(t, err)
assert.Equal(t, f.x509Bundle, bundle)
b2, err := bundle.Marshal()
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, b2)
})
}
func TestFile_Watch(t *testing.T) {
t.Run("should return when Run context has been cancelled", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
ctx, cancel := context.WithCancel(t.Context())
go func() {
errCh <- f.Run(ctx)
}()
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure f.Run has started and is running
watchDone := make(chan struct{})
go func() {
ta.Watch(t.Context(), make(chan []byte))
close(watchDone)
}()
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected Run to return in time")
}
select {
case <-watchDone:
case <-time.After(time.Second):
assert.Fail(t, "expected Watch to have returned")
}
})
t.Run("should return when given context has been cancelled", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
errCh := make(chan error)
ctx1, cancel1 := context.WithCancel(t.Context())
go func() {
errCh <- f.Run(ctx1)
}()
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure f.Run has started and is running
watchDone := make(chan struct{})
ctx2, cancel2 := context.WithCancel(t.Context())
go func() {
ta.Watch(ctx2, make(chan []byte))
close(watchDone)
}()
cancel2()
select {
case <-watchDone:
case <-time.After(time.Second):
assert.Fail(t, "expected Watch to have returned")
}
cancel1()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected Run to return in time")
}
})
t.Run("should update Watch subscribers when root PEM has been changed", func(t *testing.T) {
pki1 := test.GenPKI(t, test.PKIOptions{})
pki2 := test.GenPKI(t, test.PKIOptions{})
pki3 := test.GenPKI(t, test.PKIOptions{})
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
f.fsWatcherInterval = time.Millisecond
errCh := make(chan error)
ctx, cancel := context.WithCancel(t.Context())
go func() {
errCh <- f.Run(ctx)
}()
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure f.Run has started and is running
select {
case <-f.readyCh:
case <-time.After(time.Second):
assert.Fail(t, "expected to be ready in time")
}
watchDone1, watchDone2 := make(chan struct{}), make(chan struct{})
tCh1, tCh2 := make(chan []byte), make(chan []byte)
go func() {
ta.Watch(t.Context(), tCh1)
close(watchDone1)
}()
go func() {
ta.Watch(t.Context(), tCh2)
close(watchDone2)
}()
//nolint:gocritic
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
for _, ch := range []chan []byte{tCh1, tCh2} {
select {
case b := <-ch:
assert.Equal(t, string(roots), string(b))
case <-time.After(time.Second):
assert.Fail(t, "failed to get subscribed file watch in time")
}
}
//nolint:gocritic
roots = append(pki1.RootCertPEM, append(pki2.RootCertPEM, pki3.RootCertPEM...)...)
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
for _, ch := range []chan []byte{tCh1, tCh2} {
select {
case b := <-ch:
assert.Equal(t, string(roots), string(b))
case <-time.After(time.Second):
assert.Fail(t, "failed to get subscribed file watch in time")
}
}
cancel()
for _, ch := range []chan struct{}{watchDone1, watchDone2} {
select {
case <-ch:
case <-time.After(time.Second):
assert.Fail(t, "expected Watch to have returned")
}
}
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected Run to return in time")
}
})
}
func TestFile_CurrentTrustAnchors(t *testing.T) {
t.Run("returns trust anchors as they change", func(t *testing.T) {
pki1, pki2, pki3 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
tmp := filepath.Join(t.TempDir(), "ca.crt")
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
ta := From(Options{
Log: logger.NewLogger("test"),
CAPath: tmp,
})
f, ok := ta.(*file)
require.True(t, ok)
f.initFileWatchInterval = time.Millisecond
f.fsWatcherInterval = time.Millisecond
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
errCh <- f.Run(ctx)
}()
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure f.Run has started and is running
//nolint:gocritic
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure the file watcher has time to pick up the change
assert.EventuallyWithT(t, func(c *assert.CollectT) {
pem, err := ta.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(c, roots, pem)
}, time.Second, time.Millisecond)
//nolint:gocritic
roots = append(pki1.RootCertPEM, append(pki2.RootCertPEM, pki3.RootCertPEM...)...)
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
time.Sleep(time.Millisecond * 10) // adding a small delay to ensure the file watcher has time to pick up the change
assert.EventuallyWithT(t, func(c *assert.CollectT) {
pem, err := ta.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(c, roots, pem)
}, time.Second, time.Millisecond)
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "expected Run to return in time")
}
})
}


@ -0,0 +1,86 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package multi
import (
"context"
"errors"
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/dapr/kit/concurrency"
"github.com/dapr/kit/crypto/spiffe/trustanchors"
)
var (
ErrNotImplemented = errors.New("not implemented")
ErrTrustDomainNotFound = errors.New("trust domain not found")
)
type Options struct {
TrustAnchors map[spiffeid.TrustDomain]trustanchors.Interface
}
// multi is a TrustAnchors implementation which uses multiple trust anchors
// which are indexed by trust domain.
type multi struct {
trustAnchors map[spiffeid.TrustDomain]trustanchors.Interface
}
func From(opts Options) trustanchors.Interface {
return &multi{
trustAnchors: opts.TrustAnchors,
}
}
func (m *multi) Run(ctx context.Context) error {
r := concurrency.NewRunnerManager()
for _, ta := range m.trustAnchors {
if err := r.Add(ta.Run); err != nil {
return err
}
}
return r.Run(ctx)
}
func (m *multi) CurrentTrustAnchors(context.Context) ([]byte, error) {
return nil, ErrNotImplemented
}
func (m *multi) GetX509BundleForTrustDomain(td spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
for tad, ta := range m.trustAnchors {
if td.Compare(tad) == 0 {
return ta.GetX509BundleForTrustDomain(td)
}
}
return nil, ErrTrustDomainNotFound
}
func (m *multi) GetJWTBundleForTrustDomain(td spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
for tad, ta := range m.trustAnchors {
if td.Compare(tad) == 0 {
return ta.GetJWTBundleForTrustDomain(td)
}
}
return nil, ErrTrustDomainNotFound
}
func (m *multi) Watch(context.Context, chan<- []byte) {
return
}


@ -0,0 +1,99 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package static
import (
"context"
"errors"
"fmt"
"sync/atomic"
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/dapr/kit/crypto/pem"
"github.com/dapr/kit/crypto/spiffe/trustanchors"
)
// static is a TrustAnchors implementation that uses a static list of trust
// anchors.
type static struct {
x509Bundle *x509bundle.Bundle
jwtBundle *jwtbundle.Bundle
anchors []byte
running atomic.Bool
closeCh chan struct{}
}
type Options struct {
Anchors []byte
Jwks []byte
}
func From(opts Options) (trustanchors.Interface, error) {
// Create empty trust domain for now
emptyTD := spiffeid.TrustDomain{}
var jwtBundle *jwtbundle.Bundle
if opts.Jwks != nil {
var err error
jwtBundle, err = jwtbundle.Parse(emptyTD, opts.Jwks)
if err != nil {
return nil, fmt.Errorf("failed to create JWT bundle: %w", err)
}
}
trustAnchorCerts, err := pem.DecodePEMCertificates(opts.Anchors)
if err != nil {
return nil, fmt.Errorf("failed to decode trust anchors: %w", err)
}
return &static{
anchors: opts.Anchors,
x509Bundle: x509bundle.FromX509Authorities(emptyTD, trustAnchorCerts),
jwtBundle: jwtBundle,
closeCh: make(chan struct{}),
}, nil
}
func (s *static) CurrentTrustAnchors(context.Context) ([]byte, error) {
bundle := make([]byte, len(s.anchors))
copy(bundle, s.anchors)
return bundle, nil
}
func (s *static) Run(ctx context.Context) error {
if !s.running.CompareAndSwap(false, true) {
return errors.New("trust anchors source is already running")
}
<-ctx.Done()
close(s.closeCh)
return nil
}
func (s *static) GetX509BundleForTrustDomain(spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
return s.x509Bundle, nil
}
func (s *static) GetJWTBundleForTrustDomain(_ spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
return s.jwtBundle, nil
}
func (s *static) Watch(ctx context.Context, _ chan<- []byte) {
select {
case <-ctx.Done():
case <-s.closeCh:
}
}


@ -0,0 +1,210 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package static
import (
"context"
"testing"
"time"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/kit/crypto/test"
)
func TestFromStatic(t *testing.T) {
t.Run("empty root should return error", func(t *testing.T) {
_, err := From(Options{})
require.Error(t, err)
})
t.Run("garbage data should return error", func(t *testing.T) {
_, err := From(Options{Anchors: []byte("garbage data")})
require.Error(t, err)
})
t.Run("just garbage data should return error", func(t *testing.T) {
_, err := From(Options{Anchors: []byte("garbage data")})
require.Error(t, err)
})
t.Run("garbage data in root should return error", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
root := pki.RootCertPEM[10:]
_, err := From(Options{Anchors: root})
require.Error(t, err)
})
t.Run("single root should be correctly parsed", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
taPEM, err := ta.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, taPEM)
})
t.Run("garbage data outside of root should be ignored", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
root := append(pki.RootCertPEM, []byte("garbage data")...)
ta, err := From(Options{Anchors: root})
require.NoError(t, err)
taPEM, err := ta.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(t, root, taPEM)
})
t.Run("multiple roots should be correctly parsed", func(t *testing.T) {
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
ta, err := From(Options{Anchors: roots})
require.NoError(t, err)
taPEM, err := ta.CurrentTrustAnchors(t.Context())
require.NoError(t, err)
assert.Equal(t, roots, taPEM)
})
}
func TestStatic_GetX509BundleForTrustDomain(t *testing.T) {
t.Run("Should return full PEM regardless of the given trust domain", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
//nolint:gocritic
root := append(pki.RootCertPEM, []byte("garbage data")...)
ta, err := From(Options{Anchors: root})
require.NoError(t, err)
s, ok := ta.(*static)
require.True(t, ok)
trustDomain1, err := spiffeid.TrustDomainFromString("example.com")
require.NoError(t, err)
bundle, err := s.GetX509BundleForTrustDomain(trustDomain1)
require.NoError(t, err)
assert.Equal(t, s.x509Bundle, bundle)
b1, err := bundle.Marshal()
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, b1)
trustDomain2, err := spiffeid.TrustDomainFromString("another-example.org")
require.NoError(t, err)
bundle, err = s.GetX509BundleForTrustDomain(trustDomain2)
require.NoError(t, err)
assert.Equal(t, s.x509Bundle, bundle)
b2, err := bundle.Marshal()
require.NoError(t, err)
assert.Equal(t, pki.RootCertPEM, b2)
})
}
func TestStatic_Run(t *testing.T) {
t.Run("Run multiple times should return error", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
s, ok := ta.(*static)
require.True(t, ok)
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
errCh <- s.Run(ctx)
}()
go func() {
errCh <- s.Run(ctx)
}()
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
assert.Fail(t, "Expected error")
}
select {
case <-s.closeCh:
assert.Fail(t, "closeCh should not be closed")
default:
}
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "first Run should have returned with no error")
}
})
}
func TestStatic_Watch(t *testing.T) {
t.Run("should return when context is cancelled", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
ctx, cancel := context.WithCancel(t.Context())
doneCh := make(chan struct{})
go func() {
ta.Watch(ctx, nil)
close(doneCh)
}()
cancel()
select {
case <-doneCh:
case <-time.After(time.Second):
assert.Fail(t, "Expected doneCh to be closed")
}
})
t.Run("should return when closeCh is closed via cancelled Run", func(t *testing.T) {
pki := test.GenPKI(t, test.PKIOptions{})
ta, err := From(Options{Anchors: pki.RootCertPEM})
require.NoError(t, err)
ctx, cancel := context.WithCancel(t.Context())
doneCh := make(chan struct{})
errCh := make(chan error)
go func() {
errCh <- ta.Run(ctx)
}()
go func() {
ta.Watch(t.Context(), nil)
close(doneCh)
}()
cancel()
select {
case <-doneCh:
case <-time.After(time.Second):
assert.Fail(t, "Expected doneCh to be closed")
}
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "Expected Run to return no error")
}
})
}


@ -0,0 +1,42 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package trustanchors
import (
"context"
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
)
// Interface exposes a SPIFFE trust anchor from a source.
// Allows consumers to get the current trust anchor bundle, and subscribe to
// bundle updates.
type Interface interface {
// Source implements the SPIFFE trust anchor x509 bundle source.
x509bundle.Source
// Source implements the SPIFFE trust anchor jwt bundle source.
jwtbundle.Source
// CurrentTrustAnchors returns the current trust anchor PEM bundle.
CurrentTrustAnchors(ctx context.Context) ([]byte, error)
// Watch watches for changes to the trust domains and returns the PEM encoded
// trust domain roots.
// Returns when the given context is canceled.
Watch(ctx context.Context, ch chan<- []byte)
// Run starts the trust anchor source.
Run(ctx context.Context) error
}


@ -191,7 +191,7 @@ func encryptSymmetricAESGCM(plaintext []byte, algorithm string, key []byte, nonc
return nil, nil, ErrKeyTypeMismatch
}
-return encryptSymmetricAEAD(aead, plaintext, algorithm, key, nonce, associatedData)
+return encryptSymmetricAEAD(aead, plaintext, nonce, associatedData)
}
func encryptSymmetricAESCBCHMAC(plaintext []byte, algorithm string, key []byte, nonce []byte, associatedData []byte) (ciphertext []byte, tag []byte, err error) {
@ -200,10 +200,10 @@ func encryptSymmetricAESCBCHMAC(plaintext []byte, algorithm string, key []byte,
return nil, nil, err
}
-return encryptSymmetricAEAD(aead, plaintext, algorithm, key, nonce, associatedData)
+return encryptSymmetricAEAD(aead, plaintext, nonce, associatedData)
}
-func encryptSymmetricAEAD(aead cipher.AEAD, plaintext []byte, algorithm string, key []byte, nonce []byte, associatedData []byte) (ciphertext []byte, tag []byte, err error) {
+func encryptSymmetricAEAD(aead cipher.AEAD, plaintext []byte, nonce []byte, associatedData []byte) (ciphertext []byte, tag []byte, err error) {
if len(nonce) != aead.NonceSize() {
return nil, nil, ErrInvalidNonce
}
@ -229,7 +229,7 @@ func decryptSymmetricAESGCM(ciphertext []byte, algorithm string, key []byte, non
return nil, ErrKeyTypeMismatch
}
-return decryptSymmetricAEAD(aead, ciphertext, algorithm, key, nonce, tag, associatedData)
+return decryptSymmetricAEAD(aead, ciphertext, nonce, tag, associatedData)
}
func decryptSymmetricAESCBCHMAC(ciphertext []byte, algorithm string, key []byte, nonce []byte, tag []byte, associatedData []byte) (plaintext []byte, err error) {
@ -238,10 +238,10 @@ func decryptSymmetricAESCBCHMAC(ciphertext []byte, algorithm string, key []byte,
return nil, err
}
-return decryptSymmetricAEAD(aead, ciphertext, algorithm, key, nonce, tag, associatedData)
+return decryptSymmetricAEAD(aead, ciphertext, nonce, tag, associatedData)
}
-func decryptSymmetricAEAD(aead cipher.AEAD, ciphertext []byte, algorithm string, key []byte, nonce []byte, tag []byte, associatedData []byte) (plaintext []byte, err error) {
+func decryptSymmetricAEAD(aead cipher.AEAD, ciphertext []byte, nonce []byte, tag []byte, associatedData []byte) (plaintext []byte, err error) {
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidNonce
}

crypto/test/pki.go Normal file

@ -0,0 +1,240 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package test
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"net/url"
"testing"
"time"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffegrpc/grpccredentials"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/examples/helloworld/helloworld"
"google.golang.org/grpc/peer"
)
type PKIOptions struct {
LeafDNS string
LeafID spiffeid.ID
ClientDNS string
ClientID spiffeid.ID
}
type PKI struct {
RootCertPEM []byte
RootCert *x509.Certificate
LeafCert *x509.Certificate
LeafCertPEM []byte
LeafPKPEM []byte
LeafPK crypto.Signer
ClientCertPEM []byte
ClientCert *x509.Certificate
ClientPKPEM []byte
ClientPK crypto.Signer
leafID spiffeid.ID
clientID spiffeid.ID
}
func GenPKI(t *testing.T, opts PKIOptions) PKI {
t.Helper()
pki, err := GenPKIError(opts)
require.NoError(t, err)
return pki
}
func GenPKIError(opts PKIOptions) (PKI, error) {
rootPK, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return PKI{}, err
}
rootCert := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "Dapr Test Root CA"},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
KeyUsage: x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
rootCertBytes, err := x509.CreateCertificate(rand.Reader, rootCert, rootCert, &rootPK.PublicKey, rootPK)
if err != nil {
return PKI{}, err
}
rootCert, err = x509.ParseCertificate(rootCertBytes)
if err != nil {
return PKI{}, err
}
rootCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: rootCertBytes})
leafCertPEM, leafPKPEM, leafCert, leafPK, err := genLeafCert(rootPK, rootCert, opts.LeafID, opts.LeafDNS)
if err != nil {
return PKI{}, err
}
clientCertPEM, clientPKPEM, clientCert, clientPK, err := genLeafCert(rootPK, rootCert, opts.ClientID, opts.ClientDNS)
if err != nil {
return PKI{}, err
}
return PKI{
RootCert: rootCert,
RootCertPEM: rootCertPEM,
LeafCertPEM: leafCertPEM,
LeafPKPEM: leafPKPEM,
LeafCert: leafCert,
LeafPK: leafPK,
ClientCertPEM: clientCertPEM,
ClientPKPEM: clientPKPEM,
ClientCert: clientCert,
ClientPK: clientPK,
leafID: opts.LeafID,
clientID: opts.ClientID,
}, nil
}
func (p PKI) ClientGRPCCtx(t *testing.T) context.Context {
t.Helper()
bundle := x509bundle.New(spiffeid.RequireTrustDomainFromString("example.org"))
bundle.AddX509Authority(p.RootCert)
serverSVID := &mockSVID{
bundle: bundle,
svid: &x509svid.SVID{
ID: p.leafID,
Certificates: []*x509.Certificate{p.LeafCert},
PrivateKey: p.LeafPK,
},
}
clientSVID := &mockSVID{
bundle: bundle,
svid: &x509svid.SVID{
ID: p.clientID,
Certificates: []*x509.Certificate{p.ClientCert},
PrivateKey: p.ClientPK,
},
}
server := grpc.NewServer(grpc.Creds(grpccredentials.MTLSServerCredentials(serverSVID, serverSVID, tlsconfig.AuthorizeAny())))
gs := new(greeterServer)
helloworld.RegisterGreeterServer(server, gs)
lis, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
go func() {
server.Serve(lis)
}()
//nolint:staticcheck
conn, err := grpc.DialContext(t.Context(), lis.Addr().String(),
grpc.WithTransportCredentials(grpccredentials.MTLSClientCredentials(clientSVID, clientSVID, tlsconfig.AuthorizeAny())),
)
require.NoError(t, err)
_, err = helloworld.NewGreeterClient(conn).SayHello(t.Context(), new(helloworld.HelloRequest))
require.NoError(t, err)
lis.Close()
server.Stop()
return gs.ctx
}
func genLeafCert(rootPK *ecdsa.PrivateKey, rootCert *x509.Certificate, id spiffeid.ID, dns string) ([]byte, []byte, *x509.Certificate, crypto.Signer, error) {
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, nil, nil, nil, err
}
pkBytes, err := x509.MarshalPKCS8PrivateKey(pk)
if err != nil {
return nil, nil, nil, nil, err
}
cert := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
x509.ExtKeyUsageClientAuth,
},
}
if len(dns) > 0 {
cert.DNSNames = []string{dns}
}
if !id.IsZero() {
cert.URIs = []*url.URL{id.URL()}
}
certBytes, err := x509.CreateCertificate(rand.Reader, cert, rootCert, &pk.PublicKey, rootPK)
if err != nil {
return nil, nil, nil, nil, err
}
cert, err = x509.ParseCertificate(certBytes)
if err != nil {
return nil, nil, nil, nil, err
}
pkPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: pkBytes})
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes})
return certPEM, pkPEM, cert, pk, nil
}
type mockSVID struct {
svid *x509svid.SVID
bundle *x509bundle.Bundle
}
func (m *mockSVID) GetX509BundleForTrustDomain(_ spiffeid.TrustDomain) (*x509bundle.Bundle, error) {
return m.bundle, nil
}
func (m *mockSVID) GetX509SVID() (*x509svid.SVID, error) {
return m.svid, nil
}
type greeterServer struct {
helloworld.UnimplementedGreeterServer
ctx context.Context
}
func (s *greeterServer) SayHello(ctx context.Context, _ *helloworld.HelloRequest) (*helloworld.HelloReply, error) {
p, _ := peer.FromContext(ctx)
s.ctx = peer.NewContext(context.Background(), p)
return new(helloworld.HelloReply), nil
}


@ -41,7 +41,7 @@ pHZ3vWGFAoGAc5Um3YYkhh2QScQBy5+kumH40LhFFy2ETznWEp0tS2NwmTfTm/Nl
Sg+Ct2nOw93cIhwDjWyoilkIapuuX2obY+sUc3kj2ugU+hONfuBStsF020IPP1sk
A9okIZVbz8ycqcjaBiNc4+TeiXED1K7bV9Kg+A9lxDxfGRybJ1/ECWA=
-----END RSA PRIVATE KEY-----
-`
+` // #nosec G101
privateKeyRSAPKCS8 = `-----BEGIN PRIVATE KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDcjaZ0griZFG77
LAytiNRnMHG3Q2UBUusyEaVomxvLs9ZMyIullWKnhIEP0bCcJTRMYUPuTb7u1+zT
@ -128,7 +128,7 @@ MHcCAQEEIOcFe4Q6ardS97ml2tV4+194nmlfQPh8o9ir/qsacEozoAoGCCqGSM49
AwEHoUQDQgAEUMn1c2ioMNi2DqvC8hdBVUERFZ97eVFsNVcQIgR0Hsq5PVrQ/dQ4
uI5u97b6k4wXHYFXMvPmsW1T6qZAE9bB3Q==
-----END EC PRIVATE KEY-----
-`
+` // #nosec G101
privateKeyP256PKCS8 = `-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQg5wV7hDpqt1L3uaXa
1Xj7X3ieaV9A+Hyj2Kv+qxpwSjOhRANCAARQyfVzaKgw2LYOq8LyF0FVQREVn3t5

env/env.go vendored Normal file

@ -0,0 +1,42 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package env
import (
"fmt"
"os"
"time"
)
// GetDurationWithRange returns the time.Duration value of the environment variable specified by `envVar`.
// If the environment variable is not set, it returns `defaultValue`.
// If the value is set but is not valid (not a valid time.Duration or falls outside the specified range
// [minValue, maxValue] inclusively), it returns `defaultValue` and an error.
func GetDurationWithRange(envVar string, defaultValue, min, max time.Duration) (time.Duration, error) {
v := os.Getenv(envVar)
if v == "" {
return defaultValue, nil
}
val, err := time.ParseDuration(v)
if err != nil {
return defaultValue, fmt.Errorf("invalid time.Duration value %s for the %s env variable: %w", val, envVar, err)
}
if val < min || val > max {
return defaultValue, fmt.Errorf("invalid value for the %s env variable: value should be between %s and %s, got %s", envVar, min, max, val)
}
return val, nil
}

env/env_test.go vendored Normal file

@ -0,0 +1,96 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package env
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestGetDurationWithRangeWrongValues(t *testing.T) {
testValues := []struct {
name string
envVarVal string
min time.Duration
max time.Duration
error string
}{
{
"should error if value is not a valid time.Duration",
"0.5",
time.Second,
2 * time.Second,
"invalid time.Duration value 0s for the MY_ENV env variable",
},
{
"should error if value is lower than 1s",
"0s",
time.Second,
10 * time.Second,
"value should be between 1s and 10s",
},
{
"should error if value is higher than 10s",
"2m",
time.Second,
10 * time.Second,
"value should be between 1s and 10s",
},
}
defaultValue := 3 * time.Second
for _, tt := range testValues {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("MY_ENV", tt.envVarVal)
val, err := GetDurationWithRange("MY_ENV", defaultValue, tt.min, tt.max)
require.Error(t, err)
require.Contains(t, err.Error(), tt.error)
require.Equal(t, defaultValue, val)
})
}
}
func TestGetEnvDurationWithRangeValidValues(t *testing.T) {
testValues := []struct {
name string
envVarVal string
result time.Duration
}{
{
"should return default value if env variable is not set",
"",
3 * time.Second,
},
{
"should return value if env variable value is valid",
"4s",
4 * time.Second,
},
}
for _, tt := range testValues {
t.Run(tt.name, func(t *testing.T) {
if tt.envVarVal != "" {
t.Setenv("MY_ENV", tt.envVarVal)
}
val, err := GetDurationWithRange("MY_ENV", 3*time.Second, time.Second, 5*time.Second)
require.NoError(t, err)
require.Equal(t, tt.result, val)
})
}
}

errors/README.md Normal file

@ -0,0 +1,34 @@
# Errors
Standardized errors for Dapr, based on the gRPC Richer Error Model and the [accepted dapr/proposal](https://github.com/dapr/proposals/blob/main/20230511-BCIRS-error-handling-codes.md).
## Usage
Define the error
```go
import kitErrors "github.com/dapr/kit/errors"
// Define error in dapr pkg/api/errors/<building_block>.go
func PubSubNotFound(name string, pubsubType string, metadata map[string]string) error {
message := fmt.Sprintf("pubsub %s is not found", name)
return kitErrors.NewBuilder(
grpcCodes.NotFound,
http.StatusBadRequest,
message,
kitErrors.CodePrefixPubSub+kitErrors.CodeNotFound,
).
WithErrorInfo(kitErrors.CodePrefixPubSub+kitErrors.CodeNotFound, metadata).
WithResourceInfo(pubsubType, name, "", message).
Build()
}
```
Use the error
```go
import apiErrors "github.com/dapr/dapr/pkg/api/errors"
// Use error in dapr and pass in relevant information
err = apiErrors.PubSubNotFound(pubsubName, pubsubType, metadata)
```

errors/codes.go Normal file

@ -0,0 +1,39 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package errors
const (
// Generic
CodeNotFound = "NOT_FOUND"
CodeNotConfigured = "NOT_CONFIGURED"
CodeNotSupported = "NOT_SUPPORTED"
CodeIllegalKey = "ILLEGAL_KEY"
// Components
CodePrefixStateStore = "DAPR_STATE_"
CodePrefixPubSub = "DAPR_PUBSUB_"
CodePrefixBindings = "DAPR_BINDING_"
CodePrefixSecretStore = "DAPR_SECRET_" // #nosec G101
CodePrefixConfigurationStore = "DAPR_CONFIGURATION_"
CodePrefixLock = "DAPR_LOCK_"
CodePrefixNameResolution = "DAPR_NAME_RESOLUTION_"
CodePrefixMiddleware = "DAPR_MIDDLEWARE_"
CodePrefixCryptography = "DAPR_CRYPTOGRAPHY_"
CodePrefixPlacement = "DAPR_PLACEMENT_"
// State
CodePostfixGetStateFailed = "GET_STATE_FAILED"
CodePostfixTooManyTransactions = "TOO_MANY_TRANSACTIONS"
CodePostfixQueryFailed = "QUERY_FAILED"
)

errors/errors.go Normal file (457 lines)

@@ -0,0 +1,457 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package errors
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"google.golang.org/genproto/googleapis/rpc/errdetails"
grpcCodes "google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/runtime/protoiface"
"github.com/dapr/kit/logger"
)
const (
Domain = "dapr.io"
errStringFormat = "api error: code = %s desc = %s"
typeGoogleAPI = "type.googleapis.com/"
)
var log = logger.NewLogger("dapr.kit")
// Error implements the error interface and the interface required by "google.golang.org/grpc/status".FromError().
// It can be used to send errors to HTTP and gRPC servers, indicating the correct status code for each.
type Error struct {
// Added error details. To see available details see:
// https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto
details []proto.Message
// Status code for gRPC responses.
grpcCode grpcCodes.Code
// Status code for HTTP responses.
httpCode int
// Message is the human-readable error message.
message string
// Tag is a string identifying the error, used with HTTP responses only.
tag string
// Category is a string identifying the category of the error (i.e. "actor", "job", "pubsub), used for error code metrics only.
// Category is a string identifying the category of the error (e.g. "actor", "job", "pubsub"), used for error code metrics only.
category string
}
// ErrorBuilder is used to build the error
type ErrorBuilder struct {
err *Error
}
// errorJSON is used to build the error for the HTTP Methods json output
type errorJSON struct {
ErrorCode string `json:"errorCode"`
Message string `json:"message"`
Details []any `json:"details,omitempty"`
}
/**************************************
Error
**************************************/
// HTTPStatusCode gets the error http code
func (e *Error) HTTPStatusCode() int {
return e.httpCode
}
// GrpcStatusCode gets the error grpc code
func (e *Error) GrpcStatusCode() grpcCodes.Code {
return e.grpcCode
}
// ErrorCode returns the error code, preferring the legacy Error.tag and falling back to the ErrorInfo.Reason.
func (e *Error) ErrorCode() string {
errorCode := e.tag
for _, detail := range e.details {
if _, ok := detail.(*errdetails.ErrorInfo); ok {
if _, errInfoReason := convertErrorDetails(detail, *e); errInfoReason != "" {
return errInfoReason
}
}
}
return errorCode
}
// Category returns the error code's category
func (e *Error) Category() string {
return e.category
}
// Error implements the error interface.
func (e *Error) Error() string {
return e.String()
}
// String returns the string representation.
func (e *Error) String() string {
return fmt.Sprintf(errStringFormat, e.grpcCode.String(), e.message)
}
// Is implements the interface that checks if the error matches the given one.
func (e *Error) Is(targetI error) bool {
// Ignore the message in the comparison because the target could have been formatted
var target *Error
if !errors.As(targetI, &target) {
return false
}
return e.tag == target.tag &&
e.grpcCode == target.grpcCode &&
e.httpCode == target.httpCode
}
// Allow details to be mutable and added to the error in runtime
func (e *Error) AddDetails(details ...proto.Message) *Error {
e.details = append(e.details, details...)
return e
}
// FromError returns the wrapped *Error and true if err wraps one, otherwise nil and false.
func FromError(err error) (*Error, bool) {
if err == nil {
return nil, false
}
var kitErr *Error
if errors.As(err, &kitErr) {
return kitErr, true
}
return nil, false
}
/*** GRPC Methods ***/
// GRPCStatus returns the gRPC status.Status object.
func (e *Error) GRPCStatus() *status.Status {
stat := status.New(e.grpcCode, e.message)
// convert details from proto.Msg -> protoiface.MsgV1
var convertedDetails []protoiface.MessageV1
for _, detail := range e.details {
if v1, ok := detail.(protoiface.MessageV1); ok {
convertedDetails = append(convertedDetails, v1)
} else {
log.Debugf("Failed to convert error details: %s", detail)
}
}
if len(e.details) > 0 {
var err error
stat, err = stat.WithDetails(convertedDetails...)
if err != nil {
log.Debugf("Failed to add error details: %s to status: %s", err, stat)
}
}
return stat
}
/*** HTTP Methods ***/
// JSONErrorValue implements the errorResponseValue interface.
func (e *Error) JSONErrorValue() []byte {
grpcStatus := e.GRPCStatus().Proto()
// Make httpCode human readable
// If there is no http legacy code, use the http status text
// This will get overwritten later if there is an ErrorInfo code
httpStatus := e.tag
if httpStatus == "" {
httpStatus = http.StatusText(e.httpCode)
}
errJSON := errorJSON{
ErrorCode: httpStatus,
Message: grpcStatus.GetMessage(),
}
// Handle err details
details := e.details
if len(details) > 0 {
errJSON.Details = make([]any, len(details))
for i, detail := range details {
detailMap, errorCode := convertErrorDetails(detail, *e)
errJSON.Details[i] = detailMap
// If there is an errorCode, update the overall ErrorCode
if errorCode != "" {
errJSON.ErrorCode = errorCode
}
}
}
errBytes, err := json.Marshal(errJSON)
if err != nil {
errJSON, _ := json.Marshal(fmt.Sprintf("failed to encode proto to JSON: %v", err))
return errJSON
}
return errBytes
}
func convertErrorDetails(detail any, e Error) (map[string]interface{}, string) {
// cast to interface to be able to do type switch
// over all possible error_details defined
// https://github.com/googleapis/go-genproto/blob/main/googleapis/rpc/errdetails/error_details.pb.go
switch typedDetail := detail.(type) {
case *errdetails.ErrorInfo:
desc := typedDetail.ProtoReflect().Descriptor()
detailMap := map[string]interface{}{
"@type": typeGoogleAPI + desc.FullName(),
"reason": typedDetail.GetReason(),
"domain": typedDetail.GetDomain(),
"metadata": typedDetail.GetMetadata(),
}
var errorCode string
// If there is an ErrorInfo Reason, but no legacy Tag code, use the ErrorInfo Reason as the error code
if e.tag == "" && typedDetail.GetReason() != "" {
errorCode = typedDetail.GetReason()
}
return detailMap, errorCode
case *errdetails.RetryInfo:
desc := typedDetail.ProtoReflect().Descriptor()
detailMap := map[string]interface{}{
"@type": typeGoogleAPI + desc.FullName(),
"retry_delay": typedDetail.GetRetryDelay(),
}
return detailMap, ""
case *errdetails.DebugInfo:
desc := typedDetail.ProtoReflect().Descriptor()
detailMap := map[string]interface{}{
"@type": typeGoogleAPI + desc.FullName(),
"stack_entries": typedDetail.GetStackEntries(),
"detail": typedDetail.GetDetail(),
}
return detailMap, ""
case *errdetails.QuotaFailure:
desc := typedDetail.ProtoReflect().Descriptor()
detailMap := map[string]interface{}{
"@type": typeGoogleAPI + desc.FullName(),
"violations": typedDetail.GetViolations(),
}
return detailMap, ""
case *errdetails.PreconditionFailure:
desc := typedDetail.ProtoReflect().Descriptor()
detailMap := map[string]interface{}{
"@type": typeGoogleAPI + desc.FullName(),
"violations": typedDetail.GetViolations(),
}
return detailMap, ""
case *errdetails.BadRequest:
desc := typedDetail.ProtoReflect().Descriptor()
detailMap := map[string]interface{}{
"@type": typeGoogleAPI + desc.FullName(),
"field_violations": typedDetail.GetFieldViolations(),
}
return detailMap, ""
case *errdetails.RequestInfo:
desc := typedDetail.ProtoReflect().Descriptor()
detailMap := map[string]interface{}{
"@type": typeGoogleAPI + desc.FullName(),
"request_id": typedDetail.GetRequestId(),
"serving_data": typedDetail.GetServingData(),
}
return detailMap, ""
case *errdetails.ResourceInfo:
desc := typedDetail.ProtoReflect().Descriptor()
detailMap := map[string]interface{}{
"@type": typeGoogleAPI + desc.FullName(),
"resource_type": typedDetail.GetResourceType(),
"resource_name": typedDetail.GetResourceName(),
"owner": typedDetail.GetOwner(),
"description": typedDetail.GetDescription(),
}
return detailMap, ""
case *errdetails.Help:
desc := typedDetail.ProtoReflect().Descriptor()
detailMap := map[string]interface{}{
"@type": typeGoogleAPI + desc.FullName(),
"links": typedDetail.GetLinks(),
}
return detailMap, ""
case *errdetails.LocalizedMessage:
desc := typedDetail.ProtoReflect().Descriptor()
detailMap := map[string]interface{}{
"@type": typeGoogleAPI + desc.FullName(),
"locale": typedDetail.GetLocale(),
"message": typedDetail.GetMessage(),
}
return detailMap, ""
case *errdetails.QuotaFailure_Violation:
desc := typedDetail.ProtoReflect().Descriptor()
detailMap := map[string]interface{}{
"@type": typeGoogleAPI + desc.FullName(),
"subject": typedDetail.GetSubject(),
"description": typedDetail.GetDescription(),
}
return detailMap, ""
case *errdetails.PreconditionFailure_Violation:
desc := typedDetail.ProtoReflect().Descriptor()
detailMap := map[string]interface{}{
"@type": typeGoogleAPI + desc.FullName(),
"subject": typedDetail.GetSubject(),
"description": typedDetail.GetDescription(),
"type": typedDetail.GetType(),
}
return detailMap, ""
case *errdetails.BadRequest_FieldViolation:
desc := typedDetail.ProtoReflect().Descriptor()
detailMap := map[string]interface{}{
"@type": typeGoogleAPI + desc.FullName(),
"field": typedDetail.GetField(),
"description": typedDetail.GetDescription(),
}
return detailMap, ""
case *errdetails.Help_Link:
desc := typedDetail.ProtoReflect().Descriptor()
detailMap := map[string]interface{}{
"@type": typeGoogleAPI + desc.FullName(),
"description": typedDetail.GetDescription(),
"url": typedDetail.GetUrl(),
}
return detailMap, ""
default:
log.Debugf("Failed to convert error details due to incorrect type. \nSee types here: https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto. \nDetail: %s", detail)
// Handle unknown detail types
unknownDetail := map[string]interface{}{
"unknownDetailType": fmt.Sprintf("%T", typedDetail),
"unknownDetails": fmt.Sprintf("%#v", typedDetail),
}
return unknownDetail, ""
}
}
/**************************************
ErrorBuilder
**************************************/
// NewBuilder creates a new ErrorBuilder using the supplied required error fields
func NewBuilder(grpcCode grpcCodes.Code, httpCode int, message string, tag string, category string) *ErrorBuilder {
return &ErrorBuilder{
err: &Error{
details: make([]proto.Message, 0),
grpcCode: grpcCode,
httpCode: httpCode,
message: message,
tag: tag,
category: category,
},
}
}
// WithResourceInfo is used to pass ResourceInfo error details to the Error struct.
func (b *ErrorBuilder) WithResourceInfo(resourceType string, resourceName string, owner string, description string) *ErrorBuilder {
resourceInfo := &errdetails.ResourceInfo{
ResourceType: resourceType,
ResourceName: resourceName,
Owner: owner,
Description: description,
}
b.err.details = append(b.err.details, resourceInfo)
return b
}
// WithHelpLink is used to pass HelpLink error details to the Error struct.
func (b *ErrorBuilder) WithHelpLink(url string, description string) *ErrorBuilder {
link := errdetails.Help_Link{
Description: description,
Url: url,
}
var links []*errdetails.Help_Link
links = append(links, &link)
help := &errdetails.Help{Links: links}
b.err.details = append(b.err.details, help)
return b
}
// WithHelp is used to pass Help error details to the Error struct.
func (b *ErrorBuilder) WithHelp(links []*errdetails.Help_Link) *ErrorBuilder {
b.err.details = append(b.err.details, &errdetails.Help{Links: links})
return b
}
// WithErrorInfo adds error information to the Error struct.
func (b *ErrorBuilder) WithErrorInfo(reason string, metadata map[string]string) *ErrorBuilder {
errorInfo := &errdetails.ErrorInfo{
Domain: Domain,
Reason: reason,
Metadata: metadata,
}
b.err.details = append(b.err.details, errorInfo)
return b
}
// WithFieldViolation is used to pass FieldViolation error details to the Error struct.
func (b *ErrorBuilder) WithFieldViolation(fieldName string, msg string) *ErrorBuilder {
br := &errdetails.BadRequest{
FieldViolations: []*errdetails.BadRequest_FieldViolation{{
Field: fieldName,
Description: msg,
}},
}
b.err.details = append(b.err.details, br)
return b
}
// WithDetails is used to pass any error details to the Error struct.
func (b *ErrorBuilder) WithDetails(details ...proto.Message) *ErrorBuilder {
b.err.details = append(b.err.details, details...)
return b
}
// Build builds our error
func (b *ErrorBuilder) Build() error {
// Check for ErrorInfo, since it's required per the proposal
containsErrorInfo := false
for _, detail := range b.err.details {
if _, ok := detail.(*errdetails.ErrorInfo); ok {
containsErrorInfo = true
break
}
}
if !containsErrorInfo {
log.Errorf("Must include ErrorInfo in error details. Error: %s", b.err.Error())
panic("Must include ErrorInfo in error details.")
}
return b.err
}

errors/errors_test.go Normal file (1017 lines)

File diff suppressed because it is too large


@@ -14,28 +14,37 @@ limitations under the License.
package batcher
import (
"context"
"sync"
"sync/atomic"
"time"
"k8s.io/utils/clock"
"github.com/dapr/kit/events/queue"
)
// key is the type of the comparable key used to batch events.
type key interface {
comparable
type eventCh[T any] struct {
id int
ch chan<- T
}
type Options struct {
Interval time.Duration
Clock clock.Clock
}
// Batcher is a one to many event batcher. It batches events and sends them to
// the added event channel subscribers. Events are sent to the channels after
// the interval has elapsed. If events with the same key are received within
// the interval, the timer is reset.
type Batcher[T key] struct {
interval time.Duration
actives map[T]clock.Timer
eventChs []chan<- struct{}
type Batcher[K comparable, T any] struct {
interval time.Duration
eventChs []*eventCh[T]
queue *queue.Processor[K, *item[K, T]]
currentID int
clock clock.WithDelayedExecution
clock clock.Clock
lock sync.Mutex
wg sync.WaitGroup
closeCh chan struct{}
@@ -43,88 +52,129 @@ type Batcher[T key] struct {
}
// New creates a new Batcher with the given interval and key type.
func New[T key](interval time.Duration) *Batcher[T] {
return &Batcher[T]{
interval: interval,
actives: make(map[T]clock.Timer),
clock: clock.RealClock{},
func New[K comparable, T any](opts Options) *Batcher[K, T] {
cl := opts.Clock
if cl == nil {
cl = clock.RealClock{}
}
b := &Batcher[K, T]{
interval: opts.Interval,
clock: cl,
closeCh: make(chan struct{}),
}
}
// WithClock sets the clock used by the batcher. Used for testing.
func (b *Batcher[T]) WithClock(clock clock.WithDelayedExecution) {
b.clock = clock
b.queue = queue.NewProcessor[K, *item[K, T]](queue.Options[K, *item[K, T]]{
ExecuteFn: b.execute,
Clock: opts.Clock,
})
return b
}
// Subscribe adds a new event channel subscriber. If the batcher is closed, the
// subscriber is silently dropped.
func (b *Batcher[T]) Subscribe(eventCh ...chan<- struct{}) {
func (b *Batcher[K, T]) Subscribe(ctx context.Context, ch ...chan<- T) {
b.lock.Lock()
defer b.lock.Unlock()
for _, c := range ch {
b.subscribe(ctx, c)
}
}
func (b *Batcher[K, T]) subscribe(ctx context.Context, ch chan<- T) {
if b.closed.Load() {
return
}
id := b.currentID
b.currentID++
bufferedCh := make(chan T, 50)
b.eventChs = append(b.eventChs, &eventCh[T]{
id: id,
ch: bufferedCh,
})
b.wg.Add(1)
go func() {
defer func() {
b.lock.Lock()
close(ch)
for i, eventCh := range b.eventChs {
if eventCh.id == id {
b.eventChs = append(b.eventChs[:i], b.eventChs[i+1:]...)
break
}
}
b.lock.Unlock()
b.wg.Done()
}()
for {
select {
case <-ctx.Done():
return
case <-b.closeCh:
return
case env := <-bufferedCh:
select {
case ch <- env:
case <-ctx.Done():
case <-b.closeCh:
}
}
}
}()
}
func (b *Batcher[K, T]) execute(i *item[K, T]) {
b.lock.Lock()
defer b.lock.Unlock()
if b.closed.Load() {
return
}
b.eventChs = append(b.eventChs, eventCh...)
for _, ev := range b.eventChs {
select {
case ev.ch <- i.value:
case <-b.closeCh:
}
}
}
// Batch adds the given key to the batcher. If an event for this key is already
// active, the timer is reset. If the batcher is closed, the key is silently
// dropped.
func (b *Batcher[T]) Batch(key T) {
b.lock.Lock()
defer b.lock.Unlock()
if b.closed.Load() {
return
}
if active, ok := b.actives[key]; ok {
if !active.Stop() {
<-active.C()
}
active.Reset(b.interval)
return
}
b.actives[key] = b.clock.AfterFunc(b.interval, func() {
b.lock.Lock()
defer b.lock.Unlock()
b.wg.Add(len(b.eventChs))
delete(b.actives, key)
for _, eventCh := range b.eventChs {
go func(eventCh chan<- struct{}) {
defer b.wg.Done()
select {
case eventCh <- struct{}{}:
case <-b.closeCh:
}
}(eventCh)
}
func (b *Batcher[K, T]) Batch(key K, value T) {
b.queue.Enqueue(&item[K, T]{
key: key,
value: value,
ttl: b.clock.Now().Add(b.interval),
})
}
// Close closes the batcher. It blocks until all events have been sent to the
// subscribers. The batcher will be a no-op after this call.
func (b *Batcher[T]) Close() {
func (b *Batcher[K, T]) Close() {
defer b.wg.Wait()
// Lock to ensure that no new timers are created.
b.queue.Close()
b.lock.Lock()
if b.closed.CompareAndSwap(false, true) {
close(b.closeCh)
}
actives := b.actives
b.lock.Unlock()
for _, active := range actives {
if !active.Stop() {
<-active.C()
}
}
b.lock.Lock()
b.actives = nil
b.lock.Unlock()
}
// item implements queue.queueable.
type item[K comparable, T any] struct {
key K
value T
ttl time.Time
}
func (b *item[K, T]) Key() K {
return b.key
}
func (b *item[K, T]) ScheduledTime() time.Time {
return b.ttl
}


@@ -25,49 +25,50 @@ func TestNew(t *testing.T) {
t.Parallel()
interval := time.Millisecond * 10
b := New[string](interval)
assert.Equal(t, interval, b.interval)
b := New[string, struct{}](Options{Interval: interval})
assert.False(t, b.closed.Load())
}
func TestWithClock(t *testing.T) {
b := New[string](time.Millisecond * 10)
fakeClock := testingclock.NewFakeClock(time.Now())
b.WithClock(fakeClock)
b := New[string, struct{}](Options{
Interval: time.Millisecond * 10,
Clock: fakeClock,
})
assert.Equal(t, fakeClock, b.clock)
}
func TestSubscribe(t *testing.T) {
t.Parallel()
b := New[string](time.Millisecond * 10)
b := New[string, struct{}](Options{Interval: time.Millisecond * 10})
ch := make(chan struct{})
b.Subscribe(ch)
assert.Equal(t, 1, len(b.eventChs))
b.Subscribe(t.Context(), ch)
assert.Len(t, b.eventChs, 1)
}
func TestBatch(t *testing.T) {
t.Parallel()
fakeClock := testingclock.NewFakeClock(time.Now())
b := New[string](time.Millisecond * 10)
b.clock = fakeClock
b := New[string, struct{}](Options{
Interval: time.Millisecond * 10,
Clock: fakeClock,
})
ch1 := make(chan struct{})
ch2 := make(chan struct{})
ch3 := make(chan struct{})
b.Subscribe(ch1, ch2)
b.Subscribe(ch3)
b.Subscribe(t.Context(), ch1, ch2)
b.Subscribe(t.Context(), ch3)
b.Batch("key1")
b.Batch("key1")
b.Batch("key1")
b.Batch("key1")
b.Batch("key2")
b.Batch("key2")
b.Batch("key3")
b.Batch("key3")
assert.Equal(t, 3, len(b.actives))
b.Batch("key1", struct{}{})
b.Batch("key1", struct{}{})
b.Batch("key1", struct{}{})
b.Batch("key1", struct{}{})
b.Batch("key2", struct{}{})
b.Batch("key2", struct{}{})
b.Batch("key3", struct{}{})
b.Batch("key3", struct{}{})
assert.Eventually(t, func() bool {
return fakeClock.HasWaiters()
@@ -93,7 +94,7 @@ func TestBatch(t *testing.T) {
fakeClock.Step(time.Millisecond * 5)
for i := 0; i < 3; i++ {
for range 3 {
for _, ch := range []chan struct{}{ch1, ch2, ch3} {
select {
case <-ch:
@@ -102,37 +103,58 @@
}
}
}
t.Run("ensure items are received in order with latest value", func(t *testing.T) {
fakeClock := testingclock.NewFakeClock(time.Now())
b := New[int, int](Options{
Interval: time.Millisecond * 10,
Clock: fakeClock,
})
t.Cleanup(b.Close)
ch1 := make(chan int, 10)
ch2 := make(chan int, 10)
ch3 := make(chan int, 10)
b.Subscribe(t.Context(), ch1, ch2)
b.Subscribe(t.Context(), ch3)
for i := range 10 {
b.Batch(i, i)
b.Batch(i, i+1)
b.Batch(i, i+2)
fakeClock.Step(time.Millisecond * 10)
}
for _, ch := range []chan int{ch1} {
for i := range 10 {
select {
case v := <-ch:
assert.Equal(t, i+2, v)
case <-time.After(time.Second):
assert.Fail(t, "should be triggered")
}
}
}
})
}
func TestClose(t *testing.T) {
t.Parallel()
b := New[string](time.Millisecond * 10)
b := New[string, struct{}](Options{Interval: time.Millisecond * 10})
ch := make(chan struct{})
b.Subscribe(ch)
b.Subscribe(t.Context(), ch)
assert.Len(t, b.eventChs, 1)
b.Batch("key1")
assert.Len(t, b.actives, 1)
b.Batch("key1", struct{}{})
b.Close()
assert.True(t, b.closed.Load())
assert.Equal(t, 0, len(b.actives))
}
func TestBatchAfterClose(t *testing.T) {
t.Parallel()
b := New[string](time.Millisecond * 10)
b.Close()
b.Batch("key1")
assert.Equal(t, 0, len(b.actives))
}
func TestSubscribeAfterClose(t *testing.T) {
t.Parallel()
b := New[string](time.Millisecond * 10)
b := New[string, struct{}](Options{Interval: time.Millisecond * 10})
b.Close()
ch := make(chan struct{})
b.Subscribe(ch)
assert.Equal(t, 0, len(b.eventChs))
b.Subscribe(t.Context(), ch)
assert.Empty(t, b.eventChs)
}


@@ -0,0 +1,132 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package broadcaster
import (
"context"
"sync"
"sync/atomic"
)
const bufferSize = 10
type eventCh[T any] struct {
id uint64
ch chan<- T
closeEventCh chan struct{}
}
type Broadcaster[T any] struct {
eventChs []*eventCh[T]
currentID uint64
lock sync.Mutex
wg sync.WaitGroup
closeCh chan struct{}
closed atomic.Bool
}
// New creates a new Broadcaster.
func New[T any]() *Broadcaster[T] {
return &Broadcaster[T]{
closeCh: make(chan struct{}),
}
}
// Subscribe adds a new event channel subscriber. If the broadcaster is closed,
// the subscriber is silently dropped.
func (b *Broadcaster[T]) Subscribe(ctx context.Context, ch ...chan<- T) {
b.lock.Lock()
defer b.lock.Unlock()
for _, c := range ch {
b.subscribe(ctx, c)
}
}
func (b *Broadcaster[T]) subscribe(ctx context.Context, ch chan<- T) {
if b.closed.Load() {
return
}
id := b.currentID
b.currentID++
bufferedCh := make(chan T, bufferSize)
closeEventCh := make(chan struct{})
b.eventChs = append(b.eventChs, &eventCh[T]{
id: id,
ch: bufferedCh,
closeEventCh: closeEventCh,
})
b.wg.Add(1)
go func() {
defer func() {
close(closeEventCh)
b.lock.Lock()
for i, eventCh := range b.eventChs {
if eventCh.id == id {
b.eventChs = append(b.eventChs[:i], b.eventChs[i+1:]...)
break
}
}
b.lock.Unlock()
b.wg.Done()
}()
for {
select {
case <-ctx.Done():
return
case <-b.closeCh:
return
case val := <-bufferedCh:
select {
case <-ctx.Done():
return
case <-b.closeCh:
return
case ch <- val:
}
}
}
}()
}
// Broadcast sends the given value to all subscribers.
func (b *Broadcaster[T]) Broadcast(value T) {
b.lock.Lock()
defer b.lock.Unlock()
if b.closed.Load() {
return
}
for _, ev := range b.eventChs {
select {
case <-ev.closeEventCh:
case ev.ch <- value:
case <-b.closeCh:
}
}
}
// Close closes the Broadcaster. It blocks until all events have been sent to
// the subscribers. The Broadcaster will be a no-op after this call.
func (b *Broadcaster[T]) Close() {
defer b.wg.Wait()
b.lock.Lock()
if b.closed.CompareAndSwap(false, true) {
close(b.closeCh)
}
b.lock.Unlock()
}

events/loop/fake/fake.go Normal file (65 lines)

@@ -0,0 +1,65 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fake
import (
"context"
"github.com/dapr/kit/events/loop"
)
type Fake[T any] struct {
runFn func(context.Context) error
enqueueFn func(T)
closeFn func(T)
}
func New[T any]() *Fake[T] {
return &Fake[T]{
runFn: func(context.Context) error { return nil },
enqueueFn: func(T) {},
closeFn: func(T) {},
}
}
func (f *Fake[T]) WithRun(fn func(context.Context) error) *Fake[T] {
f.runFn = fn
return f
}
func (f *Fake[T]) WithEnqueue(fn func(T)) *Fake[T] {
f.enqueueFn = fn
return f
}
func (f *Fake[T]) WithClose(fn func(T)) *Fake[T] {
f.closeFn = fn
return f
}
func (f *Fake[T]) Run(ctx context.Context) error {
return f.runFn(ctx)
}
func (f *Fake[T]) Enqueue(t T) {
f.enqueueFn(t)
}
func (f *Fake[T]) Close(t T) {
f.closeFn(t)
}
func (f *Fake[T]) Reset(loop.Handler[T], uint64) loop.Interface[T] {
return f
}


@@ -0,0 +1,24 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fake
import (
"testing"
"github.com/dapr/kit/events/loop"
)
func Test_Fake(*testing.T) {
var _ loop.Interface[int] = New[int]()
}

events/loop/loop.go Normal file (111 lines)

@@ -0,0 +1,111 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package loop
import (
"context"
"sync"
)
type Handler[T any] interface {
Handle(ctx context.Context, t T) error
}
type Interface[T any] interface {
Run(ctx context.Context) error
Enqueue(t T)
Close(t T)
Reset(h Handler[T], size uint64) Interface[T]
}
type loop[T any] struct {
queue chan T
handler Handler[T]
closed bool
closeCh chan struct{}
lock sync.RWMutex
}
func New[T any](h Handler[T], size uint64) Interface[T] {
return &loop[T]{
queue: make(chan T, size),
handler: h,
closeCh: make(chan struct{}),
}
}
func Empty[T any]() Interface[T] {
return new(loop[T])
}
func (l *loop[T]) Run(ctx context.Context) error {
defer close(l.closeCh)
for {
req, ok := <-l.queue
if !ok {
return nil
}
if err := l.handler.Handle(ctx, req); err != nil {
return err
}
}
}
func (l *loop[T]) Enqueue(req T) {
l.lock.RLock()
defer l.lock.RUnlock()
if l.closed {
return
}
select {
case l.queue <- req:
case <-l.closeCh:
}
}
func (l *loop[T]) Close(req T) {
l.lock.Lock()
l.closed = true
select {
case l.queue <- req:
case <-l.closeCh:
}
close(l.queue)
l.lock.Unlock()
<-l.closeCh
}
func (l *loop[T]) Reset(h Handler[T], size uint64) Interface[T] {
if l == nil {
return New[T](h, size)
}
l.lock.Lock()
defer l.lock.Unlock()
l.closed = false
l.closeCh = make(chan struct{})
l.handler = h
// TODO: @joshvanl: use a ring buffer so that we don't need to reallocate and
// improve performance.
l.queue = make(chan T, size)
return l
}


@@ -43,7 +43,9 @@ func ExampleProcessor() {
}
// Create the processor
processor := NewProcessor[*queueableItem](executeFn)
processor := NewProcessor[string, *queueableItem](Options[string, *queueableItem]{
ExecuteFn: executeFn,
})
// Add items to the processor, in any order, using Enqueue
processor.Enqueue(&queueableItem{Name: "item1", ExecutionTime: time.Now().Add(500 * time.Millisecond)})
@@ -57,7 +59,7 @@ func ExampleProcessor() {
// Using Dequeue allows removing an item from the queue
processor.Dequeue("item4")
for i := 0; i < 3; i++ {
for range 3 {
fmt.Println(<-executed)
}
// Output:


@@ -14,7 +14,6 @@ limitations under the License.
package queue
import (
"errors"
"sync"
"sync/atomic"
"time"
@@ -22,13 +21,15 @@ import (
kclock "k8s.io/utils/clock"
)
// ErrProcessorStopped is returned when the processor is not running.
var ErrProcessorStopped = errors.New("processor is stopped")
type Options[K comparable, T Queueable[K]] struct {
ExecuteFn func(r T)
Clock kclock.Clock
}
// Processor manages the queue of items and processes them at the correct time.
type Processor[T queueable] struct {
type Processor[K comparable, T Queueable[K]] struct {
executeFn func(r T)
queue queue[T]
queue queue[K, T]
clock kclock.Clock
lock sync.Mutex
wg sync.WaitGroup
@@ -40,48 +41,51 @@ type Processor[T queueable] struct {
// NewProcessor returns a new Processor object.
// executeFn is the callback invoked when the item is to be executed; this will be invoked in a background goroutine.
func NewProcessor[T queueable](executeFn func(r T)) *Processor[T] {
return &Processor[T]{
executeFn: executeFn,
queue: newQueue[T](),
func NewProcessor[K comparable, T Queueable[K]](opts Options[K, T]) *Processor[K, T] {
cl := opts.Clock
if cl == nil {
cl = kclock.RealClock{}
}
return &Processor[K, T]{
executeFn: opts.ExecuteFn,
queue: newQueue[K, T](),
processorRunningCh: make(chan struct{}, 1),
stopCh: make(chan struct{}),
resetCh: make(chan struct{}, 1),
clock: kclock.RealClock{},
clock: cl,
}
}
// WithClock sets the clock used by the processor. Used for testing.
func (p *Processor[T]) WithClock(clock kclock.Clock) *Processor[T] {
p.clock = clock
return p
}
// Enqueue adds a new item to the queue.
// Enqueue adds new items to the queue.
// If an item with the same ID already exists, it'll be replaced.
func (p *Processor[T]) Enqueue(r T) error {
func (p *Processor[K, T]) Enqueue(rs ...T) {
if p.stopped.Load() {
return ErrProcessorStopped
return
}
p.lock.Lock()
defer p.lock.Unlock()
for _, r := range rs {
p.enqueue(r)
}
}
func (p *Processor[K, T]) enqueue(r T) {
// Insert or replace the item in the queue
// If the item added or replaced is the first one in the queue, we need to know that
p.lock.Lock()
peek, ok := p.queue.Peek()
isFirst := (ok && peek.Key() == r.Key()) // This is going to be true if the item being replaced is the first one in the queue
p.queue.Insert(r, true)
peek, _ = p.queue.Peek() // No need to check for "ok" here because we know this will return an item
isFirst = isFirst || (peek == r) // This is also going to be true if the item just added landed at the front of the queue
p.process(isFirst)
p.lock.Unlock()
return nil
}
// Dequeue removes an item from the queue.
func (p *Processor[T]) Dequeue(key string) error {
func (p *Processor[K, T]) Dequeue(key K) {
if p.stopped.Load() {
return ErrProcessorStopped
return
}
// We need to check if this is the next item in the queue, as that requires stopping the processor
@ -93,13 +97,11 @@ func (p *Processor[T]) Dequeue(key string) error {
p.process(true)
}
p.lock.Unlock()
return nil
}
// Close stops the processor.
// This method blocks until the processor loop returns.
func (p *Processor[T]) Close() error {
func (p *Processor[K, T]) Close() error {
defer p.wg.Wait()
if p.stopped.CompareAndSwap(false, true) {
// Send a signal to stop
@ -114,7 +116,7 @@ func (p *Processor[T]) Close() error {
// Start the processing loop if it's not already running.
// This must be invoked while the caller has a lock.
func (p *Processor[T]) process(isNext bool) {
func (p *Processor[K, T]) process(isNext bool) {
// Do not start a loop if it's already running
select {
case p.processorRunningCh <- struct{}{}:
@ -140,7 +142,7 @@ func (p *Processor[T]) process(isNext bool) {
}
// Processing loop.
func (p *Processor[T]) processLoop() {
func (p *Processor[K, T]) processLoop() {
defer func() {
// Release the channel when exiting
<-p.processorRunningCh
@ -209,7 +211,7 @@ func (p *Processor[T]) processLoop() {
}
// Executes an item when it's time.
func (p *Processor[T]) execute(r T) {
func (p *Processor[K, T]) execute(r T) {
// Pop the item now that we're ready to process it
// There's a small chance this is a different item than the one we peeked before
p.lock.Lock()
@ -226,5 +228,5 @@ func (p *Processor[T]) execute(r T) {
return
}
go p.executeFn(r)
p.executeFn(r)
}


@ -31,10 +31,12 @@ func TestProcessor(t *testing.T) {
// Create the processor
clock := clocktesting.NewFakeClock(time.Now())
executeCh := make(chan *queueableItem)
processor := NewProcessor(func(r *queueableItem) {
executeCh <- r
processor := NewProcessor[string, *queueableItem](Options[string, *queueableItem]{
ExecuteFn: func(r *queueableItem) {
executeCh <- r
},
Clock: clock,
})
processor.clock = clock
assertExecutedItem := func(t *testing.T) *queueableItem {
t.Helper()
@ -63,10 +65,9 @@ func TestProcessor(t *testing.T) {
t.Run("enqueue items", func(t *testing.T) {
for i := 1; i <= 5; i++ {
err := processor.Enqueue(
processor.Enqueue(
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
)
require.NoError(t, err)
}
// Advance tickers by 500ms to start
@ -83,8 +84,7 @@ func TestProcessor(t *testing.T) {
t.Run("enqueue item to be executed right away", func(t *testing.T) {
r := newTestItem(1, clock.Now())
err := processor.Enqueue(r)
require.NoError(t, err)
processor.Enqueue(r)
clock.Step(500 * time.Millisecond)
@ -95,10 +95,9 @@ func TestProcessor(t *testing.T) {
t.Run("enqueue item at the front of the queue", func(t *testing.T) {
// Enqueue 4 items
for i := 1; i <= 4; i++ {
err := processor.Enqueue(
processor.Enqueue(
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
)
require.NoError(t, err)
}
assert.Eventually(t, clock.HasWaiters, time.Second, 100*time.Millisecond)
@ -111,10 +110,9 @@ func TestProcessor(t *testing.T) {
assert.Equal(t, "1", received.Name)
// Add a new item at the front of the queue
err := processor.Enqueue(
processor.Enqueue(
newTestItem(99, clock.Now()),
)
require.NoError(t, err)
// Advance tickers and assert messages are coming in order
for i := 1; i <= 4; i++ {
@ -136,19 +134,16 @@ func TestProcessor(t *testing.T) {
// Enqueue 5 items
for i := 1; i <= 5; i++ {
err := processor.Enqueue(
processor.Enqueue(
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
)
require.NoError(t, err)
}
assert.Equal(t, 5, processor.queue.Len())
// Dequeue items 2 and 4
// Note that this is a string because it's the key
err := processor.Dequeue("2")
require.NoError(t, err)
err = processor.Dequeue("4")
require.NoError(t, err)
processor.Dequeue("2")
processor.Dequeue("4")
assert.Equal(t, 3, processor.queue.Len())
@ -173,10 +168,9 @@ func TestProcessor(t *testing.T) {
t.Run("dequeue item from the front of the queue", func(t *testing.T) {
// Enqueue 6 items
for i := 1; i <= 6; i++ {
err := processor.Enqueue(
processor.Enqueue(
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
)
require.NoError(t, err)
}
// Advance tickers and assert messages are coming in order
@ -187,8 +181,7 @@ func TestProcessor(t *testing.T) {
if i == 2 || i == 5 {
// Dequeue the item at the front of the queue
// Note that this is a string because it's the key
err := processor.Dequeue(strconv.Itoa(i))
require.NoError(t, err)
processor.Dequeue(strconv.Itoa(i))
// Skip items that have been removed
t.Logf("Should not receive signal %d", i)
@ -206,15 +199,13 @@ func TestProcessor(t *testing.T) {
t.Run("replace item", func(t *testing.T) {
// Enqueue 5 items
for i := 1; i <= 5; i++ {
err := processor.Enqueue(
processor.Enqueue(
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
)
require.NoError(t, err)
}
// Replace item 4, bumping its priority down
err := processor.Enqueue(newTestItem(4, clock.Now().Add(6*time.Second)))
require.NoError(t, err)
processor.Enqueue(newTestItem(4, clock.Now().Add(6*time.Second)))
// Advance tickers and assert messages are coming in order
for i := 1; i <= 6; i++ {
@ -241,10 +232,9 @@ func TestProcessor(t *testing.T) {
t.Run("replace item at the front of the queue", func(t *testing.T) {
// Enqueue 5 items
for i := 1; i <= 5; i++ {
err := processor.Enqueue(
processor.Enqueue(
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
)
require.NoError(t, err)
}
// Advance tickers and assert messages are coming in order
@ -253,8 +243,7 @@ func TestProcessor(t *testing.T) {
if i == 2 {
// Replace item 2, bumping its priority down, while it's at the front of the queue
err := processor.Enqueue(newTestItem(2, clock.Now().Add(5*time.Second)))
require.NoError(t, err)
processor.Enqueue(newTestItem(2, clock.Now().Add(5*time.Second)))
// This item has been pushed down
t.Logf("Should not receive signal %d now", i)
@ -282,13 +271,12 @@ func TestProcessor(t *testing.T) {
)
now := clock.Now()
wg := sync.WaitGroup{}
for i := 0; i < count; i++ {
for i := range count {
wg.Add(1)
go func(i int) {
defer wg.Done()
execTime := now.Add(time.Second * time.Duration(rand.Intn(maxDelay))) //nolint:gosec
err := processor.Enqueue(newTestItem(i, execTime))
require.NoError(t, err)
processor.Enqueue(newTestItem(i, execTime))
}(i)
}
wg.Wait()
@ -324,7 +312,7 @@ func TestProcessor(t *testing.T) {
close(doneCh)
// Ensure all items are true
for i := 0; i < count; i++ {
for i := range count {
assert.Truef(t, collected[i], "item %d not received", i)
}
})
@ -332,10 +320,9 @@ func TestProcessor(t *testing.T) {
t.Run("stop processor", func(t *testing.T) {
// Enqueue 5 items
for i := 1; i <= 5; i++ {
err := processor.Enqueue(
processor.Enqueue(
newTestItem(i, clock.Now().Add(time.Second*time.Duration(i))),
)
require.NoError(t, err)
}
assert.Eventually(t, clock.HasWaiters, time.Second, 100*time.Millisecond)
@ -348,10 +335,8 @@ func TestProcessor(t *testing.T) {
assertNoExecutedItem(t)
// Enqueuing and dequeueing should fail
err := processor.Enqueue(newTestItem(99, clock.Now()))
require.ErrorIs(t, err, ErrProcessorStopped)
err = processor.Dequeue("99")
require.ErrorIs(t, err, ErrProcessorStopped)
processor.Enqueue(newTestItem(99, clock.Now()))
processor.Dequeue("99")
// Stopping again is a nop (should not crash)
require.NoError(t, processor.Close())
@ -364,10 +349,12 @@ func TestClose(t *testing.T) {
// Create the processor
clock := clocktesting.NewFakeClock(time.Now())
executeCh := make(chan *queueableItem)
processor := NewProcessor(func(r *queueableItem) {
executeCh <- r
processor := NewProcessor[string, *queueableItem](Options[string, *queueableItem]{
ExecuteFn: func(r *queueableItem) {
executeCh <- r
},
Clock: clock,
})
processor.clock = clock
processor.Enqueue(newTestItem(1, clock.Now().Add(time.Second)))
processor.Enqueue(newTestItem(2, clock.Now().Add(time.Second*2)))
@ -415,14 +402,14 @@ func TestClose(t *testing.T) {
default:
}
for i := 0; i < 3; i++ {
for range 3 {
select {
case err := <-closeCh:
assert.NoError(t, err)
require.NoError(t, err)
case <-time.After(time.Second * 3):
t.Fatal("close should have returned")
}
}
assert.NoError(t, processor.Close())
require.NoError(t, processor.Close())
}


@ -18,10 +18,10 @@ import (
"time"
)
// queueable is the interface for items that can be added to the queue.
type queueable interface {
// Queueable is the interface for items that can be added to the queue.
type Queueable[T comparable] interface {
comparable
Key() string
Key() T
ScheduledTime() time.Time
}
@ -29,27 +29,27 @@ type queueable interface {
// It acts as a "priority queue", in which items are added in order of when they're scheduled.
// Internally, it uses a heap (from container/heap) that allows Insert and Pop operations to be completed in O(log N) time (where N is the queue's length).
// Note: methods in this struct are not safe for concurrent use. Callers should use locks to ensure consistency.
type queue[T queueable] struct {
heap *queueHeap[T]
items map[string]*queueItem[T]
type queue[K comparable, T Queueable[K]] struct {
heap *queueHeap[K, T]
items map[K]*queueItem[K, T]
}
// newQueue creates a new queue.
func newQueue[T queueable]() queue[T] {
return queue[T]{
heap: new(queueHeap[T]),
items: make(map[string]*queueItem[T]),
func newQueue[K comparable, T Queueable[K]]() queue[K, T] {
return queue[K, T]{
heap: new(queueHeap[K, T]),
items: make(map[K]*queueItem[K, T]),
}
}
// Len returns the number of items in the queue.
func (p *queue[T]) Len() int {
func (p *queue[K, T]) Len() int {
return p.heap.Len()
}
// Insert inserts a new item into the queue.
// If replace is true, existing items are replaced
func (p *queue[T]) Insert(r T, replace bool) {
func (p *queue[K, T]) Insert(r T, replace bool) {
key := r.Key()
// Check if the item already exists
@ -62,7 +62,7 @@ func (p *queue[T]) Insert(r T, replace bool) {
return
}
item = &queueItem[T]{
item = &queueItem[K, T]{
value: r,
}
heap.Push(p.heap, item)
@ -71,13 +71,13 @@ func (p *queue[T]) Insert(r T, replace bool) {
// Pop removes the next item in the queue and returns it.
// The returned boolean value will be "true" if an item was found.
func (p *queue[T]) Pop() (T, bool) {
func (p *queue[K, T]) Pop() (T, bool) {
if p.Len() == 0 {
var zero T
return zero, false
}
item, ok := heap.Pop(p.heap).(*queueItem[T])
item, ok := heap.Pop(p.heap).(*queueItem[K, T])
if !ok || item == nil {
var zero T
return zero, false
@ -89,7 +89,7 @@ func (p *queue[T]) Pop() (T, bool) {
// Peek returns the next item in the queue, without removing it.
// The returned boolean value will be "true" if an item was found.
func (p *queue[T]) Peek() (T, bool) {
func (p *queue[K, T]) Peek() (T, bool) {
if p.Len() == 0 {
var zero T
return zero, false
@ -99,7 +99,7 @@ func (p *queue[T]) Peek() (T, bool) {
}
// Remove an item from the queue.
func (p *queue[T]) Remove(key string) {
func (p *queue[K, T]) Remove(key K) {
// If the item is not in the queue, this is a nop
item, ok := p.items[key]
if !ok {
@ -111,7 +111,7 @@ func (p *queue[T]) Remove(key string) {
}
// Update an item in the queue.
func (p *queue[T]) Update(r T) {
func (p *queue[K, T]) Update(r T) {
// If the item is not in the queue, this is a nop
item, ok := p.items[r.Key()]
if !ok {
@ -122,37 +122,37 @@ func (p *queue[T]) Update(r T) {
heap.Fix(p.heap, item.index)
}
type queueItem[T queueable] struct {
type queueItem[K comparable, T Queueable[K]] struct {
value T
// The index of the item in the heap. This is maintained by the heap.Interface methods.
index int
}
type queueHeap[T queueable] []*queueItem[T]
type queueHeap[K comparable, T Queueable[K]] []*queueItem[K, T]
func (pq queueHeap[T]) Len() int {
func (pq queueHeap[K, T]) Len() int {
return len(pq)
}
func (pq queueHeap[T]) Less(i, j int) bool {
func (pq queueHeap[K, T]) Less(i, j int) bool {
return pq[i].value.ScheduledTime().Before(pq[j].value.ScheduledTime())
}
func (pq queueHeap[T]) Swap(i, j int) {
func (pq queueHeap[K, T]) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i
pq[j].index = j
}
func (pq *queueHeap[T]) Push(x any) {
func (pq *queueHeap[K, T]) Push(x any) {
n := len(*pq)
item := x.(*queueItem[T])
item := x.(*queueItem[K, T])
item.index = n
*pq = append(*pq, item)
}
func (pq *queueHeap[T]) Pop() any {
func (pq *queueHeap[K, T]) Pop() any {
old := *pq
n := len(old)
item := old[n-1]


@ -23,7 +23,7 @@ import (
)
func TestQueue(t *testing.T) {
queue := newQueue[*queueableItem]()
queue := newQueue[string, *queueableItem]()
// Add 5 items, which are not in order
queue.Insert(newTestItem(2, "2022-02-02T02:02:02Z"), false)
@ -56,7 +56,7 @@ func TestQueue(t *testing.T) {
}
func TestQueueSkipDuplicates(t *testing.T) {
queue := newQueue[*queueableItem]()
queue := newQueue[string, *queueableItem]()
// Add 2 items
queue.Insert(newTestItem(2, "2022-02-02T02:02:02Z"), false)
@ -78,7 +78,7 @@ func TestQueueSkipDuplicates(t *testing.T) {
}
func TestQueueReplaceDuplicates(t *testing.T) {
queue := newQueue[*queueableItem]()
queue := newQueue[string, *queueableItem]()
// Add 2 items
queue.Insert(newTestItem(2, "2022-02-02T02:02:02Z"), false)
@ -100,7 +100,7 @@ func TestQueueReplaceDuplicates(t *testing.T) {
}
func TestAddToQueue(t *testing.T) {
queue := newQueue[*queueableItem]()
queue := newQueue[string, *queueableItem]()
// Add 5 items, which are not in order
queue.Insert(newTestItem(2, "2022-02-02T02:02:02Z"), false)
@ -151,7 +151,7 @@ func TestAddToQueue(t *testing.T) {
}
func TestRemoveFromQueue(t *testing.T) {
queue := newQueue[*queueableItem]()
queue := newQueue[string, *queueableItem]()
// Add 5 items, which are not in order
queue.Insert(newTestItem(2, "2022-02-02T02:02:02Z"), false)
@ -193,7 +193,7 @@ func TestRemoveFromQueue(t *testing.T) {
}
func TestUpdateInQueue(t *testing.T) {
queue := newQueue[*queueableItem]()
queue := newQueue[string, *queueableItem]()
// Add 5 items, which are not in order
queue.Insert(newTestItem(2, "2022-02-02T02:02:02Z"), false)
@ -238,7 +238,7 @@ func TestUpdateInQueue(t *testing.T) {
}
func TestQueuePeek(t *testing.T) {
queue := newQueue[*queueableItem]()
queue := newQueue[string, *queueableItem]()
// Peeking an empty queue returns false
_, ok := queue.Peek()
@ -299,7 +299,7 @@ func newTestItem(n int, dueTime any) *queueableItem {
return r
}
func popAndCompare(t *testing.T, q *queue[*queueableItem], expectN int, expectDueTime string) {
func popAndCompare(t *testing.T, q *queue[string, *queueableItem], expectN int, expectDueTime string) {
r, ok := q.Pop()
require.True(t, ok)
require.NotNil(t, r)
@ -307,7 +307,7 @@ func popAndCompare(t *testing.T, q *queue[*queueableItem], expectN int, expectDu
assert.Equal(t, expectDueTime, r.ScheduledTime().Format(time.RFC3339))
}
func peekAndCompare(t *testing.T, q *queue[*queueableItem], expectN int, expectDueTime string) {
func peekAndCompare(t *testing.T, q *queue[string, *queueableItem], expectN int, expectDueTime string) {
r, ok := q.Peek()
require.True(t, ok)
require.NotNil(t, r)


@ -0,0 +1,251 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ratelimiting
import (
"context"
"errors"
"sync"
"sync/atomic"
"time"
"k8s.io/utils/clock"
)
// OptionsCoalescing configures a Coalescing RateLimiter.
type OptionsCoalescing struct {
// InitialDelay is the initial delay for the rate limiter. The rate limiter
// will not delay events less than the initial delay.
// Defaults to 500ms.
InitialDelay *time.Duration
// MaxDelay is the maximum delay for the rate limiter. The rate limiter will
// not delay events longer than the max delay.
// Defaults to 5s.
MaxDelay *time.Duration
// MaxPendingEvents is the maximum number of events that can be pending on a
// rate limiter before it fires an event anyway. Useful to prevent a rate
// limiter from never firing events in a high-throughput scenario.
// Defaults to unlimited.
MaxPendingEvents *int
}
// coalescing is a rate limiter that rate limits events. It coalesces events
// that occur within a rate limiting window.
type coalescing struct {
initialDelay time.Duration
maxDelay time.Duration
maxPendingEvents *int
pendingEvents int
timer clock.Timer
hasTimer atomic.Bool
inputCh chan struct{}
currentDur time.Duration
backoffFactor int
wg sync.WaitGroup
lock sync.RWMutex
clock clock.WithTicker
running atomic.Bool
closeCh chan struct{}
closed atomic.Bool
}
func NewCoalescing(opts OptionsCoalescing) (RateLimiter, error) {
initialDelay := time.Millisecond * 500
if opts.InitialDelay != nil {
initialDelay = *opts.InitialDelay
}
if initialDelay <= 0 {
return nil, errors.New("initial delay must be > 0")
}
maxDelay := time.Second * 5
if opts.MaxDelay != nil {
maxDelay = *opts.MaxDelay
}
if maxDelay <= 0 {
return nil, errors.New("max delay must be > 0")
}
if maxDelay < initialDelay {
return nil, errors.New("max delay must be >= base delay")
}
if opts.MaxPendingEvents != nil && *opts.MaxPendingEvents <= 0 {
return nil, errors.New("max pending events must be > 0")
}
return &coalescing{
initialDelay: initialDelay,
maxDelay: maxDelay,
maxPendingEvents: opts.MaxPendingEvents,
currentDur: initialDelay,
backoffFactor: 1,
inputCh: make(chan struct{}),
closeCh: make(chan struct{}),
clock: clock.RealClock{},
}, nil
}
// Run runs the rate limiter. It will begin rate limiting events after the
// first event is received.
func (c *coalescing) Run(ctx context.Context, ch chan<- struct{}) error {
if !c.running.CompareAndSwap(false, true) {
return errors.New("already running")
}
// Prevent wg race condition on Close and Run.
c.lock.Lock()
c.wg.Add(1)
c.lock.Unlock()
defer c.wg.Done()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for {
// If the timer doesn't exist yet, we're waiting for the first event (which
// will fire immediately when received).
var timerCh <-chan time.Time
c.lock.RLock()
if c.hasTimer.Load() {
timerCh = c.timer.C()
}
c.lock.RUnlock()
select {
case <-ctx.Done():
return nil
case <-c.closeCh:
cancel()
return nil
case <-c.inputCh:
c.handleInputCh(ctx, ch)
case <-timerCh:
c.handleTimerFired(ctx, ch)
}
}
}
func (c *coalescing) handleInputCh(ctx context.Context, ch chan<- struct{}) {
c.lock.Lock()
defer c.lock.Unlock()
switch {
case !c.hasTimer.Load():
// We don't have a timer yet, so this is the first event that has fired. We
// fire the event immediately, and set the timer to fire again after the
// initial delay.
c.timer = c.clock.NewTimer(c.initialDelay)
c.hasTimer.Store(true)
c.fireEvent(ctx, ch)
default:
// If maxPendingEvents is set and we have reached it then fire the event
// immediately.
if c.maxPendingEvents != nil && c.pendingEvents >= *c.maxPendingEvents {
c.fireEvent(ctx, ch)
return
}
if !c.timer.Stop() {
<-c.timer.C()
}
// Setup backoff. Backoff is exponential. If initial is 500ms and max is
// 5s, the backoff will follow:
// 500ms, 1s, 2s, 4s, 5s, 5s, 5s, ...
if c.currentDur < c.maxDelay {
c.backoffFactor *= 2
c.currentDur = time.Duration(float64(c.initialDelay) * float64(c.backoffFactor))
if c.currentDur > c.maxDelay {
c.currentDur = c.maxDelay
}
}
c.timer.Reset(c.currentDur)
}
}
func (c *coalescing) handleTimerFired(ctx context.Context, ch chan<- struct{}) {
c.lock.Lock()
defer c.lock.Unlock()
c.fireEvent(ctx, ch)
c.reset()
}
func (c *coalescing) fireEvent(ctx context.Context, ch chan<- struct{}) {
// Important to only send on the channel if there are pending events,
// otherwise we will double send an event, for example if only a single event
// was sent and then the rate limiting window expired with no new events.
if c.pendingEvents > 0 {
c.pendingEvents = 0
c.wg.Add(1)
go func() {
defer c.wg.Done()
select {
case ch <- struct{}{}:
case <-ctx.Done():
}
}()
}
}
func (c *coalescing) reset() {
if !c.timer.Stop() {
select {
case <-c.timer.C():
default:
}
}
c.pendingEvents = 0
c.currentDur = c.initialDelay
c.backoffFactor = 1
c.hasTimer.Store(false)
c.timer = nil
}
func (c *coalescing) Add() {
c.lock.Lock()
defer c.lock.Unlock()
c.pendingEvents++
c.wg.Add(1)
go func() {
defer c.wg.Done()
select {
case c.inputCh <- struct{}{}:
case <-c.closeCh:
}
}()
}
func (c *coalescing) Close() {
defer func() {
// Prevent wg race condition on Close and Run.
c.lock.Lock()
c.wg.Wait()
c.lock.Unlock()
}()
if c.closed.CompareAndSwap(false, true) {
close(c.closeCh)
}
}
var _ RateLimiter = (*coalescing)(nil)


@ -0,0 +1,358 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ratelimiting
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/utils/clock"
clocktesting "k8s.io/utils/clock/testing"
"github.com/dapr/kit/ptr"
)
func TestCoalescing(t *testing.T) {
runCoalescingTests := func(t *testing.T, clock clock.WithTicker, opts OptionsCoalescing) (*coalescing, chan struct{}) {
t.Helper()
c, err := NewCoalescing(opts)
require.NoError(t, err)
if clock != nil {
c.(RateLimiterWithTicker).WithTicker(clock)
}
ch := make(chan struct{})
errCh := make(chan error)
go func() {
errCh <- c.Run(t.Context(), ch)
}()
t.Cleanup(func() {
c.Close()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
require.Fail(t, "timeout")
}
})
return c.(*coalescing), ch
}
assertChannel := func(t *testing.T, ch chan struct{}) {
t.Helper()
select {
case <-ch:
case <-time.After(time.Second):
require.Fail(t, "timeout")
}
}
assertNoChannel := func(t *testing.T, ch chan struct{}) {
t.Helper()
select {
case <-ch:
require.Fail(t, "should not have received event")
case <-time.After(time.Millisecond * 10):
}
}
t.Run("closing context should return Run", func(t *testing.T) {
c, err := NewCoalescing(OptionsCoalescing{})
require.NoError(t, err)
ctx, cancel := context.WithCancel(t.Context())
errCh := make(chan error)
go func() {
errCh <- c.Run(ctx, make(chan struct{}))
}()
cancel()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
require.Fail(t, "timeout")
}
})
t.Run("calling Close should return Run", func(t *testing.T) {
c, err := NewCoalescing(OptionsCoalescing{})
require.NoError(t, err)
errCh := make(chan error)
go func() {
errCh <- c.Run(t.Context(), make(chan struct{}))
}()
c.Close()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
require.Fail(t, "timeout")
}
})
t.Run("calling Run twice should error", func(t *testing.T) {
c, err := NewCoalescing(OptionsCoalescing{})
require.NoError(t, err)
errCh := make(chan error)
go func() {
errCh <- c.Run(t.Context(), make(chan struct{}))
}()
c.Close()
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
require.Fail(t, "timeout")
}
go func() {
errCh <- c.Run(t.Context(), make(chan struct{}))
}()
select {
case err := <-errCh:
require.Error(t, err)
case <-time.After(time.Second):
require.Fail(t, "timeout")
}
})
t.Run("options", func(t *testing.T) {
_, err := NewCoalescing(OptionsCoalescing{
InitialDelay: ptr.Of(-time.Second),
})
require.Error(t, err)
_, err = NewCoalescing(OptionsCoalescing{
MaxDelay: ptr.Of(-time.Second),
})
require.Error(t, err)
_, err = NewCoalescing(OptionsCoalescing{
MaxPendingEvents: ptr.Of(0),
})
require.Error(t, err)
_, err = NewCoalescing(OptionsCoalescing{
MaxPendingEvents: ptr.Of(-1),
})
require.Error(t, err)
_, err = NewCoalescing(OptionsCoalescing{
InitialDelay: ptr.Of(time.Second),
MaxDelay: ptr.Of(time.Second / 2),
})
require.Error(t, err)
_, err = NewCoalescing(OptionsCoalescing{
InitialDelay: ptr.Of(time.Second),
MaxDelay: ptr.Of(time.Second * 2),
MaxPendingEvents: ptr.Of(2),
})
require.NoError(t, err)
})
t.Run("sending a single event initially should immediately send it", func(t *testing.T) {
c, ch := runCoalescingTests(t, nil, OptionsCoalescing{})
c.Add()
select {
case <-ch:
case <-time.After(time.Second):
require.Fail(t, "timeout")
}
})
t.Run("second event after initial delay should not be rate limited", func(t *testing.T) {
clock := clocktesting.NewFakeClock(time.Now())
c, ch := runCoalescingTests(t, clock, OptionsCoalescing{
InitialDelay: ptr.Of(time.Second),
MaxDelay: ptr.Of(time.Second * 2),
})
c.Add()
assertChannel(t, ch)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second)
c.Add()
assertChannel(t, ch)
assertNoChannel(t, ch)
})
t.Run("second event before initial delay should be rate limited", func(t *testing.T) {
clock := clocktesting.NewFakeClock(time.Now())
c, ch := runCoalescingTests(t, clock, OptionsCoalescing{
InitialDelay: ptr.Of(time.Second),
MaxDelay: ptr.Of(time.Second * 2),
})
c.Add()
assertChannel(t, ch)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second / 2)
c.Add()
assertNoChannel(t, ch)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second * 2)
assertChannel(t, ch)
assertNoChannel(t, ch)
})
t.Run("multiple events before initial delay should be rate limited to single event", func(t *testing.T) {
clock := clocktesting.NewFakeClock(time.Now())
c, ch := runCoalescingTests(t, clock, OptionsCoalescing{
InitialDelay: ptr.Of(time.Second),
MaxDelay: ptr.Of(time.Second * 2),
})
c.Add()
assertChannel(t, ch)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second / 2)
c.Add()
c.Add()
c.Add()
assertNoChannel(t, ch)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second * 2)
assertChannel(t, ch)
assertNoChannel(t, ch)
})
t.Run("rate limiting should increase if events keep being added", func(t *testing.T) {
clock := clocktesting.NewFakeClock(time.Now())
c, ch := runCoalescingTests(t, clock, OptionsCoalescing{
InitialDelay: ptr.Of(time.Second),
MaxDelay: ptr.Of(time.Second * 5),
})
c.Add()
assertChannel(t, ch)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second / 2)
c.Add()
assertNoChannel(t, ch)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second * 1)
c.Add()
assertNoChannel(t, ch)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second * 2)
c.Add()
assertNoChannel(t, ch)
for range 4 {
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second * 4)
c.Add()
assertNoChannel(t, ch)
}
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second * 5)
assertChannel(t, ch)
assertNoChannel(t, ch)
})
t.Run("should fire event if reached maximum pending events", func(t *testing.T) {
clock := clocktesting.NewFakeClock(time.Now())
c, ch := runCoalescingTests(t, clock, OptionsCoalescing{
InitialDelay: ptr.Of(time.Second),
MaxDelay: ptr.Of(time.Second * 5),
MaxPendingEvents: ptr.Of(3),
})
c.Add()
assertChannel(t, ch)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second / 2)
c.Add()
assertNoChannel(t, ch)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second * 1)
c.Add()
assertNoChannel(t, ch)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second * 1)
// We have reached 3 pending events, so an event should fire even though we
// are rate limited.
c.Add()
assertChannel(t, ch)
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second * 2)
c.Add()
assertNoChannel(t, ch)
// Expire rate limit and fire event.
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second * 5)
assertChannel(t, ch)
assertNoChannel(t, ch)
// New event should fire immediately.
c.Add()
assertChannel(t, ch)
})
t.Run("lots of events fired in the first rate limiting window will trigger 2 event omitted", func(t *testing.T) {
clock := clocktesting.NewFakeClock(time.Now())
c, ch := runCoalescingTests(t, clock, OptionsCoalescing{
InitialDelay: ptr.Of(time.Second),
MaxDelay: ptr.Of(time.Second * 5),
})
c.Add()
assertChannel(t, ch)
assert.Eventually(t, c.hasTimer.Load, time.Second, time.Millisecond)
for range 10 {
c.Add()
}
assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond)
clock.Step(time.Second * 5)
assertChannel(t, ch)
assert.False(t, clock.HasWaiters())
assertNoChannel(t, ch)
})
}


@ -0,0 +1,29 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ratelimiting
import "context"
// RateLimiter is the interface for rate limiting events.
type RateLimiter interface {
// Run starts the rate limiter. The given channel will have events sent to
// it, according to the rate limited parameters.
Run(ctx context.Context, eventCh chan<- struct{}) error
// Add adds a new event to the rate limiter.
Add()
// Close closes the rate limiter and waits for all resources to be released.
Close()
}


@ -0,0 +1,32 @@
//go:build unit
// +build unit
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ratelimiting
import "k8s.io/utils/clock"
// RateLimiterWithTicker is a RateLimiter that can be configured with a ticker.
// Used for testing.
type RateLimiterWithTicker interface {
RateLimiter
WithTicker(c clock.WithTicker)
}
func (c *coalescing) WithTicker(clock clock.WithTicker) {
c.clock = clock
}
var _ RateLimiterWithTicker = (*coalescing)(nil)

fswatcher/fswatcher.go Normal file

@ -0,0 +1,98 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fswatcher
import (
"context"
"errors"
"fmt"
"sync/atomic"
"time"
"github.com/fsnotify/fsnotify"
"github.com/dapr/kit/events/batcher"
)
// Options are the options for the FSWatcher.
type Options struct {
// Targets is a list of directories to watch for changes.
Targets []string
// Interval is the interval to wait before sending a notification after a file has changed.
// Defaults to 500ms.
Interval *time.Duration
}
// FSWatcher watches for changes to a directory on the filesystem and sends a notification to eventCh every time a file in the folder is changed.
// Although it's possible to watch for individual files, that's not recommended; watch for the file's parent folder instead.
// That is because, like in Kubernetes which uses system links on mounted volumes, the file may be deleted and recreated with a different inode.
// Note that changes are batched for 0.5 seconds before notifications are sent as events on a single file often come in batches.
type FSWatcher struct {
w *fsnotify.Watcher
running atomic.Bool
batcher *batcher.Batcher[string, struct{}]
}
func New(opts Options) (*FSWatcher, error) {
w, err := fsnotify.NewWatcher()
if err != nil {
return nil, fmt.Errorf("failed to create watcher: %w", err)
}
for _, target := range opts.Targets {
if err = w.Add(target); err != nil {
return nil, fmt.Errorf("failed to add target %s: %w", target, err)
}
}
interval := time.Millisecond * 500
if opts.Interval != nil {
interval = *opts.Interval
}
if interval < 0 {
return nil, errors.New("interval must be positive")
}
return &FSWatcher{
w: w,
// Often the case, writes to files are not atomic and involve multiple file system events.
// We want to hold off on sending events until we are sure that the file has been written to completion. We do this by waiting for a period of time after the last event has been received for a file name.
batcher: batcher.New[string, struct{}](batcher.Options{
Interval: interval,
}),
}, nil
}
func (f *FSWatcher) Run(ctx context.Context, eventCh chan<- struct{}) error {
if !f.running.CompareAndSwap(false, true) {
return errors.New("watcher already running")
}
defer f.batcher.Close()
f.batcher.Subscribe(ctx, eventCh)
for {
select {
case <-ctx.Done():
return f.w.Close()
case err := <-f.w.Errors:
return errors.Join(fmt.Errorf("watcher error: %w", err), f.w.Close())
case event := <-f.w.Events:
f.batcher.Batch(event.Name, struct{}{})
}
}
}

fswatcher/fswatcher_test.go Normal file
@@ -0,0 +1,248 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fswatcher

import (
	"context"
	"os"
	"path/filepath"
	"runtime"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	clocktesting "k8s.io/utils/clock/testing"

	"github.com/dapr/kit/events/batcher"
	"github.com/dapr/kit/ptr"
)

func TestFSWatcher(t *testing.T) {
	runWatcher := func(t *testing.T, opts Options, b *batcher.Batcher[string, struct{}]) <-chan struct{} {
		t.Helper()

		f, err := New(opts)
		require.NoError(t, err)

		if b != nil {
			f.WithBatcher(b)
		}

		errCh := make(chan error)
		ctx, cancel := context.WithCancel(t.Context())
		eventsCh := make(chan struct{})

		go func() {
			errCh <- f.Run(ctx, eventsCh)
		}()

		t.Cleanup(func() {
			cancel()
			select {
			case err := <-errCh:
				require.NoError(t, err)
			case <-time.After(time.Second):
				assert.Fail(t, "timeout waiting for watcher to stop")
			}
		})

		assert.Eventually(t, f.running.Load, time.Second, time.Millisecond*10)

		return eventsCh
	}

	t.Run("creating fswatcher with no directory should not error", func(t *testing.T) {
		runWatcher(t, Options{}, nil)
	})

	t.Run("creating fswatcher with 0 interval should not error", func(t *testing.T) {
		_, err := New(Options{
			Interval: ptr.Of(time.Duration(0)),
		})
		require.NoError(t, err)
	})

	t.Run("creating fswatcher with negative interval should error", func(t *testing.T) {
		_, err := New(Options{
			Interval: ptr.Of(time.Duration(-1)),
		})
		require.Error(t, err)
	})

	t.Run("running Run twice should error", func(t *testing.T) {
		fs, err := New(Options{})
		require.NoError(t, err)
		ctx, cancel := context.WithCancel(t.Context())
		cancel()
		require.NoError(t, fs.Run(ctx, make(chan struct{})))
		require.Error(t, fs.Run(ctx, make(chan struct{})))
	})

	t.Run("creating fswatcher with non-existent directory should error", func(t *testing.T) {
		dir := t.TempDir()
		require.NoError(t, os.RemoveAll(dir))
		_, err := New(Options{
			Targets: []string{dir},
		})
		require.Error(t, err)
	})

	t.Run("should fire event when event occurs on target file", func(t *testing.T) {
		fp := filepath.Join(t.TempDir(), "test.txt")
		require.NoError(t, os.WriteFile(fp, []byte{}, 0o600))
		eventsCh := runWatcher(t, Options{
			Targets:  []string{fp},
			Interval: ptr.Of(time.Duration(1)),
		}, nil)
		assert.Empty(t, eventsCh)
		if runtime.GOOS == "windows" {
			// If running on Windows, wait for notify to be ready.
			time.Sleep(time.Second)
		}
		require.NoError(t, os.WriteFile(fp, []byte{}, 0o600))
		select {
		case <-eventsCh:
		case <-time.After(time.Second):
			assert.Fail(t, "timeout waiting for event")
		}
	})

	t.Run("should fire 2 events when event occurs on 2 file target", func(t *testing.T) {
		fp1 := filepath.Join(t.TempDir(), "test.txt")
		fp2 := filepath.Join(t.TempDir(), "test.txt")
		require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
		require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
		eventsCh := runWatcher(t, Options{
			Targets:  []string{fp1, fp2},
			Interval: ptr.Of(time.Duration(1)),
		}, nil)
		assert.Empty(t, eventsCh)
		require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
		require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
		for range 2 {
			select {
			case <-eventsCh:
			case <-time.After(time.Second):
				assert.Fail(t, "timeout waiting for event")
			}
		}
	})

	t.Run("should fire 2 events when event occurs on 2 files inside target directory", func(t *testing.T) {
		dir := t.TempDir()
		fp1 := filepath.Join(dir, "test1.txt")
		fp2 := filepath.Join(dir, "test2.txt")
		require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
		require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
		eventsCh := runWatcher(t, Options{
			Targets:  []string{fp1, fp2},
			Interval: ptr.Of(time.Duration(1)),
		}, nil)
		if runtime.GOOS == "windows" {
			// If running on Windows, wait for notify to be ready.
			time.Sleep(time.Second)
		}
		assert.Empty(t, eventsCh)
		require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
		require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
		for range 2 {
			select {
			case <-eventsCh:
			case <-time.After(time.Second):
				assert.Fail(t, "timeout waiting for event")
			}
		}
	})

	t.Run("should fire 2 events when event occurs on 2 target directories", func(t *testing.T) {
		dir1 := t.TempDir()
		dir2 := t.TempDir()
		fp1 := filepath.Join(dir1, "test1.txt")
		fp2 := filepath.Join(dir2, "test2.txt")
		eventsCh := runWatcher(t, Options{
			Targets:  []string{dir1, dir2},
			Interval: ptr.Of(time.Duration(1)),
		}, nil)
		assert.Empty(t, eventsCh)
		require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
		require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
		for range 2 {
			select {
			case <-eventsCh:
			case <-time.After(time.Second):
				assert.Fail(t, "timeout waiting for event")
			}
		}
	})

	t.Run("should batch events of the same file for multiple events", func(t *testing.T) {
		clock := clocktesting.NewFakeClock(time.Time{})
		batcher := batcher.New[string, struct{}](batcher.Options{
			Interval: time.Millisecond * 500,
			Clock:    clock,
		})
		dir1 := t.TempDir()
		dir2 := t.TempDir()
		fp1 := filepath.Join(dir1, "test1.txt")
		fp2 := filepath.Join(dir2, "test2.txt")
		eventsCh := runWatcher(t, Options{Targets: []string{dir1, dir2}}, batcher)
		assert.Empty(t, eventsCh)
		if runtime.GOOS == "windows" {
			// If running on Windows, wait for notify to be ready.
			time.Sleep(time.Second)
		}
		for range 10 {
			require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
			require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
		}
		assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond*10)
		select {
		case <-eventsCh:
			assert.Fail(t, "unexpected event")
		case <-time.After(time.Millisecond * 10):
		}
		clock.Step(time.Millisecond * 250)
		for range 10 {
			require.NoError(t, os.WriteFile(fp1, []byte{}, 0o600))
			require.NoError(t, os.WriteFile(fp2, []byte{}, 0o600))
		}
		select {
		case <-eventsCh:
			assert.Fail(t, "unexpected event")
		case <-time.After(time.Millisecond * 10):
		}
		assert.Eventually(t, clock.HasWaiters, time.Second, time.Millisecond*10)
		clock.Step(time.Millisecond * 500)
		for range 2 {
			select {
			case <-eventsCh:
			case <-time.After(time.Second):
				assert.Fail(t, "timeout waiting for event")
			}
			clock.Step(1)
		}
	})
}

fswatcher/unit.go Normal file
@@ -0,0 +1,28 @@
//go:build unit
// +build unit
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fswatcher

import (
	"github.com/dapr/kit/events/batcher"
)

func (f *FSWatcher) WithBatcher(b *batcher.Batcher[string, struct{}]) *FSWatcher {
	f.batcher = b
	return f
}

fswatcher/unit_test.go Normal file
@@ -0,0 +1,37 @@
//go:build unit
// +build unit
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fswatcher

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/dapr/kit/events/batcher"
)

func TestWithBatcher(t *testing.T) {
	b := batcher.New[string, struct{}](batcher.Options{
		Interval: time.Millisecond * 10,
	})
	f, err := New(Options{})
	require.NoError(t, err)
	f.WithBatcher(b)
	assert.Equal(t, b, f.batcher)
}

@@ -1,75 +0,0 @@
/*
Copyright 2022 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fswatcher

import (
	"context"
	"fmt"
	"strings"
	"time"

	"github.com/fsnotify/fsnotify"
)

// Watch watches for changes to a directory on the filesystem and sends a notification to eventCh every time a file in the folder is changed.
// Although it's possible to watch individual files, that's not recommended; watch the file's parent folder instead.
// Note that changes are batched for 0.5 seconds before notifications are sent.
func Watch(ctx context.Context, dir string, eventCh chan<- struct{}) error {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		return fmt.Errorf("failed to create watcher: %w", err)
	}
	defer watcher.Close()

	err = watcher.Add(dir)
	if err != nil {
		return fmt.Errorf("watcher error: %w", err)
	}

	batchCh := make(chan struct{}, 1)
	defer close(batchCh)

	for {
		select {
		// Watch for events
		case event := <-watcher.Events:
			if event.Op&fsnotify.Create == fsnotify.Create ||
				event.Op&fsnotify.Write == fsnotify.Write {
				if strings.Contains(event.Name, dir) {
					// Batch the change
					select {
					case batchCh <- struct{}{}:
						go func() {
							time.Sleep(500 * time.Millisecond)
							<-batchCh
							eventCh <- struct{}{}
						}()
					default:
						// There's already a change in the batch - nop
					}
				}
			}
		// Abort in case of errors
		case err = <-watcher.Errors:
			return fmt.Errorf("watcher listen error: %w", err)
		// Stop on context canceled
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

@@ -1,166 +0,0 @@
/*
Copyright 2022 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fswatcher

import (
	"context"
	"errors"
	"os"
	"path/filepath"
	"testing"
	"time"
)

func TestWatch(t *testing.T) {
	baseDir := t.TempDir()

	t.Run("watch for file changes", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		// Start watching
		eventCh := make(chan struct{})
		doneCh := make(chan struct{})
		go func() {
			err := Watch(ctx, baseDir, eventCh)
			if errors.Is(err, context.Canceled) {
				doneCh <- struct{}{}
			} else {
				panic(err)
			}
		}()

		// Wait 1s for the watcher to start before touching the file
		time.Sleep(time.Second)

		statusCh := make(chan bool)
		go func() {
			select {
			case <-eventCh:
				statusCh <- true
			case <-time.After(2 * time.Second):
				statusCh <- false
			}
		}()

		touchFile(baseDir, "file1")

		// Expect a successful notification
		if !(<-statusCh) {
			t.Fatalf("did not get event within 2 seconds")
		}

		// Cancel and wait for the watcher to exit
		cancel()
		select {
		case <-doneCh:
			// All good - nop
		case <-time.After(2 * time.Second):
			t.Fatalf("did not stop within 2 seconds")
		}
	})

	t.Run("changes are batched", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		// Start watching
		eventCh := make(chan struct{})
		doneCh := make(chan struct{})
		go func() {
			err := Watch(ctx, baseDir, eventCh)
			if errors.Is(err, context.Canceled) {
				doneCh <- struct{}{}
			} else {
				panic(err)
			}
		}()

		// Wait 1s for the watcher to start before touching the file
		time.Sleep(time.Second)

		statusCh := make(chan bool, 1)
		go func() {
			for {
				select {
				case <-eventCh:
					statusCh <- true
				case <-time.After(2 * time.Second):
					statusCh <- false
					return
				}
			}
		}()

		// Touch the files
		touchFile(baseDir, "file1")
		touchFile(baseDir, "file2")
		touchFile(baseDir, "file3")

		// First message should be true
		if !(<-statusCh) {
			t.Fatalf("did not get event within 2 seconds")
		}
		// Second should be false
		if <-statusCh {
			t.Fatalf("got more than 1 change notification")
		}

		// Repeat
		go func() {
			for {
				select {
				case <-eventCh:
					statusCh <- true
				case <-time.After(2 * time.Second):
					statusCh <- false
					return
				}
			}
		}()
		touchFile(baseDir, "file1")
		touchFile(baseDir, "file2")
		touchFile(baseDir, "file3")

		// First message should be true
		if !(<-statusCh) {
			t.Fatalf("did not get event within 2 seconds")
		}
		// Second should be false
		if <-statusCh {
			t.Fatalf("got more than 1 change notification")
		}

		// Cancel and wait for the watcher to exit
		cancel()
		select {
		case <-doneCh:
			// All good - nop
		case <-time.After(2 * time.Second):
			t.Fatalf("did not stop within 2 seconds")
		}
	})
}

func touchFile(base, name string) {
	path := filepath.Join(base, name)
	err := os.WriteFile(path, []byte("hola"), 0o666)
	if err != nil {
		panic(err)
	}
}

go.mod
@@ -1,30 +1,48 @@
module github.com/dapr/kit
go 1.20
go 1.24.3
require (
github.com/alphadose/haxmap v1.3.1
github.com/cenkalti/backoff/v4 v4.2.1
github.com/fsnotify/fsnotify v1.6.0
github.com/lestrrat-go/httprc v1.0.4
github.com/lestrrat-go/jwx/v2 v2.0.12
github.com/fsnotify/fsnotify v1.7.0
github.com/lestrrat-go/httprc v1.0.5
github.com/lestrrat-go/jwx/v2 v2.0.21
github.com/mitchellh/mapstructure v1.5.1-0.20220423185008-bf980b35cac4
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.8.4
golang.org/x/crypto v0.12.0
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
github.com/spf13/cast v1.5.1
github.com/spiffe/go-spiffe/v2 v2.5.0
github.com/stretchr/testify v1.10.0
github.com/tidwall/transform v0.0.0-20201103190739-32f242e2dbde
golang.org/x/crypto v0.39.0
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
golang.org/x/tools v0.33.0
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822
google.golang.org/grpc v1.73.0
google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20
google.golang.org/protobuf v1.36.6
k8s.io/apimachinery v0.26.9
k8s.io/utils v0.0.0-20230726121419-3b25d923346b
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/go-jose/go-jose/v4 v4.0.5 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/lestrrat-go/blackmagic v1.0.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/iter v1.0.2 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
golang.org/x/sys v0.11.0 // indirect
github.com/zeebo/errs v1.4.0 // indirect
golang.org/x/mod v0.25.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/sync v0.15.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

go.sum
@@ -1,94 +1,155 @@
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/alphadose/haxmap v1.3.1 h1:KmZh75duO1tC8pt3LmUwoTYiZ9sh4K52FX8p7/yrlqU=
github.com/alphadose/haxmap v1.3.1/go.mod h1:rjHw1IAqbxm0S3U5tD16GoKsiAd8FWx5BJ2IYqXwgmM=
github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY=
github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/lestrrat-go/blackmagic v1.0.1 h1:lS5Zts+5HIC/8og6cGHb0uCcNCa3OUt1ygh3Qz2Fe80=
github.com/lestrrat-go/blackmagic v1.0.1/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.1.0 h1:Hsa8mG0dQ46ij8Sl2AYJDUv1oA9/d6Vk+3LG99Oe02g=
github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8=
github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
github.com/lestrrat-go/httprc v1.0.5 h1:bsTfiH8xaKOJPrg1R+E3iE/AWZr/x0Phj9PBTG/OLUk=
github.com/lestrrat-go/httprc v1.0.5/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
github.com/lestrrat-go/jwx/v2 v2.0.12 h1:3d589+5w/b9b7S3DneICPW16AqTyYXB7VRjgluSDWeA=
github.com/lestrrat-go/jwx/v2 v2.0.12/go.mod h1:Mq4KN1mM7bp+5z/W5HS8aCNs5RKZ911G/0y2qUjAQuQ=
github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lestrrat-go/jwx/v2 v2.0.21 h1:jAPKupy4uHgrHFEdjVjNkUgoBKtVDgrQPB/h55FHrR0=
github.com/lestrrat-go/jwx/v2 v2.0.21/go.mod h1:09mLW8zto6bWL9GbwnqAli+ArLf+5M33QLQPDggkUWM=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/mitchellh/mapstructure v1.5.1-0.20220423185008-bf980b35cac4 h1:BpfhmLKZf+SjVanKKhCgf3bg+511DmU9eDQTen7LLbY=
github.com/mitchellh/mapstructure v1.5.1-0.20220423185008-bf980b35cac4/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cast v1.5.1 h1:R+kOtfhWQE6TVQzY+4D7wJLBgkdVasCEFxSUBYBYIlA=
github.com/spf13/cast v1.5.1/go.mod h1:b9PdjNptOpzXr7Rq1q9gJML/2cdGQAo69NKzQ10KN48=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spiffe/go-spiffe/v2 v2.5.0 h1:N2I01KCUkv1FAjZXJMwh95KK1ZIQLYbPfhaxw8WS0hE=
github.com/spiffe/go-spiffe/v2 v2.5.0/go.mod h1:P+NxobPc6wXhVtINNtFjNWGBTreew1GBUCwT2wPmb7g=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/transform v0.0.0-20201103190739-32f242e2dbde h1:AMNpJRc7P+GTwVbl8DkK2I9I8BBUzNiHuH/tlxrpan0=
github.com/tidwall/transform v0.0.0-20201103190739-32f242e2dbde/go.mod h1:MvrEmduDUz4ST5pGZ7CABCnOU5f3ZiOAZzT6b1A6nX8=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM=
github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY=
go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg=
go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o=
go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w=
go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok=
google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc=
google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 h1:MLBCGN1O7GzIx+cBiwfYPwtmZ41U3Mn/cotLJciaArI=
google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20/go.mod h1:Nr5H8+MlGWr5+xX/STzdoEqJrO+YteqFbMyCsrb6mH0=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
k8s.io/apimachinery v0.26.9 h1:5yAV9cFR7Z4gIorKcAjWnx4uxtxiFsERwq4Pvmx0CCg=
k8s.io/apimachinery v0.26.9/go.mod h1:qYzLkrQ9lhrZRh0jNKo2cfvf/R1/kQONnSiyB7NUJU0=
k8s.io/utils v0.0.0-20230726121419-3b25d923346b h1:sgn3ZU783SCgtaSJjpcVVlRqd6GSnlTLKgpAAttJvpI=
k8s.io/utils v0.0.0-20230726121419-3b25d923346b/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=

grpccodes/grpccodes.go Normal file

@ -0,0 +1,100 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package grpccodes

import (
	"net/http"

	"google.golang.org/grpc/codes"
)

// HTTPStatusFromCode converts a gRPC error code into the corresponding HTTP response status.
// https://github.com/grpc-ecosystem/grpc-gateway/blob/master/runtime/errors.go#L15
// See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto
func HTTPStatusFromCode(code codes.Code) int {
	switch code {
	case codes.OK:
		return http.StatusOK
	case codes.Canceled:
		return http.StatusRequestTimeout
	case codes.Unknown:
		return http.StatusInternalServerError
	case codes.InvalidArgument:
		return http.StatusBadRequest
	case codes.DeadlineExceeded:
		return http.StatusGatewayTimeout
	case codes.NotFound:
		return http.StatusNotFound
	case codes.AlreadyExists:
		return http.StatusConflict
	case codes.PermissionDenied:
		return http.StatusForbidden
	case codes.Unauthenticated:
		return http.StatusUnauthorized
	case codes.ResourceExhausted:
		return http.StatusTooManyRequests
	case codes.FailedPrecondition:
		// Note, this deliberately doesn't translate to the similarly named '412 Precondition Failed' HTTP response status.
		return http.StatusBadRequest
	case codes.Aborted:
		return http.StatusConflict
	case codes.OutOfRange:
		return http.StatusBadRequest
	case codes.Unimplemented:
		return http.StatusNotImplemented
	case codes.Internal:
		return http.StatusInternalServerError
	case codes.Unavailable:
		return http.StatusServiceUnavailable
	case codes.DataLoss:
		return http.StatusInternalServerError
	default:
		return http.StatusInternalServerError
	}
}

// CodeFromHTTPStatus converts an HTTP status code to a gRPC status code.
// See: https://github.com/grpc/grpc/blob/master/doc/http-grpc-status-mapping.md
func CodeFromHTTPStatus(httpStatusCode int) codes.Code {
	if httpStatusCode >= 200 && httpStatusCode < 300 {
		return codes.OK
	}

	switch httpStatusCode {
	case http.StatusRequestTimeout:
		return codes.Canceled
	case http.StatusInternalServerError:
		return codes.Unknown
	case http.StatusBadRequest:
		return codes.Internal
	case http.StatusGatewayTimeout:
		return codes.DeadlineExceeded
	case http.StatusNotFound:
		return codes.NotFound
	case http.StatusConflict:
		return codes.AlreadyExists
	case http.StatusForbidden:
		return codes.PermissionDenied
	case http.StatusUnauthorized:
		return codes.Unauthenticated
	case http.StatusTooManyRequests:
		return codes.ResourceExhausted
	case http.StatusNotImplemented:
		return codes.Unimplemented
	case http.StatusServiceUnavailable:
		return codes.Unavailable
	default:
		return codes.Unknown
	}
}
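The mapping above can be exercised with a short, stdlib-only sketch. Since google.golang.org/grpc isn't needed just to illustrate the idea, the gRPC codes below are local stand-ins carrying the real numeric values from google/rpc/code.proto, and httpStatusFromCode is a trimmed copy of the new HTTPStatusFromCode covering only those cases:

```go
package main

import (
	"fmt"
	"net/http"
)

// Code mirrors google.golang.org/grpc/codes.Code; the numeric values below
// match google/rpc/code.proto, so this sketch needs only the stdlib.
type Code uint32

const (
	OK                Code = 0
	Canceled          Code = 1
	NotFound          Code = 5
	ResourceExhausted Code = 8
)

// httpStatusFromCode is a trimmed, local copy of HTTPStatusFromCode
// covering just the codes exercised below.
func httpStatusFromCode(c Code) int {
	switch c {
	case OK:
		return http.StatusOK
	case Canceled:
		return http.StatusRequestTimeout
	case NotFound:
		return http.StatusNotFound
	case ResourceExhausted:
		return http.StatusTooManyRequests
	default:
		return http.StatusInternalServerError
	}
}

func main() {
	for _, c := range []Code{OK, Canceled, NotFound, ResourceExhausted} {
		fmt.Printf("gRPC code %d -> HTTP %d\n", c, httpStatusFromCode(c))
	}
}
```

The default branch matches the real helper: any unrecognized code falls back to 500.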

grpccodes/grpccodes_test.go Normal file

@ -0,0 +1,254 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package grpccodes

import (
	"net/http"
	"testing"

	"github.com/stretchr/testify/assert"
	"google.golang.org/grpc/codes"
)

func TestHTTPStatusFromCode(t *testing.T) {
	tests := []struct {
		name   string
		code   codes.Code
		result int
	}{
		{
			name:   "codes.OK-http.StatusOK",
			code:   codes.OK,
			result: http.StatusOK,
		},
		{
			name:   "codes.Canceled-http.StatusRequestTimeout",
			code:   codes.Canceled,
			result: http.StatusRequestTimeout,
		},
		{
			name:   "codes.Unknown-http.StatusInternalServerError",
			code:   codes.Unknown,
			result: http.StatusInternalServerError,
		},
		{
			name:   "codes.InvalidArgument-http.StatusBadRequest",
			code:   codes.InvalidArgument,
			result: http.StatusBadRequest,
		},
		{
			name:   "codes.DeadlineExceeded-http.StatusGatewayTimeout",
			code:   codes.DeadlineExceeded,
			result: http.StatusGatewayTimeout,
		},
		{
			name:   "codes.NotFound-http.StatusNotFound",
			code:   codes.NotFound,
			result: http.StatusNotFound,
		},
		{
			name:   "codes.AlreadyExists-http.StatusConflict",
			code:   codes.AlreadyExists,
			result: http.StatusConflict,
		},
		{
			name:   "codes.PermissionDenied-http.StatusForbidden",
			code:   codes.PermissionDenied,
			result: http.StatusForbidden,
		},
		{
			name:   "codes.Unauthenticated-http.StatusUnauthorized",
			code:   codes.Unauthenticated,
			result: http.StatusUnauthorized,
		},
		{
			name:   "codes.ResourceExhausted-http.StatusTooManyRequests",
			code:   codes.ResourceExhausted,
			result: http.StatusTooManyRequests,
		},
		{
			name:   "codes.FailedPrecondition-http.StatusBadRequest",
			code:   codes.FailedPrecondition,
			result: http.StatusBadRequest,
		},
		{
			name:   "codes.Aborted-http.StatusConflict",
			code:   codes.Aborted,
			result: http.StatusConflict,
		},
		{
			name:   "codes.OutOfRange-http.StatusBadRequest",
			code:   codes.OutOfRange,
			result: http.StatusBadRequest,
		},
		{
			name:   "codes.Unimplemented-http.StatusNotImplemented",
			code:   codes.Unimplemented,
			result: http.StatusNotImplemented,
		},
		{
			name:   "codes.Internal-http.StatusInternalServerError",
			code:   codes.Internal,
			result: http.StatusInternalServerError,
		},
		{
			name:   "codes.Unavailable-http.StatusServiceUnavailable",
			code:   codes.Unavailable,
			result: http.StatusServiceUnavailable,
		},
		{
			name:   "codes.DataLoss-http.StatusInternalServerError",
			code:   codes.DataLoss,
			result: http.StatusInternalServerError,
		},
		{
			name:   "codes.InvalidCode-http.StatusInternalServerError",
			code:   57, // codes.Code does not exist
			result: http.StatusInternalServerError,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rt := HTTPStatusFromCode(tt.code)
			assert.Equal(t, tt.result, rt)
		})
	}
}

func TestCodeFromHTTPStatus(t *testing.T) {
	tests := []struct {
		name     string
		httpcode int
		result   codes.Code
	}{
		{
			name:     "http.StatusOK-codes.OK",
			httpcode: http.StatusOK,
			result:   codes.OK,
		},
		{
			name:     "http.StatusCreated-codes.OK",
			httpcode: http.StatusCreated,
			result:   codes.OK,
		},
		{
			name:     "http.StatusAccepted-codes.OK",
			httpcode: http.StatusAccepted,
			result:   codes.OK,
		},
		{
			name:     "http.StatusNonAuthoritativeInfo-codes.OK",
			httpcode: http.StatusNonAuthoritativeInfo,
			result:   codes.OK,
		},
		{
			name:     "http.StatusNoContent-codes.OK",
			httpcode: http.StatusNoContent,
			result:   codes.OK,
		},
		{
			name:     "http.StatusResetContent-codes.OK",
			httpcode: http.StatusResetContent,
			result:   codes.OK,
		},
		{
			name:     "http.StatusPartialContent-codes.OK",
			httpcode: http.StatusPartialContent,
			result:   codes.OK,
		},
		{
			name:     "http.StatusMultiStatus-codes.OK",
			httpcode: http.StatusMultiStatus,
			result:   codes.OK,
		},
		{
			name:     "http.StatusAlreadyReported-codes.OK",
			httpcode: http.StatusAlreadyReported,
			result:   codes.OK,
		},
		{
			name:     "http.StatusIMUsed-codes.OK",
			httpcode: http.StatusIMUsed,
			result:   codes.OK,
		},
		{
			name:     "http.StatusRequestTimeout-codes.Canceled",
			httpcode: http.StatusRequestTimeout,
			result:   codes.Canceled,
		},
		{
			name:     "http.StatusInternalServerError-codes.Unknown",
			httpcode: http.StatusInternalServerError,
			result:   codes.Unknown,
		},
		{
			name:     "http.StatusBadRequest-codes.Internal",
			httpcode: http.StatusBadRequest,
			result:   codes.Internal,
		},
		{
			name:     "http.StatusGatewayTimeout-codes.DeadlineExceeded",
			httpcode: http.StatusGatewayTimeout,
			result:   codes.DeadlineExceeded,
		},
		{
			name:     "http.StatusNotFound-codes.NotFound",
			httpcode: http.StatusNotFound,
			result:   codes.NotFound,
		},
		{
			name:     "http.StatusConflict-codes.AlreadyExists",
			httpcode: http.StatusConflict,
			result:   codes.AlreadyExists,
		},
		{
			name:     "http.StatusForbidden-codes.PermissionDenied",
			httpcode: http.StatusForbidden,
			result:   codes.PermissionDenied,
		},
		{
			name:     "http.StatusUnauthorized-codes.Unauthenticated",
			httpcode: http.StatusUnauthorized,
			result:   codes.Unauthenticated,
		},
		{
			name:     "http.StatusTooManyRequests-codes.ResourceExhausted",
			httpcode: http.StatusTooManyRequests,
			result:   codes.ResourceExhausted,
		},
		{
			name:     "http.StatusNotImplemented-codes.Unimplemented",
			httpcode: http.StatusNotImplemented,
			result:   codes.Unimplemented,
		},
		{
			name:     "http.StatusServiceUnavailable-codes.Unavailable",
			httpcode: http.StatusServiceUnavailable,
			result:   codes.Unavailable,
		},
		{
			name:     "HTTPStatusDoesNotExist-codes.Unknown",
			httpcode: 999,
			result:   codes.Unknown,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rt := CodeFromHTTPStatus(tt.httpcode)
			assert.Equal(t, tt.result, rt)
		})
	}
}


@ -22,6 +22,7 @@ package jwkscache
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
@ -36,6 +37,7 @@ import (
"github.com/lestrrat-go/httprc"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/dapr/kit/crypto/pem"
"github.com/dapr/kit/fswatcher"
"github.com/dapr/kit/logger"
)
@ -49,11 +51,11 @@ const (
// JWKSCache is a cache of JWKS objects.
// It fetches a JWKS object from a file on disk, a URL, or from a value passed as-is.
// TODO: Move this to dapr/kit and use it for the JWKS crypto component too
type JWKSCache struct {
location string
requestTimeout time.Duration
minRefreshInterval time.Duration
caCertificate string
jwks jwk.Set
logger logger.Logger
@ -113,6 +115,12 @@ func (c *JWKSCache) SetMinRefreshInterval(minRefreshInterval time.Duration) {
c.minRefreshInterval = minRefreshInterval
}
// SetCACertificate sets the CA certificate to trust.
// Can be a path to a local file or an actual, PEM-encoded certificate
func (c *JWKSCache) SetCACertificate(caCertificate string) {
c.caCertificate = caCertificate
}
// SetHTTPClient sets the HTTP client object to use.
func (c *JWKSCache) SetHTTPClient(client *http.Client) {
c.client = client
@ -184,12 +192,28 @@ func (c *JWKSCache) initJWKSFromURL(ctx context.Context, url string) error {
// We also need to create a custom HTTP client (if we don't have one already) because otherwise there's no timeout.
if c.client == nil {
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
}
// Load CA certificates if we have one
if c.caCertificate != "" {
caCert, err := pem.GetPEM(c.caCertificate)
if err != nil {
return fmt.Errorf("failed to load CA certificate: %w", err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return errors.New("failed to add root certificate to certificate pool")
}
tlsConfig.RootCAs = caCertPool
}
c.client = &http.Client{
Timeout: c.requestTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
},
TLSClientConfig: tlsConfig,
},
}
}
@ -223,10 +247,16 @@ func (c *JWKSCache) initJWKSFromFile(ctx context.Context, file string) error {
eventCh := make(chan struct{})
loaded := make(chan error, 1) // Needs to be buffered to prevent a goroutine leak
go func() {
watchErr := fswatcher.Watch(ctx, path, eventCh)
if watchErr != nil && !errors.Is(watchErr, context.Canceled) {
// Log errors only
c.logger.Errorf("Error while watching for changes to the local JWKS file: %v", watchErr)
// Log errors only
fw, err := fswatcher.New(fswatcher.Options{
Targets: []string{path},
})
if err != nil {
c.logger.Errorf("Error while watching for changes to the local JWKS file: %v", err)
return
}
if err = fw.Run(ctx, eventCh); err != nil {
c.logger.Errorf("Error while watching for changes to the local JWKS file: %v", err)
}
}()
go func() {


@ -40,7 +40,7 @@ func TestJWKSCache(t *testing.T) {
t.Run("init with value", func(t *testing.T) {
cache := NewJWKSCache(testJWKS1, log)
err := cache.initCache(context.Background())
err := cache.initCache(t.Context())
require.NoError(t, err)
set := cache.KeySet()
@ -53,7 +53,7 @@ func TestJWKSCache(t *testing.T) {
t.Run("init with base64-encoded value", func(t *testing.T) {
cache := NewJWKSCache(base64.StdEncoding.EncodeToString([]byte(testJWKS1)), log)
err := cache.initCache(context.Background())
err := cache.initCache(t.Context())
require.NoError(t, err)
set := cache.KeySet()
@ -68,12 +68,12 @@ func TestJWKSCache(t *testing.T) {
// Create a temporary directory and put the JWKS in there
dir := t.TempDir()
path := filepath.Join(dir, "jwks.json")
err := os.WriteFile(path, []byte(testJWKS1), 0o666)
err := os.WriteFile(path, []byte(testJWKS1), 0o600)
require.NoError(t, err)
// Should wait for first file to be loaded before initialization is reported as completed
cache := NewJWKSCache(path, log)
err = cache.initCache(context.Background())
err = cache.initCache(t.Context())
require.NoError(t, err)
set := cache.KeySet()
@ -87,7 +87,7 @@ func TestJWKSCache(t *testing.T) {
time.Sleep(time.Second)
// Update the file and verify it's picked up
err = os.WriteFile(path, []byte(testJWKS2), 0o666)
err = os.WriteFile(path, []byte(testJWKS2), 0o600)
require.NoError(t, err)
assert.Eventually(t, func() bool {
@ -127,7 +127,7 @@ func TestJWKSCache(t *testing.T) {
cache := NewJWKSCache("http://localhost/jwks.json", log)
cache.SetHTTPClient(client)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
defer cancel()
err := cache.initCache(ctx)
require.NoError(t, err)
@ -142,7 +142,7 @@ func TestJWKSCache(t *testing.T) {
t.Run("start and wait for init", func(t *testing.T) {
cache := NewJWKSCache(testJWKS1, log)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
defer cancel()
// Start in background
@ -157,7 +157,7 @@ func TestJWKSCache(t *testing.T) {
// Canceling the context should make Start() return
cancel()
require.Nil(t, <-errCh)
require.NoError(t, <-errCh)
})
t.Run("start and init fails", func(t *testing.T) {
@ -174,7 +174,7 @@ func TestJWKSCache(t *testing.T) {
cache := NewJWKSCache("https://localhost/jwks.json", log)
cache.SetHTTPClient(client)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
defer cancel()
// Start in background
@ -186,7 +186,7 @@ func TestJWKSCache(t *testing.T) {
// Wait for initialization
err := cache.WaitForCacheReady(ctx)
require.Error(t, err)
assert.ErrorContains(t, err, "failed to fetch JWKS")
require.ErrorContains(t, err, "failed to fetch JWKS")
// Canceling the context should make Start() return with the init error
cancel()
@ -194,7 +194,7 @@ func TestJWKSCache(t *testing.T) {
})
t.Run("start and init times out", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1500*time.Millisecond)
ctx, cancel := context.WithTimeout(t.Context(), 1500*time.Millisecond)
defer cancel()
// Create a custom HTTP client with a RoundTripper that doesn't require starting a TCP listener
@ -223,10 +223,10 @@ func TestJWKSCache(t *testing.T) {
}()
// Wait for initialization
err := cache.WaitForCacheReady(context.Background())
err := cache.WaitForCacheReady(t.Context())
require.Error(t, err)
assert.ErrorContains(t, err, "failed to fetch JWKS")
assert.ErrorIs(t, err, context.DeadlineExceeded)
require.ErrorContains(t, err, "failed to fetch JWKS")
require.ErrorIs(t, err, context.DeadlineExceeded)
// Canceling the context should make Start() return with the init error
cancel()


@ -25,7 +25,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
)
const fakeLoggerName = "fakeLogger"
@ -150,7 +149,7 @@ func TestJSONLoggerFields(t *testing.T) {
b, _ := buf.ReadBytes('\n')
var o map[string]interface{}
assert.NoError(t, json.Unmarshal(b, &o))
require.NoError(t, json.Unmarshal(b, &o))
// assert
assert.Equal(t, tt.appID, o[logFieldAppID])
@ -160,7 +159,7 @@ func TestJSONLoggerFields(t *testing.T) {
assert.Equal(t, fakeLoggerName, o[logFieldScope])
assert.Equal(t, tt.message, o[logFieldMessage])
_, err := time.Parse(time.RFC3339, o[logFieldTimeStamp].(string))
assert.NoError(t, err)
require.NoError(t, err)
})
}
}
@ -286,7 +285,7 @@ func TestWithTypeFields(t *testing.T) {
testLogger.Info("testLogger with log LogType")
b, _ = buf.ReadBytes('\n')
maps.Clear(o)
clear(o)
require.NoError(t, json.Unmarshal(b, &o))
assert.Equalf(t, LogTypeLog, o[logFieldType], "testLogger must be %s type", LogTypeLog)
@ -309,12 +308,12 @@ func TestWithFields(t *testing.T) {
}).Info("🙃")
b, _ := buf.ReadBytes('\n')
maps.Clear(o)
clear(o)
require.NoError(t, json.Unmarshal(b, &o))
assert.Equal(t, "🙃", o["msg"])
assert.Equal(t, "world", o["hello"])
assert.Equal(t, float64(42), o["answer"])
assert.InDelta(t, float64(42), o["answer"], 0.1)
// Test with other fields
testLogger.WithFields(map[string]any{
@ -322,7 +321,7 @@ func TestWithFields(t *testing.T) {
}).Info("🐶")
b, _ = buf.ReadBytes('\n')
maps.Clear(o)
clear(o)
require.NoError(t, json.Unmarshal(b, &o))
assert.Equal(t, "🐶", o["msg"])
@ -336,7 +335,7 @@ func TestWithFields(t *testing.T) {
testLogger.Info("🤔")
b, _ = buf.ReadBytes('\n')
maps.Clear(o)
clear(o)
require.NoError(t, json.Unmarshal(b, &o))
assert.Equal(t, "🤔", o["msg"])


@ -14,7 +14,6 @@ limitations under the License.
package logger
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
@ -78,7 +77,7 @@ func TestToLogLevel(t *testing.T) {
func TestNewContext(t *testing.T) {
t.Run("input nil logger", func(t *testing.T) {
ctx := NewContext(context.Background(), nil)
ctx := NewContext(t.Context(), nil)
assert.NotNil(t, ctx, "ctx is not nil")
logger := FromContextOrDefault(ctx)
@ -91,7 +90,7 @@ func TestNewContext(t *testing.T) {
logger := NewLogger(testLoggerName)
assert.NotNil(t, logger)
ctx := NewContext(context.Background(), logger)
ctx := NewContext(t.Context(), logger)
assert.NotNil(t, ctx, "ctx is not nil")
logger2 := FromContextOrDefault(ctx)
assert.NotNil(t, logger2)
