From d098e38d6a4c12c4c1a2e64ed724e4bd3e528a80 Mon Sep 17 00:00:00 2001 From: Josh van Leeuwen Date: Thu, 16 Feb 2023 22:18:35 +0000 Subject: [PATCH] Propagate context from caller to appropriate places in the code (#2474) * Propagates contexts to callers where appropriate. Signed-off-by: joshvanl * Updates units tests with new func signature Signed-off-by: joshvanl * Fix linting errors Signed-off-by: joshvanl * Add atomic gate to alicloud rocketmq close channel. Signed-off-by: joshvanl * bindings/aws/kinesis use a separate ctx variable name Signed-off-by: joshvanl * binding/kafka: use atomic to prevent closing the channel twice Signed-off-by: joshvanl * bindings/mqtt3: use atomic bool to prevent close channel being closed multiple times Signed-off-by: joshvanl * bindings/mqtt3: use Background context for handle operations:w Signed-off-by: joshvanl * state/cocroachdb: add context to Ping() Signed-off-by: joshvanl * bindings/postgres: add comment explaining use of context. Signed-off-by: joshvanl * Adds comment header to health/pinger.go Signed-off-by: joshvanl * pubsub/aws/snssqs: add waitgroup to wait for all go routines to finish and block on Close(). Shuts down the subscription if there are no topic handlers. Signed-off-by: joshvanl * pubsub/mqtt3: add atomic bool to prevent multiple channel closes. Add wait group to block close on all goroutines to finish. Signed-off-by: joshvanl * pubsub/rabbitmq: fixes race conditions, uses atomic to prevent multiple closes, add wait group to block close on all goroutines Signed-off-by: joshvanl * pubsub/redis: revert ctx passed when it could be cancelled. Add wait group wait when closing. Signed-off-by: joshvanl * state/postges: pass context in init, and wait group on close. Signed-off-by: joshvanl * Update all `Ping()` to `PingContext()` where possible. Signed-off-by: joshvanl * state/in-memory: add atomic bool to prevent closing channel multiple times. Add wait group to block on close() Signed-off-by: joshvanl * state/mysql: don't use same ctx variable name Signed-off-by: joshvanl * Pass correct loop context to redis go routines Signed-off-by: joshvanl * Rename context when creating timeouts in state Signed-off-by: joshvanl * Remove state.Features() from requiring a context Signed-off-by: joshvanl * Revert wasm request handle Close func to be without context to implement io.Closer interface. Add 5 second timeout. Add io.Closer assertion in test. Signed-off-by: joshvanl * Remove superfluous go lint vet directive Signed-off-by: joshvanl * Change Configuration Init function to take context Signed-off-by: joshvanl * Updates input binding interface to include a `Close() error` function. `Close` blocks until all resources have been released and go routines have returned. Signed-off-by: joshvanl * Change `Close() error` in input binding struct to `io.Closer` interface. Signed-off-by: joshvanl * Update go.mod files to point to dapr/dapr PR https://github.com/dapr/dapr/pull/5831 Signed-off-by: joshvanl * pubsub/redis: watch closeCh to shutdown worker instead of init context. Signed-off-by: joshvanl * pubsub/aws/snssqs + bindings/kubemq: ensure closeCh is caught so Close correctly returns Signed-off-by: joshvanl * Close kubemq binding client on close. Ensure kafka consumer channel cannot be closed more than once. Signed-off-by: joshvanl * Tweaks Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * Fixed cert tests Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> * binding/mqtt3: add inline Background context instead of passing to handleMessage Signed-off-by: joshvanl * pubsub/mqtt3: remove context from createSubscriberClientOptions Signed-off-by: joshvanl * pubsub/mqtt3: Remove `ResetConnection` func Signed-off-by: joshvanl * pubsub/kafka: Don't resubscribe if Subscribe is cancelled. Signed-off-by: joshvanl * binding/mqtt3: don't use context to control establishing connection Signed-off-by: joshvanl * bindings/mqtt3: Fix linting errors Signed-off-by: joshvanl --------- Signed-off-by: joshvanl Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: Yaron Schneider --- .build-tools/go.mod | 2 + .../conformance/temporal/worker/go.mod | 2 + bindings/alicloud/dingtalk/webhook/webhook.go | 9 +- .../alicloud/dingtalk/webhook/webhook_test.go | 19 ++- bindings/alicloud/oss/oss.go | 2 +- bindings/alicloud/sls/sls.go | 2 +- bindings/alicloud/tablestore/tablestore.go | 2 +- .../alicloud/tablestore/tablestore_test.go | 2 +- bindings/apns/apns.go | 2 +- bindings/apns/apns_test.go | 22 +-- bindings/aws/dynamodb/dynamodb.go | 2 +- bindings/aws/kinesis/kinesis.go | 72 ++++++++-- bindings/aws/s3/s3.go | 2 +- bindings/aws/ses/ses.go | 2 +- bindings/aws/sns/sns.go | 2 +- bindings/aws/sqs/sqs.go | 46 +++++- bindings/azure/blobstorage/blobstorage.go | 2 +- bindings/azure/cosmosdb/cosmosdb.go | 8 +- .../cosmosdbgremlinapi/cosmosdbgremlinapi.go | 2 +- bindings/azure/eventgrid/eventgrid.go | 35 ++++- bindings/azure/eventhubs/eventhubs.go | 2 +- .../eventhubs/eventhubs_integration_test.go | 4 +- .../servicebusqueues/servicebusqueues.go | 44 +++++- bindings/azure/signalr/signalr.go | 2 +- bindings/azure/storagequeues/storagequeues.go | 67 +++++++-- .../azure/storagequeues/storagequeues_test.go | 58 +++++--- bindings/cloudflare/queues/cfqueues.go | 2 +- bindings/commercetools/commercetools.go | 2 +- bindings/cron/cron.go | 29 +++- bindings/cron/cron_test.go | 39 +++--- bindings/dubbo/dubbo_output.go | 2 +- bindings/dubbo/dubbo_output_test.go | 5 +- bindings/gcp/bucket/bucket.go | 3 +- bindings/gcp/pubsub/pubsub.go | 23 ++- bindings/graphql/graphql.go | 2 +- bindings/http/http.go | 4 +- bindings/http/http_test.go | 4 +- bindings/huawei/obs/obs.go | 2 +- bindings/huawei/obs/obs_test.go | 10 +- bindings/influx/influx.go | 2 +- bindings/influx/influx_test.go | 2 +- bindings/input_binding.go | 10 +- bindings/kafka/kafka.go | 57 ++++---- bindings/kubemq/kubemq.go | 65 ++++++--- bindings/kubemq/kubemq_integration_test.go | 8 +- bindings/kubernetes/kubernetes.go | 46 +++++- bindings/localstorage/localstorage.go | 2 +- bindings/mqtt3/mqtt.go | 131 +++++++++--------- bindings/mqtt3/mqtt_integration_test.go | 4 +- bindings/mqtt3/mqtt_test.go | 3 +- bindings/mysql/mysql.go | 4 +- bindings/mysql/mysql_integration_test.go | 2 +- bindings/nacos/nacos.go | 30 +++- bindings/nacos/nacos_test.go | 2 +- bindings/output_binding.go | 6 +- bindings/postgres/postgres.go | 6 +- bindings/postgres/postgres_test.go | 2 +- bindings/postmark/postmark.go | 2 +- bindings/rabbitmq/rabbitmq.go | 41 +++++- .../rabbitmq/rabbitmq_integration_test.go | 14 +- bindings/redis/redis.go | 15 +- bindings/redis/redis_test.go | 3 - bindings/rethinkdb/statechange/statechange.go | 39 +++++- .../rethinkdb/statechange/statechange_test.go | 2 +- bindings/rocketmq/rocketmq.go | 40 ++++-- bindings/rocketmq/rocketmq_test.go | 4 +- bindings/smtp/smtp.go | 2 +- bindings/twilio/sendgrid/sendgrid.go | 2 +- bindings/twilio/sms/sms.go | 8 +- bindings/twilio/sms/sms_test.go | 10 +- bindings/twitter/twitter.go | 36 +++-- bindings/twitter/twitter_test.go | 13 +- bindings/zeebe/command/command.go | 4 +- bindings/zeebe/command/command_test.go | 4 +- bindings/zeebe/command/throw_error.go | 4 +- bindings/zeebe/jobworker/jobworker.go | 33 ++++- bindings/zeebe/jobworker/jobworker_test.go | 21 +-- configuration/azure/appconfig/appconfig.go | 2 +- .../azure/appconfig/appconfig_test.go | 4 +- configuration/postgres/postgres.go | 4 +- configuration/redis/redis.go | 10 +- configuration/store.go | 2 +- go.mod | 2 + health/pinger.go | 19 ++- internal/component/kafka/consumer.go | 11 +- internal/component/kafka/kafka.go | 3 +- internal/component/kafka/sasl_oauthbearer.go | 3 +- lock/redis/standalone.go | 17 +-- lock/redis/standalone_test.go | 8 +- lock/store.go | 2 +- middleware/http/bearer/bearer_middleware.go | 4 +- middleware/http/oauth2/oauth2_middleware.go | 3 +- .../mocks/mock_TokenProviderInterface.go | 3 +- .../oauth2clientcredentials_middleware.go | 10 +- ...oauth2clientcredentials_middleware_test.go | 13 +- middleware/http/opa/middleware.go | 4 +- middleware/http/opa/middleware_test.go | 3 +- .../http/ratelimit/ratelimit_middleware.go | 3 +- .../routeralias/routeralias_middleware.go | 2 +- .../routeralias_middleware_test.go | 3 +- .../routerchecker/routerchecker_middleware.go | 3 +- .../routerchecker_middleware_test.go | 5 +- middleware/http/sentinel/middleware.go | 3 +- middleware/http/sentinel/middleware_test.go | 3 +- middleware/http/wasm/benchmark_test.go | 3 +- middleware/http/wasm/example/go.mod | 2 + middleware/http/wasm/httpwasm.go | 12 +- middleware/http/wasm/httpwasm_test.go | 12 +- .../http/wasm/internal/e2e-guests/go.mod | 2 + middleware/http/wasm/internal/e2e_test.go | 3 +- middleware/middleware.go | 3 +- pubsub/aws/snssqs/snssqs.go | 88 ++++++++---- pubsub/azure/eventhubs/eventhubs.go | 2 +- .../eventhubs/eventhubs_integration_test.go | 2 +- pubsub/azure/servicebus/queues/servicebus.go | 2 +- pubsub/azure/servicebus/topics/servicebus.go | 2 +- pubsub/gcp/pubsub/pubsub.go | 4 +- pubsub/hazelcast/hazelcast.go | 2 +- pubsub/in-memory/in-memory.go | 2 +- pubsub/in-memory/in-memory_test.go | 8 +- pubsub/jetstream/jetstream.go | 2 +- pubsub/kafka/kafka.go | 4 +- pubsub/kubemq/kubemq.go | 5 +- pubsub/kubemq/kubemq_test.go | 10 +- pubsub/mqtt3/mqtt.go | 131 +++++++----------- pubsub/mqtt3/mqtt_test.go | 1 - pubsub/natsstreaming/natsstreaming.go | 2 +- pubsub/pubsub.go | 6 +- pubsub/pulsar/pulsar.go | 2 +- pubsub/rabbitmq/rabbitmq.go | 56 ++++++-- pubsub/rabbitmq/rabbitmq_test.go | 64 ++++----- pubsub/redis/redis.go | 72 +++++++--- pubsub/redis/redis_test.go | 1 - pubsub/rocketmq/rocketmq.go | 2 +- pubsub/rocketmq/rocketmq_test.go | 4 +- pubsub/solace/amqp/amqp.go | 18 ++- .../alicloud/parameterstore/parameterstore.go | 2 +- .../parameterstore/parameterstore_test.go | 6 +- .../aws/parameterstore/parameterstore.go | 2 +- .../aws/parameterstore/parameterstore_test.go | 2 +- .../aws/secretmanager/secretmanager.go | 2 +- .../aws/secretmanager/secretmanager_test.go | 2 +- secretstores/azure/keyvault/keyvault.go | 2 +- secretstores/azure/keyvault/keyvault_test.go | 9 +- .../gcp/secretmanager/secretmanager.go | 7 +- .../gcp/secretmanager/secretmanager_test.go | 13 +- secretstores/hashicorp/vault/vault.go | 2 +- secretstores/hashicorp/vault/vault_test.go | 29 ++-- secretstores/huaweicloud/csms/csms.go | 2 +- secretstores/kubernetes/kubernetes.go | 2 +- secretstores/local/env/envstore.go | 2 +- secretstores/local/env/envstore_test.go | 8 +- secretstores/local/file/filestore.go | 2 +- secretstores/local/file/filestore_test.go | 14 +- secretstores/secret_store.go | 6 +- secretstores/tencentcloud/ssm/ssm.go | 2 +- state/aerospike/aerospike.go | 2 +- state/alicloud/tablestore/tablestore.go | 2 +- state/alicloud/tablestore/tablestore_test.go | 2 +- state/aws/dynamodb/dynamodb.go | 2 +- state/aws/dynamodb/dynamodb_test.go | 8 +- state/azure/blobstorage/blobstorage.go | 6 +- state/azure/blobstorage/blobstorage_test.go | 9 +- state/azure/cosmosdb/cosmosdb.go | 40 +++--- state/azure/tablestorage/tablestorage.go | 13 +- state/cassandra/cassandra.go | 2 +- state/cloudflare/workerskv/workerskv.go | 2 +- state/cockroachdb/cockroachdb.go | 8 +- state/cockroachdb/cockroachdb_access.go | 10 +- .../cockroachdb_integration_test.go | 4 +- state/cockroachdb/cockroachdb_test.go | 6 +- state/cockroachdb/dbaccess.go | 4 +- state/couchbase/couchbase.go | 2 +- state/gcp/firestore/firestore.go | 3 +- state/hashicorp/consul/consul.go | 2 +- state/hazelcast/hazelcast.go | 2 +- state/in-memory/in_memory.go | 37 +++-- state/in-memory/in_memory_test.go | 2 +- state/jetstream/jetstream.go | 2 +- state/jetstream/jetstream_test.go | 2 +- state/memcached/memcached.go | 4 +- state/mongodb/mongodb.go | 14 +- state/mysql/mysql.go | 53 ++++--- state/mysql/mysql_integration_test.go | 10 +- state/mysql/mysql_test.go | 30 ++-- state/oci/objectstorage/objectstorage.go | 30 ++-- .../objectstorage_integration_test.go | 32 ++--- state/oci/objectstorage/objectstorage_test.go | 34 ++--- state/oracledatabase/dbaccess.go | 4 +- state/oracledatabase/oracledatabase.go | 8 +- .../oracledatabase_integration_test.go | 4 +- state/oracledatabase/oracledatabase_test.go | 10 +- state/oracledatabase/oracledatabaseaccess.go | 8 +- state/postgresql/dbaccess.go | 2 +- state/postgresql/postgresdbaccess.go | 38 +++-- state/postgresql/postgresql.go | 4 +- .../postgresql/postgresql_integration_test.go | 4 +- state/postgresql/postgresql_test.go | 4 +- state/redis/redis.go | 37 ++--- state/redis/redis_test.go | 23 ++- state/rethinkdb/rethinkdb.go | 26 ++-- state/rethinkdb/rethinkdb_test.go | 8 +- state/sqlite/sqlite.go | 8 +- state/sqlite/sqlite_dbaccess.go | 4 +- state/sqlite/sqlite_integration_test.go | 4 +- state/sqlite/sqlite_test.go | 8 +- state/sqlserver/sqlserver.go | 2 +- state/sqlserver/sqlserver_integration_test.go | 4 +- state/sqlserver/sqlserver_test.go | 7 +- state/store.go | 6 +- state/store_test.go | 4 +- state/transactional_store.go | 2 +- state/zookeeper/zk.go | 2 +- .../bindings/azure/blobstorage/go.mod | 2 + .../bindings/azure/blobstorage/go.sum | 4 +- .../bindings/azure/cosmosdb/go.mod | 2 + .../bindings/azure/cosmosdb/go.sum | 4 +- .../bindings/azure/eventhubs/go.mod | 2 + .../bindings/azure/eventhubs/go.sum | 4 +- .../bindings/azure/servicebusqueues/go.mod | 2 + .../bindings/azure/servicebusqueues/go.sum | 4 +- .../bindings/azure/storagequeues/go.mod | 2 + .../bindings/azure/storagequeues/go.sum | 4 +- tests/certification/bindings/cron/go.mod | 2 + tests/certification/bindings/cron/go.sum | 4 +- tests/certification/bindings/dubbo/go.mod | 2 + tests/certification/bindings/dubbo/go.sum | 4 +- tests/certification/bindings/kafka/go.mod | 2 + tests/certification/bindings/kafka/go.sum | 4 +- .../bindings/localstorage/go.mod | 2 + .../bindings/localstorage/go.sum | 4 +- tests/certification/bindings/nacos/go.mod | 2 + tests/certification/bindings/nacos/go.sum | 4 +- tests/certification/bindings/postgres/go.mod | 2 + tests/certification/bindings/postgres/go.sum | 4 +- tests/certification/bindings/rabbitmq/go.mod | 2 + tests/certification/bindings/rabbitmq/go.sum | 4 +- tests/certification/bindings/redis/go.mod | 2 + tests/certification/bindings/redis/go.sum | 4 +- tests/certification/go.mod | 2 + tests/certification/go.sum | 4 +- tests/certification/pubsub/aws/snssqs/go.mod | 2 + tests/certification/pubsub/aws/snssqs/go.sum | 4 +- .../pubsub/azure/eventhubs/go.mod | 2 + .../pubsub/azure/eventhubs/go.sum | 4 +- .../pubsub/azure/servicebus/topics/go.mod | 2 + .../pubsub/azure/servicebus/topics/go.sum | 4 +- tests/certification/pubsub/kafka/go.mod | 2 + tests/certification/pubsub/kafka/go.sum | 4 +- tests/certification/pubsub/mqtt3/go.mod | 2 + tests/certification/pubsub/mqtt3/go.sum | 4 +- tests/certification/pubsub/pulsar/go.mod | 2 + tests/certification/pubsub/pulsar/go.sum | 4 +- tests/certification/pubsub/rabbitmq/go.mod | 2 + tests/certification/pubsub/rabbitmq/go.sum | 4 +- .../secretstores/azure/keyvault/go.mod | 2 + .../secretstores/azure/keyvault/go.sum | 4 +- .../secretstores/hashicorp/vault/go.mod | 2 + .../secretstores/hashicorp/vault/go.sum | 4 +- .../secretstores/local/env/go.mod | 2 + .../secretstores/local/env/go.sum | 4 +- .../secretstores/local/file/go.mod | 2 + .../secretstores/local/file/go.sum | 4 +- tests/certification/state/aws/dynamodb/go.mod | 2 + tests/certification/state/aws/dynamodb/go.sum | 4 +- .../state/azure/blobstorage/go.mod | 2 + .../state/azure/blobstorage/go.sum | 4 +- .../certification/state/azure/cosmosdb/go.mod | 2 + .../certification/state/azure/cosmosdb/go.sum | 4 +- .../state/azure/tablestorage/go.mod | 2 + .../state/azure/tablestorage/go.sum | 4 +- tests/certification/state/cassandra/go.mod | 2 + tests/certification/state/cassandra/go.sum | 4 +- .../state/cockroachdb/cockroachdb_test.go | 4 +- tests/certification/state/cockroachdb/go.mod | 2 + tests/certification/state/cockroachdb/go.sum | 4 +- tests/certification/state/memcached/go.mod | 2 + tests/certification/state/memcached/go.sum | 4 +- tests/certification/state/mongodb/go.mod | 2 + tests/certification/state/mongodb/go.sum | 4 +- tests/certification/state/mysql/go.mod | 2 + tests/certification/state/mysql/go.sum | 4 +- tests/certification/state/mysql/mysql_test.go | 13 +- tests/certification/state/postgresql/go.mod | 2 + tests/certification/state/postgresql/go.sum | 4 +- .../state/postgresql/postgresql_test.go | 20 +-- tests/certification/state/redis/go.mod | 2 + tests/certification/state/redis/go.sum | 4 +- tests/certification/state/sqlserver/go.mod | 2 + tests/certification/state/sqlserver/go.sum | 4 +- tests/conformance/bindings/bindings.go | 14 +- .../configuration/configuration.go | 2 +- tests/conformance/pubsub/pubsub.go | 4 +- .../conformance/secretstores/secretstores.go | 4 +- tests/conformance/state/state.go | 4 +- tests/e2e/bindings/zeebe/helper.go | 4 +- tests/e2e/pubsub/jetstream/go.mod | 2 + 297 files changed, 1797 insertions(+), 1120 deletions(-) diff --git a/.build-tools/go.mod b/.build-tools/go.mod index 58f6ed770..b3860c1b0 100644 --- a/.build-tools/go.mod +++ b/.build-tools/go.mod @@ -18,3 +18,5 @@ require ( github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/.github/infrastructure/conformance/temporal/worker/go.mod b/.github/infrastructure/conformance/temporal/worker/go.mod index b1b6d2bf6..c485e6d77 100644 --- a/.github/infrastructure/conformance/temporal/worker/go.mod +++ b/.github/infrastructure/conformance/temporal/worker/go.mod @@ -33,3 +33,5 @@ require ( google.golang.org/protobuf v1.28.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/bindings/alicloud/dingtalk/webhook/webhook.go b/bindings/alicloud/dingtalk/webhook/webhook.go index b75a837f1..1e7785749 100644 --- a/bindings/alicloud/dingtalk/webhook/webhook.go +++ b/bindings/alicloud/dingtalk/webhook/webhook.go @@ -80,7 +80,7 @@ func NewDingTalkWebhook(l logger.Logger) bindings.InputOutputBinding { } // Init performs metadata parsing. -func (t *DingTalkWebhook) Init(metadata bindings.Metadata) error { +func (t *DingTalkWebhook) Init(_ context.Context, metadata bindings.Metadata) error { var err error if err = t.settings.Decode(metadata.Properties); err != nil { return fmt.Errorf("dingtalk configuration error: %w", err) @@ -107,6 +107,13 @@ func (t *DingTalkWebhook) Read(ctx context.Context, handler bindings.Handler) er return nil } +func (t *DingTalkWebhook) Close() error { + webhooks.Lock() + defer webhooks.Unlock() + delete(webhooks.m, t.settings.ID) + return nil +} + // Operations returns list of operations supported by dingtalk webhook binding. func (t *DingTalkWebhook) Operations() []bindings.OperationKind { return []bindings.OperationKind{bindings.CreateOperation, bindings.GetOperation} diff --git a/bindings/alicloud/dingtalk/webhook/webhook_test.go b/bindings/alicloud/dingtalk/webhook/webhook_test.go index b527171af..f4b2cee2e 100644 --- a/bindings/alicloud/dingtalk/webhook/webhook_test.go +++ b/bindings/alicloud/dingtalk/webhook/webhook_test.go @@ -57,7 +57,7 @@ func TestPublishMsg(t *testing.T) { //nolint:paralleltest }}} d := NewDingTalkWebhook(logger.NewLogger("test")) - err := d.Init(m) + err := d.Init(context.Background(), m) require.NoError(t, err) req := &bindings.InvokeRequest{Data: []byte(msg), Operation: bindings.CreateOperation, Metadata: map[string]string{}} @@ -78,7 +78,7 @@ func TestBindingReadAndInvoke(t *testing.T) { //nolint:paralleltest }} d := NewDingTalkWebhook(logger.NewLogger("test")) - err := d.Init(m) + err := d.Init(context.Background(), m) assert.NoError(t, err) var count int32 @@ -106,3 +106,18 @@ func TestBindingReadAndInvoke(t *testing.T) { //nolint:paralleltest require.FailNow(t, "read timeout") } } + +func TestBindingClose(t *testing.T) { + d := NewDingTalkWebhook(logger.NewLogger("test")) + m := bindings.Metadata{Base: metadata.Base{ + Name: "test", + Properties: map[string]string{ + "url": "/test", + "secret": "", + "id": "x", + }, + }} + assert.NoError(t, d.Init(context.Background(), m)) + assert.NoError(t, d.Close()) + assert.NoError(t, d.Close(), "second close should not error") +} diff --git a/bindings/alicloud/oss/oss.go b/bindings/alicloud/oss/oss.go index ff850fa4c..9a424283d 100644 --- a/bindings/alicloud/oss/oss.go +++ b/bindings/alicloud/oss/oss.go @@ -45,7 +45,7 @@ func NewAliCloudOSS(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing and connection creation. -func (s *AliCloudOSS) Init(metadata bindings.Metadata) error { +func (s *AliCloudOSS) Init(_ context.Context, metadata bindings.Metadata) error { m, err := s.parseMetadata(metadata) if err != nil { return err diff --git a/bindings/alicloud/sls/sls.go b/bindings/alicloud/sls/sls.go index 610a19b0f..032e6931a 100644 --- a/bindings/alicloud/sls/sls.go +++ b/bindings/alicloud/sls/sls.go @@ -31,7 +31,7 @@ type Callback struct { } // parse metadata field -func (s *AliCloudSlsLogstorage) Init(metadata bindings.Metadata) error { +func (s *AliCloudSlsLogstorage) Init(_ context.Context, metadata bindings.Metadata) error { m, err := s.parseMeta(metadata) if err != nil { return err diff --git a/bindings/alicloud/tablestore/tablestore.go b/bindings/alicloud/tablestore/tablestore.go index 63f9ba740..050464639 100644 --- a/bindings/alicloud/tablestore/tablestore.go +++ b/bindings/alicloud/tablestore/tablestore.go @@ -58,7 +58,7 @@ func NewAliCloudTableStore(log logger.Logger) bindings.OutputBinding { } } -func (s *AliCloudTableStore) Init(metadata bindings.Metadata) error { +func (s *AliCloudTableStore) Init(_ context.Context, metadata bindings.Metadata) error { m, err := s.parseMetadata(metadata) if err != nil { return err diff --git a/bindings/alicloud/tablestore/tablestore_test.go b/bindings/alicloud/tablestore/tablestore_test.go index 699adad0c..a9fd6c555 100644 --- a/bindings/alicloud/tablestore/tablestore_test.go +++ b/bindings/alicloud/tablestore/tablestore_test.go @@ -51,7 +51,7 @@ func TestDataEncodeAndDecode(t *testing.T) { metadata := bindings.Metadata{Base: metadata.Base{ Properties: getTestProperties(), }} - aliCloudTableStore.Init(metadata) + aliCloudTableStore.Init(context.Background(), metadata) // test create putData := map[string]interface{}{ diff --git a/bindings/apns/apns.go b/bindings/apns/apns.go index 5ad4836fe..83a8c9886 100644 --- a/bindings/apns/apns.go +++ b/bindings/apns/apns.go @@ -78,7 +78,7 @@ func NewAPNS(logger logger.Logger) bindings.OutputBinding { // Init will configure the APNS output binding using the metadata specified // in the binding's configuration. -func (a *APNS) Init(metadata bindings.Metadata) error { +func (a *APNS) Init(ctx context.Context, metadata bindings.Metadata) error { if err := a.makeURLPrefix(metadata); err != nil { return err } diff --git a/bindings/apns/apns_test.go b/bindings/apns/apns_test.go index 9dc7e7d9a..446188abe 100644 --- a/bindings/apns/apns_test.go +++ b/bindings/apns/apns_test.go @@ -51,7 +51,7 @@ func TestInit(t *testing.T) { }, }} binding := NewAPNS(testLogger).(*APNS) - err := binding.Init(metadata) + err := binding.Init(context.Background(), metadata) assert.Nil(t, err) assert.Equal(t, developmentPrefix, binding.urlPrefix) }) @@ -66,7 +66,7 @@ func TestInit(t *testing.T) { }, }} binding := NewAPNS(testLogger).(*APNS) - err := binding.Init(metadata) + err := binding.Init(context.Background(), metadata) assert.Nil(t, err) assert.Equal(t, productionPrefix, binding.urlPrefix) }) @@ -80,7 +80,7 @@ func TestInit(t *testing.T) { }, }} binding := NewAPNS(testLogger).(*APNS) - err := binding.Init(metadata) + err := binding.Init(context.Background(), metadata) assert.Nil(t, err) assert.Equal(t, productionPrefix, binding.urlPrefix) }) @@ -95,7 +95,7 @@ func TestInit(t *testing.T) { }, }} binding := NewAPNS(testLogger).(*APNS) - err := binding.Init(metadata) + err := binding.Init(context.Background(), metadata) assert.Error(t, err, "invalid value for development parameter: True") }) @@ -107,7 +107,7 @@ func TestInit(t *testing.T) { }, }} binding := NewAPNS(testLogger).(*APNS) - err := binding.Init(metadata) + err := binding.Init(context.Background(), metadata) assert.Error(t, err, "the key-id parameter is required") }) @@ -120,7 +120,7 @@ func TestInit(t *testing.T) { }, }} binding := NewAPNS(testLogger).(*APNS) - err := binding.Init(metadata) + err := binding.Init(context.Background(), metadata) assert.Nil(t, err) assert.Equal(t, testKeyID, binding.authorizationBuilder.keyID) }) @@ -133,7 +133,7 @@ func TestInit(t *testing.T) { }, }} binding := NewAPNS(testLogger).(*APNS) - err := binding.Init(metadata) + err := binding.Init(context.Background(), metadata) assert.Error(t, err, "the team-id parameter is required") }) @@ -146,7 +146,7 @@ func TestInit(t *testing.T) { }, }} binding := NewAPNS(testLogger).(*APNS) - err := binding.Init(metadata) + err := binding.Init(context.Background(), metadata) assert.Nil(t, err) assert.Equal(t, testTeamID, binding.authorizationBuilder.teamID) }) @@ -159,7 +159,7 @@ func TestInit(t *testing.T) { }, }} binding := NewAPNS(testLogger).(*APNS) - err := binding.Init(metadata) + err := binding.Init(context.Background(), metadata) assert.Error(t, err, "the private-key parameter is required") }) @@ -172,7 +172,7 @@ func TestInit(t *testing.T) { }, }} binding := NewAPNS(testLogger).(*APNS) - err := binding.Init(metadata) + err := binding.Init(context.Background(), metadata) assert.Nil(t, err) assert.NotNil(t, binding.authorizationBuilder.privateKey) }) @@ -335,7 +335,7 @@ func makeTestBinding(t *testing.T, log logger.Logger) *APNS { privateKeyKey: testPrivateKey, }, }} - err := testBinding.Init(bindingMetadata) + err := testBinding.Init(context.Background(), bindingMetadata) assert.Nil(t, err) return testBinding diff --git a/bindings/aws/dynamodb/dynamodb.go b/bindings/aws/dynamodb/dynamodb.go index 88ac33757..536404eed 100644 --- a/bindings/aws/dynamodb/dynamodb.go +++ b/bindings/aws/dynamodb/dynamodb.go @@ -49,7 +49,7 @@ func NewDynamoDB(logger logger.Logger) bindings.OutputBinding { } // Init performs connection parsing for DynamoDB. -func (d *DynamoDB) Init(metadata bindings.Metadata) error { +func (d *DynamoDB) Init(_ context.Context, metadata bindings.Metadata) error { meta, err := d.getDynamoDBMetadata(metadata) if err != nil { return err diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index ef42e485d..70427c26e 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -15,7 +15,10 @@ package kinesis import ( "context" + "errors" "fmt" + "sync" + "sync/atomic" "time" "github.com/aws/aws-sdk-go/aws" @@ -45,6 +48,10 @@ type AWSKinesis struct { streamARN *string consumerARN *string logger logger.Logger + + closed atomic.Bool + closeCh chan struct{} + wg sync.WaitGroup } type kinesisMetadata struct { @@ -83,11 +90,14 @@ type recordProcessor struct { // NewAWSKinesis returns a new AWS Kinesis instance. func NewAWSKinesis(logger logger.Logger) bindings.InputOutputBinding { - return &AWSKinesis{logger: logger} + return &AWSKinesis{ + logger: logger, + closeCh: make(chan struct{}), + } } // Init does metadata parsing and connection creation. -func (a *AWSKinesis) Init(metadata bindings.Metadata) error { +func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error { m, err := a.parseMetadata(metadata) if err != nil { return err @@ -107,7 +117,7 @@ func (a *AWSKinesis) Init(metadata bindings.Metadata) error { } streamName := aws.String(m.StreamName) - stream, err := client.DescribeStream(&kinesis.DescribeStreamInput{ + stream, err := client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{ StreamName: streamName, }) if err != nil { @@ -147,6 +157,10 @@ func (a *AWSKinesis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (* } func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err error) { + if a.closed.Load() { + return errors.New("binding is closed") + } + if a.metadata.KinesisConsumerMode == SharedThroughput { a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.workerConfig) err = a.worker.Start() @@ -166,8 +180,13 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er } // Wait for context cancelation then stop + a.wg.Add(1) go func() { - <-ctx.Done() + defer a.wg.Done() + select { + case <-ctx.Done(): + case <-a.closeCh: + } if a.metadata.KinesisConsumerMode == SharedThroughput { a.worker.Shutdown() } else if a.metadata.KinesisConsumerMode == ExtendedFanout { @@ -188,14 +207,25 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes a.consumerARN = consumerARN + a.wg.Add(len(streamDesc.Shards)) for i, shard := range streamDesc.Shards { - go func(idx int, s *kinesis.Shard) error { + go func(idx int, s *kinesis.Shard) { + defer a.wg.Done() + // Reconnection backoff bo := backoff.NewExponentialBackOff() bo.InitialInterval = 2 * time.Second - // Repeat until context is canceled - for ctx.Err() == nil { + // Repeat until context is canceled or binding closed. + for { + select { + case <-ctx.Done(): + return + case <-a.closeCh: + return + default: + } + sub, err := a.client.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{ ConsumerARN: consumerARN, ShardId: s.ShardId, @@ -204,8 +234,12 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes if err != nil { wait := bo.NextBackOff() a.logger.Errorf("Error while reading from shard %v: %v. Attempting to reconnect in %s...", s.ShardId, err, wait) - time.Sleep(wait) - continue + select { + case <-ctx.Done(): + return + case <-time.After(wait): + continue + } } // Reset the backoff on connection success @@ -223,22 +257,30 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes } } } - return nil }(i, shard) } return nil } -func (a *AWSKinesis) ensureConsumer(parentCtx context.Context, streamARN *string) (*string, error) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - consumer, err := a.client.DescribeStreamConsumerWithContext(ctx, &kinesis.DescribeStreamConsumerInput{ +func (a *AWSKinesis) Close() error { + if a.closed.CompareAndSwap(false, true) { + close(a.closeCh) + } + a.wg.Wait() + return nil +} + +func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*string, error) { + // Only set timeout on consumer call. + conCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + consumer, err := a.client.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{ ConsumerName: &a.metadata.ConsumerName, StreamARN: streamARN, }) - cancel() if err != nil { - return a.registerConsumer(parentCtx, streamARN) + return a.registerConsumer(ctx, streamARN) } return consumer.ConsumerDescription.ConsumerARN, nil diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index 9d7d0790c..cd9643106 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -99,7 +99,7 @@ func NewAWSS3(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing and connection creation. -func (s *AWSS3) Init(metadata bindings.Metadata) error { +func (s *AWSS3) Init(_ context.Context, metadata bindings.Metadata) error { m, err := s.parseMetadata(metadata) if err != nil { return err diff --git a/bindings/aws/ses/ses.go b/bindings/aws/ses/ses.go index 3cc1e3804..61b36c7e6 100644 --- a/bindings/aws/ses/ses.go +++ b/bindings/aws/ses/ses.go @@ -61,7 +61,7 @@ func NewAWSSES(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing. -func (a *AWSSES) Init(metadata bindings.Metadata) error { +func (a *AWSSES) Init(_ context.Context, metadata bindings.Metadata) error { // Parse input metadata meta, err := a.parseMetadata(metadata) if err != nil { diff --git a/bindings/aws/sns/sns.go b/bindings/aws/sns/sns.go index 8b0abd012..03a6f1285 100644 --- a/bindings/aws/sns/sns.go +++ b/bindings/aws/sns/sns.go @@ -53,7 +53,7 @@ func NewAWSSNS(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing. -func (a *AWSSNS) Init(metadata bindings.Metadata) error { +func (a *AWSSNS) Init(_ context.Context, metadata bindings.Metadata) error { m, err := a.parseMetadata(metadata) if err != nil { return err diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index 444891fc6..72e05040f 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -16,6 +16,9 @@ package sqs import ( "context" "encoding/json" + "errors" + "sync" + "sync/atomic" "time" "github.com/aws/aws-sdk-go/aws" @@ -31,7 +34,10 @@ type AWSSQS struct { Client *sqs.SQS QueueURL *string - logger logger.Logger + logger logger.Logger + wg sync.WaitGroup + closeCh chan struct{} + closed atomic.Bool } type sqsMetadata struct { @@ -45,11 +51,14 @@ type sqsMetadata struct { // NewAWSSQS returns a new AWS SQS instance. func NewAWSSQS(logger logger.Logger) bindings.InputOutputBinding { - return &AWSSQS{logger: logger} + return &AWSSQS{ + logger: logger, + closeCh: make(chan struct{}), + } } // Init does metadata parsing and connection creation. -func (a *AWSSQS) Init(metadata bindings.Metadata) error { +func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error { m, err := a.parseSQSMetadata(metadata) if err != nil { return err @@ -61,7 +70,7 @@ func (a *AWSSQS) Init(metadata bindings.Metadata) error { } queueName := m.QueueName - resultURL, err := client.GetQueueUrl(&sqs.GetQueueUrlInput{ + resultURL, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ QueueName: aws.String(queueName), }) if err != nil { @@ -89,9 +98,20 @@ func (a *AWSSQS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind } func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { + if a.closed.Load() { + return errors.New("binding is closed") + } + + a.wg.Add(1) go func() { - // Repeat until the context is canceled - for ctx.Err() == nil { + defer a.wg.Done() + + // Repeat until the context is canceled or component is closed + for { + if ctx.Err() != nil || a.closed.Load() { + return + } + result, err := a.Client.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ QueueUrl: a.QueueURL, AttributeNames: aws.StringSlice([]string{ @@ -126,13 +146,25 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { } } - time.Sleep(time.Millisecond * 50) + select { + case <-ctx.Done(): + case <-a.closeCh: + case <-time.After(time.Millisecond * 50): + } } }() return nil } +func (a *AWSSQS) Close() error { + if a.closed.CompareAndSwap(false, true) { + close(a.closeCh) + } + a.wg.Wait() + return nil +} + func (a *AWSSQS) parseSQSMetadata(metadata bindings.Metadata) (*sqsMetadata, error) { b, err := json.Marshal(metadata.Properties) if err != nil { diff --git a/bindings/azure/blobstorage/blobstorage.go b/bindings/azure/blobstorage/blobstorage.go index 58cc7e83f..9539527e6 100644 --- a/bindings/azure/blobstorage/blobstorage.go +++ b/bindings/azure/blobstorage/blobstorage.go @@ -92,7 +92,7 @@ func NewAzureBlobStorage(logger logger.Logger) bindings.OutputBinding { } // Init performs metadata parsing. -func (a *AzureBlobStorage) Init(metadata bindings.Metadata) error { +func (a *AzureBlobStorage) Init(_ context.Context, metadata bindings.Metadata) error { var err error a.containerClient, a.metadata, err = storageinternal.CreateContainerStorageClient(a.logger, metadata.Properties) if err != nil { diff --git a/bindings/azure/cosmosdb/cosmosdb.go b/bindings/azure/cosmosdb/cosmosdb.go index 4a3cbaf8b..b91caef61 100644 --- a/bindings/azure/cosmosdb/cosmosdb.go +++ b/bindings/azure/cosmosdb/cosmosdb.go @@ -53,7 +53,7 @@ func NewCosmosDB(logger logger.Logger) bindings.OutputBinding { } // Init performs CosmosDB connection parsing and connecting. -func (c *CosmosDB) Init(metadata bindings.Metadata) error { +func (c *CosmosDB) Init(ctx context.Context, metadata bindings.Metadata) error { m, err := c.parseMetadata(metadata) if err != nil { return err @@ -103,9 +103,9 @@ func (c *CosmosDB) Init(metadata bindings.Metadata) error { } c.client = dbContainer - ctx, cancel := context.WithTimeout(context.Background(), timeoutValue*time.Second) - _, err = c.client.Read(ctx, nil) - cancel() + readCtx, readCancel := context.WithTimeout(ctx, timeoutValue*time.Second) + defer readCancel() + _, err = c.client.Read(readCtx, nil) return err } diff --git a/bindings/azure/cosmosdbgremlinapi/cosmosdbgremlinapi.go b/bindings/azure/cosmosdbgremlinapi/cosmosdbgremlinapi.go index 80763609f..7836a1668 100644 --- a/bindings/azure/cosmosdbgremlinapi/cosmosdbgremlinapi.go +++ b/bindings/azure/cosmosdbgremlinapi/cosmosdbgremlinapi.go @@ -59,7 +59,7 @@ func NewCosmosDBGremlinAPI(logger logger.Logger) bindings.OutputBinding { } // Init performs CosmosDBGremlinAPI connection parsing and connecting. -func (c *CosmosDBGremlinAPI) Init(metadata bindings.Metadata) error { +func (c *CosmosDBGremlinAPI) Init(_ context.Context, metadata bindings.Metadata) error { c.logger.Debug("Initializing Cosmos Graph DB binding") m, err := c.parseMetadata(metadata) diff --git a/bindings/azure/eventgrid/eventgrid.go b/bindings/azure/eventgrid/eventgrid.go index d913a6576..b7549c57b 100644 --- a/bindings/azure/eventgrid/eventgrid.go +++ b/bindings/azure/eventgrid/eventgrid.go @@ -19,6 +19,8 @@ import ( "fmt" "net/http" "net/url" + "sync" + "sync/atomic" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -42,6 +44,9 @@ const armOperationTimeout = 30 * time.Second type AzureEventGrid struct { metadata *azureEventGridMetadata logger logger.Logger + closeCh chan struct{} + closed atomic.Bool + wg sync.WaitGroup } type azureEventGridMetadata struct { @@ -70,11 +75,14 @@ type azureEventGridMetadata struct { // NewAzureEventGrid returns a new Azure Event Grid instance. func NewAzureEventGrid(logger logger.Logger) bindings.InputOutputBinding { - return &AzureEventGrid{logger: logger} + return &AzureEventGrid{ + logger: logger, + closeCh: make(chan struct{}), + } } // Init performs metadata init. -func (a *AzureEventGrid) Init(metadata bindings.Metadata) error { +func (a *AzureEventGrid) Init(_ context.Context, metadata bindings.Metadata) error { m, err := a.parseMetadata(metadata) if err != nil { return err @@ -85,6 +93,10 @@ func (a *AzureEventGrid) Init(metadata bindings.Metadata) error { } func (a *AzureEventGrid) Read(ctx context.Context, handler bindings.Handler) error { + if a.closed.Load() { + return errors.New("binding is closed") + } + err := a.ensureInputBindingMetadata() if err != nil { return err @@ -120,17 +132,22 @@ func (a *AzureEventGrid) Read(ctx context.Context, handler bindings.Handler) err } // Run the server in background + a.wg.Add(2) go func() { + defer a.wg.Done() a.logger.Debugf("About to start listening for Event Grid events at http://localhost:%s/api/events", a.metadata.HandshakePort) srvErr := srv.ListenAndServe(":" + a.metadata.HandshakePort) if err != nil { a.logger.Errorf("Error starting server: %v", srvErr) } }() - - // Close the server when context is canceled + // Close the server when context is canceled or binding closed. go func() { - <-ctx.Done() + defer a.wg.Done() + select { + case <-ctx.Done(): + case <-a.closeCh: + } srvErr := srv.Shutdown() if err != nil { a.logger.Errorf("Error shutting down server: %v", srvErr) @@ -149,6 +166,14 @@ func (a *AzureEventGrid) Operations() []bindings.OperationKind { return []bindings.OperationKind{bindings.CreateOperation} } +func (a *AzureEventGrid) Close() error { + if a.closed.CompareAndSwap(false, true) { + close(a.closeCh) + } + a.wg.Wait() + return nil +} + func (a *AzureEventGrid) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { err := a.ensureOutputBindingMetadata() if err != nil { diff --git a/bindings/azure/eventhubs/eventhubs.go b/bindings/azure/eventhubs/eventhubs.go index 0be77490c..ec44fdd29 100644 --- a/bindings/azure/eventhubs/eventhubs.go +++ b/bindings/azure/eventhubs/eventhubs.go @@ -37,7 +37,7 @@ func NewAzureEventHubs(logger logger.Logger) bindings.InputOutputBinding { } // Init performs metadata init. -func (a *AzureEventHubs) Init(metadata bindings.Metadata) error { +func (a *AzureEventHubs) Init(_ context.Context, metadata bindings.Metadata) error { return a.AzureEventHubs.Init(metadata.Properties) } diff --git a/bindings/azure/eventhubs/eventhubs_integration_test.go b/bindings/azure/eventhubs/eventhubs_integration_test.go index 2eede303c..3276f2bee 100644 --- a/bindings/azure/eventhubs/eventhubs_integration_test.go +++ b/bindings/azure/eventhubs/eventhubs_integration_test.go @@ -102,7 +102,7 @@ func testEventHubsBindingsAADAuthentication(t *testing.T) { metadata := createEventHubsBindingsAADMetadata() eventHubsBindings := NewAzureEventHubs(log) - err := eventHubsBindings.Init(metadata) + err := eventHubsBindings.Init(context.Background(), metadata) require.NoError(t, err) req := &bindings.InvokeRequest{ @@ -146,7 +146,7 @@ func testReadIotHubEvents(t *testing.T) { logger := logger.NewLogger("bindings.azure.eventhubs.integration.test") eh := NewAzureEventHubs(logger) - err := eh.Init(createIotHubBindingsMetadata()) + err := eh.Init(context.Background(), createIotHubBindingsMetadata()) require.NoError(t, err) // Invoke az CLI via bash script to send test IoT device events diff --git a/bindings/azure/servicebusqueues/servicebusqueues.go b/bindings/azure/servicebusqueues/servicebusqueues.go index a84d5705c..ef4791bc4 100644 --- a/bindings/azure/servicebusqueues/servicebusqueues.go +++ b/bindings/azure/servicebusqueues/servicebusqueues.go @@ -17,6 +17,8 @@ import ( "context" "errors" "fmt" + "sync" + "sync/atomic" "time" servicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus" @@ -39,17 +41,21 @@ type AzureServiceBusQueues struct { client *impl.Client timeout time.Duration logger logger.Logger + closed atomic.Bool + wg sync.WaitGroup + closeCh chan struct{} } // NewAzureServiceBusQueues returns a new AzureServiceBusQueues instance. func NewAzureServiceBusQueues(logger logger.Logger) bindings.InputOutputBinding { return &AzureServiceBusQueues{ - logger: logger, + logger: logger, + closeCh: make(chan struct{}), } } // Init parses connection properties and creates a new Service Bus Queue client. -func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) (err error) { +func (a *AzureServiceBusQueues) Init(ctx context.Context, metadata bindings.Metadata) (err error) { a.metadata, err = impl.ParseMetadata(metadata.Properties, a.logger, (impl.MetadataModeBinding | impl.MetadataModeQueues)) if err != nil { return err @@ -62,7 +68,7 @@ func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) (err error) { } // Will do nothing if DisableEntityManagement is false - err = a.client.EnsureQueue(context.Background(), a.metadata.QueueName) + err = a.client.EnsureQueue(ctx, a.metadata.QueueName) if err != nil { return err } @@ -100,14 +106,33 @@ func (a *AzureServiceBusQueues) Invoke(invokeCtx context.Context, req *bindings. return nil, nil } -func (a *AzureServiceBusQueues) Read(subscribeCtx context.Context, handler bindings.Handler) error { +func (a *AzureServiceBusQueues) Read(parentCtx context.Context, handler bindings.Handler) error { + if a.closed.Load() { + return errors.New("binding is closed") + } + // Reconnection backoff policy bo := backoff.NewExponentialBackOff() bo.MaxElapsedTime = 0 bo.InitialInterval = time.Duration(a.metadata.MinConnectionRecoveryInSec) * time.Second bo.MaxInterval = time.Duration(a.metadata.MaxConnectionRecoveryInSec) * time.Second + subscribeCtx, subscribeCancel := context.WithCancel(parentCtx) + + // Close the subscription context when the binding is closed. + a.wg.Add(2) go func() { + defer a.wg.Done() + select { + case <-a.closeCh: + subscribeCancel() + case <-parentCtx.Done(): + // nop + } + }() + + go func() { + defer a.wg.Done() // Reconnect loop. for { sub := impl.NewSubscription(subscribeCtx, impl.SubsriptionOptions{ @@ -165,7 +190,12 @@ func (a *AzureServiceBusQueues) Read(subscribeCtx context.Context, handler bindi wait := bo.NextBackOff() a.logger.Warnf("Subscription to queue %s lost connection, attempting to reconnect in %s...", a.metadata.QueueName, wait) - time.Sleep(wait) + select { + case <-time.After(wait): + // nop + case <-a.closeCh: + return + } } }() @@ -204,7 +234,11 @@ func (a *AzureServiceBusQueues) getHandlerFn(handler bindings.Handler) impl.Hand } func (a *AzureServiceBusQueues) Close() (err error) { + if a.closed.CompareAndSwap(false, true) { + close(a.closeCh) + } a.logger.Debug("Closing component") a.client.CloseSender(a.metadata.QueueName) + a.wg.Wait() return nil } diff --git a/bindings/azure/signalr/signalr.go b/bindings/azure/signalr/signalr.go index 1f0d49770..74aedafca 100644 --- a/bindings/azure/signalr/signalr.go +++ b/bindings/azure/signalr/signalr.go @@ -78,7 +78,7 @@ type SignalR struct { } // Init is responsible for initializing the SignalR output based on the metadata. -func (s *SignalR) Init(metadata bindings.Metadata) (err error) { +func (s *SignalR) Init(_ context.Context, metadata bindings.Metadata) (err error) { s.userAgent = "dapr-" + logger.DaprVersion err = s.parseMetadata(metadata.Properties) diff --git a/bindings/azure/storagequeues/storagequeues.go b/bindings/azure/storagequeues/storagequeues.go index 048e9a359..31c423920 100644 --- a/bindings/azure/storagequeues/storagequeues.go +++ b/bindings/azure/storagequeues/storagequeues.go @@ -16,9 +16,12 @@ package storagequeues import ( "context" "encoding/base64" + "errors" "fmt" "net/url" "strconv" + "sync" + "sync/atomic" "time" "github.com/Azure/azure-storage-queue-go/azqueue" @@ -40,9 +43,10 @@ type consumer struct { // QueueHelper enables injection for testnig. type QueueHelper interface { - Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) + Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error) Write(ctx context.Context, data []byte, ttl *time.Duration) error Read(ctx context.Context, consumer *consumer) error + Close() error } // AzureQueueHelper concrete impl of queue helper. @@ -55,7 +59,7 @@ type AzureQueueHelper struct { } // Init sets up this helper. -func (d *AzureQueueHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) { +func (d *AzureQueueHelper) Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error) { m, err := parseMetadata(metadata) if err != nil { return nil, err @@ -89,9 +93,9 @@ func (d *AzureQueueHelper) Init(metadata bindings.Metadata) (*storageQueuesMetad d.queueURL = azqueue.NewQueueURL(*URL, p) } - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) - _, err = d.queueURL.Create(ctx, azqueue.Metadata{}) - cancel() + createCtx, createCancel := context.WithTimeout(ctx, 2*time.Minute) + _, err = d.queueURL.Create(createCtx, azqueue.Metadata{}) + createCancel() if err != nil { return nil, err } @@ -128,7 +132,10 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error { } if res.NumMessages() == 0 { // Queue was empty so back off by 10 seconds before trying again - time.Sleep(10 * time.Second) + select { + case <-time.After(10 * time.Second): + case <-ctx.Done(): + } return nil } mt := res.Message(0).Text @@ -162,6 +169,10 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error { return nil } +func (d *AzureQueueHelper) Close() error { + return nil +} + // NewAzureQueueHelper creates new helper. func NewAzureQueueHelper(logger logger.Logger) QueueHelper { return &AzureQueueHelper{ @@ -175,6 +186,10 @@ type AzureStorageQueues struct { helper QueueHelper logger logger.Logger + + wg sync.WaitGroup + closeCh chan struct{} + closed atomic.Bool } type storageQueuesMetadata struct { @@ -189,12 +204,16 @@ type storageQueuesMetadata struct { // NewAzureStorageQueues returns a new AzureStorageQueues instance. func NewAzureStorageQueues(logger logger.Logger) bindings.InputOutputBinding { - return &AzureStorageQueues{helper: NewAzureQueueHelper(logger), logger: logger} + return &AzureStorageQueues{ + helper: NewAzureQueueHelper(logger), + logger: logger, + closeCh: make(chan struct{}), + } } // Init parses connection properties and creates a new Storage Queue client. -func (a *AzureStorageQueues) Init(metadata bindings.Metadata) (err error) { - a.metadata, err = a.helper.Init(metadata) +func (a *AzureStorageQueues) Init(ctx context.Context, metadata bindings.Metadata) (err error) { + a.metadata, err = a.helper.Init(ctx, metadata) if err != nil { return err } @@ -261,14 +280,32 @@ func (a *AzureStorageQueues) Invoke(ctx context.Context, req *bindings.InvokeReq } func (a *AzureStorageQueues) Read(ctx context.Context, handler bindings.Handler) error { + if a.closed.Load() { + return errors.New("input binding is closed") + } + c := consumer{ callback: handler, } + + // Close read context when binding is closed. + readCtx, cancel := context.WithCancel(ctx) + a.wg.Add(2) go func() { + defer a.wg.Done() + defer cancel() + + select { + case <-a.closeCh: + case <-ctx.Done(): + } + }() + go func() { + defer a.wg.Done() // Read until context is canceled var err error - for ctx.Err() == nil { - err = a.helper.Read(ctx, &c) + for readCtx.Err() == nil { + err = a.helper.Read(readCtx, &c) if err != nil { a.logger.Errorf("error from c: %s", err) } @@ -277,3 +314,11 @@ func (a *AzureStorageQueues) Read(ctx context.Context, handler bindings.Handler) return nil } + +func (a *AzureStorageQueues) Close() error { + if a.closed.CompareAndSwap(false, true) { + close(a.closeCh) + } + a.wg.Wait() + return nil +} diff --git a/bindings/azure/storagequeues/storagequeues_test.go b/bindings/azure/storagequeues/storagequeues_test.go index 678ba388d..2aba25b4e 100644 --- a/bindings/azure/storagequeues/storagequeues_test.go +++ b/bindings/azure/storagequeues/storagequeues_test.go @@ -16,6 +16,7 @@ package storagequeues import ( "context" "encoding/base64" + "sync" "testing" "time" @@ -32,9 +33,11 @@ type MockHelper struct { mock.Mock messages chan []byte metadata *storageQueuesMetadata + closeCh chan struct{} + wg sync.WaitGroup } -func (m *MockHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) { +func (m *MockHelper) Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error) { m.messages = make(chan []byte, 10) var err error m.metadata, err = parseMetadata(metadata) @@ -50,12 +53,23 @@ func (m *MockHelper) Write(ctx context.Context, data []byte, ttl *time.Duration) func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error { retvals := m.Called(ctx, consumer) + readCtx, cancel := context.WithCancel(ctx) + m.wg.Add(2) go func() { + defer m.wg.Done() + defer cancel() + select { + case <-readCtx.Done(): + case <-m.closeCh: + } + }() + go func() { + defer m.wg.Done() for msg := range m.messages { if m.metadata.DecodeBase64 { msg, _ = base64.StdEncoding.DecodeString(string(msg)) } - go consumer.callback(ctx, &bindings.ReadResponse{ + go consumer.callback(readCtx, &bindings.ReadResponse{ Data: msg, }) } @@ -64,18 +78,24 @@ func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error { return retvals.Error(0) } +func (m *MockHelper) Close() error { + defer m.wg.Wait() + close(m.closeCh) + return nil +} + func TestWriteQueue(t *testing.T) { mm := new(MockHelper) mm.On("Write", mock.AnythingOfType("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool { return in == nil })).Return(nil) - a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")} + a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})} m := bindings.Metadata{} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"} - err := a.Init(m) + err := a.Init(context.Background(), m) assert.Nil(t, err) r := bindings.InvokeRequest{Data: []byte("This is my message")} @@ -83,6 +103,7 @@ func TestWriteQueue(t *testing.T) { _, err = a.Invoke(context.Background(), &r) assert.Nil(t, err) + assert.NoError(t, a.Close()) } func TestWriteWithTTLInQueue(t *testing.T) { @@ -91,12 +112,12 @@ func TestWriteWithTTLInQueue(t *testing.T) { return in != nil && *in == time.Second })).Return(nil) - a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")} + a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})} m := bindings.Metadata{} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", metadata.TTLMetadataKey: "1"} - err := a.Init(m) + err := a.Init(context.Background(), m) assert.Nil(t, err) r := bindings.InvokeRequest{Data: []byte("This is my message")} @@ -104,6 +125,7 @@ func TestWriteWithTTLInQueue(t *testing.T) { _, err = a.Invoke(context.Background(), &r) assert.Nil(t, err) + assert.NoError(t, a.Close()) } func TestWriteWithTTLInWrite(t *testing.T) { @@ -112,12 +134,12 @@ func TestWriteWithTTLInWrite(t *testing.T) { return in != nil && *in == time.Second })).Return(nil) - a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")} + a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})} m := bindings.Metadata{} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", metadata.TTLMetadataKey: "1"} - err := a.Init(m) + err := a.Init(context.Background(), m) assert.Nil(t, err) r := bindings.InvokeRequest{ @@ -128,6 +150,7 @@ func TestWriteWithTTLInWrite(t *testing.T) { _, err = a.Invoke(context.Background(), &r) assert.Nil(t, err) + assert.NoError(t, a.Close()) } // Uncomment this function to write a message to local storage queue @@ -138,7 +161,7 @@ func TestWriteWithTTLInWrite(t *testing.T) { m := bindings.Metadata{} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"} - err := a.Init(m) + err := a.Init(context.Background(), m) assert.Nil(t, err) r := bindings.InvokeRequest{Data: []byte("This is my message")} @@ -152,12 +175,12 @@ func TestReadQueue(t *testing.T) { mm := new(MockHelper) mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil) mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil) - a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")} + a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})} m := bindings.Metadata{} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"} - err := a.Init(m) + err := a.Init(context.Background(), m) assert.Nil(t, err) r := bindings.InvokeRequest{Data: []byte("This is my message")} @@ -186,6 +209,7 @@ func TestReadQueue(t *testing.T) { t.Fatal("Timeout waiting for messages") } assert.Equal(t, 1, received) + assert.NoError(t, a.Close()) } func TestReadQueueDecode(t *testing.T) { @@ -193,12 +217,12 @@ func TestReadQueueDecode(t *testing.T) { mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil) mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil) - a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")} + a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})} m := bindings.Metadata{} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", "decodeBase64": "true"} - err := a.Init(m) + err := a.Init(context.Background(), m) assert.Nil(t, err) r := bindings.InvokeRequest{Data: []byte("VGhpcyBpcyBteSBtZXNzYWdl")} @@ -227,6 +251,7 @@ func TestReadQueueDecode(t *testing.T) { t.Fatal("Timeout waiting for messages") } assert.Equal(t, 1, received) + assert.NoError(t, a.Close()) } // Uncomment this function to test reding from local queue @@ -237,7 +262,7 @@ func TestReadQueueDecode(t *testing.T) { m := bindings.Metadata{} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"} - err := a.Init(m) + err := a.Init(context.Background(), m) assert.Nil(t, err) r := bindings.InvokeRequest{Data: []byte("This is my message")} @@ -263,12 +288,12 @@ func TestReadQueueNoMessage(t *testing.T) { mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil) mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil) - a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")} + a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})} m := bindings.Metadata{} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"} - err := a.Init(m) + err := a.Init(context.Background(), m) assert.Nil(t, err) ctx, cancel := context.WithCancel(context.Background()) @@ -285,6 +310,7 @@ func TestReadQueueNoMessage(t *testing.T) { time.Sleep(1 * time.Second) cancel() assert.Equal(t, 0, received) + assert.NoError(t, a.Close()) } func TestParseMetadata(t *testing.T) { diff --git a/bindings/cloudflare/queues/cfqueues.go b/bindings/cloudflare/queues/cfqueues.go index db8c89a62..984e7666b 100644 --- a/bindings/cloudflare/queues/cfqueues.go +++ b/bindings/cloudflare/queues/cfqueues.go @@ -48,7 +48,7 @@ func NewCFQueues(logger logger.Logger) bindings.OutputBinding { } // Init the component. -func (q *CFQueues) Init(metadata bindings.Metadata) error { +func (q *CFQueues) Init(_ context.Context, metadata bindings.Metadata) error { // Decode the metadata err := mapstructure.Decode(metadata.Properties, &q.metadata) if err != nil { diff --git a/bindings/commercetools/commercetools.go b/bindings/commercetools/commercetools.go index 9b2fd4cc9..a6c1d4a09 100644 --- a/bindings/commercetools/commercetools.go +++ b/bindings/commercetools/commercetools.go @@ -51,7 +51,7 @@ func NewCommercetools(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing and connection establishment. -func (ct *Binding) Init(metadata bindings.Metadata) error { +func (ct *Binding) Init(_ context.Context, metadata bindings.Metadata) error { commercetoolsM, err := ct.getCommercetoolsMetadata(metadata) if err != nil { return err diff --git a/bindings/cron/cron.go b/bindings/cron/cron.go index 24549cdea..7113c761d 100644 --- a/bindings/cron/cron.go +++ b/bindings/cron/cron.go @@ -16,6 +16,8 @@ package cron import ( "context" "fmt" + "sync" + "sync/atomic" "time" "github.com/benbjohnson/clock" @@ -34,6 +36,9 @@ type Binding struct { schedule string parser cron.Parser clk clock.Clock + closed atomic.Bool + closeCh chan struct{} + wg sync.WaitGroup } // NewCron returns a new Cron event input binding. @@ -48,6 +53,7 @@ func NewCronWithClock(logger logger.Logger, clk clock.Clock) bindings.InputBindi parser: cron.NewParser( cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor, ), + closeCh: make(chan struct{}), } } @@ -56,7 +62,7 @@ func NewCronWithClock(logger logger.Logger, clk clock.Clock) bindings.InputBindi // // "15 * * * * *" - Every 15 sec // "0 30 * * * *" - Every 30 min -func (b *Binding) Init(metadata bindings.Metadata) error { +func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error { b.name = metadata.Name s, f := metadata.Properties["schedule"] if !f || s == "" { @@ -73,6 +79,10 @@ func (b *Binding) Init(metadata bindings.Metadata) error { // Read triggers the Cron scheduler. func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error { + if b.closed.Load() { + return errors.New("binding is closed") + } + c := cron.New(cron.WithParser(b.parser), cron.WithClock(b.clk)) id, err := c.AddFunc(b.schedule, func() { b.logger.Debugf("name: %s, schedule fired: %v", b.name, time.Now()) @@ -89,12 +99,25 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error { c.Start() b.logger.Debugf("name: %s, next run: %v", b.name, time.Until(c.Entry(id).Next)) + b.wg.Add(1) go func() { - // Wait for context to be canceled - <-ctx.Done() + defer b.wg.Done() + // Wait for context to be canceled or component to be closed. + select { + case <-ctx.Done(): + case <-b.closeCh: + } b.logger.Debugf("name: %s, stopping schedule: %s", b.name, b.schedule) c.Stop() }() return nil } + +func (b *Binding) Close() error { + if b.closed.CompareAndSwap(false, true) { + close(b.closeCh) + } + b.wg.Wait() + return nil +} diff --git a/bindings/cron/cron_test.go b/bindings/cron/cron_test.go index f48fb8415..b3d6897f7 100644 --- a/bindings/cron/cron_test.go +++ b/bindings/cron/cron_test.go @@ -16,6 +16,7 @@ package cron import ( "context" "os" + "sync/atomic" "testing" "time" @@ -84,7 +85,7 @@ func TestCronInitSuccess(t *testing.T) { for _, test := range initTests { c := getNewCron() - err := c.Init(getTestMetadata(test.schedule)) + err := c.Init(context.Background(), getTestMetadata(test.schedule)) if test.errorExpected { assert.Errorf(t, err, "Got no error while initializing an invalid schedule: %s", test.schedule) } else { @@ -99,38 +100,41 @@ func TestCronRead(t *testing.T) { clk := clock.NewMock() c := getNewCronWithClock(clk) schedule := "@every 1s" - assert.NoErrorf(t, c.Init(getTestMetadata(schedule)), "error initializing valid schedule") - expectedCount := 5 - observedCount := 0 + assert.NoErrorf(t, c.Init(context.Background(), getTestMetadata(schedule)), "error initializing valid schedule") + expectedCount := int32(5) + var observedCount atomic.Int32 err := c.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) { assert.NotNil(t, res) - observedCount++ + observedCount.Add(1) return nil, nil }) // Check if cron triggers 5 times in 5 seconds - for i := 0; i < expectedCount; i++ { + for i := int32(0); i < expectedCount; i++ { // Add time to mock clock in 1 second intervals using loop to allow cron go routine to run clk.Add(time.Second) } // Wait for 1 second after adding the last second to mock clock to allow cron to finish triggering - time.Sleep(1 * time.Second) - assert.Equal(t, expectedCount, observedCount, "Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount) + assert.Eventually(t, func() bool { + return observedCount.Load() == expectedCount + }, time.Second, time.Millisecond*10, + "Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount.Load()) assert.NoErrorf(t, err, "error on read") + assert.NoError(t, c.Close()) } func TestCronReadWithContextCancellation(t *testing.T) { clk := clock.NewMock() c := getNewCronWithClock(clk) schedule := "@every 1s" - assert.NoErrorf(t, c.Init(getTestMetadata(schedule)), "error initializing valid schedule") - expectedCount := 5 - observedCount := 0 + assert.NoErrorf(t, c.Init(context.Background(), getTestMetadata(schedule)), "error initializing valid schedule") + expectedCount := int32(5) + var observedCount atomic.Int32 ctx, cancel := context.WithCancel(context.Background()) err := c.Read(ctx, func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) { assert.NotNil(t, res) - assert.LessOrEqualf(t, observedCount, expectedCount, "Invoke didn't stop the schedule") - observedCount++ - if observedCount == expectedCount { + assert.LessOrEqualf(t, observedCount.Load(), expectedCount, "Invoke didn't stop the schedule") + observedCount.Add(1) + if observedCount.Load() == expectedCount { // Cancel context after 5 triggers cancel() } @@ -141,7 +145,10 @@ func TestCronReadWithContextCancellation(t *testing.T) { // Add time to mock clock in 1 second intervals using loop to allow cron go routine to run clk.Add(time.Second) } - time.Sleep(1 * time.Second) - assert.Equal(t, expectedCount, observedCount, "Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount) + assert.Eventually(t, func() bool { + return observedCount.Load() == expectedCount + }, time.Second, time.Millisecond*10, + "Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount.Load()) assert.NoErrorf(t, err, "error on read") + assert.NoError(t, c.Close()) } diff --git a/bindings/dubbo/dubbo_output.go b/bindings/dubbo/dubbo_output.go index ad55ea9f0..8310f4409 100644 --- a/bindings/dubbo/dubbo_output.go +++ b/bindings/dubbo/dubbo_output.go @@ -47,7 +47,7 @@ func NewDubboOutput(logger logger.Logger) bindings.OutputBinding { return dubboBinding } -func (out *DubboOutputBinding) Init(_ bindings.Metadata) error { +func (out *DubboOutputBinding) Init(_ context.Context, _ bindings.Metadata) error { dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{}) return nil } diff --git a/bindings/dubbo/dubbo_output_test.go b/bindings/dubbo/dubbo_output_test.go index c6ff53da9..46b57d979 100644 --- a/bindings/dubbo/dubbo_output_test.go +++ b/bindings/dubbo/dubbo_output_test.go @@ -54,12 +54,13 @@ func TestInvoke(t *testing.T) { // 0. init dapr provided and dubbo server stopCh := make(chan struct{}) defer close(stopCh) + // Create output and set serializer before go routine to prevent data race. + output := NewDubboOutput(logger.NewLogger("test")) + dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{}) go func() { assert.Nil(t, runDubboServer(stopCh)) }() time.Sleep(time.Second * 3) - dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{}) - output := NewDubboOutput(logger.NewLogger("test")) // 1. create req/rsp value reqUser := &User{Name: testName} diff --git a/bindings/gcp/bucket/bucket.go b/bindings/gcp/bucket/bucket.go index 9a83b1e96..02c00c4f6 100644 --- a/bindings/gcp/bucket/bucket.go +++ b/bindings/gcp/bucket/bucket.go @@ -83,14 +83,13 @@ func NewGCPStorage(logger logger.Logger) bindings.OutputBinding { } // Init performs connection parsing. -func (g *GCPStorage) Init(metadata bindings.Metadata) error { +func (g *GCPStorage) Init(ctx context.Context, metadata bindings.Metadata) error { m, b, err := g.parseMetadata(metadata) if err != nil { return err } clientOptions := option.WithCredentialsJSON(b) - ctx := context.Background() client, err := storage.NewClient(ctx, clientOptions) if err != nil { return err diff --git a/bindings/gcp/pubsub/pubsub.go b/bindings/gcp/pubsub/pubsub.go index f641f51a0..6df5abfd8 100644 --- a/bindings/gcp/pubsub/pubsub.go +++ b/bindings/gcp/pubsub/pubsub.go @@ -16,7 +16,10 @@ package pubsub import ( "context" "encoding/json" + "errors" "fmt" + "sync" + "sync/atomic" "cloud.google.com/go/pubsub" "google.golang.org/api/option" @@ -36,6 +39,9 @@ type GCPPubSub struct { client *pubsub.Client metadata *pubSubMetadata logger logger.Logger + closed atomic.Bool + closeCh chan struct{} + wg sync.WaitGroup } type pubSubMetadata struct { @@ -55,11 +61,14 @@ type pubSubMetadata struct { // NewGCPPubSub returns a new GCPPubSub instance. func NewGCPPubSub(logger logger.Logger) bindings.InputOutputBinding { - return &GCPPubSub{logger: logger} + return &GCPPubSub{ + logger: logger, + closeCh: make(chan struct{}), + } } // Init parses metadata and creates a new Pub Sub client. -func (g *GCPPubSub) Init(metadata bindings.Metadata) error { +func (g *GCPPubSub) Init(ctx context.Context, metadata bindings.Metadata) error { b, err := g.parseMetadata(metadata) if err != nil { return err @@ -71,7 +80,6 @@ func (g *GCPPubSub) Init(metadata bindings.Metadata) error { return err } clientOptions := option.WithCredentialsJSON(b) - ctx := context.Background() pubsubClient, err := pubsub.NewClient(ctx, pubsubMeta.ProjectID, clientOptions) if err != nil { return fmt.Errorf("error creating pubsub client: %s", err) @@ -88,7 +96,12 @@ func (g *GCPPubSub) parseMetadata(metadata bindings.Metadata) ([]byte, error) { } func (g *GCPPubSub) Read(ctx context.Context, handler bindings.Handler) error { + if g.closed.Load() { + return errors.New("binding is closed") + } + g.wg.Add(1) go func() { + defer g.wg.Done() sub := g.client.Subscription(g.metadata.Subscription) err := sub.Receive(ctx, func(c context.Context, m *pubsub.Message) { _, err := handler(c, &bindings.ReadResponse{ @@ -128,5 +141,9 @@ func (g *GCPPubSub) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*b } func (g *GCPPubSub) Close() error { + if g.closed.CompareAndSwap(false, true) { + close(g.closeCh) + } + defer g.wg.Wait() return g.client.Close() } diff --git a/bindings/graphql/graphql.go b/bindings/graphql/graphql.go index db8fb9cf3..da9a527a0 100644 --- a/bindings/graphql/graphql.go +++ b/bindings/graphql/graphql.go @@ -58,7 +58,7 @@ func NewGraphQL(logger logger.Logger) bindings.OutputBinding { } // Init initializes the GraphQL binding. -func (gql *GraphQL) Init(metadata bindings.Metadata) error { +func (gql *GraphQL) Init(_ context.Context, metadata bindings.Metadata) error { gql.logger.Debug("GraphQL Error: Initializing GraphQL binding") p := metadata.Properties diff --git a/bindings/http/http.go b/bindings/http/http.go index 00b55e8ce..3c3477e74 100644 --- a/bindings/http/http.go +++ b/bindings/http/http.go @@ -74,7 +74,7 @@ func NewHTTP(logger logger.Logger) bindings.OutputBinding { } // Init performs metadata parsing. -func (h *HTTPSource) Init(metadata bindings.Metadata) error { +func (h *HTTPSource) Init(_ context.Context, metadata bindings.Metadata) error { var err error if err = mapstructure.Decode(metadata.Properties, &h.metadata); err != nil { return err @@ -104,7 +104,7 @@ func (h *HTTPSource) Init(metadata bindings.Metadata) error { Transport: netTransport, } - if val, ok := metadata.Properties["errorIfNot2XX"]; ok { + if val := metadata.Properties["errorIfNot2XX"]; val != "" { h.errorIfNot2XX = utils.IsTruthy(val) } else { // Default behavior diff --git a/bindings/http/http_test.go b/bindings/http/http_test.go index 017f47c4c..75b1797e4 100644 --- a/bindings/http/http_test.go +++ b/bindings/http/http_test.go @@ -132,7 +132,7 @@ func InitBinding(s *httptest.Server, extraProps map[string]string) (bindings.Out } hs := NewHTTP(logger.NewLogger("test")) - err := hs.Init(m) + err := hs.Init(context.Background(), m) return hs, err } @@ -269,7 +269,7 @@ func InitBindingForHTTPS(s *httptest.Server, extraProps map[string]string) (bind m.Properties[k] = v } hs := NewHTTP(logger.NewLogger("test")) - err := hs.Init(m) + err := hs.Init(context.Background(), m) return hs, err } diff --git a/bindings/huawei/obs/obs.go b/bindings/huawei/obs/obs.go index 18615bb5c..8a60ea3a2 100644 --- a/bindings/huawei/obs/obs.go +++ b/bindings/huawei/obs/obs.go @@ -75,7 +75,7 @@ func NewHuaweiOBS(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing and connection creation. -func (o *HuaweiOBS) Init(metadata bindings.Metadata) error { +func (o *HuaweiOBS) Init(_ context.Context, metadata bindings.Metadata) error { o.logger.Debugf("initializing Huawei OBS binding and parsing metadata") m, err := o.parseMetadata(metadata) diff --git a/bindings/huawei/obs/obs_test.go b/bindings/huawei/obs/obs_test.go index fa0f83702..605459451 100644 --- a/bindings/huawei/obs/obs_test.go +++ b/bindings/huawei/obs/obs_test.go @@ -92,7 +92,7 @@ func TestInit(t *testing.T) { "accessKey": "dummy-ak", "secretKey": "dummy-sk", } - err := obs.Init(m) + err := obs.Init(context.Background(), m) assert.Nil(t, err) }) t.Run("Init with missing bucket name", func(t *testing.T) { @@ -102,7 +102,7 @@ func TestInit(t *testing.T) { "accessKey": "dummy-ak", "secretKey": "dummy-sk", } - err := obs.Init(m) + err := obs.Init(context.Background(), m) assert.NotNil(t, err) assert.Equal(t, err, fmt.Errorf("missing obs bucket name")) }) @@ -113,7 +113,7 @@ func TestInit(t *testing.T) { "endpoint": "dummy-endpoint", "secretKey": "dummy-sk", } - err := obs.Init(m) + err := obs.Init(context.Background(), m) assert.NotNil(t, err) assert.Equal(t, err, fmt.Errorf("missing the huawei access key")) }) @@ -124,7 +124,7 @@ func TestInit(t *testing.T) { "endpoint": "dummy-endpoint", "accessKey": "dummy-ak", } - err := obs.Init(m) + err := obs.Init(context.Background(), m) assert.NotNil(t, err) assert.Equal(t, err, fmt.Errorf("missing the huawei secret key")) }) @@ -135,7 +135,7 @@ func TestInit(t *testing.T) { "accessKey": "dummy-ak", "secretKey": "dummy-sk", } - err := obs.Init(m) + err := obs.Init(context.Background(), m) assert.NotNil(t, err) assert.Equal(t, err, fmt.Errorf("missing obs endpoint")) }) diff --git a/bindings/influx/influx.go b/bindings/influx/influx.go index bc50e8c93..208050214 100644 --- a/bindings/influx/influx.go +++ b/bindings/influx/influx.go @@ -64,7 +64,7 @@ func NewInflux(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing and connection establishment. -func (i *Influx) Init(metadata bindings.Metadata) error { +func (i *Influx) Init(_ context.Context, metadata bindings.Metadata) error { influxMeta, err := i.getInfluxMetadata(metadata) if err != nil { return err diff --git a/bindings/influx/influx_test.go b/bindings/influx/influx_test.go index aa049a5ba..33d506b59 100644 --- a/bindings/influx/influx_test.go +++ b/bindings/influx/influx_test.go @@ -54,7 +54,7 @@ func TestInflux_Init(t *testing.T) { assert.Nil(t, influx.client) m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{"Url": "a", "Token": "a", "Org": "a", "Bucket": "a"}}} - err := influx.Init(m) + err := influx.Init(context.Background(), m) assert.Nil(t, err) assert.NotNil(t, influx.queryAPI) diff --git a/bindings/input_binding.go b/bindings/input_binding.go index b1eebb7b9..4ce031911 100644 --- a/bindings/input_binding.go +++ b/bindings/input_binding.go @@ -16,6 +16,7 @@ package bindings import ( "context" "fmt" + "io" "github.com/dapr/components-contrib/health" ) @@ -23,18 +24,21 @@ import ( // InputBinding is the interface to define a binding that triggers on incoming events. type InputBinding interface { // Init passes connection and properties metadata to the binding implementation. - Init(metadata Metadata) error + Init(ctx context.Context, metadata Metadata) error // Read is a method that runs in background and triggers the callback function whenever an event arrives. Read(ctx context.Context, handler Handler) error + // Close is a method that closes the connection to the binding. Must be + // called when the binding is no longer needed to free up resources. + io.Closer } // Handler is the handler used to invoke the app handler. type Handler func(context.Context, *ReadResponse) ([]byte, error) -func PingInpBinding(inputBinding InputBinding) error { +func PingInpBinding(ctx context.Context, inputBinding InputBinding) error { // checks if this input binding has the ping option then executes if inputBindingWithPing, ok := inputBinding.(health.Pinger); ok { - return inputBindingWithPing.Ping() + return inputBindingWithPing.Ping(ctx) } else { return fmt.Errorf("ping is not implemented by this input binding") } diff --git a/bindings/kafka/kafka.go b/bindings/kafka/kafka.go index 11884a7ea..579a98f7b 100644 --- a/bindings/kafka/kafka.go +++ b/bindings/kafka/kafka.go @@ -15,7 +15,10 @@ package kafka import ( "context" + "errors" "strings" + "sync" + "sync/atomic" "github.com/dapr/kit/logger" @@ -29,12 +32,13 @@ const ( ) type Binding struct { - kafka *kafka.Kafka - publishTopic string - topics []string - logger logger.Logger - subscribeCtx context.Context - subscribeCancel context.CancelFunc + kafka *kafka.Kafka + publishTopic string + topics []string + logger logger.Logger + closeCh chan struct{} + closed atomic.Bool + wg sync.WaitGroup } // NewKafka returns a new kafka binding instance. @@ -43,15 +47,14 @@ func NewKafka(logger logger.Logger) bindings.InputOutputBinding { // in kafka binding component, disable consumer retry by default k.DefaultConsumeRetryEnabled = false return &Binding{ - kafka: k, - logger: logger, + kafka: k, + logger: logger, + closeCh: make(chan struct{}), } } -func (b *Binding) Init(metadata bindings.Metadata) error { - b.subscribeCtx, b.subscribeCancel = context.WithCancel(context.Background()) - - err := b.kafka.Init(metadata.Properties) +func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error { + err := b.kafka.Init(ctx, metadata.Properties) if err != nil { return err } @@ -74,7 +77,10 @@ func (b *Binding) Operations() []bindings.OperationKind { } func (b *Binding) Close() (err error) { - b.subscribeCancel() + if b.closed.CompareAndSwap(false, true) { + close(b.closeCh) + } + defer b.wg.Wait() return b.kafka.Close() } @@ -84,6 +90,10 @@ func (b *Binding) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin } func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error { + if b.closed.Load() { + return errors.New("error: binding is closed") + } + if len(b.topics) == 0 { b.logger.Warnf("kafka binding: no topic defined, input bindings will not be started") return nil @@ -96,31 +106,22 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error { for _, t := range b.topics { b.kafka.AddTopicHandler(t, handlerConfig) } - + b.wg.Add(1) go func() { - // Wait for context cancelation + defer b.wg.Done() + // Wait for context cancelation or closure. select { case <-ctx.Done(): - case <-b.subscribeCtx.Done(): + case <-b.closeCh: } - // Remove the topic handler before restarting the subscriber + // Remove the topic handlers. for _, t := range b.topics { b.kafka.RemoveTopicHandler(t) } - - // If the component's context has been canceled, do not re-subscribe - if b.subscribeCtx.Err() != nil { - return - } - - err := b.kafka.Subscribe(b.subscribeCtx) - if err != nil { - b.logger.Errorf("kafka binding: error re-subscribing: %v", err) - } }() - return b.kafka.Subscribe(b.subscribeCtx) + return b.kafka.Subscribe(ctx) } func adaptHandler(handler bindings.Handler) kafka.EventHandler { diff --git a/bindings/kubemq/kubemq.go b/bindings/kubemq/kubemq.go index ca6c589ce..331e37ca2 100644 --- a/bindings/kubemq/kubemq.go +++ b/bindings/kubemq/kubemq.go @@ -2,8 +2,11 @@ package kubemq import ( "context" + "errors" "fmt" "strings" + "sync" + "sync/atomic" "time" qs "github.com/kubemq-io/kubemq-go/queues_stream" @@ -19,31 +22,30 @@ type Kubemq interface { } type kubeMQ struct { - client *qs.QueuesStreamClient - opts *options - logger logger.Logger - ctx context.Context - ctxCancel context.CancelFunc + client *qs.QueuesStreamClient + opts *options + logger logger.Logger + closed atomic.Bool + closeCh chan struct{} + wg sync.WaitGroup } func NewKubeMQ(logger logger.Logger) Kubemq { return &kubeMQ{ - client: nil, - opts: nil, - logger: logger, - ctx: nil, - ctxCancel: nil, + client: nil, + opts: nil, + logger: logger, + closeCh: make(chan struct{}), } } -func (k *kubeMQ) Init(metadata bindings.Metadata) error { +func (k *kubeMQ) Init(ctx context.Context, metadata bindings.Metadata) error { opts, err := createOptions(metadata) if err != nil { return err } k.opts = opts - k.ctx, k.ctxCancel = context.WithCancel(context.Background()) - client, err := qs.NewQueuesStreamClient(k.ctx, + client, err := qs.NewQueuesStreamClient(ctx, qs.WithAddress(opts.host, opts.port), qs.WithCheckConnection(true), qs.WithAuthToken(opts.authToken), @@ -53,22 +55,39 @@ func (k *kubeMQ) Init(metadata bindings.Metadata) error { k.logger.Errorf("error init kubemq client error: %s", err.Error()) return err } - k.ctx, k.ctxCancel = context.WithCancel(context.Background()) k.client = client return nil } func (k *kubeMQ) Read(ctx context.Context, handler bindings.Handler) error { + if k.closed.Load() { + return errors.New("binding is closed") + } + k.wg.Add(2) + processCtx, cancel := context.WithCancel(ctx) go func() { + defer k.wg.Done() + defer cancel() + select { + case <-k.closeCh: + case <-processCtx.Done(): + } + }() + go func() { + defer k.wg.Done() for { - err := k.processQueueMessage(k.ctx, handler) + err := k.processQueueMessage(processCtx, handler) if err != nil { k.logger.Error(err.Error()) - time.Sleep(time.Second) } - if k.ctx.Err() != nil { - return + // If context cancelled or kubeMQ closed, exit. Otherwise, continue + // after a second. + select { + case <-time.After(time.Second): + continue + case <-processCtx.Done(): } + return } }() return nil @@ -82,7 +101,7 @@ func (k *kubeMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind SetPolicyExpirationSeconds(parsePolicyExpirationSeconds(req.Metadata)). SetPolicyMaxReceiveCount(parseSetPolicyMaxReceiveCount(req.Metadata)). SetPolicyMaxReceiveQueue(parsePolicyMaxReceiveQueue(req.Metadata)) - result, err := k.client.Send(k.ctx, queueMessage) + result, err := k.client.Send(ctx, queueMessage) if err != nil { return nil, err } @@ -101,6 +120,14 @@ func (k *kubeMQ) Operations() []bindings.OperationKind { return []bindings.OperationKind{bindings.CreateOperation} } +func (k *kubeMQ) Close() error { + if k.closed.CompareAndSwap(false, true) { + close(k.closeCh) + } + defer k.wg.Wait() + return k.client.Close() +} + func (k *kubeMQ) processQueueMessage(ctx context.Context, handler bindings.Handler) error { pr := qs.NewPollRequest(). SetChannel(k.opts.channel). diff --git a/bindings/kubemq/kubemq_integration_test.go b/bindings/kubemq/kubemq_integration_test.go index d9ead92ee..9f3b2ceee 100644 --- a/bindings/kubemq/kubemq_integration_test.go +++ b/bindings/kubemq/kubemq_integration_test.go @@ -106,7 +106,7 @@ func Test_kubeMQ_Init(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { kubemq := NewKubeMQ(logger.NewLogger("test")) - err := kubemq.Init(tt.meta) + err := kubemq.Init(context.Background(), tt.meta) if tt.wantErr { require.Error(t, err) } else { @@ -120,7 +120,7 @@ func Test_kubeMQ_Invoke_Read_Single_Message(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() kubemq := NewKubeMQ(logger.NewLogger("test")) - err := kubemq.Init(getDefaultMetadata("test.read.single")) + err := kubemq.Init(context.Background(), getDefaultMetadata("test.read.single")) require.NoError(t, err) dataReadCh := make(chan []byte) invokeRequest := &bindings.InvokeRequest{ @@ -147,7 +147,7 @@ func Test_kubeMQ_Invoke_Read_Single_MessageWithHandlerError(t *testing.T) { kubemq := NewKubeMQ(logger.NewLogger("test")) md := getDefaultMetadata("test.read.single.error") md.Properties["autoAcknowledged"] = "false" - err := kubemq.Init(md) + err := kubemq.Init(context.Background(), md) require.NoError(t, err) invokeRequest := &bindings.InvokeRequest{ Data: []byte("test"), @@ -182,7 +182,7 @@ func Test_kubeMQ_Invoke_Error(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() kubemq := NewKubeMQ(logger.NewLogger("test")) - err := kubemq.Init(getDefaultMetadata("***test***")) + err := kubemq.Init(context.Background(), getDefaultMetadata("***test***")) require.NoError(t, err) invokeRequest := &bindings.InvokeRequest{ diff --git a/bindings/kubernetes/kubernetes.go b/bindings/kubernetes/kubernetes.go index 3e344bedd..09b568ecf 100644 --- a/bindings/kubernetes/kubernetes.go +++ b/bindings/kubernetes/kubernetes.go @@ -18,6 +18,8 @@ import ( "encoding/json" "errors" "strconv" + "sync" + "sync/atomic" "time" v1 "k8s.io/api/core/v1" @@ -35,6 +37,9 @@ type kubernetesInput struct { namespace string resyncPeriod time.Duration logger logger.Logger + closed atomic.Bool + closeCh chan struct{} + wg sync.WaitGroup } type EventResponse struct { @@ -45,10 +50,13 @@ type EventResponse struct { // NewKubernetes returns a new Kubernetes event input binding. func NewKubernetes(logger logger.Logger) bindings.InputBinding { - return &kubernetesInput{logger: logger} + return &kubernetesInput{ + logger: logger, + closeCh: make(chan struct{}), + } } -func (k *kubernetesInput) Init(metadata bindings.Metadata) error { +func (k *kubernetesInput) Init(ctx context.Context, metadata bindings.Metadata) error { client, err := kubeclient.GetKubeClient() if err != nil { return err @@ -78,6 +86,9 @@ func (k *kubernetesInput) parseMetadata(metadata bindings.Metadata) error { } func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) error { + if k.closed.Load() { + return errors.New("binding is closed") + } watchlist := cache.NewListWatchFromClient( k.kubeClient.CoreV1().RESTClient(), "events", @@ -126,12 +137,28 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er }, ) + k.wg.Add(3) + readCtx, cancel := context.WithCancel(ctx) + + // catch when binding is closed. + go func() { + defer k.wg.Done() + defer cancel() + select { + case <-readCtx.Done(): + case <-k.closeCh: + } + }() + // Start the controller in backgound - stopCh := make(chan struct{}) - go controller.Run(stopCh) + go func() { + defer k.wg.Done() + controller.Run(readCtx.Done()) + }() // Watch for new messages and for context cancellation go func() { + defer k.wg.Done() var ( obj EventResponse data []byte @@ -148,8 +175,7 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er Data: data, }) } - case <-ctx.Done(): - close(stopCh) + case <-readCtx.Done(): return } } @@ -157,3 +183,11 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er return nil } + +func (k *kubernetesInput) Close() error { + if k.closed.CompareAndSwap(false, true) { + close(k.closeCh) + } + k.wg.Wait() + return nil +} diff --git a/bindings/localstorage/localstorage.go b/bindings/localstorage/localstorage.go index 562c618d0..fc66b46d3 100644 --- a/bindings/localstorage/localstorage.go +++ b/bindings/localstorage/localstorage.go @@ -64,7 +64,7 @@ func NewLocalStorage(logger logger.Logger) bindings.OutputBinding { } // Init performs metadata parsing. -func (ls *LocalStorage) Init(metadata bindings.Metadata) error { +func (ls *LocalStorage) Init(_ context.Context, metadata bindings.Metadata) error { m, err := ls.parseMetadata(metadata) if err != nil { return fmt.Errorf("failed to parse metadata: %w", err) diff --git a/bindings/mqtt3/mqtt.go b/bindings/mqtt3/mqtt.go index d5c86dcdb..bff7fdabd 100644 --- a/bindings/mqtt3/mqtt.go +++ b/bindings/mqtt3/mqtt.go @@ -40,30 +40,29 @@ type MQTT struct { logger logger.Logger isSubscribed atomic.Bool readHandler bindings.Handler - ctx context.Context - cancel context.CancelFunc backOff backoff.BackOff + closeCh chan struct{} + closed atomic.Bool + wg sync.WaitGroup } // NewMQTT returns a new MQTT instance. func NewMQTT(logger logger.Logger) bindings.InputOutputBinding { return &MQTT{ - logger: logger, + logger: logger, + closeCh: make(chan struct{}), } } // Init does MQTT connection parsing. -func (m *MQTT) Init(metadata bindings.Metadata) (err error) { +func (m *MQTT) Init(ctx context.Context, metadata bindings.Metadata) (err error) { m.metadata, err = parseMQTTMetaData(metadata, m.logger) if err != nil { return err } - m.ctx, m.cancel = context.WithCancel(context.Background()) - // TODO: Make the backoff configurable for constant or exponential - b := backoff.NewConstantBackOff(5 * time.Second) - m.backOff = backoff.WithContext(b, m.ctx) + m.backOff = backoff.NewConstantBackOff(5 * time.Second) return nil } @@ -104,7 +103,7 @@ func (m *MQTT) getProducer() (mqtt.Client, error) { return p, nil } -func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { +func (m *MQTT) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { producer, err := m.getProducer() if err != nil { return nil, fmt.Errorf("failed to create producer connection: %w", err) @@ -118,7 +117,7 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (* bo := backoff.WithMaxRetries( backoff.NewConstantBackOff(200*time.Millisecond), 3, ) - bo = backoff.WithContext(bo, parentCtx) + bo = backoff.WithContext(bo, ctx) topic, ok := req.Metadata[mqttTopic] if !ok || topic == "" { @@ -127,14 +126,13 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (* } return nil, retry.NotifyRecover(func() (err error) { token := producer.Publish(topic, m.metadata.qos, m.metadata.retain, req.Data) - ctx, cancel := context.WithTimeout(parentCtx, defaultWait) - defer cancel() select { case <-token.Done(): err = token.Error() - case <-m.ctx.Done(): - // Context canceled - err = m.ctx.Err() + case <-m.closeCh: + err = errors.New("mqtt client closed") + case <-time.After(defaultWait): + err = errors.New("mqtt client timeout") case <-ctx.Done(): // Context canceled err = ctx.Err() @@ -151,6 +149,10 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (* } func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error { + if m.closed.Load() { + return errors.New("error: binding is closed") + } + // If the subscription is already active, wait 2s before retrying (in case we're still disconnecting), otherwise return an error if !m.isSubscribed.CompareAndSwap(false, true) { m.logger.Debug("Subscription is already active; waiting 2s before retrying…") @@ -177,11 +179,14 @@ func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error { // In background, watch for contexts cancelation and stop the connection // However, do not call "unsubscribe" which would cause the broker to stop tracking the last message received by this consumer group + m.wg.Add(1) go func() { + defer m.wg.Done() + select { case <-ctx.Done(): // nop - case <-m.ctx.Done(): + case <-m.closeCh: // nop } @@ -208,14 +213,12 @@ func (m *MQTT) connect(clientID string, isSubscriber bool) (mqtt.Client, error) } client := mqtt.NewClient(opts) - ctx, cancel := context.WithTimeout(m.ctx, defaultWait) - defer cancel() token := client.Connect() select { case <-token.Done(): err = token.Error() - case <-ctx.Done(): - err = ctx.Err() + case <-time.After(defaultWait): + err = errors.New("mqtt client timed out connecting") } if err != nil { return nil, fmt.Errorf("failed to connect: %w", err) @@ -290,43 +293,46 @@ func (m *MQTT) createClientOptions(uri *url.URL, clientID string) *mqtt.ClientOp return opts } -func (m *MQTT) handleMessage(client mqtt.Client, mqttMsg mqtt.Message) { - // We're using m.ctx as context in this method because we don't have access to the Read context - // Canceling the Read context makes Read invoke "Disconnect" anyways - ctx := m.ctx +func (m *MQTT) handleMessage() func(client mqtt.Client, mqttMsg mqtt.Message) { + return func(client mqtt.Client, mqttMsg mqtt.Message) { + bo := m.backOff + if m.metadata.backOffMaxRetries >= 0 { + bo = backoff.WithMaxRetries(bo, uint64(m.metadata.backOffMaxRetries)) + } - var bo backoff.BackOff = backoff.WithContext(m.backOff, ctx) - if m.metadata.backOffMaxRetries >= 0 { - bo = backoff.WithMaxRetries(bo, uint64(m.metadata.backOffMaxRetries)) - } + err := retry.NotifyRecover( + func() error { + m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID()) + // Use a background context here so that the context is not tied to the + // first Invoke first created the producer. + // TODO: add context to mqtt library, and add a OnConnectWithContext option + // to change this func signature to + // func(c mqtt.Client, ctx context.Context) + _, err := m.readHandler(context.Background(), &bindings.ReadResponse{ + Data: mqttMsg.Payload(), + Metadata: map[string]string{ + mqttTopic: mqttMsg.Topic(), + }, + }) + if err != nil { + return err + } - err := retry.NotifyRecover( - func() error { - m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID()) - _, err := m.readHandler(ctx, &bindings.ReadResponse{ - Data: mqttMsg.Payload(), - Metadata: map[string]string{ - mqttTopic: mqttMsg.Topic(), - }, - }) - if err != nil { - return err - } - - // Ack the message on success - mqttMsg.Ack() - return nil - }, - bo, - func(err error, d time.Duration) { - m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying…", mqttMsg.Topic(), mqttMsg.MessageID()) - }, - func() { - m.logger.Infof("Successfully processed MQTT message after it previously failed: %s/%d", mqttMsg.Topic(), mqttMsg.MessageID()) - }, - ) - if err != nil { - m.logger.Errorf("Failed processing MQTT message: %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err) + // Ack the message on success + mqttMsg.Ack() + return nil + }, + bo, + func(err error, d time.Duration) { + m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying…", mqttMsg.Topic(), mqttMsg.MessageID()) + }, + func() { + m.logger.Infof("Successfully processed MQTT message after it previously failed: %s/%d", mqttMsg.Topic(), mqttMsg.MessageID()) + }, + ) + if err != nil { + m.logger.Errorf("Failed processing MQTT message: %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err) + } } } @@ -336,17 +342,15 @@ func (m *MQTT) createSubscriberClientOptions(uri *url.URL, clientID string) *mqt // On (re-)connection, add the topic subscription opts.OnConnect = func(c mqtt.Client) { - token := c.Subscribe(m.metadata.topic, m.metadata.qos, m.handleMessage) + token := c.Subscribe(m.metadata.topic, m.metadata.qos, m.handleMessage()) var err error - subscribeCtx, subscribeCancel := context.WithTimeout(m.ctx, defaultWait) - defer subscribeCancel() select { case <-token.Done(): // Subscription went through (sucecessfully or not) err = token.Error() - case <-subscribeCtx.Done(): - err = fmt.Errorf("error while waiting for subscription token: %w", subscribeCtx.Err()) + case <-time.After(defaultWait): + err = errors.New("timed out waiting for subscription to complete") } // Nothing we can do in case of errors besides logging them @@ -363,13 +367,16 @@ func (m *MQTT) Close() error { m.producerLock.Lock() defer m.producerLock.Unlock() - // Canceling the context also causes Read to stop receiving messages - m.cancel() + if m.closed.CompareAndSwap(false, true) { + close(m.closeCh) + } if m.producer != nil { m.producer.Disconnect(200) m.producer = nil } + m.wg.Wait() + return nil } diff --git a/bindings/mqtt3/mqtt_integration_test.go b/bindings/mqtt3/mqtt_integration_test.go index b83534be6..7db2cafbb 100644 --- a/bindings/mqtt3/mqtt_integration_test.go +++ b/bindings/mqtt3/mqtt_integration_test.go @@ -49,6 +49,7 @@ func getConnectionString() string { func TestInvokeWithTopic(t *testing.T) { t.Parallel() + ctx := context.Background() url := getConnectionString() if url == "" { @@ -79,7 +80,7 @@ func TestInvokeWithTopic(t *testing.T) { logger := logger.NewLogger("test") r := NewMQTT(logger).(*MQTT) - err := r.Init(metadata) + err := r.Init(ctx, metadata) assert.Nil(t, err) conn, err := r.connect(uuid.NewString(), false) @@ -127,4 +128,5 @@ func TestInvokeWithTopic(t *testing.T) { assert.True(t, ok) assert.Equal(t, dataCustomized, mqttMessage.Payload()) assert.Equal(t, topicCustomized, mqttMessage.Topic()) + assert.NoError(t, r.Close()) } diff --git a/bindings/mqtt3/mqtt_test.go b/bindings/mqtt3/mqtt_test.go index 73b3f6ae6..1977c7e4a 100644 --- a/bindings/mqtt3/mqtt_test.go +++ b/bindings/mqtt3/mqtt_test.go @@ -205,7 +205,6 @@ func TestParseMetadata(t *testing.T) { logger := logger.NewLogger("test") m := NewMQTT(logger).(*MQTT) m.backOff = backoff.NewConstantBackOff(5 * time.Second) - m.ctx, m.cancel = context.WithCancel(context.Background()) m.readHandler = func(ctx context.Context, r *bindings.ReadResponse) ([]byte, error) { assert.Equal(t, payload, r.Data) metadata := r.Metadata @@ -215,7 +214,7 @@ func TestParseMetadata(t *testing.T) { return r.Data, nil } - m.handleMessage(nil, &mqttMockMessage{ + m.handleMessage()(nil, &mqttMockMessage{ topic: topic, payload: payload, }) diff --git a/bindings/mysql/mysql.go b/bindings/mysql/mysql.go index 9769ee7c6..1fba4aee7 100644 --- a/bindings/mysql/mysql.go +++ b/bindings/mysql/mysql.go @@ -81,7 +81,7 @@ func NewMysql(logger logger.Logger) bindings.OutputBinding { } // Init initializes the MySQL binding. -func (m *Mysql) Init(metadata bindings.Metadata) error { +func (m *Mysql) Init(ctx context.Context, metadata bindings.Metadata) error { m.logger.Debug("Initializing MySql binding") p := metadata.Properties @@ -115,7 +115,7 @@ func (m *Mysql) Init(metadata bindings.Metadata) error { return err } - err = db.Ping() + err = db.PingContext(ctx) if err != nil { return fmt.Errorf("unable to ping the DB: %w", err) } diff --git a/bindings/mysql/mysql_integration_test.go b/bindings/mysql/mysql_integration_test.go index 08d605c89..fcc173526 100644 --- a/bindings/mysql/mysql_integration_test.go +++ b/bindings/mysql/mysql_integration_test.go @@ -75,7 +75,7 @@ func TestMysqlIntegration(t *testing.T) { b := NewMysql(logger.NewLogger("test")).(*Mysql) m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}} - if err := b.Init(m); err != nil { + if err := b.Init(context.Background(), m); err != nil { t.Fatal(err) } diff --git a/bindings/nacos/nacos.go b/bindings/nacos/nacos.go index 165c8ba8a..a5db9c320 100644 --- a/bindings/nacos/nacos.go +++ b/bindings/nacos/nacos.go @@ -21,6 +21,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/nacos-group/nacos-sdk-go/v2/clients" @@ -56,6 +57,9 @@ type Nacos struct { logger logger.Logger configClient config_client.IConfigClient //nolint:nosnakecase readHandler func(ctx context.Context, response *bindings.ReadResponse) ([]byte, error) + wg sync.WaitGroup + closed atomic.Bool + closeCh chan struct{} } // NewNacos returns a new Nacos instance. @@ -63,11 +67,12 @@ func NewNacos(logger logger.Logger) bindings.OutputBinding { return &Nacos{ logger: logger, watchesLock: sync.Mutex{}, + closeCh: make(chan struct{}), } } // Init implements InputBinding/OutputBinding's Init method. -func (n *Nacos) Init(metadata bindings.Metadata) error { +func (n *Nacos) Init(_ context.Context, metadata bindings.Metadata) error { n.settings = Settings{ Timeout: defaultTimeout, } @@ -146,6 +151,10 @@ func (n *Nacos) createConfigClient() error { // Read implements InputBinding's Read method. func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error { + if n.closed.Load() { + return errors.New("binding is closed") + } + n.readHandler = handler n.watchesLock.Lock() @@ -154,9 +163,14 @@ func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error { } n.watchesLock.Unlock() + n.wg.Add(1) go func() { + defer n.wg.Done() // Cancel all listeners when the context is done - <-ctx.Done() + select { + case <-ctx.Done(): + case <-n.closeCh: + } n.cancelAllListeners() }() @@ -165,8 +179,14 @@ func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error { // Close implements cancel all listeners, see https://github.com/dapr/components-contrib/issues/779 func (n *Nacos) Close() error { + if n.closed.CompareAndSwap(false, true) { + close(n.closeCh) + } + n.cancelAllListeners() + n.wg.Wait() + return nil } @@ -223,7 +243,11 @@ func (n *Nacos) addListener(ctx context.Context, config configParam) { func (n *Nacos) addListenerFoInputBinding(ctx context.Context, config configParam) { if n.addToWatches(config) { - go n.addListener(ctx, config) + n.wg.Add(1) + go func() { + defer n.wg.Done() + n.addListener(ctx, config) + }() } } diff --git a/bindings/nacos/nacos_test.go b/bindings/nacos/nacos_test.go index 4a4da1b3a..4b74b0ef4 100644 --- a/bindings/nacos/nacos_test.go +++ b/bindings/nacos/nacos_test.go @@ -35,7 +35,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest m.Properties, err = getNacosLocalCacheMetadata() require.NoError(t, err) n := NewNacos(logger.NewLogger("test")).(*Nacos) - err = n.Init(m) + err = n.Init(context.Background(), m) require.NoError(t, err) var count int32 ch := make(chan bool, 1) diff --git a/bindings/output_binding.go b/bindings/output_binding.go index e96e9f2ad..5b6d0f52b 100644 --- a/bindings/output_binding.go +++ b/bindings/output_binding.go @@ -22,15 +22,15 @@ import ( // OutputBinding is the interface for an output binding, allowing users to invoke remote systems with optional payloads. type OutputBinding interface { - Init(metadata Metadata) error + Init(ctx context.Context, metadata Metadata) error Invoke(ctx context.Context, req *InvokeRequest) (*InvokeResponse, error) Operations() []OperationKind } -func PingOutBinding(outputBinding OutputBinding) error { +func PingOutBinding(ctx context.Context, outputBinding OutputBinding) error { // checks if this output binding has the ping option then executes if outputBindingWithPing, ok := outputBinding.(health.Pinger); ok { - return outputBindingWithPing.Ping() + return outputBindingWithPing.Ping(ctx) } else { return fmt.Errorf("ping is not implemented by this output binding") } diff --git a/bindings/postgres/postgres.go b/bindings/postgres/postgres.go index c9946f7c9..1b6fe39ee 100644 --- a/bindings/postgres/postgres.go +++ b/bindings/postgres/postgres.go @@ -48,7 +48,7 @@ func NewPostgres(logger logger.Logger) bindings.OutputBinding { } // Init initializes the PostgreSql binding. -func (p *Postgres) Init(metadata bindings.Metadata) error { +func (p *Postgres) Init(ctx context.Context, metadata bindings.Metadata) error { url, ok := metadata.Properties[connectionURLKey] if !ok || url == "" { return errors.Errorf("required metadata not set: %s", connectionURLKey) @@ -59,7 +59,9 @@ func (p *Postgres) Init(metadata bindings.Metadata) error { return errors.Wrap(err, "error opening DB connection") } - p.db, err = pgxpool.NewWithConfig(context.Background(), poolConfig) + // This context doesn't control the lifetime of the connection pool, and is + // only scoped to postgres creating resources at init. + p.db, err = pgxpool.NewWithConfig(ctx, poolConfig) if err != nil { return errors.Wrap(err, "unable to ping the DB") } diff --git a/bindings/postgres/postgres_test.go b/bindings/postgres/postgres_test.go index 5359c4321..78adb43ab 100644 --- a/bindings/postgres/postgres_test.go +++ b/bindings/postgres/postgres_test.go @@ -64,7 +64,7 @@ func TestPostgresIntegration(t *testing.T) { // live DB test b := NewPostgres(logger.NewLogger("test")).(*Postgres) m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}} - if err := b.Init(m); err != nil { + if err := b.Init(context.Background(), m); err != nil { t.Fatal(err) } diff --git a/bindings/postmark/postmark.go b/bindings/postmark/postmark.go index 5729643f7..760b0ca29 100644 --- a/bindings/postmark/postmark.go +++ b/bindings/postmark/postmark.go @@ -74,7 +74,7 @@ func (p *Postmark) parseMetadata(meta bindings.Metadata) (postmarkMetadata, erro } // Init does metadata parsing and not much else :). -func (p *Postmark) Init(metadata bindings.Metadata) error { +func (p *Postmark) Init(_ context.Context, metadata bindings.Metadata) error { // Parse input metadata meta, err := p.parseMetadata(metadata) if err != nil { diff --git a/bindings/rabbitmq/rabbitmq.go b/bindings/rabbitmq/rabbitmq.go index 70f80cdcc..f04165557 100644 --- a/bindings/rabbitmq/rabbitmq.go +++ b/bindings/rabbitmq/rabbitmq.go @@ -19,6 +19,8 @@ import ( "fmt" "math" "strconv" + "sync" + "sync/atomic" "time" amqp "github.com/rabbitmq/amqp091-go" @@ -50,6 +52,9 @@ type RabbitMQ struct { metadata rabbitMQMetadata logger logger.Logger queue amqp.Queue + closed atomic.Bool + closeCh chan struct{} + wg sync.WaitGroup } // Metadata is the rabbitmq config. @@ -66,11 +71,14 @@ type rabbitMQMetadata struct { // NewRabbitMQ returns a new rabbitmq instance. func NewRabbitMQ(logger logger.Logger) bindings.InputOutputBinding { - return &RabbitMQ{logger: logger} + return &RabbitMQ{ + logger: logger, + closeCh: make(chan struct{}), + } } // Init does metadata parsing and connection creation. -func (r *RabbitMQ) Init(metadata bindings.Metadata) error { +func (r *RabbitMQ) Init(_ context.Context, metadata bindings.Metadata) error { err := r.parseMetadata(metadata) if err != nil { return err @@ -226,6 +234,10 @@ func (r *RabbitMQ) declareQueue() (amqp.Queue, error) { } func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error { + if r.closed.Load() { + return errors.New("binding already closed") + } + msgs, err := r.channel.Consume( r.queue.Name, "", @@ -239,14 +251,27 @@ func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error { return err } + readCtx, cancel := context.WithCancel(ctx) + + r.wg.Add(2) go func() { + defer r.wg.Done() + defer cancel() + select { + case <-r.closeCh: + case <-readCtx.Done(): + } + }() + + go func() { + defer r.wg.Done() var err error for { select { - case <-ctx.Done(): + case <-readCtx.Done(): return case d := <-msgs: - _, err = handler(ctx, &bindings.ReadResponse{ + _, err = handler(readCtx, &bindings.ReadResponse{ Data: d.Body, }) if err != nil { @@ -260,3 +285,11 @@ func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error { return nil } + +func (r *RabbitMQ) Close() error { + if r.closed.CompareAndSwap(false, true) { + close(r.closeCh) + } + defer r.wg.Wait() + return r.channel.Close() +} diff --git a/bindings/rabbitmq/rabbitmq_integration_test.go b/bindings/rabbitmq/rabbitmq_integration_test.go index e3939bbce..79bdc33a4 100644 --- a/bindings/rabbitmq/rabbitmq_integration_test.go +++ b/bindings/rabbitmq/rabbitmq_integration_test.go @@ -85,7 +85,7 @@ func TestQueuesWithTTL(t *testing.T) { logger := logger.NewLogger("test") r := NewRabbitMQ(logger).(*RabbitMQ) - err := r.Init(metadata) + err := r.Init(context.Background(), metadata) assert.Nil(t, err) // Assert that if waited too long, we won't see any message @@ -117,6 +117,7 @@ func TestQueuesWithTTL(t *testing.T) { assert.True(t, ok) msgBody := string(msg.Body) assert.Equal(t, testMsgContent, msgBody) + assert.NoError(t, r.Close()) } func TestPublishingWithTTL(t *testing.T) { @@ -144,7 +145,7 @@ func TestPublishingWithTTL(t *testing.T) { logger := logger.NewLogger("test") rabbitMQBinding1 := NewRabbitMQ(logger).(*RabbitMQ) - err := rabbitMQBinding1.Init(metadata) + err := rabbitMQBinding1.Init(context.Background(), metadata) assert.Nil(t, err) // Assert that if waited too long, we won't see any message @@ -175,7 +176,7 @@ func TestPublishingWithTTL(t *testing.T) { // Getting before it is expired, should return it rabbitMQBinding2 := NewRabbitMQ(logger).(*RabbitMQ) - err = rabbitMQBinding2.Init(metadata) + err = rabbitMQBinding2.Init(context.Background(), metadata) assert.Nil(t, err) const testMsgContent = "test_msg" @@ -193,6 +194,9 @@ func TestPublishingWithTTL(t *testing.T) { assert.True(t, ok) msgBody := string(msg.Body) assert.Equal(t, testMsgContent, msgBody) + + assert.NoError(t, rabbitMQBinding1.Close()) + assert.NoError(t, rabbitMQBinding1.Close()) } func TestExclusiveQueue(t *testing.T) { @@ -222,7 +226,7 @@ func TestExclusiveQueue(t *testing.T) { logger := logger.NewLogger("test") r := NewRabbitMQ(logger).(*RabbitMQ) - err := r.Init(metadata) + err := r.Init(context.Background(), metadata) assert.Nil(t, err) // Assert that if waited too long, we won't see any message @@ -276,7 +280,7 @@ func TestPublishWithPriority(t *testing.T) { logger := logger.NewLogger("test") r := NewRabbitMQ(logger).(*RabbitMQ) - err := r.Init(metadata) + err := r.Init(context.Background(), metadata) assert.Nil(t, err) // Assert that if waited too long, we won't see any message diff --git a/bindings/redis/redis.go b/bindings/redis/redis.go index 305baeb44..03551202c 100644 --- a/bindings/redis/redis.go +++ b/bindings/redis/redis.go @@ -28,9 +28,6 @@ type Redis struct { client rediscomponent.RedisClient clientSettings *rediscomponent.Settings logger logger.Logger - - ctx context.Context - cancel context.CancelFunc } // NewRedis returns a new redis bindings instance. @@ -39,15 +36,13 @@ func NewRedis(logger logger.Logger) bindings.OutputBinding { } // Init performs metadata parsing and connection creation. -func (r *Redis) Init(meta bindings.Metadata) (err error) { +func (r *Redis) Init(ctx context.Context, meta bindings.Metadata) (err error) { r.client, r.clientSettings, err = rediscomponent.ParseClientFromProperties(meta.Properties, nil) if err != nil { return err } - r.ctx, r.cancel = context.WithCancel(context.Background()) - - _, err = r.client.PingResult(r.ctx) + _, err = r.client.PingResult(ctx) if err != nil { return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err) } @@ -55,8 +50,8 @@ func (r *Redis) Init(meta bindings.Metadata) (err error) { return err } -func (r *Redis) Ping() error { - if _, err := r.client.PingResult(r.ctx); err != nil { +func (r *Redis) Ping(ctx context.Context) error { + if _, err := r.client.PingResult(ctx); err != nil { return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err) } @@ -101,7 +96,5 @@ func (r *Redis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi } func (r *Redis) Close() error { - r.cancel() - return r.client.Close() } diff --git a/bindings/redis/redis_test.go b/bindings/redis/redis_test.go index be5bb6519..59d71924b 100644 --- a/bindings/redis/redis_test.go +++ b/bindings/redis/redis_test.go @@ -40,7 +40,6 @@ func TestInvokeCreate(t *testing.T) { client: c, logger: logger.NewLogger("test"), } - bind.ctx, bind.cancel = context.WithCancel(context.Background()) _, err := c.DoRead(context.Background(), "GET", testKey) assert.Equal(t, redis.Nil, err) @@ -66,7 +65,6 @@ func TestInvokeGet(t *testing.T) { client: c, logger: logger.NewLogger("test"), } - bind.ctx, bind.cancel = context.WithCancel(context.Background()) err := c.DoWrite(context.Background(), "SET", testKey, testData) assert.Equal(t, nil, err) @@ -87,7 +85,6 @@ func TestInvokeDelete(t *testing.T) { client: c, logger: logger.NewLogger("test"), } - bind.ctx, bind.cancel = context.WithCancel(context.Background()) err := c.DoWrite(context.Background(), "SET", testKey, testData) assert.Equal(t, nil, err) diff --git a/bindings/rethinkdb/statechange/statechange.go b/bindings/rethinkdb/statechange/statechange.go index c6203bfcf..e820b1b97 100644 --- a/bindings/rethinkdb/statechange/statechange.go +++ b/bindings/rethinkdb/statechange/statechange.go @@ -18,6 +18,8 @@ import ( "encoding/json" "strconv" "strings" + "sync" + "sync/atomic" "time" r "github.com/dancannon/gorethink" @@ -34,6 +36,9 @@ type Binding struct { logger logger.Logger session *r.Session config StateConfig + closed atomic.Bool + closeCh chan struct{} + wg sync.WaitGroup } // StateConfig is the binding config. @@ -45,12 +50,13 @@ type StateConfig struct { // NewRethinkDBStateChangeBinding returns a new RethinkDB actor event input binding. func NewRethinkDBStateChangeBinding(logger logger.Logger) bindings.InputBinding { return &Binding{ - logger: logger, + logger: logger, + closeCh: make(chan struct{}), } } // Init initializes the RethinkDB binding. -func (b *Binding) Init(metadata bindings.Metadata) error { +func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error { cfg, err := metadataToConfig(metadata.Properties, b.logger) if err != nil { return errors.Wrap(err, "unable to parse metadata properties") @@ -68,6 +74,10 @@ func (b *Binding) Init(metadata bindings.Metadata) error { // Read triggers the RethinkDB scheduler. func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error { + if b.closed.Load() { + return errors.New("binding is closed") + } + b.logger.Infof("subscribing to state changes in %s.%s...", b.config.Database, b.config.Table) cursor, err := r.DB(b.config.Database). Table(b.config.Table). @@ -81,8 +91,21 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error { errors.Wrapf(err, "error connecting to table %s", b.config.Table) } + readCtx, cancel := context.WithCancel(ctx) + b.wg.Add(2) + go func() { - for ctx.Err() == nil { + defer b.wg.Done() + defer cancel() + select { + case <-b.closeCh: + case <-readCtx.Done(): + } + }() + + go func() { + defer b.wg.Done() + for readCtx.Err() == nil { var change interface{} ok := cursor.Next(&change) if !ok { @@ -105,7 +128,7 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error { }, } - if _, err := handler(ctx, resp); err != nil { + if _, err := handler(readCtx, resp); err != nil { b.logger.Errorf("error invoking change handler: %v", err) continue } @@ -117,6 +140,14 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error { return nil } +func (b *Binding) Close() error { + if b.closed.CompareAndSwap(false, true) { + close(b.closeCh) + } + defer b.wg.Wait() + return b.session.Close() +} + func metadataToConfig(cfg map[string]string, logger logger.Logger) (StateConfig, error) { c := StateConfig{} for k, v := range cfg { diff --git a/bindings/rethinkdb/statechange/statechange_test.go b/bindings/rethinkdb/statechange/statechange_test.go index a4a405c1a..61995edf9 100644 --- a/bindings/rethinkdb/statechange/statechange_test.go +++ b/bindings/rethinkdb/statechange/statechange_test.go @@ -71,7 +71,7 @@ func TestBinding(t *testing.T) { assert.NotNil(t, m.Properties) b := getNewRethinkActorBinding() - err := b.Init(m) + err := b.Init(context.Background(), m) assert.NoErrorf(t, err, "error initializing") ctx, cancel := context.WithCancel(context.Background()) diff --git a/bindings/rocketmq/rocketmq.go b/bindings/rocketmq/rocketmq.go index 477c3938a..8b9ed0e50 100644 --- a/bindings/rocketmq/rocketmq.go +++ b/bindings/rocketmq/rocketmq.go @@ -18,6 +18,8 @@ import ( "errors" "fmt" "strings" + "sync" + "sync/atomic" "time" mqc "github.com/apache/rocketmq-client-go/v2/consumer" @@ -35,27 +37,27 @@ type RocketMQ struct { settings Settings producer mqw.Producer - ctx context.Context - cancel context.CancelFunc backOffConfig retry.Config + closeCh chan struct{} + closed atomic.Bool + wg sync.WaitGroup } func NewRocketMQ(l logger.Logger) *RocketMQ { return &RocketMQ{ //nolint:exhaustivestruct logger: l, producer: nil, + closeCh: make(chan struct{}), } } // Init performs metadata parsing. -func (a *RocketMQ) Init(metadata bindings.Metadata) error { +func (a *RocketMQ) Init(ctx context.Context, metadata bindings.Metadata) error { var err error if err = a.settings.Decode(metadata.Properties); err != nil { return err } - a.ctx, a.cancel = context.WithCancel(context.Background()) - // Default retry configuration is used if no // backOff properties are set. if err = retry.DecodeConfigWithPrefix( @@ -75,6 +77,10 @@ func (a *RocketMQ) Init(metadata bindings.Metadata) error { // Read triggers the rocketmq subscription. func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error { + if a.closed.Load() { + return errors.New("error: binding is closed") + } + a.logger.Debugf("binding rocketmq: start read input binding") consumer, err := a.setupConsumer() @@ -114,10 +120,12 @@ func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error { a.logger.Debugf("binding-rocketmq: consumer started") // Listen for context cancelation to stop the subscription + a.wg.Add(1) go func() { + defer a.wg.Done() select { case <-ctx.Done(): - case <-a.ctx.Done(): + case <-a.closeCh: } innerErr := consumer.Shutdown() @@ -131,8 +139,10 @@ func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error { // Close implements cancel all listeners, see https://github.com/dapr/components-contrib/issues/779 func (a *RocketMQ) Close() error { - a.cancel() - + defer a.wg.Wait() + if a.closed.CompareAndSwap(false, true) { + close(a.closeCh) + } return nil } @@ -199,21 +209,21 @@ func (a *RocketMQ) Operations() []bindings.OperationKind { return []bindings.OperationKind{bindings.CreateOperation} } -func (a *RocketMQ) Invoke(req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { +func (a *RocketMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { rst := &bindings.InvokeResponse{Data: nil, Metadata: nil} if req.Operation != bindings.CreateOperation { return rst, fmt.Errorf("binding-rocketmq error: unsupported operation %s", req.Operation) } - return rst, a.sendMessage(req) + return rst, a.sendMessage(ctx, req) } -func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error { +func (a *RocketMQ) sendMessage(ctx context.Context, req *bindings.InvokeRequest) error { topic := req.Metadata[metadataRocketmqTopic] if topic != "" { - _, err := a.send(topic, req.Metadata[metadataRocketmqTag], req.Metadata[metadataRocketmqKey], req.Data) + _, err := a.send(ctx, topic, req.Metadata[metadataRocketmqTag], req.Metadata[metadataRocketmqKey], req.Data) if err != nil { return err } @@ -229,7 +239,7 @@ func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error { if err != nil { return err } - _, err = a.send(topic, mqExpression, req.Metadata[metadataRocketmqKey], req.Data) + _, err = a.send(ctx, topic, mqExpression, req.Metadata[metadataRocketmqKey], req.Data) if err != nil { return err } @@ -239,9 +249,9 @@ func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error { return nil } -func (a *RocketMQ) send(topic, mqExpr, key string, data []byte) (bool, error) { +func (a *RocketMQ) send(ctx context.Context, topic, mqExpr, key string, data []byte) (bool, error) { msg := primitive.NewMessage(topic, data).WithTag(mqExpr).WithKeys([]string{key}) - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() rst, err := a.producer.SendSync(ctx, msg) if err != nil { diff --git a/bindings/rocketmq/rocketmq_test.go b/bindings/rocketmq/rocketmq_test.go index aefa57c34..4cb162257 100644 --- a/bindings/rocketmq/rocketmq_test.go +++ b/bindings/rocketmq/rocketmq_test.go @@ -35,7 +35,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest m := bindings.Metadata{} //nolint:exhaustivestruct m.Properties = getTestMetadata() r := NewRocketMQ(logger.NewLogger("test")) - err := r.Init(m) + err := r.Init(context.Background(), m) require.NoError(t, err) var count int32 @@ -51,7 +51,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest time.Sleep(5 * time.Second) atomic.StoreInt32(&count, 0) req := &bindings.InvokeRequest{Data: []byte("hello"), Operation: bindings.CreateOperation, Metadata: map[string]string{}} - _, err = r.Invoke(req) + _, err = r.Invoke(context.Background(), req) require.NoError(t, err) time.Sleep(10 * time.Second) diff --git a/bindings/smtp/smtp.go b/bindings/smtp/smtp.go index 2e6365cc0..8cbc1ad85 100644 --- a/bindings/smtp/smtp.go +++ b/bindings/smtp/smtp.go @@ -61,7 +61,7 @@ func NewSMTP(logger logger.Logger) bindings.OutputBinding { } // Init smtp component (parse metadata). -func (s *Mailer) Init(metadata bindings.Metadata) error { +func (s *Mailer) Init(_ context.Context, metadata bindings.Metadata) error { // parse metadata meta, err := s.parseMetadata(metadata) if err != nil { diff --git a/bindings/twilio/sendgrid/sendgrid.go b/bindings/twilio/sendgrid/sendgrid.go index a5d143b8d..f77f74918 100644 --- a/bindings/twilio/sendgrid/sendgrid.go +++ b/bindings/twilio/sendgrid/sendgrid.go @@ -84,7 +84,7 @@ func (sg *SendGrid) parseMetadata(meta bindings.Metadata) (sendGridMetadata, err } // Init does metadata parsing and not much else :). -func (sg *SendGrid) Init(metadata bindings.Metadata) error { +func (sg *SendGrid) Init(_ context.Context, metadata bindings.Metadata) error { // Parse input metadata meta, err := sg.parseMetadata(metadata) if err != nil { diff --git a/bindings/twilio/sms/sms.go b/bindings/twilio/sms/sms.go index 70a866de6..d01f8282d 100644 --- a/bindings/twilio/sms/sms.go +++ b/bindings/twilio/sms/sms.go @@ -60,19 +60,19 @@ func NewSMS(logger logger.Logger) bindings.OutputBinding { } } -func (t *SMS) Init(metadata bindings.Metadata) error { +func (t *SMS) Init(_ context.Context, metadata bindings.Metadata) error { twilioM := twilioMetadata{ timeout: time.Minute * 5, } if metadata.Properties[fromNumber] == "" { - return errors.New("\"fromNumber\" is a required field") + return errors.New(`"fromNumber" is a required field`) } if metadata.Properties[accountSid] == "" { - return errors.New("\"accountSid\" is a required field") + return errors.New(`"accountSid" is a required field`) } if metadata.Properties[authToken] == "" { - return errors.New("\"authToken\" is a required field") + return errors.New(`"authToken" is a required field`) } twilioM.toNumber = metadata.Properties[toNumber] diff --git a/bindings/twilio/sms/sms_test.go b/bindings/twilio/sms/sms_test.go index 2aba96320..7d1549f43 100644 --- a/bindings/twilio/sms/sms_test.go +++ b/bindings/twilio/sms/sms_test.go @@ -53,7 +53,7 @@ func TestInit(t *testing.T) { m := bindings.Metadata{} m.Properties = map[string]string{"toNumber": "toNumber", "fromNumber": "fromNumber"} tw := NewSMS(logger.NewLogger("test")) - err := tw.Init(m) + err := tw.Init(context.Background(), m) assert.NotNil(t, err) } @@ -66,7 +66,7 @@ func TestParseDuration(t *testing.T) { "authToken": "authToken", "timeout": "badtimeout", } tw := NewSMS(logger.NewLogger("test")) - err := tw.Init(m) + err := tw.Init(context.Background(), m) assert.NotNil(t, err) } @@ -85,7 +85,7 @@ func TestWriteShouldSucceed(t *testing.T) { tw.httpClient = &http.Client{ Transport: httpTransport, } - err := tw.Init(m) + err := tw.Init(context.Background(), m) assert.NoError(t, err) t.Run("Should succeed with expected url and headers", func(t *testing.T) { @@ -123,7 +123,7 @@ func TestWriteShouldFail(t *testing.T) { tw.httpClient = &http.Client{ Transport: httpTransport, } - err := tw.Init(m) + err := tw.Init(context.Background(), m) assert.NoError(t, err) t.Run("Missing 'to' should fail", func(t *testing.T) { @@ -180,7 +180,7 @@ func TestMessageBody(t *testing.T) { tw.httpClient = &http.Client{ Transport: httpTransport, } - err := tw.Init(m) + err := tw.Init(context.Background(), m) require.NoError(t, err) tester := func(reqData []byte, expectBody string) func(t *testing.T) { diff --git a/bindings/twitter/twitter.go b/bindings/twitter/twitter.go index e106246da..e3eedb468 100644 --- a/bindings/twitter/twitter.go +++ b/bindings/twitter/twitter.go @@ -19,6 +19,8 @@ import ( "encoding/json" "fmt" "strconv" + "sync" + "sync/atomic" "time" "github.com/dghubble/go-twitter/twitter" @@ -31,18 +33,21 @@ import ( // Binding represents Twitter input/output binding. type Binding struct { - client *twitter.Client - query string - logger logger.Logger + client *twitter.Client + query string + logger logger.Logger + closed atomic.Bool + closeCh chan struct{} + wg sync.WaitGroup } // NewTwitter returns a new Twitter event input binding. func NewTwitter(logger logger.Logger) bindings.InputOutputBinding { - return &Binding{logger: logger} + return &Binding{logger: logger, closeCh: make(chan struct{})} } // Init initializes the Twitter binding. -func (t *Binding) Init(metadata bindings.Metadata) error { +func (t *Binding) Init(ctx context.Context, metadata bindings.Metadata) error { ck, f := metadata.Properties["consumerKey"] if !f || ck == "" { return fmt.Errorf("consumerKey not set") @@ -124,10 +129,17 @@ func (t *Binding) Read(ctx context.Context, handler bindings.Handler) error { } t.logger.Debug("starting handler...") - go demux.HandleChan(stream.Messages) - + t.wg.Add(2) go func() { - <-ctx.Done() + defer t.wg.Done() + demux.HandleChan(stream.Messages) + }() + go func() { + defer t.wg.Done() + select { + case <-t.closeCh: + case <-ctx.Done(): + } t.logger.Debug("stopping handler...") stream.Stop() }() @@ -135,6 +147,14 @@ func (t *Binding) Read(ctx context.Context, handler bindings.Handler) error { return nil } +func (t *Binding) Close() error { + if t.closed.CompareAndSwap(false, true) { + close(t.closeCh) + } + t.wg.Wait() + return nil +} + // Invoke handles all operations. func (t *Binding) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { t.logger.Debugf("operation: %v", req.Operation) diff --git a/bindings/twitter/twitter_test.go b/bindings/twitter/twitter_test.go index 096a64523..ba38dcca2 100644 --- a/bindings/twitter/twitter_test.go +++ b/bindings/twitter/twitter_test.go @@ -60,7 +60,7 @@ func getRuntimeMetadata() map[string]string { func TestInit(t *testing.T) { m := getTestMetadata() tw := NewTwitter(logger.NewLogger("test")).(*Binding) - err := tw.Init(m) + err := tw.Init(context.Background(), m) assert.Nilf(t, err, "error initializing valid metadata properties") } @@ -69,7 +69,7 @@ func TestInit(t *testing.T) { func TestReadError(t *testing.T) { tw := NewTwitter(logger.NewLogger("test")).(*Binding) m := getTestMetadata() - err := tw.Init(m) + err := tw.Init(context.Background(), m) assert.Nilf(t, err, "error initializing valid metadata properties") err = tw.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) { @@ -79,6 +79,8 @@ func TestReadError(t *testing.T) { return nil, nil }) assert.Error(t, err) + + assert.NoError(t, tw.Close()) } // TestRead executes the Read method which calls Twiter API @@ -93,7 +95,7 @@ func TestRead(t *testing.T) { m.Properties["query"] = "microsoft" tw := NewTwitter(logger.NewLogger("test")).(*Binding) tw.logger.SetOutputLevel(logger.DebugLevel) - err := tw.Init(m) + err := tw.Init(context.Background(), m) assert.Nilf(t, err, "error initializing read") ctx, cancel := context.WithCancel(context.Background()) @@ -116,6 +118,8 @@ func TestRead(t *testing.T) { cancel() t.Fatal("Timeout waiting for messages") } + + assert.NoError(t, tw.Close()) } // TestInvoke executes the Invoke method which calls Twiter API @@ -129,7 +133,7 @@ func TestInvoke(t *testing.T) { m.Properties = getRuntimeMetadata() tw := NewTwitter(logger.NewLogger("test")).(*Binding) tw.logger.SetOutputLevel(logger.DebugLevel) - err := tw.Init(m) + err := tw.Init(context.Background(), m) assert.Nilf(t, err, "error initializing Invoke") req := &bindings.InvokeRequest{ @@ -141,4 +145,5 @@ func TestInvoke(t *testing.T) { resp, err := tw.Invoke(context.Background(), req) assert.Nilf(t, err, "error on invoke") assert.NotNil(t, resp) + assert.NoError(t, tw.Close()) } diff --git a/bindings/zeebe/command/command.go b/bindings/zeebe/command/command.go index e43e25df7..55ebc9673 100644 --- a/bindings/zeebe/command/command.go +++ b/bindings/zeebe/command/command.go @@ -61,7 +61,7 @@ func NewZeebeCommand(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing and connection creation. -func (z *ZeebeCommand) Init(metadata bindings.Metadata) error { +func (z *ZeebeCommand) Init(ctx context.Context, metadata bindings.Metadata) error { client, err := z.clientFactory.Get(metadata) if err != nil { return err @@ -114,7 +114,7 @@ func (z *ZeebeCommand) Invoke(ctx context.Context, req *bindings.InvokeRequest) case UpdateJobRetriesOperation: return z.updateJobRetries(ctx, req) case ThrowErrorOperation: - return z.throwError(req) + return z.throwError(ctx, req) case bindings.GetOperation: fallthrough case bindings.CreateOperation: diff --git a/bindings/zeebe/command/command_test.go b/bindings/zeebe/command/command_test.go index 53e93835f..a3b857b2a 100644 --- a/bindings/zeebe/command/command_test.go +++ b/bindings/zeebe/command/command_test.go @@ -58,7 +58,7 @@ func TestInit(t *testing.T) { } cmd := ZeebeCommand{clientFactory: mcf, logger: testLogger} - err := cmd.Init(metadata) + err := cmd.Init(context.Background(), metadata) assert.Error(t, err, errParsing) }) @@ -67,7 +67,7 @@ func TestInit(t *testing.T) { mcf := mockClientFactory{} cmd := ZeebeCommand{clientFactory: mcf, logger: testLogger} - err := cmd.Init(metadata) + err := cmd.Init(context.Background(), metadata) assert.NoError(t, err) diff --git a/bindings/zeebe/command/throw_error.go b/bindings/zeebe/command/throw_error.go index 984b810d4..3b0f43e8f 100644 --- a/bindings/zeebe/command/throw_error.go +++ b/bindings/zeebe/command/throw_error.go @@ -30,7 +30,7 @@ type throwErrorPayload struct { ErrorMessage string `json:"errorMessage"` } -func (z *ZeebeCommand) throwError(req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { +func (z *ZeebeCommand) throwError(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { var payload throwErrorPayload err := json.Unmarshal(req.Data, &payload) if err != nil { @@ -53,7 +53,7 @@ func (z *ZeebeCommand) throwError(req *bindings.InvokeRequest) (*bindings.Invoke cmd = cmd.ErrorMessage(payload.ErrorMessage) } - _, err = cmd.Send(context.Background()) + _, err = cmd.Send(ctx) if err != nil { return nil, fmt.Errorf("cannot throw error for job key %d: %w", payload.JobKey, err) } diff --git a/bindings/zeebe/jobworker/jobworker.go b/bindings/zeebe/jobworker/jobworker.go index f3175d5c8..df6904af1 100644 --- a/bindings/zeebe/jobworker/jobworker.go +++ b/bindings/zeebe/jobworker/jobworker.go @@ -19,6 +19,8 @@ import ( "errors" "fmt" "strconv" + "sync" + "sync/atomic" "time" "github.com/camunda/zeebe/clients/go/v8/pkg/entities" @@ -39,6 +41,9 @@ type ZeebeJobWorker struct { client zbc.Client metadata *jobWorkerMetadata logger logger.Logger + closed atomic.Bool + closeCh chan struct{} + wg sync.WaitGroup } // https://docs.zeebe.io/basics/job-workers.html @@ -64,11 +69,15 @@ type jobHandler struct { // NewZeebeJobWorker returns a new ZeebeJobWorker instance. func NewZeebeJobWorker(logger logger.Logger) bindings.InputBinding { - return &ZeebeJobWorker{clientFactory: zeebe.NewClientFactoryImpl(logger), logger: logger} + return &ZeebeJobWorker{ + clientFactory: zeebe.NewClientFactoryImpl(logger), + logger: logger, + closeCh: make(chan struct{}), + } } // Init does metadata parsing and connection creation. -func (z *ZeebeJobWorker) Init(metadata bindings.Metadata) error { +func (z *ZeebeJobWorker) Init(ctx context.Context, metadata bindings.Metadata) error { meta, err := z.parseMetadata(metadata) if err != nil { return err @@ -90,6 +99,10 @@ func (z *ZeebeJobWorker) Init(metadata bindings.Metadata) error { } func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) error { + if z.closed.Load() { + return fmt.Errorf("binding is closed") + } + h := jobHandler{ callback: handler, logger: z.logger, @@ -99,8 +112,14 @@ func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) err jobWorker := z.getJobWorker(h) + z.wg.Add(1) go func() { - <-ctx.Done() + defer z.wg.Done() + + select { + case <-z.closeCh: + case <-ctx.Done(): + } jobWorker.Close() jobWorker.AwaitClose() @@ -110,6 +129,14 @@ func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) err return nil } +func (z *ZeebeJobWorker) Close() error { + if z.closed.CompareAndSwap(false, true) { + close(z.closeCh) + } + z.wg.Wait() + return nil +} + func (z *ZeebeJobWorker) parseMetadata(meta bindings.Metadata) (*jobWorkerMetadata, error) { var m jobWorkerMetadata err := metadata.DecodeMetadata(meta.Properties, &m) diff --git a/bindings/zeebe/jobworker/jobworker_test.go b/bindings/zeebe/jobworker/jobworker_test.go index 663fba1dd..46096bf46 100644 --- a/bindings/zeebe/jobworker/jobworker_test.go +++ b/bindings/zeebe/jobworker/jobworker_test.go @@ -14,6 +14,7 @@ limitations under the License. package jobworker import ( + "context" "errors" "testing" @@ -53,10 +54,11 @@ func TestInit(t *testing.T) { metadata := bindings.Metadata{} var mcf mockClientFactory - jobWorker := ZeebeJobWorker{clientFactory: &mcf, logger: testLogger} - err := jobWorker.Init(metadata) + jobWorker := ZeebeJobWorker{clientFactory: &mcf, logger: testLogger, closeCh: make(chan struct{})} + err := jobWorker.Init(context.Background(), metadata) assert.Error(t, err, ErrMissingJobType) + assert.NoError(t, jobWorker.Close()) }) t.Run("sets client from client factory", func(t *testing.T) { @@ -66,8 +68,8 @@ func TestInit(t *testing.T) { mcf := mockClientFactory{ metadata: metadata, } - jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger} - err := jobWorker.Init(metadata) + jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})} + err := jobWorker.Init(context.Background(), metadata) assert.NoError(t, err) @@ -76,6 +78,7 @@ func TestInit(t *testing.T) { assert.NoError(t, err) assert.Equal(t, mc, jobWorker.client) assert.Equal(t, metadata, mcf.metadata) + assert.NoError(t, jobWorker.Close()) }) t.Run("returns error if client could not be instantiated properly", func(t *testing.T) { @@ -85,9 +88,10 @@ func TestInit(t *testing.T) { error: errParsing, } - jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger} - err := jobWorker.Init(metadata) + jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})} + err := jobWorker.Init(context.Background(), metadata) assert.Error(t, err, errParsing) + assert.NoError(t, jobWorker.Close()) }) t.Run("sets client from client factory", func(t *testing.T) { @@ -98,8 +102,8 @@ func TestInit(t *testing.T) { metadata: metadata, } - jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger} - err := jobWorker.Init(metadata) + jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})} + err := jobWorker.Init(context.Background(), metadata) assert.NoError(t, err) @@ -108,5 +112,6 @@ func TestInit(t *testing.T) { assert.NoError(t, err) assert.Equal(t, mc, jobWorker.client) assert.Equal(t, metadata, mcf.metadata) + assert.NoError(t, jobWorker.Close()) }) } diff --git a/configuration/azure/appconfig/appconfig.go b/configuration/azure/appconfig/appconfig.go index 849763c5c..0253e1abf 100644 --- a/configuration/azure/appconfig/appconfig.go +++ b/configuration/azure/appconfig/appconfig.go @@ -73,7 +73,7 @@ func NewAzureAppConfigurationStore(logger logger.Logger) configuration.Store { } // Init does metadata and connection parsing. -func (r *ConfigurationStore) Init(metadata configuration.Metadata) error { +func (r *ConfigurationStore) Init(_ context.Context, metadata configuration.Metadata) error { m, err := parseMetadata(metadata) if err != nil { return err diff --git a/configuration/azure/appconfig/appconfig_test.go b/configuration/azure/appconfig/appconfig_test.go index 933ccea60..da1c3974d 100644 --- a/configuration/azure/appconfig/appconfig_test.go +++ b/configuration/azure/appconfig/appconfig_test.go @@ -204,7 +204,7 @@ func TestInit(t *testing.T) { Properties: testProperties, }} - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) cs, ok := s.(*ConfigurationStore) assert.True(t, ok) @@ -229,7 +229,7 @@ func TestInit(t *testing.T) { Properties: testProperties, }} - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) cs, ok := s.(*ConfigurationStore) assert.True(t, ok) diff --git a/configuration/postgres/postgres.go b/configuration/postgres/postgres.go index 5e81d74ab..d894e1536 100644 --- a/configuration/postgres/postgres.go +++ b/configuration/postgres/postgres.go @@ -86,7 +86,7 @@ func NewPostgresConfigurationStore(logger logger.Logger) configuration.Store { } } -func (p *ConfigurationStore) Init(metadata configuration.Metadata) error { +func (p *ConfigurationStore) Init(parentCtx context.Context, metadata configuration.Metadata) error { p.logger.Debug(InfoStartInit) if p.client != nil { return fmt.Errorf(ErrorAlreadyInitialized) @@ -98,7 +98,7 @@ func (p *ConfigurationStore) Init(metadata configuration.Metadata) error { p.metadata = m } p.ActiveSubscriptions = make(map[string]*subscription) - ctx, cancel := context.WithTimeout(context.Background(), p.metadata.maxIdleTimeout) + ctx, cancel := context.WithTimeout(parentCtx, p.metadata.maxIdleTimeout) defer cancel() client, err := Connect(ctx, p.metadata.connectionString, p.metadata.maxIdleTimeout) if err != nil { diff --git a/configuration/redis/redis.go b/configuration/redis/redis.go index 7d42583e0..7e1170017 100644 --- a/configuration/redis/redis.go +++ b/configuration/redis/redis.go @@ -143,7 +143,7 @@ func parseRedisMetadata(meta configuration.Metadata) (metadata, error) { } // Init does metadata and connection parsing. -func (r *ConfigurationStore) Init(metadata configuration.Metadata) error { +func (r *ConfigurationStore) Init(ctx context.Context, metadata configuration.Metadata) error { m, err := parseRedisMetadata(metadata) if err != nil { return err @@ -156,11 +156,11 @@ func (r *ConfigurationStore) Init(metadata configuration.Metadata) error { r.client = r.newClient(m) } - if _, err = r.client.Ping(context.TODO()).Result(); err != nil { + if _, err = r.client.Ping(ctx).Result(); err != nil { return fmt.Errorf("redis store: error connecting to redis at %s: %s", m.Host, err) } - r.replicas, err = r.getConnectedSlaves() + r.replicas, err = r.getConnectedSlaves(ctx) return err } @@ -204,8 +204,8 @@ func (r *ConfigurationStore) newFailoverClient(m metadata) *redis.Client { return redis.NewFailoverClient(opts) } -func (r *ConfigurationStore) getConnectedSlaves() (int, error) { - res, err := r.client.Do(context.Background(), "INFO", "replication").Result() +func (r *ConfigurationStore) getConnectedSlaves(ctx context.Context) (int, error) { + res, err := r.client.Do(ctx, "INFO", "replication").Result() if err != nil { return 0, err } diff --git a/configuration/store.go b/configuration/store.go index 4622b0d73..84af8b6a8 100644 --- a/configuration/store.go +++ b/configuration/store.go @@ -18,7 +18,7 @@ import "context" // Store is an interface to perform operations on store. type Store interface { // Init configuration store. - Init(metadata Metadata) error + Init(ctx context.Context, metadata Metadata) error // Get configuration. Get(ctx context.Context, req *GetRequest) (*GetResponse, error) diff --git a/go.mod b/go.mod index 33b8c55ab..8a740deb1 100644 --- a/go.mod +++ b/go.mod @@ -400,3 +400,5 @@ replace github.com/toolkits/concurrent => github.com/niean/gotools v0.0.0-201512 // this is a fork which addresses a performance issues due to go routines replace dubbo.apache.org/dubbo-go/v3 => dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3 + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/health/pinger.go b/health/pinger.go index 1d697ee4c..02bb0721a 100644 --- a/health/pinger.go +++ b/health/pinger.go @@ -1,5 +1,22 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package health +import ( + "context" +) + type Pinger interface { - Ping() error + Ping(ctx context.Context) error } diff --git a/internal/component/kafka/consumer.go b/internal/component/kafka/consumer.go index 2099f6b48..1601fd741 100644 --- a/internal/component/kafka/consumer.go +++ b/internal/component/kafka/consumer.go @@ -19,6 +19,7 @@ import ( "fmt" "strconv" "sync" + "sync/atomic" "time" "github.com/Shopify/sarama" @@ -31,6 +32,7 @@ type consumer struct { k *Kafka ready chan bool running chan struct{} + stopped atomic.Bool once sync.Once mutex sync.Mutex } @@ -275,9 +277,6 @@ func (k *Kafka) Subscribe(ctx context.Context) error { k.cg = cg - ctx, cancel := context.WithCancel(ctx) - k.cancel = cancel - ready := make(chan bool) k.consumer = consumer{ k: k, @@ -320,7 +319,10 @@ func (k *Kafka) Subscribe(ctx context.Context) error { k.logger.Errorf("Error closing consumer group: %v", err) } - close(k.consumer.running) + // Ensure running channel is only closed once. + if k.consumer.stopped.CompareAndSwap(false, true) { + close(k.consumer.running) + } }() <-ready @@ -331,7 +333,6 @@ func (k *Kafka) Subscribe(ctx context.Context) error { // Close down consumer group resources, refresh once. func (k *Kafka) closeSubscriptionResources() { if k.cg != nil { - k.cancel() err := k.cg.Close() if err != nil { k.logger.Errorf("Error closing consumer group: %v", err) diff --git a/internal/component/kafka/kafka.go b/internal/component/kafka/kafka.go index 9b1b8e301..c3b8b8707 100644 --- a/internal/component/kafka/kafka.go +++ b/internal/component/kafka/kafka.go @@ -36,7 +36,6 @@ type Kafka struct { saslPassword string initialOffset int64 cg sarama.ConsumerGroup - cancel context.CancelFunc consumer consumer config *sarama.Config subscribeTopics TopicHandlerConfig @@ -60,7 +59,7 @@ func NewKafka(logger logger.Logger) *Kafka { } // Init does metadata parsing and connection establishment. -func (k *Kafka) Init(metadata map[string]string) error { +func (k *Kafka) Init(_ context.Context, metadata map[string]string) error { upgradedMetadata, err := k.upgradeMetadata(metadata) if err != nil { return err diff --git a/internal/component/kafka/sasl_oauthbearer.go b/internal/component/kafka/sasl_oauthbearer.go index 99d548359..df70e5df5 100644 --- a/internal/component/kafka/sasl_oauthbearer.go +++ b/internal/component/kafka/sasl_oauthbearer.go @@ -107,7 +107,8 @@ func (ts *OAuthTokenSource) Token() (*sarama.AccessToken, error) { oidcCfg := ccred.Config{ClientID: ts.ClientID, ClientSecret: ts.ClientSecret, Scopes: ts.Scopes, TokenURL: ts.TokenEndpoint.TokenURL, AuthStyle: ts.TokenEndpoint.AuthStyle} - timeoutCtx, _ := ctx.WithTimeout(ctx.TODO(), tokenRequestTimeout) //nolint:govet + timeoutCtx, cancel := ctx.WithTimeout(ctx.TODO(), tokenRequestTimeout) + defer cancel() ts.configureClient() diff --git a/lock/redis/standalone.go b/lock/redis/standalone.go index da607f173..0f1753376 100644 --- a/lock/redis/standalone.go +++ b/lock/redis/standalone.go @@ -38,9 +38,6 @@ type StandaloneRedisLock struct { metadata rediscomponent.Metadata logger logger.Logger - - ctx context.Context - cancel context.CancelFunc } // NewStandaloneRedisLock returns a new standalone redis lock. @@ -54,7 +51,7 @@ func NewStandaloneRedisLock(logger logger.Logger) lock.Store { } // Init StandaloneRedisLock. -func (r *StandaloneRedisLock) InitLockStore(metadata lock.Metadata) error { +func (r *StandaloneRedisLock) InitLockStore(ctx context.Context, metadata lock.Metadata) error { // 1. parse config m, err := rediscomponent.ParseRedisMetadata(metadata.Properties) if err != nil { @@ -75,13 +72,12 @@ func (r *StandaloneRedisLock) InitLockStore(metadata lock.Metadata) error { if err != nil { return err } - r.ctx, r.cancel = context.WithCancel(context.Background()) // 3. connect to redis - if _, err = r.client.PingResult(r.ctx); err != nil { + if _, err = r.client.PingResult(ctx); err != nil { return fmt.Errorf("[standaloneRedisLock]: error connecting to redis at %s: %s", r.clientSettings.Host, err) } // no replica - replicas, err := r.getConnectedSlaves() + replicas, err := r.getConnectedSlaves(ctx) // pass the validation if error occurs, // since some redis versions such as miniredis do not recognize the `INFO` command. if err == nil && replicas > 0 { @@ -101,8 +97,8 @@ func needFailover(properties map[string]string) bool { return false } -func (r *StandaloneRedisLock) getConnectedSlaves() (int, error) { - res, err := r.client.DoRead(r.ctx, "INFO", "replication") +func (r *StandaloneRedisLock) getConnectedSlaves(ctx context.Context) (int, error) { + res, err := r.client.DoRead(ctx, "INFO", "replication") if err != nil { return 0, err } @@ -183,9 +179,6 @@ func newInternalErrorUnlockResponse() *lock.UnlockResponse { // Close shuts down the client's redis connections. func (r *StandaloneRedisLock) Close() error { - if r.cancel != nil { - r.cancel() - } if r.client != nil { closeErr := r.client.Close() r.client = nil diff --git a/lock/redis/standalone_test.go b/lock/redis/standalone_test.go index 796acd3ad..35a631521 100644 --- a/lock/redis/standalone_test.go +++ b/lock/redis/standalone_test.go @@ -42,7 +42,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) { cfg.Properties["redisPassword"] = "" // init - err := comp.InitLockStore(cfg) + err := comp.InitLockStore(context.Background(), cfg) assert.Error(t, err) }) @@ -58,7 +58,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) { cfg.Properties["redisPassword"] = "" // init - err := comp.InitLockStore(cfg) + err := comp.InitLockStore(context.Background(), cfg) assert.Error(t, err) }) @@ -75,7 +75,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) { cfg.Properties["maxRetries"] = "1 " // init - err := comp.InitLockStore(cfg) + err := comp.InitLockStore(context.Background(), cfg) assert.Error(t, err) }) } @@ -96,7 +96,7 @@ func TestStandaloneRedisLock_TryLock(t *testing.T) { cfg.Properties["redisHost"] = s.Addr() cfg.Properties["redisPassword"] = "" // init - err = comp.InitLockStore(cfg) + err = comp.InitLockStore(context.Background(), cfg) assert.NoError(t, err) // 1. client1 trylock ownerID1 := uuid.New().String() diff --git a/lock/store.go b/lock/store.go index 13b974a0a..5507f9f5e 100644 --- a/lock/store.go +++ b/lock/store.go @@ -17,7 +17,7 @@ import "context" type Store interface { // Init this component. - InitLockStore(metadata Metadata) error + InitLockStore(ctx context.Context, metadata Metadata) error // TryLock tries to acquire a lock. TryLock(ctx context.Context, req *TryLockRequest) (*TryLockResponse, error) diff --git a/middleware/http/bearer/bearer_middleware.go b/middleware/http/bearer/bearer_middleware.go index 5f270df8a..987ebbe0a 100644 --- a/middleware/http/bearer/bearer_middleware.go +++ b/middleware/http/bearer/bearer_middleware.go @@ -45,13 +45,13 @@ const ( ) // GetHandler retruns the HTTP handler provided by the middleware. -func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { +func (m *Middleware) GetHandler(ctx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { meta, err := m.getNativeMetadata(metadata) if err != nil { return nil, err } - provider, err := oidc.NewProvider(context.Background(), meta.IssuerURL) + provider, err := oidc.NewProvider(ctx, meta.IssuerURL) if err != nil { return nil, err } diff --git a/middleware/http/oauth2/oauth2_middleware.go b/middleware/http/oauth2/oauth2_middleware.go index 2b89c4b6b..01fbd97ba 100644 --- a/middleware/http/oauth2/oauth2_middleware.go +++ b/middleware/http/oauth2/oauth2_middleware.go @@ -14,6 +14,7 @@ limitations under the License. package oauth2 import ( + "context" "net/http" "net/url" "strings" @@ -59,7 +60,7 @@ const ( ) // GetHandler retruns the HTTP handler provided by the middleware. -func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { +func (m *Middleware) GetHandler(ctx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { meta, err := m.getNativeMetadata(metadata) if err != nil { return nil, err diff --git a/middleware/http/oauth2clientcredentials/mocks/mock_TokenProviderInterface.go b/middleware/http/oauth2clientcredentials/mocks/mock_TokenProviderInterface.go index 1397fb94f..6e107717d 100644 --- a/middleware/http/oauth2clientcredentials/mocks/mock_TokenProviderInterface.go +++ b/middleware/http/oauth2clientcredentials/mocks/mock_TokenProviderInterface.go @@ -5,6 +5,7 @@ package mock_oauth2clientcredentials import ( + "context" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -36,7 +37,7 @@ func (m *MockTokenProviderInterface) EXPECT() *MockTokenProviderInterfaceMockRec } // GetToken mocks base method -func (m *MockTokenProviderInterface) GetToken(arg0 *clientcredentials.Config) (*oauth2.Token, error) { +func (m *MockTokenProviderInterface) GetToken(ctx context.Context, arg0 *clientcredentials.Config) (*oauth2.Token, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetToken", arg0) ret0, _ := ret[0].(*oauth2.Token) diff --git a/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware.go b/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware.go index d84ef6003..0e7b22cf8 100644 --- a/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware.go +++ b/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware.go @@ -45,7 +45,7 @@ type oAuth2ClientCredentialsMiddlewareMetadata struct { // TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests. type TokenProviderInterface interface { - GetToken(conf *clientcredentials.Config) (*oauth2.Token, error) + GetToken(ctx context.Context, conf *clientcredentials.Config) (*oauth2.Token, error) } // NewOAuth2ClientCredentialsMiddleware returns a new oAuth2 middleware. @@ -68,7 +68,7 @@ type Middleware struct { } // GetHandler retruns the HTTP handler provided by the middleware. -func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { +func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { meta, err := m.getNativeMetadata(metadata) if err != nil { m.log.Errorf("getNativeMetadata error: %s", err) @@ -101,7 +101,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha if !found { m.log.Debugf("Cached token not found, try get one") - token, err := m.tokenProvider.GetToken(conf) + token, err := m.tokenProvider.GetToken(r.Context(), conf) if err != nil { m.log.Errorf("Error acquiring token: %s", err) return @@ -171,8 +171,8 @@ func (m *Middleware) SetTokenProvider(tokenProvider TokenProviderInterface) { } // GetToken returns a token from the current OAuth2 ClientCredentials Configuration. -func (m *Middleware) GetToken(conf *clientcredentials.Config) (*oauth2.Token, error) { - tokenSource := conf.TokenSource(context.Background()) +func (m *Middleware) GetToken(ctx context.Context, conf *clientcredentials.Config) (*oauth2.Token, error) { + tokenSource := conf.TokenSource(ctx) return tokenSource.Token() } diff --git a/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_test.go b/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_test.go index 5b7b2365d..e1d6951b4 100644 --- a/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_test.go +++ b/middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware_test.go @@ -14,6 +14,7 @@ limitations under the License. package oauth2clientcredentials import ( + "context" "net/http" "net/http/httptest" "testing" @@ -45,7 +46,7 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) { metadata.Properties = map[string]string{} log := logger.NewLogger("oauth2clientcredentials.test") - _, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) + _, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata) assert.EqualError(t, err, "metadata errors: Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. ") // Invalid authStyle (non int) @@ -57,17 +58,17 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) { "headerName": "someHeader", "authStyle": "asdf", // This is the value to test } - _, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) + _, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata) assert.EqualError(t, err2, "metadata errors: 1 error(s) decoding:\n\n* cannot parse 'AuthStyle' as int: strconv.ParseInt: parsing \"asdf\": invalid syntax") // Invalid authStyle (int > 2) metadata.Properties["authStyle"] = "3" - _, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) + _, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata) assert.EqualError(t, err3, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ") // Invalid authStyle (int < 0) metadata.Properties["authStyle"] = "-1" - _, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) + _, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata) assert.EqualError(t, err4, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ") } @@ -109,7 +110,7 @@ func TestOAuth2ClientCredentialsToken(t *testing.T) { log := logger.NewLogger("oauth2clientcredentials.test") oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware) oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider) - handler, err := oauth2clientcredentialsMiddleware.GetHandler(metadata) + handler, err := oauth2clientcredentialsMiddleware.GetHandler(context.Background(), metadata) require.NoError(t, err) // First handler call should return abc Token @@ -169,7 +170,7 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) { log := logger.NewLogger("oauth2clientcredentials.test") oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware) oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider) - handler, err := oauth2clientcredentialsMiddleware.GetHandler(metadata) + handler, err := oauth2clientcredentialsMiddleware.GetHandler(context.Background(), metadata) require.NoError(t, err) // First handler call should return abc Token diff --git a/middleware/http/opa/middleware.go b/middleware/http/opa/middleware.go index 6c9c2ba56..4c720859d 100644 --- a/middleware/http/opa/middleware.go +++ b/middleware/http/opa/middleware.go @@ -106,13 +106,13 @@ func (s *Status) Valid() bool { } // GetHandler returns the HTTP handler provided by the middleware. -func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { +func (m *Middleware) GetHandler(parentCtx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { meta, err := m.getNativeMetadata(metadata) if err != nil { return nil, err } - ctx, cancel := context.WithTimeout(context.TODO(), time.Minute) + ctx, cancel := context.WithTimeout(parentCtx, time.Minute) query, err := rego.New( rego.Query("result = data.http.allow"), rego.Module("inline.rego", meta.Rego), diff --git a/middleware/http/opa/middleware_test.go b/middleware/http/opa/middleware_test.go index 4efdd6b08..64c16436a 100644 --- a/middleware/http/opa/middleware_test.go +++ b/middleware/http/opa/middleware_test.go @@ -14,6 +14,7 @@ limitations under the License. package opa import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -331,7 +332,7 @@ func TestOpaPolicy(t *testing.T) { t.Run(name, func(t *testing.T) { opaMiddleware := NewMiddleware(log) - handler, err := opaMiddleware.GetHandler(test.meta) + handler, err := opaMiddleware.GetHandler(context.Background(), test.meta) if test.shouldHandlerError { require.Error(t, err) return diff --git a/middleware/http/ratelimit/ratelimit_middleware.go b/middleware/http/ratelimit/ratelimit_middleware.go index a28e39664..8b1b75575 100644 --- a/middleware/http/ratelimit/ratelimit_middleware.go +++ b/middleware/http/ratelimit/ratelimit_middleware.go @@ -14,6 +14,7 @@ limitations under the License. package ratelimit import ( + "context" "fmt" "net/http" "strconv" @@ -46,7 +47,7 @@ func NewRateLimitMiddleware(_ logger.Logger) middleware.Middleware { type Middleware struct{} // GetHandler returns the HTTP handler provided by the middleware. -func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { +func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { meta, err := m.getNativeMetadata(metadata) if err != nil { return nil, err diff --git a/middleware/http/routeralias/routeralias_middleware.go b/middleware/http/routeralias/routeralias_middleware.go index b9c02dd4d..262c9ae43 100644 --- a/middleware/http/routeralias/routeralias_middleware.go +++ b/middleware/http/routeralias/routeralias_middleware.go @@ -40,7 +40,7 @@ func NewMiddleware(logger logger.Logger) middleware.Middleware { } // GetHandler retruns the HTTP handler provided by the middleware. -func (m *Middleware) GetHandler(metadata middleware.Metadata) ( +func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) ( func(next http.Handler) http.Handler, error, ) { if err := m.getNativeMetadata(metadata); err != nil { diff --git a/middleware/http/routeralias/routeralias_middleware_test.go b/middleware/http/routeralias/routeralias_middleware_test.go index 4129d77ed..49ddd889b 100644 --- a/middleware/http/routeralias/routeralias_middleware_test.go +++ b/middleware/http/routeralias/routeralias_middleware_test.go @@ -14,6 +14,7 @@ limitations under the License. package routeralias import ( + "context" "io" "net/http" "net/http/httptest" @@ -46,7 +47,7 @@ func TestRequestHandlerWithIllegalRouterRule(t *testing.T) { } log := logger.NewLogger("routeralias.test") ralias := NewMiddleware(log) - handler, err := ralias.GetHandler(meta) + handler, err := ralias.GetHandler(context.Background(), meta) assert.Nil(t, err) t.Run("hit: change router with common request", func(t *testing.T) { diff --git a/middleware/http/routerchecker/routerchecker_middleware.go b/middleware/http/routerchecker/routerchecker_middleware.go index 44139a9db..244848a19 100644 --- a/middleware/http/routerchecker/routerchecker_middleware.go +++ b/middleware/http/routerchecker/routerchecker_middleware.go @@ -14,6 +14,7 @@ limitations under the License. package routerchecker import ( + "context" "fmt" "net/http" "regexp" @@ -40,7 +41,7 @@ type Middleware struct { } // GetHandler retruns the HTTP handler provided by the middleware. -func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { +func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { meta, err := m.getNativeMetadata(metadata) if err != nil { return nil, err diff --git a/middleware/http/routerchecker/routerchecker_middleware_test.go b/middleware/http/routerchecker/routerchecker_middleware_test.go index 986154cf2..987a24c3c 100644 --- a/middleware/http/routerchecker/routerchecker_middleware_test.go +++ b/middleware/http/routerchecker/routerchecker_middleware_test.go @@ -14,6 +14,7 @@ limitations under the License. package routerchecker import ( + "context" "net/http" "net/http/httptest" "testing" @@ -37,7 +38,7 @@ func TestRequestHandlerWithIllegalRouterRule(t *testing.T) { }}} log := logger.NewLogger("routerchecker.test") rchecker := NewMiddleware(log) - handler, err := rchecker.GetHandler(meta) + handler, err := rchecker.GetHandler(context.Background(), meta) assert.Nil(t, err) r := httptest.NewRequest(http.MethodGet, "http://localhost:5001/v1.0/invoke/qcg.default/method/%20cat%20password", nil) @@ -54,7 +55,7 @@ func TestRequestHandlerWithLegalRouterRule(t *testing.T) { log := logger.NewLogger("routerchecker.test") rchecker := NewMiddleware(log) - handler, err := rchecker.GetHandler(meta) + handler, err := rchecker.GetHandler(context.Background(), meta) assert.Nil(t, err) r := httptest.NewRequest(http.MethodGet, "http://localhost:5001/v1.0/invoke/qcg.default/method", nil) diff --git a/middleware/http/sentinel/middleware.go b/middleware/http/sentinel/middleware.go index ad83ffe82..cc2b3372f 100644 --- a/middleware/http/sentinel/middleware.go +++ b/middleware/http/sentinel/middleware.go @@ -14,6 +14,7 @@ limitations under the License. package sentinel import ( + "context" "fmt" "net/http" @@ -51,7 +52,7 @@ type Middleware struct { } // GetHandler returns the HTTP handler provided by sentinel middleware. -func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { +func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { var ( meta *middlewareMetadata err error diff --git a/middleware/http/sentinel/middleware_test.go b/middleware/http/sentinel/middleware_test.go index 891f9e8e2..c8594ad92 100644 --- a/middleware/http/sentinel/middleware_test.go +++ b/middleware/http/sentinel/middleware_test.go @@ -14,6 +14,7 @@ limitations under the License. package sentinel import ( + "context" "net/http" "net/http/httptest" "testing" @@ -48,7 +49,7 @@ func TestRequestHandlerWithFlowRules(t *testing.T) { log := logger.NewLogger("sentinel.test") sentinel := NewMiddleware(log) - handler, err := sentinel.GetHandler(meta) + handler, err := sentinel.GetHandler(context.Background(), meta) assert.Nil(t, err) r := httptest.NewRequest(http.MethodGet, "http://localhost:5001/v1.0/nodeapp/healthz", nil) diff --git a/middleware/http/wasm/benchmark_test.go b/middleware/http/wasm/benchmark_test.go index aefd01722..faff2e792 100644 --- a/middleware/http/wasm/benchmark_test.go +++ b/middleware/http/wasm/benchmark_test.go @@ -1,6 +1,7 @@ package wasm import ( + "context" "fmt" "io" "net/http" @@ -45,7 +46,7 @@ func benchmarkMiddleware(b *testing.B, path string) { l := logger.NewLogger(b.Name()) l.SetOutput(io.Discard) - handlerFn, err := NewMiddleware(l).GetHandler(dapr.Metadata{Base: md}) + handlerFn, err := NewMiddleware(l).GetHandler(context.Background(), dapr.Metadata{Base: md}) if err != nil { b.Fatal(err) } diff --git a/middleware/http/wasm/example/go.mod b/middleware/http/wasm/example/go.mod index 6b117f794..97ba0103e 100644 --- a/middleware/http/wasm/example/go.mod +++ b/middleware/http/wasm/example/go.mod @@ -3,3 +3,5 @@ module github.com/dapr/components-contrib/middleware/wasm/example go 1.19 require github.com/http-wasm/http-wasm-guest-tinygo v0.1.0 + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/middleware/http/wasm/httpwasm.go b/middleware/http/wasm/httpwasm.go index 8949bbe45..6b3a7fdca 100644 --- a/middleware/http/wasm/httpwasm.go +++ b/middleware/http/wasm/httpwasm.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" "os" + "time" "github.com/http-wasm/http-wasm-host-go/handler" @@ -21,9 +22,6 @@ import ( "github.com/dapr/kit/logger" ) -// ctx substitutes for context propagation until middleware APIs support it. -var ctx = context.Background() - // middlewareMetadata includes configuration used for the WebAssembly handler. // Detailed notes are in README.md for visibility. // @@ -47,8 +45,8 @@ func NewMiddleware(logger logger.Logger) dapr.Middleware { return &middleware{logger: logger} } -func (m *middleware) GetHandler(metadata dapr.Metadata) (func(next http.Handler) http.Handler, error) { - rh, err := m.getHandler(metadata) +func (m *middleware) GetHandler(ctx context.Context, metadata dapr.Metadata) (func(next http.Handler) http.Handler, error) { + rh, err := m.getHandler(ctx, metadata) if err != nil { return nil, err } @@ -56,7 +54,7 @@ func (m *middleware) GetHandler(metadata dapr.Metadata) (func(next http.Handler) } // getHandler is extracted for unit testing. -func (m *middleware) getHandler(metadata dapr.Metadata) (*requestHandler, error) { +func (m *middleware) getHandler(ctx context.Context, metadata dapr.Metadata) (*requestHandler, error) { meta, err := m.getMetadata(metadata) if err != nil { return nil, fmt.Errorf("wasm basic: failed to parse metadata: %w", err) @@ -165,5 +163,7 @@ func (rh *requestHandler) requestHandler(next http.Handler) http.Handler { // Close implements io.Closer func (rh *requestHandler) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() return rh.mw.Close(ctx) } diff --git a/middleware/http/wasm/httpwasm_test.go b/middleware/http/wasm/httpwasm_test.go index d60e000a1..d9aa970f7 100644 --- a/middleware/http/wasm/httpwasm_test.go +++ b/middleware/http/wasm/httpwasm_test.go @@ -2,7 +2,9 @@ package wasm import ( "bytes" + "context" _ "embed" + "io" "net/http" "net/http/httptest" "testing" @@ -31,7 +33,7 @@ func Test_middleware_log(t *testing.T) { m := &middleware{logger: l} message := "alert" - m.Log(ctx, api.LogLevelInfo, message) + m.Log(context.Background(), api.LogLevelInfo, message) require.Contains(t, buf.String(), `level=info msg=alert`) } @@ -114,7 +116,7 @@ func Test_middleware_getHandler(t *testing.T) { for _, tt := range tests { tc := tt t.Run(tc.name, func(t *testing.T) { - h, err := m.getHandler(dapr.Metadata{Base: tc.metadata}) + h, err := m.getHandler(context.Background(), dapr.Metadata{Base: tc.metadata}) if tc.expectedErr == "" { require.NoError(t, err) require.NotNil(t, h.mw) @@ -135,7 +137,7 @@ func Test_Example(t *testing.T) { // tinygo build -o router.wasm -scheduler=none --no-debug -target=wasi router.go` "path": "./example/router.wasm", }} - handlerFn, err := NewMiddleware(l).GetHandler(dapr.Metadata{Base: meta}) + handlerFn, err := NewMiddleware(l).GetHandler(context.Background(), dapr.Metadata{Base: meta}) require.NoError(t, err) handler := handlerFn(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) @@ -146,3 +148,7 @@ func Test_Example(t *testing.T) { require.Equal(t, "/hi?name=panda", httputils.RequestURI(r)) require.Empty(t, buf.String()) } + +func Test_ioCloser(t *testing.T) { + var _ io.Closer = &requestHandler{} +} diff --git a/middleware/http/wasm/internal/e2e-guests/go.mod b/middleware/http/wasm/internal/e2e-guests/go.mod index 9d8f47562..d9a5f83a1 100644 --- a/middleware/http/wasm/internal/e2e-guests/go.mod +++ b/middleware/http/wasm/internal/e2e-guests/go.mod @@ -3,3 +3,5 @@ module github.com/dapr/components-contrib/middleware/wasm/internal go 1.19 require github.com/http-wasm/http-wasm-guest-tinygo v0.1.0 + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/middleware/http/wasm/internal/e2e_test.go b/middleware/http/wasm/internal/e2e_test.go index 756cd9b7c..1ccb46421 100644 --- a/middleware/http/wasm/internal/e2e_test.go +++ b/middleware/http/wasm/internal/e2e_test.go @@ -2,6 +2,7 @@ package internal_test import ( "bytes" + "context" "log" "net/http" "net/http/httptest" @@ -139,7 +140,7 @@ func Test_EndToEnd(t *testing.T) { require.NoError(t, os.WriteFile(wasmPath, tc.guest, 0o600)) meta := metadata.Base{Properties: map[string]string{"path": wasmPath}} - handlerFn, err := wasm.NewMiddleware(l).GetHandler(middleware.Metadata{Base: meta}) + handlerFn, err := wasm.NewMiddleware(l).GetHandler(context.Background(), middleware.Metadata{Base: meta}) require.NoError(t, err) handler := handlerFn(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) tc.test(t, handler, &buf) diff --git a/middleware/middleware.go b/middleware/middleware.go index 0ff1babb7..88e13660e 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -14,10 +14,11 @@ limitations under the License. package middleware import ( + "context" "net/http" ) // Middleware is the interface for a middleware. type Middleware interface { - GetHandler(metadata Metadata) (func(next http.Handler) http.Handler, error) + GetHandler(ctx context.Context, metadata Metadata) (func(next http.Handler) http.Handler, error) } diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index e08efc432..f0958838c 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -16,10 +16,12 @@ package snssqs import ( "context" "encoding/json" + "errors" "fmt" "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/aws/aws-sdk-go/aws" @@ -61,12 +63,12 @@ type snsSqs struct { logger logger.Logger id string opsTimeout time.Duration - ctx context.Context - cancel context.CancelFunc - pollerCtx context.Context - pollerCancel context.CancelFunc backOffConfig retry.Config pollerRunning chan struct{} + + closeCh chan struct{} + closed atomic.Bool + wg sync.WaitGroup } type sqsQueueInfo struct { @@ -105,6 +107,7 @@ func NewSnsSqs(l logger.Logger) pubsub.PubSub { id: id, topicsLock: sync.RWMutex{}, pollerRunning: make(chan struct{}, 1), + closeCh: make(chan struct{}), } } @@ -147,7 +150,7 @@ func nameToAWSSanitizedName(name string, isFifo bool) string { return string(s[:j]) } -func (s *snsSqs) Init(metadata pubsub.Metadata) error { +func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { md, err := s.getSnsSqsMetatdata(metadata) if err != nil { return err @@ -172,9 +175,8 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error { s.stsClient = sts.New(sess) s.opsTimeout = time.Duration(md.assetsManagementTimeoutSeconds * float64(time.Second)) - s.ctx, s.cancel = context.WithCancel(context.Background()) - err = s.setAwsAccountIDIfNotProvided(s.ctx) + err = s.setAwsAccountIDIfNotProvided(ctx) if err != nil { return err } @@ -637,7 +639,11 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters case pubsub.Single: f(message) case pubsub.Parallel: - go f(message) + wg.Add(1) + go func(message *sqs.Message) { + defer wg.Done() + f(message) + }(message) } } wg.Wait() @@ -747,10 +753,14 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, return nil } -func (s *snsSqs) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error { +func (s *snsSqs) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error { + if s.closed.Load() { + return errors.New("error: pubsub has been closed") + } + // subscribers declare a topic ARN and declare a SQS queue to use // these should be idempotent - queues should not be created if they exist. - topicArn, sanitizedName, err := s.getOrCreateTopic(subscribeCtx, req.Topic) + topicArn, sanitizedName, err := s.getOrCreateTopic(ctx, req.Topic) if err != nil { wrappedErr := fmt.Errorf("error getting topic ARN for %s: %w", req.Topic, err) s.logger.Error(wrappedErr) @@ -760,7 +770,7 @@ func (s *snsSqs) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeReq // this is the ID of the application, it is supplied via runtime as "consumerID". var queueInfo *sqsQueueInfo - queueInfo, err = s.getOrCreateQueue(subscribeCtx, s.metadata.sqsQueueName) + queueInfo, err = s.getOrCreateQueue(ctx, s.metadata.sqsQueueName) if err != nil { wrappedErr := fmt.Errorf("error retrieving SQS queue: %w", err) s.logger.Error(wrappedErr) @@ -770,7 +780,7 @@ func (s *snsSqs) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeReq // only after a SQS queue and SNS topic had been setup, we restrict the SendMessage action to SNS as sole source // to prevent anyone but SNS to publish message to SQS. - err = s.restrictQueuePublishPolicyToOnlySNS(subscribeCtx, queueInfo, topicArn) + err = s.restrictQueuePublishPolicyToOnlySNS(ctx, queueInfo, topicArn) if err != nil { wrappedErr := fmt.Errorf("error setting sns-sqs subscription policy: %w", err) s.logger.Error(wrappedErr) @@ -783,7 +793,7 @@ func (s *snsSqs) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeReq var derr error if len(s.metadata.sqsDeadLettersQueueName) > 0 { - deadLettersQueueInfo, derr = s.getOrCreateQueue(subscribeCtx, s.metadata.sqsDeadLettersQueueName) + deadLettersQueueInfo, derr = s.getOrCreateQueue(ctx, s.metadata.sqsDeadLettersQueueName) if derr != nil { wrappedErr := fmt.Errorf("error retrieving SQS dead-letter queue: %w", err) s.logger.Error(wrappedErr) @@ -791,7 +801,7 @@ func (s *snsSqs) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeReq return wrappedErr } - err = s.setDeadLettersQueueAttributes(subscribeCtx, queueInfo, deadLettersQueueInfo) + err = s.setDeadLettersQueueAttributes(ctx, queueInfo, deadLettersQueueInfo) if err != nil { wrappedErr := fmt.Errorf("error creating dead-letter queue: %w", err) s.logger.Error(wrappedErr) @@ -801,7 +811,7 @@ func (s *snsSqs) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeReq } // subscription creation is idempotent. Subscriptions are unique by topic/queue. - _, err = s.getOrCreateSnsSqsSubscription(subscribeCtx, queueInfo.arn, topicArn) + _, err = s.getOrCreateSnsSqsSubscription(ctx, queueInfo.arn, topicArn) if err != nil { wrappedErr := fmt.Errorf("error subscribing topic: %s, to queue: %s, with error: %w", topicArn, queueInfo.arn, err) s.logger.Error(wrappedErr) @@ -815,23 +825,45 @@ func (s *snsSqs) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeReq s.topicHandlers[sanitizedName] = topicHandler{ topicName: req.Topic, handler: handler, - ctx: subscribeCtx, + ctx: ctx, } + // pollerCancel is used to cancel the polling goroutine. We use a noop cancel + // func in case the poller is already running and there is no cancel to use + // from the select below. + var pollerCancel context.CancelFunc = func() {} // Start the poller for the queue if it's not running already select { case s.pollerRunning <- struct{}{}: // If inserting in the channel succeeds, then it's not running already // Use a context that is tied to the background context - s.pollerCtx, s.pollerCancel = context.WithCancel(s.ctx) - go s.consumeSubscription(s.ctx, queueInfo, deadLettersQueueInfo) + var subctx context.Context + subctx, pollerCancel = context.WithCancel(context.Background()) + s.wg.Add(2) + go func() { + defer s.wg.Done() + defer pollerCancel() + select { + case <-s.closeCh: + case <-subctx.Done(): + } + }() + go func() { + defer s.wg.Done() + s.consumeSubscription(subctx, queueInfo, deadLettersQueueInfo) + }() default: // Do nothing, it means the poller is already running } - // Watch for subscription context cancelation to remove this subscription + // Watch for subscription context cancellation to remove this subscription + s.wg.Add(1) go func() { - <-subscribeCtx.Done() + defer s.wg.Done() + select { + case <-ctx.Done(): + case <-s.closeCh: + } s.topicsLock.Lock() defer s.topicsLock.Unlock() @@ -839,9 +871,9 @@ func (s *snsSqs) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeReq // Remove the handler delete(s.topicHandlers, sanitizedName) - // If we don't have any topic left, close the poller + // If we don't have any topic left, close the poller. if len(s.topicHandlers) == 0 { - s.pollerCancel() + pollerCancel() } }() @@ -849,6 +881,10 @@ func (s *snsSqs) Subscribe(subscribeCtx context.Context, req pubsub.SubscribeReq } func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error { + if s.closed.Load() { + return errors.New("error: pubsub has been closed") + } + topicArn, _, err := s.getOrCreateTopic(ctx, req.Topic) if err != nil { s.logger.Errorf("error getting topic ARN for %s: %v", req.Topic, err) @@ -875,9 +911,13 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error return nil } +// Close should always be called to release the resources used by the SNS/SQS +// client. Blocks until all goroutines have returned. func (s *snsSqs) Close() error { - s.cancel() - + if s.closed.CompareAndSwap(false, true) { + close(s.closeCh) + } + s.wg.Wait() return nil } diff --git a/pubsub/azure/eventhubs/eventhubs.go b/pubsub/azure/eventhubs/eventhubs.go index 195de1715..8ab8e0649 100644 --- a/pubsub/azure/eventhubs/eventhubs.go +++ b/pubsub/azure/eventhubs/eventhubs.go @@ -41,7 +41,7 @@ func NewAzureEventHubs(logger logger.Logger) pubsub.PubSub { } // Init the object. -func (aeh *AzureEventHubs) Init(metadata pubsub.Metadata) error { +func (aeh *AzureEventHubs) Init(_ context.Context, metadata pubsub.Metadata) error { return aeh.AzureEventHubs.Init(metadata.Properties) } diff --git a/pubsub/azure/eventhubs/eventhubs_integration_test.go b/pubsub/azure/eventhubs/eventhubs_integration_test.go index a3d7ccca2..fc913ab36 100644 --- a/pubsub/azure/eventhubs/eventhubs_integration_test.go +++ b/pubsub/azure/eventhubs/eventhubs_integration_test.go @@ -69,7 +69,7 @@ func testReadIotHubEvents(t *testing.T) { logger := kitLogger.NewLogger("pubsub.azure.eventhubs.integration.test") logger.SetOutputLevel(kitLogger.DebugLevel) eh := NewAzureEventHubs(logger).(*AzureEventHubs) - err := eh.Init(createIotHubPubsubMetadata()) + err := eh.Init(context.Background(), createIotHubPubsubMetadata()) assert.NoError(t, err) // Invoke az CLI via bash script to send test IoT device events diff --git a/pubsub/azure/servicebus/queues/servicebus.go b/pubsub/azure/servicebus/queues/servicebus.go index db128c902..9c1726063 100644 --- a/pubsub/azure/servicebus/queues/servicebus.go +++ b/pubsub/azure/servicebus/queues/servicebus.go @@ -49,7 +49,7 @@ func NewAzureServiceBusQueues(logger logger.Logger) pubsub.PubSub { } } -func (a *azureServiceBus) Init(metadata pubsub.Metadata) (err error) { +func (a *azureServiceBus) Init(_ context.Context, metadata pubsub.Metadata) (err error) { a.metadata, err = impl.ParseMetadata(metadata.Properties, a.logger, impl.MetadataModeQueues) if err != nil { return err diff --git a/pubsub/azure/servicebus/topics/servicebus.go b/pubsub/azure/servicebus/topics/servicebus.go index d46fd57e4..718704312 100644 --- a/pubsub/azure/servicebus/topics/servicebus.go +++ b/pubsub/azure/servicebus/topics/servicebus.go @@ -49,7 +49,7 @@ func NewAzureServiceBusTopics(logger logger.Logger) pubsub.PubSub { } } -func (a *azureServiceBus) Init(metadata pubsub.Metadata) (err error) { +func (a *azureServiceBus) Init(_ context.Context, metadata pubsub.Metadata) (err error) { a.metadata, err = impl.ParseMetadata(metadata.Properties, a.logger, impl.MetadataModeTopics) if err != nil { return err diff --git a/pubsub/gcp/pubsub/pubsub.go b/pubsub/gcp/pubsub/pubsub.go index e16ac74ae..acb43aefe 100644 --- a/pubsub/gcp/pubsub/pubsub.go +++ b/pubsub/gcp/pubsub/pubsub.go @@ -175,13 +175,13 @@ func createMetadata(pubSubMetadata pubsub.Metadata) (*metadata, error) { } // Init parses metadata and creates a new Pub Sub client. -func (g *GCPPubSub) Init(meta pubsub.Metadata) error { +func (g *GCPPubSub) Init(ctx context.Context, meta pubsub.Metadata) error { metadata, err := createMetadata(meta) if err != nil { return err } - pubsubClient, err := g.getPubSubClient(context.Background(), metadata) + pubsubClient, err := g.getPubSubClient(ctx, metadata) if err != nil { return fmt.Errorf("%s error creating pubsub client: %w", errorMessagePrefix, err) } diff --git a/pubsub/hazelcast/hazelcast.go b/pubsub/hazelcast/hazelcast.go index 5e313a41a..c6d1b3fec 100644 --- a/pubsub/hazelcast/hazelcast.go +++ b/pubsub/hazelcast/hazelcast.go @@ -65,7 +65,7 @@ func parseHazelcastMetadata(meta pubsub.Metadata) (metadata, error) { return m, nil } -func (p *Hazelcast) Init(metadata pubsub.Metadata) error { +func (p *Hazelcast) Init(ctx context.Context, metadata pubsub.Metadata) error { p.logger.Warnf("DEPRECATION NOTICE: Component pubsub.hazelcast has been deprecated and will be removed in a future Dapr release.") m, err := parseHazelcastMetadata(metadata) diff --git a/pubsub/in-memory/in-memory.go b/pubsub/in-memory/in-memory.go index b54ef5c0e..ff9e6f77b 100644 --- a/pubsub/in-memory/in-memory.go +++ b/pubsub/in-memory/in-memory.go @@ -41,7 +41,7 @@ func (a *bus) Features() []pubsub.Feature { return []pubsub.Feature{pubsub.FeatureSubscribeWildcards} } -func (a *bus) Init(metadata pubsub.Metadata) error { +func (a *bus) Init(_ context.Context, metadata pubsub.Metadata) error { a.bus = eventbus.New(true) return nil diff --git a/pubsub/in-memory/in-memory_test.go b/pubsub/in-memory/in-memory_test.go index 033190374..ce1ec3d2e 100644 --- a/pubsub/in-memory/in-memory_test.go +++ b/pubsub/in-memory/in-memory_test.go @@ -26,7 +26,7 @@ import ( func TestNewInMemoryBus(t *testing.T) { bus := New(logger.NewLogger("test")) - bus.Init(pubsub.Metadata{}) + bus.Init(context.Background(), pubsub.Metadata{}) ch := make(chan []byte) bus.Subscribe(context.Background(), pubsub.SubscribeRequest{Topic: "demo"}, func(ctx context.Context, msg *pubsub.NewMessage) error { @@ -39,7 +39,7 @@ func TestNewInMemoryBus(t *testing.T) { func TestMultipleSubscribers(t *testing.T) { bus := New(logger.NewLogger("test")) - bus.Init(pubsub.Metadata{}) + bus.Init(context.Background(), pubsub.Metadata{}) ch1 := make(chan []byte) ch2 := make(chan []byte) @@ -59,7 +59,7 @@ func TestMultipleSubscribers(t *testing.T) { func TestWildcards(t *testing.T) { bus := New(logger.NewLogger("test")) - bus.Init(pubsub.Metadata{}) + bus.Init(context.Background(), pubsub.Metadata{}) ch1 := make(chan []byte) ch2 := make(chan []byte) @@ -83,7 +83,7 @@ func TestWildcards(t *testing.T) { func TestRetry(t *testing.T) { bus := New(logger.NewLogger("test")) - bus.Init(pubsub.Metadata{}) + bus.Init(context.Background(), pubsub.Metadata{}) ch := make(chan []byte) i := -1 diff --git a/pubsub/jetstream/jetstream.go b/pubsub/jetstream/jetstream.go index 8f28cbbd7..44f41f2a1 100644 --- a/pubsub/jetstream/jetstream.go +++ b/pubsub/jetstream/jetstream.go @@ -37,7 +37,7 @@ func NewJetStream(logger logger.Logger) pubsub.PubSub { return &jetstreamPubSub{l: logger} } -func (js *jetstreamPubSub) Init(metadata pubsub.Metadata) error { +func (js *jetstreamPubSub) Init(_ context.Context, metadata pubsub.Metadata) error { var err error js.meta, err = parseMetadata(metadata) if err != nil { diff --git a/pubsub/kafka/kafka.go b/pubsub/kafka/kafka.go index c790527d9..389353d2f 100644 --- a/pubsub/kafka/kafka.go +++ b/pubsub/kafka/kafka.go @@ -31,10 +31,10 @@ type PubSub struct { subscribeCancel context.CancelFunc } -func (p *PubSub) Init(metadata pubsub.Metadata) error { +func (p *PubSub) Init(ctx context.Context, metadata pubsub.Metadata) error { p.subscribeCtx, p.subscribeCancel = context.WithCancel(context.Background()) - return p.kafka.Init(metadata.Properties) + return p.kafka.Init(ctx, metadata.Properties) } func (p *PubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error { diff --git a/pubsub/kubemq/kubemq.go b/pubsub/kubemq/kubemq.go index fdb22d4a4..a362e7a62 100644 --- a/pubsub/kubemq/kubemq.go +++ b/pubsub/kubemq/kubemq.go @@ -14,8 +14,6 @@ import ( type kubeMQ struct { metadata *metadata logger logger.Logger - ctx context.Context - ctxCancel context.CancelFunc eventsClient *kubeMQEvents eventStoreClient *kubeMQEventStore } @@ -26,14 +24,13 @@ func NewKubeMQ(logger logger.Logger) pubsub.PubSub { } } -func (k *kubeMQ) Init(metadata pubsub.Metadata) error { +func (k *kubeMQ) Init(_ context.Context, metadata pubsub.Metadata) error { meta, err := createMetadata(metadata) if err != nil { k.logger.Errorf("error init kubemq client error: %s", err.Error()) return err } k.metadata = meta - k.ctx, k.ctxCancel = context.WithCancel(context.Background()) if meta.isStore { k.eventStoreClient = newKubeMQEventsStore(k.logger) _ = k.eventStoreClient.Init(meta) diff --git a/pubsub/kubemq/kubemq_test.go b/pubsub/kubemq/kubemq_test.go index a46f6e4ca..3c9adec1b 100644 --- a/pubsub/kubemq/kubemq_test.go +++ b/pubsub/kubemq/kubemq_test.go @@ -20,8 +20,6 @@ func getMockEventsClient() *kubeMQEvents { publishFunc: nil, resultChan: nil, waitForResultTimeout: 0, - ctx: nil, - ctxCancel: nil, isInitialized: true, } } @@ -34,8 +32,6 @@ func getMockEventsStoreClient() *kubeMQEventStore { publishFunc: nil, resultChan: nil, waitForResultTimeout: 0, - ctx: nil, - ctxCancel: nil, isInitialized: true, } } @@ -103,7 +99,7 @@ func Test_kubeMQ_Init(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { k := NewKubeMQ(logger.NewLogger("test")) - err := k.Init(tt.meta) + err := k.Init(context.Background(), tt.meta) assert.Equal(t, tt.wantErr, err != nil) }) } @@ -113,8 +109,6 @@ func Test_kubeMQ_Close(t *testing.T) { type fields struct { metadata *metadata logger logger.Logger - ctx context.Context - ctxCancel context.CancelFunc eventsClient *kubeMQEvents eventStoreClient *kubeMQEventStore } @@ -151,8 +145,6 @@ func Test_kubeMQ_Close(t *testing.T) { k := &kubeMQ{ metadata: tt.fields.metadata, logger: tt.fields.logger, - ctx: tt.fields.ctx, - ctxCancel: tt.fields.ctxCancel, eventsClient: tt.fields.eventsClient, eventStoreClient: tt.fields.eventStoreClient, } diff --git a/pubsub/mqtt3/mqtt.go b/pubsub/mqtt3/mqtt.go index 03244c013..9f45c9e12 100644 --- a/pubsub/mqtt3/mqtt.go +++ b/pubsub/mqtt3/mqtt.go @@ -22,6 +22,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" mqtt "github.com/eclipse/paho.mqtt.golang" @@ -45,8 +46,9 @@ type mqttPubSub struct { topics map[string]mqttPubSubSubscription subscribingLock sync.RWMutex reconnectCh chan struct{} - ctx context.Context - cancel context.CancelFunc + closeCh chan struct{} + closed atomic.Bool + wg sync.WaitGroup } type mqttPubSubSubscription struct { @@ -59,21 +61,20 @@ type mqttPubSubSubscription struct { func NewMQTTPubSub(logger logger.Logger) pubsub.PubSub { return &mqttPubSub{ logger: logger, - reconnectCh: make(chan struct{}, 1), + reconnectCh: make(chan struct{}), + closeCh: make(chan struct{}), } } // Init parses metadata and creates a new Pub Sub client. -func (m *mqttPubSub) Init(metadata pubsub.Metadata) error { +func (m *mqttPubSub) Init(ctx context.Context, metadata pubsub.Metadata) error { mqttMeta, err := parseMQTTMetaData(metadata, m.logger) if err != nil { return err } m.metadata = mqttMeta - m.ctx, m.cancel = context.WithCancel(context.Background()) - - err = m.connect() + err = m.connect(ctx) if err != nil { return fmt.Errorf("failed to establish connection to broker: %w", err) } @@ -85,7 +86,11 @@ func (m *mqttPubSub) Init(metadata pubsub.Metadata) error { } // Publish the topic to mqtt pub sub. -func (m *mqttPubSub) Publish(parentCtx context.Context, req *pubsub.PublishRequest) (err error) { +func (m *mqttPubSub) Publish(ctx context.Context, req *pubsub.PublishRequest) (err error) { + if m.closed.Load() { + return errors.New("error: mqtt client closed") + } + if req.Topic == "" { return errors.New("topic name is empty") } @@ -103,14 +108,13 @@ func (m *mqttPubSub) Publish(parentCtx context.Context, req *pubsub.PublishReque } token := m.conn.Publish(req.Topic, m.metadata.qos, retain, req.Data) - ctx, cancel := context.WithTimeout(parentCtx, defaultWait) + ctx, cancel := context.WithTimeout(ctx, defaultWait) defer cancel() select { case <-token.Done(): err = token.Error() - case <-m.ctx.Done(): - // Context canceled - err = m.ctx.Err() + case <-m.closeCh: + err = errors.New("mqtt client closed") case <-ctx.Done(): // Context canceled err = ctx.Err() @@ -126,9 +130,8 @@ func (m *mqttPubSub) Publish(parentCtx context.Context, req *pubsub.PublishReque // Request metadata includes: // - "unsubscribeOnClose": if true, when the subscription is stopped (context canceled), then an Unsubscribe message is sent to the MQTT broker, which will stop delivering messages to this consumer ID until the subscription is explicitly re-started with a new Subscribe call. Otherwise, messages continue to be delivered but are not handled and are NACK'd automatically. "unsubscribeOnClose" should be used with dynamic subscriptions. func (m *mqttPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error { - if ctxErr := m.ctx.Err(); ctxErr != nil { - // If the global context has been canceled, we do not allow more subscriptions - return ctxErr + if m.closed.Load() { + return errors.New("error: mqtt client closed") } topic := req.Topic @@ -143,17 +146,16 @@ func (m *mqttPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, // Add the topic then start the subscription m.addTopic(topic, handler) - // Use the global context here to maintain the handler - token := m.conn.Subscribe(topic, m.metadata.qos, m.onMessage(m.ctx)) - subscribeCtx, subscribeCancel := context.WithTimeout(m.ctx, defaultWait) - defer subscribeCancel() + token := m.conn.Subscribe(topic, m.metadata.qos, m.onMessage(ctx)) var err error select { case <-token.Done(): // Subscription went through (sucecessfully or not) err = token.Error() - case <-subscribeCtx.Done(): - err = fmt.Errorf("error while waiting for subscription token: %w", subscribeCtx.Err()) + case <-ctx.Done(): + err = fmt.Errorf("error while waiting for subscription token: %w", ctx.Err()) + case <-time.After(defaultWait): + err = errors.New("timeout waiting for subscription") } if err != nil { // Return an error @@ -164,14 +166,15 @@ func (m *mqttPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, m.logger.Infof("MQTT is subscribed to topic %s (qos: %d)", topic, m.metadata.qos) // Listen for context cancelation to remove the subscription + m.wg.Add(1) go func() { + defer m.wg.Done() + select { case <-ctx.Done(): - case <-m.ctx.Done(): - } - - // If m.ctx has been canceled, nothing to do here as the entire connection will be closed - if m.ctx.Err() != nil { + case <-m.closeCh: + // If Close has been called, nothing to do here as the entire connection + // will be closed. return } @@ -187,18 +190,16 @@ func (m *mqttPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, } unsubscribeToken := m.conn.Unsubscribe(topic) - unsubscribeCtx, unsubscribeCancel := context.WithTimeout(m.ctx, defaultWait) - defer unsubscribeCancel() var unsubscribeErr error select { case <-unsubscribeToken.Done(): // Subscription went through (sucecessfully or not) unsubscribeErr = token.Error() - case <-unsubscribeCtx.Done(): - unsubscribeErr = fmt.Errorf("error while waiting for subscription token: %w", unsubscribeCtx.Err()) + case <-time.After(defaultWait): + unsubscribeErr = fmt.Errorf("timeout while unsubscribing from topic %s", topic) } if unsubscribeErr != nil { - m.logger.Warnf("Failed to ubsubscribe from topic %s: %v", topic, err) + m.logger.Warnf("Failed to ubsubscribe from topic %s: %v", topic, unsubscribeErr) } }() @@ -280,13 +281,13 @@ func (m *mqttPubSub) doConnect(ctx context.Context, clientID string) (mqtt.Clien } // Create a connection -func (m *mqttPubSub) connect() error { +func (m *mqttPubSub) connect(ctx context.Context) error { m.subscribingLock.Lock() defer m.subscribingLock.Unlock() - connCtx, connCancel := context.WithTimeout(m.ctx, defaultWait) - conn, err := m.doConnect(connCtx, m.metadata.consumerID) - connCancel() + ctx, cancel := context.WithTimeout(ctx, defaultWait) + defer cancel() + conn, err := m.doConnect(ctx, m.metadata.consumerID) if err != nil { return err } @@ -295,45 +296,6 @@ func (m *mqttPubSub) connect() error { return nil } -// Forcefully closes the connection and, after a delay, reconnects -func (m *mqttPubSub) ResetConection() { - const reconnectDelay = 30 * time.Second - - // Do not reconnect if there's already one attempt in progress - select { - case m.reconnectCh <- struct{}{}: - // nop - default: - // Already a reconnection attempt in progress, so abort - return - } - - // Disconnect - m.logger.Info("Closing connection with broker… will reconnect in " + reconnectDelay.String()) - m.conn.Disconnect(100) - - for m.ctx.Err() == nil { - time.Sleep(reconnectDelay) - - // Check for context cancelation before reconnecting, since we slept - if m.ctx.Err() != nil { - return - } - - m.logger.Debug("Reconnecting…") - err := m.connect() - if err != nil { - m.logger.Errorf("Failed to reconnect, will retry in " + reconnectDelay.String()) - } else { - m.logger.Info("Connection with broker re-established") - break - } - } - - // Release the reconnection token - <-m.reconnectCh -} - func (m *mqttPubSub) createClientOptions(uri *url.URL, clientID string) *mqtt.ClientOptions { opts := mqtt.NewClientOptions(). SetClientID(clientID). @@ -372,15 +334,22 @@ func (m *mqttPubSub) createClientOptions(uri *url.URL, clientID string) *mqtt.Cl subscribeTopics[k] = m.metadata.qos } - // Note that this is a bit unusual for a pubsub component as we're using m.ctx on the handler, which is tied to the component rather than the individual subscription + // Note that this is a bit unusual for a pubsub component as we're using a background context for the handler. // This is because we can't really use a different context for each handler in a single SubscribeMultiple call, and the alternative (multiple individual Subscribe calls) is not ideal + ctx, cancel := context.WithCancel(context.Background()) + m.wg.Add(1) + go func() { + defer m.wg.Done() + defer cancel() + <-m.closeCh + }() token := c.SubscribeMultiple( subscribeTopics, - m.onMessage(m.ctx), + m.onMessage(ctx), ) var err error - subscribeCtx, subscribeCancel := context.WithTimeout(m.ctx, defaultWait) + subscribeCtx, subscribeCancel := context.WithTimeout(ctx, defaultWait) defer subscribeCancel() select { case <-token.Done(): @@ -423,6 +392,7 @@ func (m *mqttPubSub) createClientOptions(uri *url.URL, clientID string) *mqtt.Cl return opts } +// Close the connection. Blocks until all subscriptions are closed. func (m *mqttPubSub) Close() error { m.subscribingLock.Lock() defer m.subscribingLock.Unlock() @@ -432,12 +402,15 @@ func (m *mqttPubSub) Close() error { // Clear all topics from the map as a first thing, before stopping all subscriptions (we have the lock anyways) maps.Clear(m.topics) - // Cancel the context - m.cancel() + if m.closed.CompareAndSwap(false, true) { + close(m.closeCh) + } // Disconnect m.conn.Disconnect(100) + m.wg.Wait() + return nil } diff --git a/pubsub/mqtt3/mqtt_test.go b/pubsub/mqtt3/mqtt_test.go index 4d0c9f968..a4523588a 100644 --- a/pubsub/mqtt3/mqtt_test.go +++ b/pubsub/mqtt3/mqtt_test.go @@ -697,7 +697,6 @@ func Test_mqttPubSub_Publish(t *testing.T) { m := &mqttPubSub{ conn: newMockedMQTTClient(msgCh), logger: tt.fields.logger, - ctx: tt.fields.ctx, metadata: tt.fields.metadata, } diff --git a/pubsub/natsstreaming/natsstreaming.go b/pubsub/natsstreaming/natsstreaming.go index db914dae5..5030ae9c2 100644 --- a/pubsub/natsstreaming/natsstreaming.go +++ b/pubsub/natsstreaming/natsstreaming.go @@ -187,7 +187,7 @@ func parseNATSStreamingMetadata(meta pubsub.Metadata) (metadata, error) { return m, nil } -func (n *natsStreamingPubSub) Init(metadata pubsub.Metadata) error { +func (n *natsStreamingPubSub) Init(_ context.Context, metadata pubsub.Metadata) error { m, err := parseNATSStreamingMetadata(metadata) if err != nil { return err diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go index ca9cf0660..a8979044d 100644 --- a/pubsub/pubsub.go +++ b/pubsub/pubsub.go @@ -22,7 +22,7 @@ import ( // PubSub is the interface for message buses. type PubSub interface { - Init(metadata Metadata) error + Init(ctx context.Context, metadata Metadata) error Features() []Feature Publish(ctx context.Context, req *PublishRequest) error Subscribe(ctx context.Context, req SubscribeRequest, handler Handler) error @@ -64,10 +64,10 @@ type Handler func(ctx context.Context, msg *NewMessage) error // orderly fashion. type BulkHandler func(ctx context.Context, msg *BulkMessage) ([]BulkSubscribeResponseEntry, error) -func Ping(pubsub PubSub) error { +func Ping(ctx context.Context, pubsub PubSub) error { // checks if this pubsub has the ping option then executes if pubsubWithPing, ok := pubsub.(health.Pinger); ok { - return pubsubWithPing.Ping() + return pubsubWithPing.Ping(ctx) } else { return fmt.Errorf("ping is not implemented by this pubsub") } diff --git a/pubsub/pulsar/pulsar.go b/pubsub/pulsar/pulsar.go index ffdd10a9d..a123c3b13 100644 --- a/pubsub/pulsar/pulsar.go +++ b/pubsub/pulsar/pulsar.go @@ -175,7 +175,7 @@ func parsePulsarMetadata(meta pubsub.Metadata) (*pulsarMetadata, error) { return &m, nil } -func (p *Pulsar) Init(metadata pubsub.Metadata) error { +func (p *Pulsar) Init(_ context.Context, metadata pubsub.Metadata) error { m, err := parsePulsarMetadata(metadata) if err != nil { return err diff --git a/pubsub/rabbitmq/rabbitmq.go b/pubsub/rabbitmq/rabbitmq.go index 15e536920..03cf67bff 100644 --- a/pubsub/rabbitmq/rabbitmq.go +++ b/pubsub/rabbitmq/rabbitmq.go @@ -21,6 +21,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" amqp "github.com/rabbitmq/amqp091-go" @@ -58,10 +59,11 @@ type rabbitMQ struct { connectionCount int metadata *metadata declaredExchanges map[string]bool - ctx context.Context - cancel context.CancelFunc connectionDial func(protocol, uri string, tlsCfg *tls.Config) (rabbitMQConnectionBroker, rabbitMQChannelBroker, error) + closeCh chan struct{} + closed atomic.Bool + wg sync.WaitGroup logger logger.Logger } @@ -95,6 +97,7 @@ func NewRabbitMQ(logger logger.Logger) pubsub.PubSub { declaredExchanges: make(map[string]bool), logger: logger, connectionDial: dial, + closeCh: make(chan struct{}), } } @@ -124,14 +127,12 @@ func dial(protocol, uri string, tlsCfg *tls.Config) (rabbitMQConnectionBroker, r } // Init does metadata parsing and connection creation. -func (r *rabbitMQ) Init(metadata pubsub.Metadata) error { +func (r *rabbitMQ) Init(_ context.Context, metadata pubsub.Metadata) error { meta, err := createMetadata(metadata, r.logger) if err != nil { return err } - r.ctx, r.cancel = context.WithCancel(context.Background()) - r.metadata = meta r.reconnect(0) @@ -247,6 +248,10 @@ func (r *rabbitMQ) publishSync(ctx context.Context, req *pubsub.PublishRequest) } func (r *rabbitMQ) Publish(ctx context.Context, req *pubsub.PublishRequest) error { + if r.closed.Load() { + return errors.New("error: rabbitMQ is closed") + } + r.logger.Debugf("%s publishing message to %s", logMessagePrefix, req.Topic) attempt := 0 @@ -272,6 +277,10 @@ func (r *rabbitMQ) Publish(ctx context.Context, req *pubsub.PublishRequest) erro } func (r *rabbitMQ) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error { + if r.closed.Load() { + return errors.New("error: rabbitMQ is closed") + } + if r.metadata.consumerID == "" { return errors.New("consumerID is required for subscriptions") } @@ -282,7 +291,18 @@ func (r *rabbitMQ) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, h // Do not set a timeout on the context, as we're just waiting for the first ack; we're using a semaphore instead ackCh := make(chan struct{}, 1) defer close(ackCh) - go r.subscribeForever(ctx, req, queueName, handler, ackCh) + + subctx, cancel := context.WithCancel(ctx) + r.wg.Add(2) + go func() { + defer r.wg.Done() + r.subscribeForever(subctx, req, queueName, handler, ackCh) + }() + go func() { + defer r.wg.Done() + defer cancel() + <-r.closeCh + }() // Wait for the ack for 1 minute or return an error select { @@ -467,14 +487,18 @@ func (r *rabbitMQ) listenMessages(ctx context.Context, channel rabbitMQChannelBr switch r.metadata.concurrency { case pubsub.Single: err = r.handleMessage(ctx, d, topic, handler) + if err != nil && mustReconnect(channel, err) { + return err + } case pubsub.Parallel: + r.wg.Add(1) go func(d amqp.Delivery) { - err = r.handleMessage(ctx, d, topic, handler) + defer r.wg.Done() + if err := r.handleMessage(ctx, d, topic, handler); err != nil { + r.logger.Errorf("%s error handling message: %v", logMessagePrefix, err) + } }(d) } - if err != nil && mustReconnect(channel, err) { - return err - } } } } @@ -563,17 +587,21 @@ func (r *rabbitMQ) reset() (err error) { } func (r *rabbitMQ) isStopped() bool { - return r.ctx.Err() != nil + return r.closed.Load() } +// Close closes the rabbitMQ connection. Blocks until all go routines are done. func (r *rabbitMQ) Close() error { r.channelMutex.Lock() defer r.channelMutex.Unlock() - r.cancel() - err := r.reset() + if r.closed.CompareAndSwap(false, true) { + close(r.closeCh) + } - return err + defer r.wg.Wait() + + return r.reset() } func (r *rabbitMQ) Features() []pubsub.Feature { diff --git a/pubsub/rabbitmq/rabbitmq_test.go b/pubsub/rabbitmq/rabbitmq_test.go index 764371202..963fef76c 100644 --- a/pubsub/rabbitmq/rabbitmq_test.go +++ b/pubsub/rabbitmq/rabbitmq_test.go @@ -17,6 +17,7 @@ import ( "context" "crypto/tls" "errors" + "sync/atomic" "testing" "time" @@ -39,10 +40,11 @@ func newRabbitMQTest(broker *rabbitMQInMemoryBroker) pubsub.PubSub { declaredExchanges: make(map[string]bool), logger: logger.NewLogger("test"), connectionDial: func(protocol, uri string, tlsCfg *tls.Config) (rabbitMQConnectionBroker, rabbitMQChannelBroker, error) { - broker.connectCount++ + broker.connectCount.Add(1) return broker, broker, nil }, + closeCh: make(chan struct{}), } } @@ -54,7 +56,7 @@ func TestNoConsumer(t *testing.T) { metadataHostnameKey: "anyhost", }, }} - err := pubsubRabbitMQ.Init(metadata) + err := pubsubRabbitMQ.Init(context.Background(), metadata) assert.NoError(t, err) err = pubsubRabbitMQ.Subscribe(context.Background(), pubsub.SubscribeRequest{}, nil) assert.Contains(t, err.Error(), "consumerID is required for subscriptions") @@ -71,7 +73,7 @@ func TestConcurrencyMode(t *testing.T) { pubsub.ConcurrencyKey: string(pubsub.Parallel), }, }} - err := pubsubRabbitMQ.Init(metadata) + err := pubsubRabbitMQ.Init(context.Background(), metadata) assert.Nil(t, err) assert.Equal(t, pubsub.Parallel, pubsubRabbitMQ.(*rabbitMQ).metadata.concurrency) }) @@ -86,7 +88,7 @@ func TestConcurrencyMode(t *testing.T) { pubsub.ConcurrencyKey: string(pubsub.Single), }, }} - err := pubsubRabbitMQ.Init(metadata) + err := pubsubRabbitMQ.Init(context.Background(), metadata) assert.Nil(t, err) assert.Equal(t, pubsub.Single, pubsubRabbitMQ.(*rabbitMQ).metadata.concurrency) }) @@ -100,7 +102,7 @@ func TestConcurrencyMode(t *testing.T) { metadataConsumerIDKey: "consumer", }, }} - err := pubsubRabbitMQ.Init(metadata) + err := pubsubRabbitMQ.Init(context.Background(), metadata) assert.Nil(t, err) assert.Equal(t, pubsub.Parallel, pubsubRabbitMQ.(*rabbitMQ).metadata.concurrency) }) @@ -115,10 +117,10 @@ func TestPublishAndSubscribe(t *testing.T) { metadataConsumerIDKey: "consumer", }, }} - err := pubsubRabbitMQ.Init(metadata) + err := pubsubRabbitMQ.Init(context.Background(), metadata) assert.Nil(t, err) - assert.Equal(t, 1, broker.connectCount) - assert.Equal(t, 0, broker.closeCount) + assert.Equal(t, int32(1), broker.connectCount.Load()) + assert.Equal(t, int32(0), broker.closeCount.Load()) topic := "mytopic" @@ -158,10 +160,10 @@ func TestPublishReconnect(t *testing.T) { metadataConsumerIDKey: "consumer", }, }} - err := pubsubRabbitMQ.Init(metadata) + err := pubsubRabbitMQ.Init(context.Background(), metadata) assert.Nil(t, err) - assert.Equal(t, 1, broker.connectCount) - assert.Equal(t, 0, broker.closeCount) + assert.Equal(t, int32(1), broker.connectCount.Load()) + assert.Equal(t, int32(0), broker.closeCount.Load()) topic := "othertopic" @@ -190,8 +192,8 @@ func TestPublishReconnect(t *testing.T) { assert.Equal(t, 1, messageCount) assert.Equal(t, "hello world", lastMessage) // Check that reconnection happened - assert.Equal(t, 3, broker.connectCount) // three counts - one initial connection plus 2 reconnect attempts - assert.Equal(t, 4, broker.closeCount) // four counts - one for connection, one for channel , times 2 reconnect attempts + assert.Equal(t, int32(3), broker.connectCount.Load()) // three counts - one initial connection plus 2 reconnect attempts + assert.Equal(t, int32(4), broker.closeCount.Load()) // four counts - one for connection, one for channel , times 2 reconnect attempts err = pubsubRabbitMQ.Publish(context.Background(), &pubsub.PublishRequest{Topic: topic, Data: []byte("foo bar")}) assert.Nil(t, err) @@ -209,10 +211,10 @@ func TestPublishReconnectAfterClose(t *testing.T) { metadataConsumerIDKey: "consumer", }, }} - err := pubsubRabbitMQ.Init(metadata) + err := pubsubRabbitMQ.Init(context.Background(), metadata) assert.Nil(t, err) - assert.Equal(t, 1, broker.connectCount) - assert.Equal(t, 0, broker.closeCount) + assert.Equal(t, int32(1), broker.connectCount.Load()) + assert.Equal(t, int32(0), broker.closeCount.Load()) topic := "mytopic2" @@ -239,15 +241,15 @@ func TestPublishReconnectAfterClose(t *testing.T) { // Close PubSub err = pubsubRabbitMQ.Close() assert.Nil(t, err) - assert.Equal(t, 2, broker.closeCount) // two counts - one for connection, one for channel + assert.Equal(t, int32(2), broker.closeCount.Load()) // two counts - one for connection, one for channel err = pubsubRabbitMQ.Publish(context.Background(), &pubsub.PublishRequest{Topic: topic, Data: []byte(errorChannelConnection)}) assert.NotNil(t, err) assert.Equal(t, 1, messageCount) assert.Equal(t, "hello world", lastMessage) // Check that reconnection did not happened - assert.Equal(t, 1, broker.connectCount) - assert.Equal(t, 2, broker.closeCount) // two counts - one for connection, one for channel + assert.Equal(t, int32(1), broker.connectCount.Load()) + assert.Equal(t, int32(2), broker.closeCount.Load()) // two counts - one for connection, one for channel } func TestSubscribeBindRoutingKeys(t *testing.T) { @@ -259,10 +261,10 @@ func TestSubscribeBindRoutingKeys(t *testing.T) { metadataConsumerIDKey: "consumer", }, }} - err := pubsubRabbitMQ.Init(metadata) + err := pubsubRabbitMQ.Init(context.Background(), metadata) assert.Nil(t, err) - assert.Equal(t, 1, broker.connectCount) - assert.Equal(t, 0, broker.closeCount) + assert.Equal(t, int32(1), broker.connectCount.Load()) + assert.Equal(t, int32(0), broker.closeCount.Load()) topic := "mytopic_routingkeys" @@ -286,10 +288,10 @@ func TestSubscribeReconnect(t *testing.T) { pubsub.ConcurrencyKey: string(pubsub.Single), }, }} - err := pubsubRabbitMQ.Init(metadata) + err := pubsubRabbitMQ.Init(context.Background(), metadata) assert.Nil(t, err) - assert.Equal(t, 1, broker.connectCount) - assert.Equal(t, 0, broker.closeCount) + assert.Equal(t, int32(1), broker.connectCount.Load()) + assert.Equal(t, int32(0), broker.closeCount.Load()) topic := "thetopic" @@ -323,8 +325,8 @@ func TestSubscribeReconnect(t *testing.T) { time.Sleep(time.Second) // Check that reconnection happened - assert.Equal(t, 3, broker.connectCount) // initial connect + 2 reconnects - assert.Equal(t, 4, broker.closeCount) // two counts for each connection closure - one for connection, one for channel + assert.Equal(t, int32(3), broker.connectCount.Load()) // initial connect + 2 reconnects + assert.Equal(t, int32(4), broker.closeCount.Load()) // two counts for each connection closure - one for connection, one for channel } func createAMQPMessage(body []byte) amqp.Delivery { @@ -334,8 +336,8 @@ func createAMQPMessage(body []byte) amqp.Delivery { type rabbitMQInMemoryBroker struct { buffer chan amqp.Delivery - connectCount int - closeCount int + connectCount atomic.Int32 + closeCount atomic.Int32 } func (r *rabbitMQInMemoryBroker) Qos(prefetchCount, prefetchSize int, global bool) error { @@ -387,11 +389,11 @@ func (r *rabbitMQInMemoryBroker) Confirm(noWait bool) error { } func (r *rabbitMQInMemoryBroker) Close() error { - r.closeCount++ + r.closeCount.Add(1) return nil } func (r *rabbitMQInMemoryBroker) IsClosed() bool { - return r.connectCount <= r.closeCount + return r.connectCount.Load() <= r.closeCount.Load() } diff --git a/pubsub/redis/redis.go b/pubsub/redis/redis.go index a8202908a..51fee060c 100644 --- a/pubsub/redis/redis.go +++ b/pubsub/redis/redis.go @@ -18,6 +18,8 @@ import ( "errors" "fmt" "strconv" + "sync" + "sync/atomic" "time" rediscomponent "github.com/dapr/components-contrib/internal/component/redis" @@ -46,11 +48,11 @@ type redisStreams struct { client rediscomponent.RedisClient clientSettings *rediscomponent.Settings logger logger.Logger + wg sync.WaitGroup + closed atomic.Bool + closeCh chan struct{} queue chan redisMessageWrapper - - ctx context.Context - cancel context.CancelFunc } // redisMessageWrapper encapsulates the message identifier, @@ -64,7 +66,10 @@ type redisMessageWrapper struct { // NewRedisStreams returns a new redis streams pub-sub implementation. func NewRedisStreams(logger logger.Logger) pubsub.PubSub { - return &redisStreams{logger: logger} + return &redisStreams{ + logger: logger, + closeCh: make(chan struct{}), + } } func parseRedisMetadata(meta pubsub.Metadata) (metadata, error) { @@ -129,7 +134,7 @@ func parseRedisMetadata(meta pubsub.Metadata) (metadata, error) { return m, nil } -func (r *redisStreams) Init(metadata pubsub.Metadata) error { +func (r *redisStreams) Init(ctx context.Context, metadata pubsub.Metadata) error { m, err := parseRedisMetadata(metadata) if err != nil { return err @@ -140,21 +145,27 @@ func (r *redisStreams) Init(metadata pubsub.Metadata) error { return err } - r.ctx, r.cancel = context.WithCancel(context.Background()) - - if _, err = r.client.PingResult(r.ctx); err != nil { + if _, err = r.client.PingResult(ctx); err != nil { return fmt.Errorf("redis streams: error connecting to redis at %s: %s", r.clientSettings.Host, err) } r.queue = make(chan redisMessageWrapper, int(r.metadata.queueDepth)) for i := uint(0); i < r.metadata.concurrency; i++ { - go r.worker() + r.wg.Add(1) + go func() { + defer r.wg.Done() + r.worker() + }() } return nil } func (r *redisStreams) Publish(ctx context.Context, req *pubsub.PublishRequest) error { + if r.closed.Load() { + return errors.New("error: redis has been closed") + } + _, err := r.client.XAdd(ctx, req.Topic, r.metadata.maxLenApprox, map[string]interface{}{"data": req.Data}) if err != nil { return fmt.Errorf("redis streams: error from publish: %s", err) @@ -164,6 +175,10 @@ func (r *redisStreams) Publish(ctx context.Context, req *pubsub.PublishRequest) } func (r *redisStreams) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error { + if r.closed.Load() { + return errors.New("error: redis has been closed") + } + err := r.client.XGroupCreateMkStream(ctx, req.Topic, r.metadata.consumerID, "0") // Ignore BUSYGROUP errors if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" { @@ -171,8 +186,26 @@ func (r *redisStreams) Subscribe(ctx context.Context, req pubsub.SubscribeReques return err } - go r.pollNewMessagesLoop(ctx, req.Topic, handler) - go r.reclaimPendingMessagesLoop(ctx, req.Topic, handler) + loopCtx, cancel := context.WithCancel(ctx) + r.wg.Add(3) + go func() { + // Add a context which catches the close signal to account for situations + // where Close is called, but the context is not cancelled. + defer r.wg.Done() + defer cancel() + select { + case <-loopCtx.Done(): + case <-r.closeCh: + } + }() + go func() { + defer r.wg.Done() + r.pollNewMessagesLoop(loopCtx, req.Topic, handler) + }() + go func() { + defer r.wg.Done() + r.reclaimPendingMessagesLoop(loopCtx, req.Topic, handler) + }() return nil } @@ -224,8 +257,8 @@ func createRedisMessageWrapper(ctx context.Context, stream string, handler pubsu func (r *redisStreams) worker() { for { select { - // Handle cancelation - case <-r.ctx.Done(): + // Handle closing + case <-r.closeCh: return case msg := <-r.queue: @@ -252,7 +285,7 @@ func (r *redisStreams) processMessage(msg redisMessageWrapper) error { return err } - // Use the background context in case subscriptionCtx is already closed + // Use the background context in case subscriptionCtx is already closed. if err := r.client.XAck(context.Background(), msg.message.Topic, r.metadata.consumerID, msg.messageID); err != nil { r.logger.Errorf("Error acknowledging Redis message %s: %v", msg.messageID, err) @@ -399,7 +432,7 @@ func (r *redisStreams) removeMessagesThatNoLongerExistFromPending(ctx context.Co // Ack the message to remove it from the pending list. if errors.Is(err, r.client.GetNilValueError()) { - // Use the background context in case subscriptionCtx is already closed + // Use the background context in case subscriptionCtx is already closed. if err = r.client.XAck(context.Background(), stream, r.metadata.consumerID, pendingID); err != nil { r.logger.Errorf("error acknowledging Redis message %s after failed claim for %s: %v", pendingID, stream, err) } @@ -411,8 +444,9 @@ func (r *redisStreams) removeMessagesThatNoLongerExistFromPending(ctx context.Co } func (r *redisStreams) Close() error { - if r.cancel != nil { - r.cancel() + defer r.wg.Wait() + if r.closed.CompareAndSwap(false, true) { + close(r.closeCh) } if r.client == nil { @@ -425,8 +459,8 @@ func (r *redisStreams) Features() []pubsub.Feature { return nil } -func (r *redisStreams) Ping() error { - if _, err := r.client.PingResult(context.Background()); err != nil { +func (r *redisStreams) Ping(ctx context.Context) error { + if _, err := r.client.PingResult(ctx); err != nil { return fmt.Errorf("redis pubsub: error connecting to redis at %s: %s", r.clientSettings.Host, err) } diff --git a/pubsub/redis/redis_test.go b/pubsub/redis/redis_test.go index 99a5f2533..75738fafe 100644 --- a/pubsub/redis/redis_test.go +++ b/pubsub/redis/redis_test.go @@ -96,7 +96,6 @@ func TestProcessStreams(t *testing.T) { // act testRedisStream := &redisStreams{logger: logger.NewLogger("test")} - testRedisStream.ctx, testRedisStream.cancel = context.WithCancel(context.Background()) testRedisStream.queue = make(chan redisMessageWrapper, 10) go testRedisStream.worker() testRedisStream.enqueueMessages(context.Background(), fakeConsumerID, fakeHandler, generateRedisStreamTestData(2, 3, expectedData)) diff --git a/pubsub/rocketmq/rocketmq.go b/pubsub/rocketmq/rocketmq.go index b941f9fb7..3fe8d2e2a 100644 --- a/pubsub/rocketmq/rocketmq.go +++ b/pubsub/rocketmq/rocketmq.go @@ -85,7 +85,7 @@ func NewRocketMQ(l logger.Logger) pubsub.PubSub { } } -func (r *rocketMQ) Init(metadata pubsub.Metadata) error { +func (r *rocketMQ) Init(_ context.Context, metadata pubsub.Metadata) error { var err error r.metadata, err = parseRocketMQMetaData(metadata) if err != nil { diff --git a/pubsub/rocketmq/rocketmq_test.go b/pubsub/rocketmq/rocketmq_test.go index 366f11602..582ad310c 100644 --- a/pubsub/rocketmq/rocketmq_test.go +++ b/pubsub/rocketmq/rocketmq_test.go @@ -47,7 +47,7 @@ func TestParseRocketMQMetadata(t *testing.T) { func TestRocketMQ_Init(t *testing.T) { meta := getTestMetadata() r := NewRocketMQ(logger.NewLogger("test")) - err := r.Init(pubsub.Metadata{Base: mdata.Base{Properties: meta}}) + err := r.Init(context.Background(), pubsub.Metadata{Base: mdata.Base{Properties: meta}}) assert.Nil(t, err) } @@ -217,6 +217,6 @@ func BuildRocketMQ() (logger.Logger, pubsub.PubSub, error) { meta := getTestMetadata() l := logger.NewLogger("test") r := NewRocketMQ(l) - err := r.Init(pubsub.Metadata{Base: mdata.Base{Properties: meta}}) + err := r.Init(context.Background(), pubsub.Metadata{Base: mdata.Base{Properties: meta}}) return l, r, err } diff --git a/pubsub/solace/amqp/amqp.go b/pubsub/solace/amqp/amqp.go index 137a77c18..c5ca42303 100644 --- a/pubsub/solace/amqp/amqp.go +++ b/pubsub/solace/amqp/amqp.go @@ -43,8 +43,6 @@ type amqpPubSub struct { logger logger.Logger publishLock sync.RWMutex publishRetryCount int - ctx context.Context - cancel context.CancelFunc } // NewAMQPPubsub returns a new AMQPPubSub instance @@ -56,7 +54,7 @@ func NewAMQPPubsub(logger logger.Logger) pubsub.PubSub { } // Init parses the metadata and creates a new Pub Sub Client. -func (a *amqpPubSub) Init(metadata pubsub.Metadata) error { +func (a *amqpPubSub) Init(ctx context.Context, metadata pubsub.Metadata) error { amqpMeta, err := parseAMQPMetaData(metadata, a.logger) if err != nil { return err @@ -64,9 +62,7 @@ func (a *amqpPubSub) Init(metadata pubsub.Metadata) error { a.metadata = amqpMeta - a.ctx, a.cancel = context.WithCancel(context.Background()) - - s, err := a.connect() + s, err := a.connect(ctx) if err != nil { return err } @@ -148,7 +144,7 @@ func (a *amqpPubSub) Publish(ctx context.Context, req *pubsub.PublishRequest) er func (a *amqpPubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, handler pubsub.Handler) error { prefixedTopic := AddPrefixToAddress(req.Topic) - receiver, err := a.session.NewReceiver(a.ctx, + receiver, err := a.session.NewReceiver(ctx, prefixedTopic, nil, ) @@ -207,7 +203,7 @@ func (a *amqpPubSub) subscribeForever(ctx context.Context, receiver *amqp.Receiv } // Connect to the AMQP broker -func (a *amqpPubSub) connect() (*amqp.Session, error) { +func (a *amqpPubSub) connect(ctx context.Context) (*amqp.Session, error) { uri, err := url.Parse(a.metadata.url) if err != nil { return nil, err @@ -222,7 +218,7 @@ func (a *amqpPubSub) connect() (*amqp.Session, error) { } // Open a session - session, err := client.NewSession(a.ctx, nil) + session, err := client.NewSession(ctx, nil) if err != nil { a.logger.Fatal("Creating AMQP session:", err) } @@ -279,7 +275,9 @@ func (a *amqpPubSub) Close() error { defer a.publishLock.Unlock() - err := a.session.Close(a.ctx) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := a.session.Close(ctx) if err != nil { a.logger.Warnf("failed to close the connection.", err) } diff --git a/secretstores/alicloud/parameterstore/parameterstore.go b/secretstores/alicloud/parameterstore/parameterstore.go index cfee8c2e3..c44d28db2 100644 --- a/secretstores/alicloud/parameterstore/parameterstore.go +++ b/secretstores/alicloud/parameterstore/parameterstore.go @@ -61,7 +61,7 @@ type oosSecretStore struct { } // Init creates a Alicloud parameter store client. -func (o *oosSecretStore) Init(metadata secretstores.Metadata) error { +func (o *oosSecretStore) Init(_ context.Context, metadata secretstores.Metadata) error { meta, err := o.getParameterStoreMetadata(metadata) if err != nil { return err diff --git a/secretstores/alicloud/parameterstore/parameterstore_test.go b/secretstores/alicloud/parameterstore/parameterstore_test.go index 9a7d44d50..d81d3d7bc 100644 --- a/secretstores/alicloud/parameterstore/parameterstore_test.go +++ b/secretstores/alicloud/parameterstore/parameterstore_test.go @@ -81,7 +81,7 @@ func TestInit(t *testing.T) { "accessKeyId": "a", "accessKeySecret": "a", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) }) @@ -90,7 +90,7 @@ func TestInit(t *testing.T) { "accessKeyId": "a", "accessKeySecret": "a", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.NotNil(t, err) }) } @@ -208,7 +208,7 @@ func TestBulkGetSecret(t *testing.T) { func TestGetFeatures(t *testing.T) { m := secretstores.Metadata{} s := NewParameterStore(logger.NewLogger("test")) - s.Init(m) + s.Init(context.Background(), m) t.Run("no features are advertised", func(t *testing.T) { f := s.Features() assert.Empty(t, f) diff --git a/secretstores/aws/parameterstore/parameterstore.go b/secretstores/aws/parameterstore/parameterstore.go index 6b81ae67b..7860e9a7a 100644 --- a/secretstores/aws/parameterstore/parameterstore.go +++ b/secretstores/aws/parameterstore/parameterstore.go @@ -55,7 +55,7 @@ type ssmSecretStore struct { } // Init creates a AWS secret manager client. -func (s *ssmSecretStore) Init(metadata secretstores.Metadata) error { +func (s *ssmSecretStore) Init(_ context.Context, metadata secretstores.Metadata) error { meta, err := s.getSecretManagerMetadata(metadata) if err != nil { return err diff --git a/secretstores/aws/parameterstore/parameterstore_test.go b/secretstores/aws/parameterstore/parameterstore_test.go index 6425b4b8b..571e46123 100644 --- a/secretstores/aws/parameterstore/parameterstore_test.go +++ b/secretstores/aws/parameterstore/parameterstore_test.go @@ -57,7 +57,7 @@ func TestInit(t *testing.T) { "SecretKey": "a", "SessionToken": "a", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) }) } diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 1318940c2..0db219f17 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -53,7 +53,7 @@ type smSecretStore struct { } // Init creates a AWS secret manager client. -func (s *smSecretStore) Init(metadata secretstores.Metadata) error { +func (s *smSecretStore) Init(_ context.Context, metadata secretstores.Metadata) error { meta, err := s.getSecretManagerMetadata(metadata) if err != nil { return err diff --git a/secretstores/aws/secretmanager/secretmanager_test.go b/secretstores/aws/secretmanager/secretmanager_test.go index 9694ff632..8dc5b60d6 100644 --- a/secretstores/aws/secretmanager/secretmanager_test.go +++ b/secretstores/aws/secretmanager/secretmanager_test.go @@ -50,7 +50,7 @@ func TestInit(t *testing.T) { "SecretKey": "a", "SessionToken": "a", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) }) } diff --git a/secretstores/azure/keyvault/keyvault.go b/secretstores/azure/keyvault/keyvault.go index 4aad72af4..69e78cfac 100644 --- a/secretstores/azure/keyvault/keyvault.go +++ b/secretstores/azure/keyvault/keyvault.go @@ -61,7 +61,7 @@ func NewAzureKeyvaultSecretStore(logger logger.Logger) secretstores.SecretStore } // Init creates a Azure Key Vault client. -func (k *keyvaultSecretStore) Init(meta secretstores.Metadata) error { +func (k *keyvaultSecretStore) Init(_ context.Context, meta secretstores.Metadata) error { m := KeyvaultMetadata{} if err := metadata.DecodeMetadata(meta.Properties, &m); err != nil { return err diff --git a/secretstores/azure/keyvault/keyvault_test.go b/secretstores/azure/keyvault/keyvault_test.go index 356c21aa3..7022d95a6 100644 --- a/secretstores/azure/keyvault/keyvault_test.go +++ b/secretstores/azure/keyvault/keyvault_test.go @@ -15,6 +15,7 @@ limitations under the License. package keyvault import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -33,7 +34,7 @@ func TestInit(t *testing.T) { "azureClientId": "00000000-0000-0000-0000-000000000000", "azureClientSecret": "passw0rd", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) kv, ok := s.(*keyvaultSecretStore) assert.True(t, ok) @@ -49,7 +50,7 @@ func TestInit(t *testing.T) { "azureClientSecret": "passw0rd", "azureEnvironment": "AZURECHINACLOUD", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) kv, ok := s.(*keyvaultSecretStore) assert.True(t, ok) @@ -64,7 +65,7 @@ func TestInit(t *testing.T) { "azureClientId": "00000000-0000-0000-0000-000000000000", "azureClientSecret": "passw0rd", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) kv, ok := s.(*keyvaultSecretStore) assert.True(t, ok) @@ -79,7 +80,7 @@ func TestInit(t *testing.T) { "azureClientId": "00000000-0000-0000-0000-000000000000", "azureClientSecret": "passw0rd", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) kv, ok := s.(*keyvaultSecretStore) assert.True(t, ok) diff --git a/secretstores/gcp/secretmanager/secretmanager.go b/secretstores/gcp/secretmanager/secretmanager.go index c818461fa..d2d649869 100644 --- a/secretstores/gcp/secretmanager/secretmanager.go +++ b/secretstores/gcp/secretmanager/secretmanager.go @@ -68,13 +68,13 @@ func NewSecreteManager(logger logger.Logger) secretstores.SecretStore { } // Init creates a GCP secret manager client. -func (s *Store) Init(metadataRaw secretstores.Metadata) error { +func (s *Store) Init(ctx context.Context, metadataRaw secretstores.Metadata) error { metadata, err := s.parseSecretManagerMetadata(metadataRaw) if err != nil { return err } - client, err := s.getClient(metadata) + client, err := s.getClient(ctx, metadata) if err != nil { return fmt.Errorf("failed to setup secretmanager client: %s", err) } @@ -85,10 +85,9 @@ func (s *Store) Init(metadataRaw secretstores.Metadata) error { return nil } -func (s *Store) getClient(metadata *GcpSecretManagerMetadata) (*secretmanager.Client, error) { +func (s *Store) getClient(ctx context.Context, metadata *GcpSecretManagerMetadata) (*secretmanager.Client, error) { b, _ := json.Marshal(metadata) clientOptions := option.WithCredentialsJSON(b) - ctx := context.Background() client, err := secretmanager.NewClient(ctx, clientOptions) if err != nil { diff --git a/secretstores/gcp/secretmanager/secretmanager_test.go b/secretstores/gcp/secretmanager/secretmanager_test.go index a6e9fe767..e93dcda88 100644 --- a/secretstores/gcp/secretmanager/secretmanager_test.go +++ b/secretstores/gcp/secretmanager/secretmanager_test.go @@ -53,6 +53,7 @@ func (s *MockStore) Close() error { } func TestInit(t *testing.T) { + ctx := context.Background() m := secretstores.Metadata{} sm := NewSecreteManager(logger.NewLogger("test")) t.Run("Init with Wrong metadata", func(t *testing.T) { @@ -69,7 +70,7 @@ func TestInit(t *testing.T) { "client_x509_cert_url": "a", } - err := sm.Init(m) + err := sm.Init(ctx, m) assert.NotNil(t, err) assert.Equal(t, err, fmt.Errorf("failed to setup secretmanager client: google: could not parse key: private key should be a PEM or plain PKCS1 or PKCS8; parse error: asn1: syntax error: truncated tag or length")) }) @@ -78,7 +79,7 @@ func TestInit(t *testing.T) { m.Properties = map[string]string{ "dummy": "a", } - err := sm.Init(m) + err := sm.Init(ctx, m) assert.NotNil(t, err) assert.Equal(t, err, fmt.Errorf("missing property `type` in metadata")) }) @@ -87,13 +88,14 @@ func TestInit(t *testing.T) { m.Properties = map[string]string{ "type": "service_account", } - err := sm.Init(m) + err := sm.Init(ctx, m) assert.NotNil(t, err) assert.Equal(t, err, fmt.Errorf("missing property `project_id` in metadata")) }) } func TestGetSecret(t *testing.T) { + ctx := context.Background() sm := NewSecreteManager(logger.NewLogger("test")) t.Run("Get Secret - without Init", func(t *testing.T) { @@ -118,7 +120,7 @@ func TestGetSecret(t *testing.T) { "client_x509_cert_url": "a", }, }} - sm.Init(m) + sm.Init(ctx, m) v, err := sm.GetSecret(context.Background(), secretstores.GetSecretRequest{Name: "test"}) assert.NotNil(t, err) assert.Equal(t, secretstores.GetSecretResponse{Data: nil}, v) @@ -147,6 +149,7 @@ func TestGetSecret(t *testing.T) { } func TestBulkGetSecret(t *testing.T) { + ctx := context.Background() sm := NewSecreteManager(logger.NewLogger("test")) t.Run("Bulk Get Secret - without Init", func(t *testing.T) { @@ -173,7 +176,7 @@ func TestBulkGetSecret(t *testing.T) { }, }, } - sm.Init(m) + sm.Init(ctx, m) v, err := sm.BulkGetSecret(context.Background(), secretstores.BulkGetSecretRequest{}) assert.NotNil(t, err) assert.Equal(t, secretstores.BulkGetSecretResponse{Data: nil}, v) diff --git a/secretstores/hashicorp/vault/vault.go b/secretstores/hashicorp/vault/vault.go index a598037e4..fb10ac435 100644 --- a/secretstores/hashicorp/vault/vault.go +++ b/secretstores/hashicorp/vault/vault.go @@ -137,7 +137,7 @@ func NewHashiCorpVaultSecretStore(logger logger.Logger) secretstores.SecretStore } // Init creates a HashiCorp Vault client. -func (v *vaultSecretStore) Init(meta secretstores.Metadata) error { +func (v *vaultSecretStore) Init(_ context.Context, meta secretstores.Metadata) error { m := VaultMetadata{ VaultKVUsePrefix: true, } diff --git a/secretstores/hashicorp/vault/vault_test.go b/secretstores/hashicorp/vault/vault_test.go index 99dca5fd3..8f4b5660e 100644 --- a/secretstores/hashicorp/vault/vault_test.go +++ b/secretstores/hashicorp/vault/vault_test.go @@ -14,6 +14,7 @@ limitations under the License. package vault import ( + "context" "encoding/base64" "os" "strconv" @@ -119,7 +120,7 @@ func TestVaultEnginePath(t *testing.T) { t.Run("without engine path config", func(t *testing.T) { v := vaultSecretStore{} - err := v.Init(secretstores.Metadata{Base: metadata.Base{Properties: map[string]string{componentVaultToken: expectedTok, "skipVerify": "true"}}}) + err := v.Init(context.Background(), secretstores.Metadata{Base: metadata.Base{Properties: map[string]string{componentVaultToken: expectedTok, "skipVerify": "true"}}}) assert.Nil(t, err) assert.Equal(t, v.vaultEnginePath, defaultVaultEnginePath) }) @@ -127,7 +128,7 @@ func TestVaultEnginePath(t *testing.T) { t.Run("with engine path config", func(t *testing.T) { v := vaultSecretStore{} - err := v.Init(secretstores.Metadata{Base: metadata.Base{Properties: map[string]string{componentVaultToken: expectedTok, "skipVerify": "true", vaultEnginePath: "kv"}}}) + err := v.Init(context.Background(), secretstores.Metadata{Base: metadata.Base{Properties: map[string]string{componentVaultToken: expectedTok, "skipVerify": "true", vaultEnginePath: "kv"}}}) assert.Nil(t, err) assert.Equal(t, v.vaultEnginePath, "kv") }) @@ -151,7 +152,7 @@ func TestVaultTokenPrefix(t *testing.T) { logger: nil, } - if err := target.Init(m); err != nil { + if err := target.Init(context.Background(), m); err != nil { t.Fatal(err) } @@ -174,7 +175,7 @@ func TestVaultTokenPrefix(t *testing.T) { logger: nil, } - if err := target.Init(m); err != nil { + if err := target.Init(context.Background(), m); err != nil { t.Fatal(err) } @@ -215,7 +216,7 @@ func TestVaultTokenMountPathOrVaultTokenRequired(t *testing.T) { logger: nil, } - err := target.Init(m) + err := target.Init(context.Background(), m) assert.Equal(t, "", target.vaultToken) assert.Equal(t, "", target.vaultTokenMountPath) @@ -237,7 +238,7 @@ func TestVaultTokenMountPathOrVaultTokenRequired(t *testing.T) { logger: nil, } - if err := target.Init(m); err != nil { + if err := target.Init(context.Background(), m); err != nil { t.Fatal(err) } @@ -259,7 +260,7 @@ func TestVaultTokenMountPathOrVaultTokenRequired(t *testing.T) { logger: nil, } - if err := target.Init(m); err != nil { + if err := target.Init(context.Background(), m); err != nil { t.Fatal(err) } @@ -282,7 +283,7 @@ func TestVaultTokenMountPathOrVaultTokenRequired(t *testing.T) { logger: nil, } - err := target.Init(m) + err := target.Init(context.Background(), m) assert.Equal(t, expectedTok, target.vaultToken) assert.Equal(t, expectedTokMountPath, target.vaultTokenMountPath) @@ -309,7 +310,7 @@ func TestDefaultVaultAddress(t *testing.T) { logger: nil, } - if err := target.Init(m); err != nil { + if err := target.Init(context.Background(), m); err != nil { t.Fatal(err) } @@ -334,7 +335,7 @@ func TestVaultValueType(t *testing.T) { logger: nil, } - err := target.Init(m) + err := target.Init(context.Background(), m) assert.Nil(t, err) assert.True(t, target.vaultValueType.isMapType()) }) @@ -355,7 +356,7 @@ func TestVaultValueType(t *testing.T) { logger: nil, } - err := target.Init(m) + err := target.Init(context.Background(), m) assert.Nil(t, err) assert.False(t, target.vaultValueType.isMapType()) }) @@ -375,7 +376,7 @@ func TestVaultValueType(t *testing.T) { logger: nil, } - err := target.Init(m) + err := target.Init(context.Background(), m) assert.Nil(t, err) assert.True(t, target.vaultValueType.isMapType()) }) @@ -396,7 +397,7 @@ func TestVaultValueType(t *testing.T) { logger: nil, } - err := target.Init(m) + err := target.Init(context.Background(), m) assert.Error(t, err, "vault init error, invalid value type incorrect, accepted values are map or text") }) } @@ -428,7 +429,7 @@ func TestGetFeatures(t *testing.T) { // the call x509.SystemCertPool() because system root pool is not // available on Windows so ignore the error for when the tests are run // on the Windows platform during CI - _ = target.Init(m) + _ = target.Init(context.Background(), m) return target } diff --git a/secretstores/huaweicloud/csms/csms.go b/secretstores/huaweicloud/csms/csms.go index b218d3ccc..80d85e376 100644 --- a/secretstores/huaweicloud/csms/csms.go +++ b/secretstores/huaweicloud/csms/csms.go @@ -57,7 +57,7 @@ func NewHuaweiCsmsSecretStore(logger logger.Logger) secretstores.SecretStore { } // Init creates a Huawei csms client. -func (c *csmsSecretStore) Init(meta secretstores.Metadata) error { +func (c *csmsSecretStore) Init(ctx context.Context, meta secretstores.Metadata) error { m := CsmsSecretStoreMetadata{} metadata.DecodeMetadata(meta.Properties, &m) auth := basic.NewCredentialsBuilder(). diff --git a/secretstores/kubernetes/kubernetes.go b/secretstores/kubernetes/kubernetes.go index 34a111e87..a5f534114 100644 --- a/secretstores/kubernetes/kubernetes.go +++ b/secretstores/kubernetes/kubernetes.go @@ -42,7 +42,7 @@ func NewKubernetesSecretStore(logger logger.Logger) secretstores.SecretStore { } // Init creates a Kubernetes client. -func (k *kubernetesSecretStore) Init(metadata secretstores.Metadata) error { +func (k *kubernetesSecretStore) Init(_ context.Context, metadata secretstores.Metadata) error { client, err := kubeclient.GetKubeClient() if err != nil { return err diff --git a/secretstores/local/env/envstore.go b/secretstores/local/env/envstore.go index ac48d7829..b669abb06 100644 --- a/secretstores/local/env/envstore.go +++ b/secretstores/local/env/envstore.go @@ -38,7 +38,7 @@ func NewEnvSecretStore(logger logger.Logger) secretstores.SecretStore { } // Init creates a Local secret store. -func (s *envSecretStore) Init(metadata secretstores.Metadata) error { +func (s *envSecretStore) Init(_ context.Context, metadata secretstores.Metadata) error { return nil } diff --git a/secretstores/local/env/envstore_test.go b/secretstores/local/env/envstore_test.go index 20641f77a..79ee3d460 100644 --- a/secretstores/local/env/envstore_test.go +++ b/secretstores/local/env/envstore_test.go @@ -36,12 +36,12 @@ func TestEnvStore(t *testing.T) { require.Equal(t, secret, os.Getenv(key)) t.Run("Init", func(t *testing.T) { - err := s.Init(secretstores.Metadata{}) + err := s.Init(context.Background(), secretstores.Metadata{}) require.NoError(t, err) }) t.Run("Get", func(t *testing.T) { - err := s.Init(secretstores.Metadata{}) + err := s.Init(context.Background(), secretstores.Metadata{}) require.NoError(t, err) resp, err := s.GetSecret(context.Background(), secretstores.GetSecretRequest{Name: key}) require.NoError(t, err) @@ -50,7 +50,7 @@ func TestEnvStore(t *testing.T) { }) t.Run("Bulk get", func(t *testing.T) { - err := s.Init(secretstores.Metadata{}) + err := s.Init(context.Background(), secretstores.Metadata{}) require.NoError(t, err) resp, err := s.BulkGetSecret(context.Background(), secretstores.BulkGetSecretRequest{}) require.NoError(t, err) @@ -63,7 +63,7 @@ func TestEnvStore(t *testing.T) { t.Setenv("DAPR_API_TOKEN", "mondo") t.Setenv("FOO", "bar") - err := s.Init(secretstores.Metadata{}) + err := s.Init(context.Background(), secretstores.Metadata{}) require.NoError(t, err) t.Run("Get", func(t *testing.T) { diff --git a/secretstores/local/file/filestore.go b/secretstores/local/file/filestore.go index 4acbace13..995876f06 100644 --- a/secretstores/local/file/filestore.go +++ b/secretstores/local/file/filestore.go @@ -56,7 +56,7 @@ func NewLocalSecretStore(logger logger.Logger) secretstores.SecretStore { } // Init creates a Local secret store. -func (j *localSecretStore) Init(metadata secretstores.Metadata) error { +func (j *localSecretStore) Init(_ context.Context, metadata secretstores.Metadata) error { meta, err := j.getLocalSecretStoreMetadata(metadata) if err != nil { return err diff --git a/secretstores/local/file/filestore_test.go b/secretstores/local/file/filestore_test.go index 57002f21b..118e1c534 100644 --- a/secretstores/local/file/filestore_test.go +++ b/secretstores/local/file/filestore_test.go @@ -42,7 +42,7 @@ func TestInit(t *testing.T) { "secretsFile": "a", "nestedSeparator": "a", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) }) @@ -50,7 +50,7 @@ func TestInit(t *testing.T) { m.Properties = map[string]string{ "dummy": "a", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.NotNil(t, err) assert.Equal(t, err, fmt.Errorf("missing local secrets file in metadata")) }) @@ -74,7 +74,7 @@ func TestSeparator(t *testing.T) { "secretsFile": "a", "nestedSeparator": ".", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) req := secretstores.GetSecretRequest{ @@ -90,7 +90,7 @@ func TestSeparator(t *testing.T) { m.Properties = map[string]string{ "secretsFile": "a", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) req := secretstores.GetSecretRequest{ @@ -118,7 +118,7 @@ func TestGetSecret(t *testing.T) { return secrets, nil }, } - s.Init(m) + s.Init(context.Background(), m) t.Run("successfully retrieve secrets", func(t *testing.T) { req := secretstores.GetSecretRequest{ @@ -160,7 +160,7 @@ func TestBulkGetSecret(t *testing.T) { return secrets, nil }, } - s.Init(m) + s.Init(context.Background(), m) t.Run("successfully retrieve secrets", func(t *testing.T) { req := secretstores.BulkGetSecretRequest{} @@ -197,7 +197,7 @@ func TestMultiValuedSecrets(t *testing.T) { return secrets, err }, } - err := s.Init(m) + err := s.Init(context.Background(), m) require.NoError(t, err) t.Run("MultiValued stores support MULTIPLE_KEY_VALUES_PER_SECRET", func(t *testing.T) { diff --git a/secretstores/secret_store.go b/secretstores/secret_store.go index 3df22fa38..4257222b5 100644 --- a/secretstores/secret_store.go +++ b/secretstores/secret_store.go @@ -23,7 +23,7 @@ import ( // SecretStore is the interface for a component that handles secrets management. type SecretStore interface { // Init authenticates with the actual secret store and performs other init operation - Init(metadata Metadata) error + Init(ctx context.Context, metadata Metadata) error // GetSecret retrieves a secret using a key and returns a map of decrypted string/string values. GetSecret(ctx context.Context, req GetSecretRequest) (GetSecretResponse, error) // BulkGetSecret retrieves all secrets in the store and returns a map of decrypted string/string values. @@ -34,10 +34,10 @@ type SecretStore interface { GetComponentMetadata() map[string]string } -func Ping(secretStore SecretStore) error { +func Ping(ctx context.Context, secretStore SecretStore) error { // checks if this secretStore has the ping option then executes if secretStoreWithPing, ok := secretStore.(health.Pinger); ok { - return secretStoreWithPing.Ping() + return secretStoreWithPing.Ping(ctx) } else { return fmt.Errorf("ping is not implemented by this secret store") } diff --git a/secretstores/tencentcloud/ssm/ssm.go b/secretstores/tencentcloud/ssm/ssm.go index 865aed829..306be4cfd 100644 --- a/secretstores/tencentcloud/ssm/ssm.go +++ b/secretstores/tencentcloud/ssm/ssm.go @@ -69,7 +69,7 @@ func NewSSM(logger logger.Logger) secretstores.SecretStore { } // Init creates a TencentCloud ssm client. -func (s *ssmSecretStore) Init(meta secretstores.Metadata) error { +func (s *ssmSecretStore) Init(_ context.Context, meta secretstores.Metadata) error { m := SsmMetadata{} err := metadata.DecodeMetadata(meta.Properties, &m) if err != nil { diff --git a/state/aerospike/aerospike.go b/state/aerospike/aerospike.go index e64fe1332..ff592587b 100644 --- a/state/aerospike/aerospike.go +++ b/state/aerospike/aerospike.go @@ -91,7 +91,7 @@ func parseAndValidateMetadata(meta state.Metadata) (*aerospikeMetadata, error) { } // Init does metadata and connection parsing. -func (aspike *Aerospike) Init(metadata state.Metadata) error { +func (aspike *Aerospike) Init(_ context.Context, metadata state.Metadata) error { m, err := parseAndValidateMetadata(metadata) if err != nil { return err diff --git a/state/alicloud/tablestore/tablestore.go b/state/alicloud/tablestore/tablestore.go index 0b6bc92d7..2d975f262 100644 --- a/state/alicloud/tablestore/tablestore.go +++ b/state/alicloud/tablestore/tablestore.go @@ -54,7 +54,7 @@ func NewAliCloudTableStore(logger logger.Logger) state.Store { } } -func (s *AliCloudTableStore) Init(metadata state.Metadata) error { +func (s *AliCloudTableStore) Init(_ context.Context, metadata state.Metadata) error { m, err := s.parse(metadata) if err != nil { return err diff --git a/state/alicloud/tablestore/tablestore_test.go b/state/alicloud/tablestore/tablestore_test.go index 790b4fc40..9c511a561 100644 --- a/state/alicloud/tablestore/tablestore_test.go +++ b/state/alicloud/tablestore/tablestore_test.go @@ -52,7 +52,7 @@ func TestReadAndWrite(t *testing.T) { defer ctl.Finish() store := NewAliCloudTableStore(logger.NewLogger("test")).(*AliCloudTableStore) - store.Init(state.Metadata{}) + store.Init(context.Background(), state.Metadata{}) store.client = &mockClient{ data: make(map[string][]byte), diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index f1eb010cc..6143d7622 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -66,7 +66,7 @@ func NewDynamoDBStateStore(_ logger.Logger) state.Store { } // Init does metadata and connection parsing. -func (d *StateStore) Init(metadata state.Metadata) error { +func (d *StateStore) Init(_ context.Context, metadata state.Metadata) error { meta, err := d.getDynamoDBMetadata(metadata) if err != nil { return err diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index ffd71a57c..eff6c90af 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -85,7 +85,7 @@ func TestInit(t *testing.T) { "Table": "a", "TtlAttributeName": "a", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) }) @@ -93,7 +93,7 @@ func TestInit(t *testing.T) { m.Properties = map[string]string{ "Dummy": "a", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.NotNil(t, err) assert.Equal(t, err, fmt.Errorf("missing dynamodb table name")) }) @@ -103,7 +103,7 @@ func TestInit(t *testing.T) { "Table": "a", "Region": "eu-west-1", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) }) @@ -113,7 +113,7 @@ func TestInit(t *testing.T) { "Table": "a", "partitionKey": pkey, } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) assert.Equal(t, s.partitionKey, pkey) }) diff --git a/state/azure/blobstorage/blobstorage.go b/state/azure/blobstorage/blobstorage.go index fb532dcf6..c495ce0ff 100644 --- a/state/azure/blobstorage/blobstorage.go +++ b/state/azure/blobstorage/blobstorage.go @@ -71,7 +71,7 @@ type StateStore struct { } // Init the connection to blob storage, optionally creates a blob container if it doesn't exist. -func (r *StateStore) Init(metadata state.Metadata) error { +func (r *StateStore) Init(_ context.Context, metadata state.Metadata) error { var err error r.containerClient, _, err = storageinternal.CreateContainerStorageClient(r.logger, metadata.Properties) if err != nil { @@ -100,8 +100,8 @@ func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error { return r.writeFile(ctx, req) } -func (r *StateStore) Ping() error { - if _, err := r.containerClient.GetProperties(context.Background(), nil); err != nil { +func (r *StateStore) Ping(ctx context.Context) error { + if _, err := r.containerClient.GetProperties(ctx, nil); err != nil { return fmt.Errorf("blob storage: error connecting to Blob storage at %s: %s", r.containerClient.URL(), err) } diff --git a/state/azure/blobstorage/blobstorage_test.go b/state/azure/blobstorage/blobstorage_test.go index e8ec8a97b..afcc768b8 100644 --- a/state/azure/blobstorage/blobstorage_test.go +++ b/state/azure/blobstorage/blobstorage_test.go @@ -14,6 +14,7 @@ limitations under the License. package blobstorage import ( + "context" "fmt" "testing" @@ -32,7 +33,7 @@ func TestInit(t *testing.T) { "accountKey": "e+Dnvl8EOxYxV94nurVaRQ==", "containerName": "dapr", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) assert.Equal(t, "https://acc.blob.core.windows.net/dapr", s.containerClient.URL()) }) @@ -41,7 +42,7 @@ func TestInit(t *testing.T) { m.Properties = map[string]string{ "invalidValue": "a", } - err := s.Init(m) + err := s.Init(context.Background(), m) assert.NotNil(t, err) assert.Equal(t, err, fmt.Errorf("missing or empty accountName field from metadata")) }) @@ -52,8 +53,8 @@ func TestInit(t *testing.T) { "accountKey": "e+Dnvl8EOxYxV94nurVaRQ==", "containerName": "dapr", } - s.Init(m) - err := s.Ping() + s.Init(context.Background(), m) + err := s.Ping(context.Background()) assert.NotNil(t, err) }) } diff --git a/state/azure/cosmosdb/cosmosdb.go b/state/azure/cosmosdb/cosmosdb.go index 3a383ebc7..e2af48b06 100644 --- a/state/azure/cosmosdb/cosmosdb.go +++ b/state/azure/cosmosdb/cosmosdb.go @@ -113,7 +113,7 @@ func (c *StateStore) GetComponentMetadata() map[string]string { } // Init does metadata and connection parsing. -func (c *StateStore) Init(meta state.Metadata) error { +func (c *StateStore) Init(ctx context.Context, meta state.Metadata) error { c.logger.Debugf("CosmosDB init start") m := metadata{ @@ -191,9 +191,9 @@ func (c *StateStore) Init(meta state.Metadata) error { c.metadata = m c.contentType = m.ContentType - ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) - _, err = c.client.Read(ctx, nil) - cancel() + readCtx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() + _, err = c.client.Read(readCtx, nil) return err } @@ -217,9 +217,9 @@ func (c *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get options.ConsistencyLevel = azcosmos.ConsistencyLevelEventual.ToPtr() } - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - readItem, err := c.client.ReadItem(ctx, azcosmos.NewPartitionKeyString(partitionKey), req.Key, &options) - cancel() + readCtx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() + readItem, err := c.client.ReadItem(readCtx, azcosmos.NewPartitionKeyString(partitionKey), req.Key, &options) if err != nil { var responseErr *azcore.ResponseError if errors.As(err, &responseErr) && responseErr.ErrorCode == "NotFound" { @@ -306,10 +306,10 @@ func (c *StateStore) Set(ctx context.Context, req *state.SetRequest) error { return err } - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + upsertCtx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() pk := azcosmos.NewPartitionKeyString(partitionKey) - _, err = c.client.UpsertItem(ctx, pk, marsh, &options) - cancel() + _, err = c.client.UpsertItem(upsertCtx, pk, marsh, &options) if err != nil { return err } @@ -335,10 +335,10 @@ func (c *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error options.ConsistencyLevel = azcosmos.ConsistencyLevelEventual.ToPtr() } - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + deleteCtx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() pk := azcosmos.NewPartitionKeyString(partitionKey) - _, err = c.client.DeleteItem(ctx, pk, req.Key, &options) - cancel() + _, err = c.client.DeleteItem(deleteCtx, pk, req.Key, &options) if err != nil && !isNotFoundError(err) { c.logger.Debugf("Error from cosmos.DeleteDocument e=%e, e.Error=%s", err, err.Error()) if req.ETag != nil && *req.ETag != "" { @@ -418,9 +418,9 @@ func (c *StateStore) Multi(ctx context.Context, request *state.TransactionalStat c.logger.Debugf("#operations=%d,partitionkey=%s", numOperations, partitionKey) - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - batchResponse, err := c.client.ExecuteTransactionalBatch(ctx, batch, nil) - cancel() + execCtx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() + batchResponse, err := c.client.ExecuteTransactionalBatch(execCtx, batch, nil) if err != nil { return err } @@ -464,10 +464,10 @@ func (c *StateStore) Query(ctx context.Context, req *state.QueryRequest) (*state }, nil } -func (c *StateStore) Ping() error { - ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) - _, err := c.client.Read(ctx, nil) - cancel() +func (c *StateStore) Ping(ctx context.Context) error { + pingCtx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() + _, err := c.client.Read(pingCtx, nil) if err != nil { return err } diff --git a/state/azure/tablestorage/tablestorage.go b/state/azure/tablestorage/tablestorage.go index e13ea2c68..c65880338 100644 --- a/state/azure/tablestorage/tablestorage.go +++ b/state/azure/tablestorage/tablestorage.go @@ -84,7 +84,7 @@ type tablesMetadata struct { } // Init Initialises connection to table storage, optionally creates a table if it doesn't exist. -func (r *StateStore) Init(metadata state.Metadata) error { +func (r *StateStore) Init(ctx context.Context, metadata state.Metadata) error { meta, err := getTablesMetadata(metadata.Properties) if err != nil { return err @@ -146,7 +146,7 @@ func (r *StateStore) Init(metadata state.Metadata) error { } if !meta.SkipCreateTable { - createContext, cancel := context.WithTimeout(context.Background(), timeout) + createContext, cancel := context.WithTimeout(ctx, timeout) defer cancel() _, innerErr := client.CreateTable(createContext, meta.TableName, nil) if innerErr != nil { @@ -171,8 +171,6 @@ func (r *StateStore) Features() []state.Feature { } func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error { - r.logger.Debugf("delete %s", req.Key) - err := r.deleteRow(ctx, req) if err != nil { if req.ETag != nil { @@ -187,7 +185,6 @@ func (r *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error } func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { - r.logger.Debugf("fetching %s", req.Key) pk, rk := getPartitionAndRowKey(req.Key, r.cosmosDBMode) ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() @@ -208,11 +205,7 @@ func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get } func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error { - r.logger.Debugf("saving %s", req.Key) - - err := r.writeRow(ctx, req) - - return err + return r.writeRow(ctx, req) } func (r *StateStore) GetComponentMetadata() map[string]string { diff --git a/state/cassandra/cassandra.go b/state/cassandra/cassandra.go index 31e84af2a..9bcc1b95b 100644 --- a/state/cassandra/cassandra.go +++ b/state/cassandra/cassandra.go @@ -78,7 +78,7 @@ func NewCassandraStateStore(logger logger.Logger) state.Store { } // Init performs metadata and connection parsing. -func (c *Cassandra) Init(metadata state.Metadata) error { +func (c *Cassandra) Init(_ context.Context, metadata state.Metadata) error { meta, err := getCassandraMetadata(metadata) if err != nil { return err diff --git a/state/cloudflare/workerskv/workerskv.go b/state/cloudflare/workerskv/workerskv.go index 9df193cf8..e516dba97 100644 --- a/state/cloudflare/workerskv/workerskv.go +++ b/state/cloudflare/workerskv/workerskv.go @@ -56,7 +56,7 @@ func NewCFWorkersKV(logger logger.Logger) state.Store { } // Init the component. -func (q *CFWorkersKV) Init(metadata state.Metadata) error { +func (q *CFWorkersKV) Init(_ context.Context, metadata state.Metadata) error { // Decode the metadata err := mapstructure.Decode(metadata.Properties, &q.metadata) if err != nil { diff --git a/state/cockroachdb/cockroachdb.go b/state/cockroachdb/cockroachdb.go index f2972a4eb..3ce66da12 100644 --- a/state/cockroachdb/cockroachdb.go +++ b/state/cockroachdb/cockroachdb.go @@ -47,8 +47,8 @@ func internalNew(logger logger.Logger, dba dbAccess) *CockroachDB { } // Init initializes the CockroachDB state store. -func (c *CockroachDB) Init(metadata state.Metadata) error { - return c.dbaccess.Init(metadata) +func (c *CockroachDB) Init(ctx context.Context, metadata state.Metadata) error { + return c.dbaccess.Init(ctx, metadata) } // Features returns the features available in this state store. @@ -72,8 +72,8 @@ func (c *CockroachDB) Set(ctx context.Context, req *state.SetRequest) error { } // Ping checks if database is available. -func (c *CockroachDB) Ping() error { - return c.dbaccess.Ping() +func (c *CockroachDB) Ping(ctx context.Context) error { + return c.dbaccess.Ping(ctx) } // BulkDelete removes multiple entries from the store. diff --git a/state/cockroachdb/cockroachdb_access.go b/state/cockroachdb/cockroachdb_access.go index 638dd5093..e4dbc0731 100644 --- a/state/cockroachdb/cockroachdb_access.go +++ b/state/cockroachdb/cockroachdb_access.go @@ -81,7 +81,7 @@ func parseMetadata(meta state.Metadata) (*cockroachDBMetadata, error) { } // Init sets up CockroachDB connection and ensures that the state table exists. -func (p *cockroachDBAccess) Init(metadata state.Metadata) error { +func (p *cockroachDBAccess) Init(ctx context.Context, metadata state.Metadata) error { p.logger.Debug("Initializing CockroachDB state store") meta, err := parseMetadata(metadata) @@ -107,7 +107,7 @@ func (p *cockroachDBAccess) Init(metadata state.Metadata) error { p.db = databaseConn - if err = databaseConn.Ping(); err != nil { + if err = databaseConn.PingContext(ctx); err != nil { return err } @@ -116,7 +116,7 @@ func (p *cockroachDBAccess) Init(metadata state.Metadata) error { } // Ensure that a connection to the database is actually established - err = p.Ping() + err = p.Ping(ctx) if err != nil { return err } @@ -399,7 +399,7 @@ func (p *cockroachDBAccess) Query(ctx context.Context, req *state.QueryRequest) } // Ping implements database ping. -func (p *cockroachDBAccess) Ping() error { +func (p *cockroachDBAccess) Ping(ctx context.Context) error { retryCount := defaultMaxConnectionAttempts if p.metadata.MaxConnectionAttempts != nil && *p.metadata.MaxConnectionAttempts >= 0 { retryCount = *p.metadata.MaxConnectionAttempts @@ -411,7 +411,7 @@ func (p *cockroachDBAccess) Ping() error { backoff := config.NewBackOff() return retry.NotifyRecover(func() error { - err := p.db.Ping() + err := p.db.PingContext(ctx) if errors.Is(err, driver.ErrBadConn) { return fmt.Errorf("error when attempting to establish connection with cockroachDB: %v", err) } diff --git a/state/cockroachdb/cockroachdb_integration_test.go b/state/cockroachdb/cockroachdb_integration_test.go index 7c0e1b97b..2f0d947ed 100644 --- a/state/cockroachdb/cockroachdb_integration_test.go +++ b/state/cockroachdb/cockroachdb_integration_test.go @@ -59,7 +59,7 @@ func TestCockroachDBIntegration(t *testing.T) { defer pgs.Close() }) - if err := pgs.Init(metadata); err != nil { + if err := pgs.Init(context.Background(), metadata); err != nil { t.Fatal(err) } @@ -615,7 +615,7 @@ func testInitConfiguration(t *testing.T) { Base: metadata.Base{Properties: rowTest.props}, } - err := cockroackDB.Init(metadata) + err := cockroackDB.Init(context.Background(), metadata) if rowTest.expectedErr == "" { assert.Nil(t, err) } else { diff --git a/state/cockroachdb/cockroachdb_test.go b/state/cockroachdb/cockroachdb_test.go index 2b3664ecd..d0262f75d 100644 --- a/state/cockroachdb/cockroachdb_test.go +++ b/state/cockroachdb/cockroachdb_test.go @@ -37,7 +37,7 @@ type fakeDBaccess struct { deleteExecuted bool } -func (m *fakeDBaccess) Init(metadata state.Metadata) error { +func (m *fakeDBaccess) Init(ctx context.Context, metadata state.Metadata) error { m.initExecuted = true return nil @@ -81,7 +81,7 @@ func (m *fakeDBaccess) Close() error { return nil } -func (m *fakeDBaccess) Ping() error { +func (m *fakeDBaccess) Ping(ctx context.Context) error { return nil } @@ -121,7 +121,7 @@ func createCockroachDB(t *testing.T) *CockroachDB { Base: metadata.Base{Properties: map[string]string{connectionStringKey: fakeConnectionString}}, } - err := pgs.Init(*metadata) + err := pgs.Init(context.Background(), *metadata) assert.Nil(t, err) assert.NotNil(t, pgs.dbaccess) diff --git a/state/cockroachdb/dbaccess.go b/state/cockroachdb/dbaccess.go index 112b64bc0..fc47f2bc4 100644 --- a/state/cockroachdb/dbaccess.go +++ b/state/cockroachdb/dbaccess.go @@ -21,7 +21,7 @@ import ( // dbAccess is a private interface which enables unit testing of CockroachDB. type dbAccess interface { - Init(metadata state.Metadata) error + Init(ctx context.Context, metadata state.Metadata) error Set(ctx context.Context, req *state.SetRequest) error BulkSet(ctx context.Context, req []state.SetRequest) error Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) @@ -29,6 +29,6 @@ type dbAccess interface { BulkDelete(ctx context.Context, req []state.DeleteRequest) error ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) - Ping() error + Ping(context.Context) error Close() error } diff --git a/state/couchbase/couchbase.go b/state/couchbase/couchbase.go index b334d6a7c..98de59d53 100644 --- a/state/couchbase/couchbase.go +++ b/state/couchbase/couchbase.go @@ -119,7 +119,7 @@ func parseAndValidateMetadata(meta state.Metadata) (*couchbaseMetadata, error) { } // Init does metadata and connection parsing. -func (cbs *Couchbase) Init(metadata state.Metadata) error { +func (cbs *Couchbase) Init(_ context.Context, metadata state.Metadata) error { meta, err := parseAndValidateMetadata(metadata) if err != nil { return err diff --git a/state/gcp/firestore/firestore.go b/state/gcp/firestore/firestore.go index a1f0a7876..5f4885a10 100644 --- a/state/gcp/firestore/firestore.go +++ b/state/gcp/firestore/firestore.go @@ -71,7 +71,7 @@ func NewFirestoreStateStore(logger logger.Logger) state.Store { } // Init does metadata and connection parsing. -func (f *Firestore) Init(metadata state.Metadata) error { +func (f *Firestore) Init(ctx context.Context, metadata state.Metadata) error { meta, err := getFirestoreMetadata(metadata) if err != nil { return err @@ -82,7 +82,6 @@ func (f *Firestore) Init(metadata state.Metadata) error { } opt := option.WithCredentialsJSON(b) - ctx := context.Background() client, err := datastore.NewClient(ctx, meta.ProjectID, opt) if err != nil { return err diff --git a/state/hashicorp/consul/consul.go b/state/hashicorp/consul/consul.go index 6f040ce8e..4c2611c68 100644 --- a/state/hashicorp/consul/consul.go +++ b/state/hashicorp/consul/consul.go @@ -54,7 +54,7 @@ func NewConsulStateStore(logger logger.Logger) state.Store { // Init does metadata and config parsing and initializes the // Consul client. -func (c *Consul) Init(metadata state.Metadata) error { +func (c *Consul) Init(_ context.Context, metadata state.Metadata) error { consulConfig, err := metadataToConfig(metadata.Properties) if err != nil { return fmt.Errorf("couldn't convert metadata properties: %s", err) diff --git a/state/hazelcast/hazelcast.go b/state/hazelcast/hazelcast.go index 07fb0fd4e..af62aa5c2 100644 --- a/state/hazelcast/hazelcast.go +++ b/state/hazelcast/hazelcast.go @@ -69,7 +69,7 @@ func validateAndParseMetadata(meta state.Metadata) (*hazelcastMetadata, error) { } // Init does metadata and connection parsing. -func (store *Hazelcast) Init(metadata state.Metadata) error { +func (store *Hazelcast) Init(_ context.Context, metadata state.Metadata) error { meta, err := validateAndParseMetadata(metadata) if err != nil { return err diff --git a/state/in-memory/in_memory.go b/state/in-memory/in_memory.go index 16ad28e88..78366dace 100644 --- a/state/in-memory/in_memory.go +++ b/state/in-memory/in_memory.go @@ -21,6 +21,7 @@ import ( "fmt" "strconv" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -41,32 +42,36 @@ type inMemStateStoreItem struct { } type inMemoryStore struct { - items map[string]*inMemStateStoreItem - lock *sync.RWMutex - log logger.Logger - - ctx context.Context - cancel context.CancelFunc + items map[string]*inMemStateStoreItem + lock *sync.RWMutex + log logger.Logger + closeCh chan struct{} + closed atomic.Bool + wg sync.WaitGroup } func NewInMemoryStateStore(logger logger.Logger) state.Store { return &inMemoryStore{ - items: map[string]*inMemStateStoreItem{}, - lock: &sync.RWMutex{}, - log: logger, + items: map[string]*inMemStateStoreItem{}, + lock: &sync.RWMutex{}, + log: logger, + closeCh: make(chan struct{}), } } -func (store *inMemoryStore) Init(metadata state.Metadata) error { - store.ctx, store.cancel = context.WithCancel(context.Background()) +func (store *inMemoryStore) Init(ctx context.Context, metadata state.Metadata) error { // start a background go routine to clean expired item - go store.startCleanThread() + store.wg.Add(1) + go func() { + defer store.wg.Done() + store.startCleanThread() + }() return nil } func (store *inMemoryStore) Close() error { - if store.cancel != nil { - store.cancel() + if store.closed.CompareAndSwap(false, true) { + close(store.closeCh) } // release memory reference @@ -76,6 +81,8 @@ func (store *inMemoryStore) Close() error { delete(store.items, k) } + store.wg.Wait() + return nil } @@ -443,7 +450,7 @@ func (store *inMemoryStore) startCleanThread() { select { case <-time.After(time.Second): store.doCleanExpiredItems() - case <-store.ctx.Done(): + case <-store.closeCh: return } } diff --git a/state/in-memory/in_memory_test.go b/state/in-memory/in_memory_test.go index bdd4e857d..9f1fcc385 100644 --- a/state/in-memory/in_memory_test.go +++ b/state/in-memory/in_memory_test.go @@ -32,7 +32,7 @@ func TestReadAndWrite(t *testing.T) { defer ctl.Finish() store := NewInMemoryStateStore(logger.NewLogger("test")) - store.Init(state.Metadata{}) + store.Init(context.Background(), state.Metadata{}) keyA := "theFirstKey" valueA := "value of key" diff --git a/state/jetstream/jetstream.go b/state/jetstream/jetstream.go index bf8f77a6b..61aeb2e6c 100644 --- a/state/jetstream/jetstream.go +++ b/state/jetstream/jetstream.go @@ -59,7 +59,7 @@ func NewJetstreamStateStore(logger logger.Logger) state.Store { } // Init does parse metadata and establishes connection to nats broker. -func (js *StateStore) Init(metadata state.Metadata) error { +func (js *StateStore) Init(_ context.Context, metadata state.Metadata) error { meta, err := js.getMetadata(metadata) if err != nil { return err diff --git a/state/jetstream/jetstream_test.go b/state/jetstream/jetstream_test.go index d3e00f3be..750371a43 100644 --- a/state/jetstream/jetstream_test.go +++ b/state/jetstream/jetstream_test.go @@ -97,7 +97,7 @@ func TestSetGetAndDelete(t *testing.T) { store := NewJetstreamStateStore(nil) - err := store.Init(state.Metadata{ + err := store.Init(context.Background(), state.Metadata{ Base: metadata.Base{Properties: map[string]string{ "natsURL": nats.DefaultURL, "bucket": "test", diff --git a/state/memcached/memcached.go b/state/memcached/memcached.go index 8ef93a85b..4ad8d87f1 100644 --- a/state/memcached/memcached.go +++ b/state/memcached/memcached.go @@ -62,7 +62,7 @@ func NewMemCacheStateStore(logger logger.Logger) state.Store { return s } -func (m *Memcached) Init(metadata state.Metadata) error { +func (m *Memcached) Init(_ context.Context, metadata state.Metadata) error { meta, err := getMemcachedMetadata(metadata) if err != nil { return err @@ -78,6 +78,8 @@ func (m *Memcached) Init(metadata state.Metadata) error { m.client = client + // TODO: pass context when PR is merged. + // https://github.com/bradfitz/gomemcache/pull/126 err = client.Ping() if err != nil { return err diff --git a/state/mongodb/mongodb.go b/state/mongodb/mongodb.go index 0ee1aff47..43cb15355 100644 --- a/state/mongodb/mongodb.go +++ b/state/mongodb/mongodb.go @@ -113,7 +113,7 @@ func NewMongoDB(logger logger.Logger) state.Store { } // Init establishes connection to the store based on the metadata. -func (m *MongoDB) Init(metadata state.Metadata) error { +func (m *MongoDB) Init(ctx context.Context, metadata state.Metadata) error { meta, err := getMongoDBMetaData(metadata) if err != nil { return err @@ -121,12 +121,12 @@ func (m *MongoDB) Init(metadata state.Metadata) error { m.operationTimeout = meta.OperationTimeout - client, err := getMongoDBClient(meta) + client, err := getMongoDBClient(ctx, meta) if err != nil { return fmt.Errorf("error in creating mongodb client: %s", err) } - if err = client.Ping(context.Background(), nil); err != nil { + if err = client.Ping(ctx, nil); err != nil { return fmt.Errorf("error in connecting to mongodb, host: %s error: %s", meta.Host, err) } @@ -168,8 +168,8 @@ func (m *MongoDB) Set(ctx context.Context, req *state.SetRequest) error { return nil } -func (m *MongoDB) Ping() error { - if err := m.client.Ping(context.Background(), nil); err != nil { +func (m *MongoDB) Ping(ctx context.Context) error { + if err := m.client.Ping(ctx, nil); err != nil { return fmt.Errorf("mongoDB store: error connecting to mongoDB at %s: %s", m.metadata.Host, err) } @@ -360,14 +360,14 @@ func getMongoURI(metadata *mongoDBMetadata) string { return fmt.Sprintf(connectionURIFormat, metadata.Host, metadata.DatabaseName, metadata.Params) } -func getMongoDBClient(metadata *mongoDBMetadata) (*mongo.Client, error) { +func getMongoDBClient(ctx context.Context, metadata *mongoDBMetadata) (*mongo.Client, error) { uri := getMongoURI(metadata) // Set client options clientOptions := options.Client().ApplyURI(uri) // Connect to MongoDB - ctx, cancel := context.WithTimeout(context.Background(), metadata.OperationTimeout) + ctx, cancel := context.WithTimeout(ctx, metadata.OperationTimeout) defer cancel() daprUserAgent := "dapr-" + logger.DaprVersion diff --git a/state/mysql/mysql.go b/state/mysql/mysql.go index d0a89f911..e26e16120 100644 --- a/state/mysql/mysql.go +++ b/state/mysql/mysql.go @@ -121,7 +121,7 @@ func newMySQLStateStore(logger logger.Logger, factory iMySQLFactory) *MySQL { // TransactionalStore // Populate the rest of the MySQL object by reading the metadata and opening // a connection to the server. -func (m *MySQL) Init(metadata state.Metadata) error { +func (m *MySQL) Init(ctx context.Context, metadata state.Metadata) error { m.logger.Debug("Initializing MySql state store") err := m.parseMetadata(metadata.Properties) @@ -136,7 +136,7 @@ func (m *MySQL) Init(metadata state.Metadata) error { } // will be nil if everything is good or an err that needs to be returned - return m.finishInit(db) + return m.finishInit(ctx, db) } func (m *MySQL) parseMetadata(md map[string]string) error { @@ -199,12 +199,12 @@ func (m *MySQL) Features() []state.Feature { } // Ping the database. -func (m *MySQL) Ping() error { +func (m *MySQL) Ping(ctx context.Context) error { if m.db == nil { return sql.ErrConnDone } - return m.PingWithContext(context.Background()) + return m.PingWithContext(ctx) } // PingWithContext is like Ping but accepts a context. @@ -215,36 +215,36 @@ func (m *MySQL) PingWithContext(parentCtx context.Context) error { } // Separated out to make this portion of code testable. -func (m *MySQL) finishInit(db *sql.DB) error { +func (m *MySQL) finishInit(ctx context.Context, db *sql.DB) error { m.db = db - err := m.ensureStateSchema() + err := m.ensureStateSchema(ctx) if err != nil { m.logger.Error(err) return err } - err = m.Ping() + err = m.Ping(ctx) if err != nil { m.logger.Error(err) return err } // will be nil if everything is good or an err that needs to be returned - return m.ensureStateTable(m.tableName) + return m.ensureStateTable(ctx, m.tableName) } -func (m *MySQL) ensureStateSchema() error { - exists, err := schemaExists(m.db, m.schemaName, m.timeout) +func (m *MySQL) ensureStateSchema(ctx context.Context) error { + exists, err := schemaExists(ctx, m.db, m.schemaName, m.timeout) if err != nil { return err } if !exists { m.logger.Infof("Creating MySql schema '%s'", m.schemaName) - ctx, cancel := context.WithTimeout(context.Background(), m.timeout) + cctx, cancel := context.WithTimeout(ctx, m.timeout) defer cancel() - _, err = m.db.ExecContext(ctx, + _, err = m.db.ExecContext(cctx, fmt.Sprintf("CREATE DATABASE %s;", m.schemaName), ) if err != nil { @@ -273,8 +273,8 @@ func (m *MySQL) ensureStateSchema() error { return err } -func (m *MySQL) ensureStateTable(stateTableName string) error { - exists, err := tableExists(m.db, stateTableName, m.timeout) +func (m *MySQL) ensureStateTable(ctx context.Context, stateTableName string) error { + exists, err := tableExists(ctx, m.db, stateTableName, m.timeout) if err != nil { return err } @@ -297,10 +297,9 @@ func (m *MySQL) ensureStateTable(stateTableName string) error { eTag VARCHAR(36) NOT NULL );`, stateTableName) - ctx, cancel := context.WithTimeout(context.Background(), m.timeout) - defer cancel() - _, err = m.db.ExecContext(ctx, createTable) - + execCtx, execCancel := context.WithTimeout(ctx, m.timeout) + defer execCancel() + _, err = m.db.ExecContext(execCtx, createTable) if err != nil { return err } @@ -309,8 +308,8 @@ func (m *MySQL) ensureStateTable(stateTableName string) error { return nil } -func schemaExists(db *sql.DB, schemaName string, timeout time.Duration) (bool, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) +func schemaExists(ctx context.Context, db *sql.DB, schemaName string, timeout time.Duration) (bool, error) { + schemeCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() // Returns 1 or 0 if the table exists or not @@ -318,12 +317,12 @@ func schemaExists(db *sql.DB, schemaName string, timeout time.Duration) (bool, e query := `SELECT EXISTS ( SELECT SCHEMA_NAME FROM information_schema.schemata WHERE SCHEMA_NAME = ? ) AS 'exists'` - err := db.QueryRowContext(ctx, query, schemaName).Scan(&exists) + err := db.QueryRowContext(schemeCtx, query, schemaName).Scan(&exists) return exists == 1, err } -func tableExists(db *sql.DB, tableName string, timeout time.Duration) (bool, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) +func tableExists(ctx context.Context, db *sql.DB, tableName string, timeout time.Duration) (bool, error) { + tableCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() // Returns 1 or 0 if the table exists or not @@ -331,7 +330,7 @@ func tableExists(db *sql.DB, tableName string, timeout time.Duration) (bool, err query := `SELECT EXISTS ( SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_NAME = ? ) AS 'exists'` - err := db.QueryRowContext(ctx, query, tableName).Scan(&exists) + err := db.QueryRowContext(tableCtx, query, tableName).Scan(&exists) return exists == 1, err } @@ -355,15 +354,15 @@ func (m *MySQL) deleteValue(parentCtx context.Context, querier querier, req *sta result sql.Result ) - ctx, cancel := context.WithTimeout(parentCtx, m.timeout) + execCtx, cancel := context.WithTimeout(parentCtx, m.timeout) defer cancel() if req.ETag == nil || *req.ETag == "" { - result, err = querier.ExecContext(ctx, fmt.Sprintf( + result, err = querier.ExecContext(execCtx, fmt.Sprintf( `DELETE FROM %s WHERE id = ?`, m.tableName), req.Key) } else { - result, err = querier.ExecContext(ctx, fmt.Sprintf( + result, err = querier.ExecContext(execCtx, fmt.Sprintf( `DELETE FROM %s WHERE id = ? and eTag = ?`, m.tableName), req.Key, *req.ETag) } diff --git a/state/mysql/mysql_integration_test.go b/state/mysql/mysql_integration_test.go index ecc7b7180..6563d024a 100644 --- a/state/mysql/mysql_integration_test.go +++ b/state/mysql/mysql_integration_test.go @@ -115,7 +115,7 @@ func TestMySQLIntegration(t *testing.T) { Base: metadata.Base{Properties: tt.props}, } - err := p.Init(metadata) + err := p.Init(context.Background(), metadata) if tt.expectedErr == "" { assert.Nil(t, err) @@ -138,7 +138,7 @@ func TestMySQLIntegration(t *testing.T) { defer mys.Close() }) - error := mys.Init(metadata) + error := mys.Init(context.Background(), metadata) if error != nil { t.Fatal(error) } @@ -149,7 +149,7 @@ func TestMySQLIntegration(t *testing.T) { tableName := "test_state" // Drop the table if it already exists - exists, err := tableExists(mys.db, tableName, 10*time.Second) + exists, err := tableExists(context.Background(), mys.db, tableName, 10*time.Second) assert.Nil(t, err) if exists { dropTable(t, mys.db, tableName) @@ -157,11 +157,11 @@ func TestMySQLIntegration(t *testing.T) { // Create the state table and test for its existence // There should be no error - err = mys.ensureStateTable(tableName) + err = mys.ensureStateTable(context.Background(), tableName) assert.Nil(t, err) // Now create it and make sure there are no errors - exists, err = tableExists(mys.db, tableName, 10*time.Second) + exists, err = tableExists(context.Background(), mys.db, tableName, 10*time.Second) assert.Nil(t, err) assert.True(t, exists) diff --git a/state/mysql/mysql_test.go b/state/mysql/mysql_test.go index 8a7944a21..b9138da14 100644 --- a/state/mysql/mysql_test.go +++ b/state/mysql/mysql_test.go @@ -49,7 +49,7 @@ func TestEnsureStateSchemaHandlesShortConnectionString(t *testing.T) { m.mock1.ExpectQuery("SELECT EXISTS").WillReturnRows(rows) // Act - m.mySQL.ensureStateSchema() + m.mySQL.ensureStateSchema(context.Background()) // Assert assert.Equal(t, "theUser:thePassword@/theSchema", m.mySQL.connectionString) @@ -64,7 +64,7 @@ func TestFinishInitHandlesSchemaExistsError(t *testing.T) { m.mock1.ExpectQuery("SELECT EXISTS").WillReturnError(expectedErr) // Act - actualErr := m.mySQL.finishInit(m.mySQL.db) + actualErr := m.mySQL.finishInit(context.Background(), m.mySQL.db) // Assert assert.NotNil(t, actualErr, "now error returned") @@ -83,7 +83,7 @@ func TestFinishInitHandlesDatabaseCreateError(t *testing.T) { m.mock1.ExpectExec("CREATE DATABASE").WillReturnError(expectedErr) // Act - actualErr := m.mySQL.finishInit(m.mySQL.db) + actualErr := m.mySQL.finishInit(context.Background(), m.mySQL.db) // Assert assert.NotNil(t, actualErr, "now error returned") @@ -107,7 +107,7 @@ func TestFinishInitHandlesPingError(t *testing.T) { m.mock2.ExpectPing().WillReturnError(expectedErr) // Act - actualErr := m.mySQL.finishInit(m.mySQL.db) + actualErr := m.mySQL.finishInit(context.Background(), m.mySQL.db) // Assert assert.NotNil(t, actualErr, "now error returned") @@ -135,7 +135,7 @@ func TestFinishInitHandlesTableExistsError(t *testing.T) { m.mock2.ExpectQuery("SELECT EXISTS").WillReturnError(fmt.Errorf("tableExistsError")) // Act - err := m.mySQL.finishInit(m.mySQL.db) + err := m.mySQL.finishInit(context.Background(), m.mySQL.db) // Assert assert.NotNil(t, err, "no error returned") @@ -541,7 +541,7 @@ func TestTableExists(t *testing.T) { m.mock1.ExpectQuery("SELECT EXISTS").WillReturnRows(rows) // Act - actual, err := tableExists(m.mySQL.db, "store", 10*time.Second) + actual, err := tableExists(context.Background(), m.mySQL.db, "store", 10*time.Second) // Assert assert.Nil(t, err, `error was returned`) @@ -559,7 +559,7 @@ func TestEnsureStateTableHandlesCreateTableError(t *testing.T) { m.mock1.ExpectExec("CREATE TABLE").WillReturnError(fmt.Errorf("CreateTableError")) // Act - err := m.mySQL.ensureStateTable("state") + err := m.mySQL.ensureStateTable(context.Background(), "state") // Assert assert.NotNil(t, err, "no error returned") @@ -580,7 +580,7 @@ func TestEnsureStateTableCreatesTable(t *testing.T) { m.mock1.ExpectExec("CREATE TABLE").WillReturnResult(sqlmock.NewResult(1, 1)) // Act - err := m.mySQL.ensureStateTable("state") + err := m.mySQL.ensureStateTable(context.Background(), "state") // Assert assert.Nil(t, err) @@ -597,7 +597,7 @@ func TestInitReturnsErrorOnNoConnectionString(t *testing.T) { } // Act - err := m.mySQL.Init(*metadata) + err := m.mySQL.Init(context.Background(), *metadata) // Assert assert.NotNil(t, err) @@ -614,7 +614,7 @@ func TestInitReturnsErrorOnFailOpen(t *testing.T) { m.mock1.ExpectQuery("SELECT EXISTS").WillReturnError(sql.ErrConnDone) // Act - err := m.mySQL.Init(*metadata) + err := m.mySQL.Init(context.Background(), *metadata) // Assert assert.NotNil(t, err) @@ -637,7 +637,7 @@ func TestInitHandlesRegisterTLSConfigError(t *testing.T) { } // Act - err := m.mySQL.Init(*metadata) + err := m.mySQL.Init(context.Background(), *metadata) // Assert assert.NotNil(t, err) @@ -653,7 +653,7 @@ func TestInitSetsTableName(t *testing.T) { } // Act - err := m.mySQL.Init(*metadata) + err := m.mySQL.Init(context.Background(), *metadata) // Assert assert.NotNil(t, err) @@ -669,7 +669,7 @@ func TestInitInvalidTableName(t *testing.T) { } // Act - err := m.mySQL.Init(*metadata) + err := m.mySQL.Init(context.Background(), *metadata) // Assert assert.ErrorContains(t, err, "table name '🙃' is not valid") @@ -684,7 +684,7 @@ func TestInitSetsSchemaName(t *testing.T) { } // Act - err := m.mySQL.Init(*metadata) + err := m.mySQL.Init(context.Background(), *metadata) // Assert assert.NotNil(t, err) @@ -700,7 +700,7 @@ func TestInitInvalidSchemaName(t *testing.T) { } // Act - err := m.mySQL.Init(*metadata) + err := m.mySQL.Init(context.Background(), *metadata) // Assert assert.ErrorContains(t, err, "schema name '?' is not valid") diff --git a/state/oci/objectstorage/objectstorage.go b/state/oci/objectstorage/objectstorage.go index 90e0d18d7..47b479ed0 100644 --- a/state/oci/objectstorage/objectstorage.go +++ b/state/oci/objectstorage/objectstorage.go @@ -85,9 +85,9 @@ type objectStoreClient interface { getObject(ctx context.Context, objectname string) (content []byte, etag *string, metadata map[string]string, err error) deleteObject(ctx context.Context, objectname string, etag *string) (err error) putObject(ctx context.Context, objectname string, contentLen int64, content io.ReadCloser, metadata map[string]string, etag *string) error - initStorageBucket() error - initOCIObjectStorageClient() (*objectstorage.ObjectStorageClient, error) - pingBucket() error + initStorageBucket(ctx context.Context) error + initOCIObjectStorageClient(ctx context.Context) (*objectstorage.ObjectStorageClient, error) + pingBucket(ctx context.Context) error } type objectStorageClient struct { @@ -101,7 +101,7 @@ type ociObjectStorageClient struct { /********* Interface Implementations Init, Features, Get, Set, Delete and the instantiation function NewOCIObjectStorageStore. */ -func (r *StateStore) Init(metadata state.Metadata) error { +func (r *StateStore) Init(ctx context.Context, metadata state.Metadata) error { r.logger.Debugf("Init OCI Object Storage State Store") meta, err := getObjectStorageMetadata(metadata.Properties) if err != nil { @@ -114,13 +114,13 @@ func (r *StateStore) Init(metadata state.Metadata) error { logger: r.logger, } - objectStorageClient, cerr := r.client.initOCIObjectStorageClient() + objectStorageClient, cerr := r.client.initOCIObjectStorageClient(ctx) if cerr != nil { return fmt.Errorf("failed to initialize client or create bucket : %w", cerr) } meta.OCIObjectStorageClient = objectStorageClient - cerr = r.client.initStorageBucket() + cerr = r.client.initStorageBucket(ctx) if cerr != nil { return fmt.Errorf("failed to create bucket : %w", cerr) } @@ -162,8 +162,8 @@ func (r *StateStore) Set(ctx context.Context, req *state.SetRequest) error { return r.writeDocument(ctx, req) } -func (r *StateStore) Ping() error { - return r.pingBucket() +func (r *StateStore) Ping(ctx context.Context) error { + return r.pingBucket(ctx) } func NewOCIObjectStorageStore(logger logger.Logger) state.Store { @@ -312,8 +312,8 @@ func (r *StateStore) readDocument(ctx context.Context, req *state.GetRequest) ([ return content, etag, nil } -func (r *StateStore) pingBucket() error { - err := r.client.pingBucket() +func (r *StateStore) pingBucket(ctx context.Context) error { + err := r.client.pingBucket(ctx) if err != nil { r.logger.Debugf("ping bucket failed err %s", err) return fmt.Errorf("failed to ping bucket on OCI Object storage : %w", err) @@ -467,8 +467,7 @@ func (c *ociObjectStorageClient) putObject(ctx context.Context, objectname strin return nil } -func (c *ociObjectStorageClient) initStorageBucket() error { - ctx := context.Background() +func (c *ociObjectStorageClient) initStorageBucket(ctx context.Context) error { err := c.ensureBucketExists(ctx, *c.objectStorageMetadata.OCIObjectStorageClient, c.objectStorageMetadata.Namespace, c.objectStorageMetadata.BucketName, c.objectStorageMetadata.CompartmentOCID) if err != nil { return fmt.Errorf("failed to read or create bucket : %w", err) @@ -476,7 +475,7 @@ func (c *ociObjectStorageClient) initStorageBucket() error { return nil } -func (c *ociObjectStorageClient) initOCIObjectStorageClient() (*objectstorage.ObjectStorageClient, error) { +func (c *ociObjectStorageClient) initOCIObjectStorageClient(ctx context.Context) (*objectstorage.ObjectStorageClient, error) { var configurationProvider common.ConfigurationProvider if c.objectStorageMetadata.InstancePrincipalAuthentication { c.logger.Debugf("instance principal authentication is used. ") @@ -499,7 +498,6 @@ func (c *ociObjectStorageClient) initOCIObjectStorageClient() (*objectstorage.Ob if cerr != nil { return nil, fmt.Errorf("failed to create ObjectStorageClient : %w", cerr) } - ctx := context.Background() c.objectStorageMetadata.Namespace, cerr = getNamespace(ctx, objectStorageClient) if cerr != nil { return nil, fmt.Errorf("failed to get namespace : %w", cerr) @@ -507,12 +505,12 @@ func (c *ociObjectStorageClient) initOCIObjectStorageClient() (*objectstorage.Ob return &objectStorageClient, nil } -func (c *ociObjectStorageClient) pingBucket() error { +func (c *ociObjectStorageClient) pingBucket(ctx context.Context) error { req := objectstorage.GetBucketRequest{ NamespaceName: &c.objectStorageMetadata.Namespace, BucketName: &c.objectStorageMetadata.BucketName, } - _, err := c.objectStorageMetadata.OCIObjectStorageClient.GetBucket(context.Background(), req) + _, err := c.objectStorageMetadata.OCIObjectStorageClient.GetBucket(ctx, req) if err != nil { return fmt.Errorf("failed to retrieve bucket details : %w", err) } diff --git a/state/oci/objectstorage/objectstorage_integration_test.go b/state/oci/objectstorage/objectstorage_integration_test.go index ac6aa2dea..224f48466 100644 --- a/state/oci/objectstorage/objectstorage_integration_test.go +++ b/state/oci/objectstorage/objectstorage_integration_test.go @@ -87,14 +87,14 @@ func testGet(t *testing.T, ociProperties map[string]string) { meta.Properties = ociProperties t.Run("Get an non-existing key", func(t *testing.T) { - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.Nil(t, err) getResponse, err := statestore.Get(context.Background(), &state.GetRequest{Key: "xyzq"}) assert.Equal(t, &state.GetResponse{}, getResponse, "Response must be empty") assert.NoError(t, err, "Non-existing key must not be treated as error") }) t.Run("Get an existing key", func(t *testing.T) { - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.Nil(t, err) err = statestore.Set(context.Background(), &state.SetRequest{Key: "test-key", Value: []byte("test-value")}) assert.Nil(t, err) @@ -104,7 +104,7 @@ func testGet(t *testing.T, ociProperties map[string]string) { assert.NotNil(t, *getResponse.ETag, "ETag should be set") }) t.Run("Get an existing composed key", func(t *testing.T) { - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.Nil(t, err) err = statestore.Set(context.Background(), &state.SetRequest{Key: "test-app||test-key", Value: []byte("test-value")}) assert.Nil(t, err) @@ -114,7 +114,7 @@ func testGet(t *testing.T, ociProperties map[string]string) { }) t.Run("Get an unexpired state element with TTL set", func(t *testing.T) { testKey := "unexpired-ttl-test-key" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.Nil(t, err) err = statestore.Set(context.Background(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "100", @@ -126,7 +126,7 @@ func testGet(t *testing.T, ociProperties map[string]string) { }) t.Run("Get a state element with TTL set to -1 (not expire)", func(t *testing.T) { testKey := "never-expiring-ttl-test-key" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.Nil(t, err) err = statestore.Set(context.Background(), &state.SetRequest{Key: testKey, Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "-1", @@ -137,7 +137,7 @@ func testGet(t *testing.T, ociProperties map[string]string) { assert.Equal(t, "test-value", string(getResponse.Data), "Value retrieved should be equal (TTL setting of -1 means never expire)") }) t.Run("Get an expired (TTL in the past) state element", func(t *testing.T) { - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.Nil(t, err) err = statestore.Set(context.Background(), &state.SetRequest{Key: "ttl-test-key", Value: []byte("test-value"), Metadata: (map[string]string{ "ttlInSeconds": "1", @@ -155,14 +155,14 @@ func testSet(t *testing.T, ociProperties map[string]string) { meta.Properties = ociProperties statestore := NewOCIObjectStorageStore(logger.NewLogger("logger")) t.Run("Set without a key", func(t *testing.T) { - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.Nil(t, err) err = statestore.Set(context.Background(), &state.SetRequest{Value: []byte("test-value")}) assert.Equal(t, err, fmt.Errorf("key for value to set was missing from request"), "Lacking Key results in error") }) t.Run("Regular Set Operation", func(t *testing.T) { testKey := "local-test-key" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.Nil(t, err) err = statestore.Set(context.Background(), &state.SetRequest{Key: testKey, Value: []byte("test-value")}) @@ -174,7 +174,7 @@ func testSet(t *testing.T, ociProperties map[string]string) { }) t.Run("Regular Set Operation with composite key", func(t *testing.T) { testKey := "test-app||other-test-key" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.Nil(t, err) err = statestore.Set(context.Background(), &state.SetRequest{Key: testKey, Value: []byte("test-value")}) @@ -198,7 +198,7 @@ func testSet(t *testing.T, ociProperties map[string]string) { t.Run("Testing Set & Concurrency (ETags)", func(t *testing.T) { testKey := "etag-test-key" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.Nil(t, err) err = statestore.Set(context.Background(), &state.SetRequest{Key: testKey, Value: []byte("test-value")}) @@ -231,14 +231,14 @@ func testDelete(t *testing.T, ociProperties map[string]string) { m.Properties = ociProperties s := NewOCIObjectStorageStore(logger.NewLogger("logger")) t.Run("Delete without a key", func(t *testing.T) { - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) err = s.Delete(context.Background(), &state.DeleteRequest{}) assert.Equal(t, err, fmt.Errorf("key for value to delete was missing from request"), "Lacking Key results in error") }) t.Run("Regular Delete Operation", func(t *testing.T) { testKey := "test-key" - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) err = s.Set(context.Background(), &state.SetRequest{Key: testKey, Value: []byte("test-value")}) @@ -248,7 +248,7 @@ func testDelete(t *testing.T, ociProperties map[string]string) { }) t.Run("Regular Delete Operation for composite key", func(t *testing.T) { testKey := "test-app||some-test-key" - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) err = s.Set(context.Background(), &state.SetRequest{Key: testKey, Value: []byte("test-value")}) @@ -263,7 +263,7 @@ func testDelete(t *testing.T, ociProperties map[string]string) { t.Run("Testing Delete & Concurrency (ETags)", func(t *testing.T) { testKey := "etag-test-delete-key" - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) // create document. err = s.Set(context.Background(), &state.SetRequest{Key: testKey, Value: []byte("test-value")}) @@ -289,9 +289,9 @@ func testPing(t *testing.T, ociProperties map[string]string) { m.Properties = ociProperties s := NewOCIObjectStorageStore(logger.NewLogger("logger")).(*StateStore) t.Run("Ping", func(t *testing.T) { - err := s.Init(m) + err := s.Init(context.Background(), m) assert.Nil(t, err) - err = s.Ping() + err = s.Ping(context.Background()) assert.Nil(t, err, "Ping should be successful") }) } diff --git a/state/oci/objectstorage/objectstorage_test.go b/state/oci/objectstorage/objectstorage_test.go index 892fb1d70..d389807b5 100644 --- a/state/oci/objectstorage/objectstorage_test.go +++ b/state/oci/objectstorage/objectstorage_test.go @@ -44,7 +44,7 @@ func TestInit(t *testing.T) { t.Parallel() t.Run("Init with beautifully complete yet incorrect metadata", func(t *testing.T) { meta.Properties = getDummyOCIObjectStorageConfiguration() - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.NotNil(t, err) assert.Error(t, err, "Incorrect configuration data should result in failure to create client") assert.Contains(t, err.Error(), "failed to initialize client", "Incorrect configuration data should result in failure to create client") @@ -52,56 +52,56 @@ func TestInit(t *testing.T) { t.Run("Init with missing region", func(t *testing.T) { meta.Properties = getDummyOCIObjectStorageConfiguration() meta.Properties[regionKey] = "" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.NotNil(t, err) assert.Equal(t, fmt.Errorf("missing or empty region field from metadata"), err, "Lacking configuration property should be spotted") }) t.Run("Init with missing tenancyOCID", func(t *testing.T) { meta.Properties = getDummyOCIObjectStorageConfiguration() meta.Properties["tenancyOCID"] = "" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.NotNil(t, err) assert.Equal(t, fmt.Errorf("missing or empty tenancyOCID field from metadata"), err, "Lacking configuration property should be spotted") }) t.Run("Init with missing userOCID", func(t *testing.T) { meta.Properties = getDummyOCIObjectStorageConfiguration() meta.Properties[userKey] = "" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.NotNil(t, err) assert.Equal(t, fmt.Errorf("missing or empty userOCID field from metadata"), err, "Lacking configuration property should be spotted") }) t.Run("Init with missing compartmentOCID", func(t *testing.T) { meta.Properties = getDummyOCIObjectStorageConfiguration() meta.Properties[compartmentKey] = "" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.NotNil(t, err) assert.Equal(t, fmt.Errorf("missing or empty compartmentOCID field from metadata"), err, "Lacking configuration property should be spotted") }) t.Run("Init with missing fingerprint", func(t *testing.T) { meta.Properties = getDummyOCIObjectStorageConfiguration() meta.Properties[fingerPrintKey] = "" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.NotNil(t, err) assert.Equal(t, fmt.Errorf("missing or empty fingerPrint field from metadata"), err, "Lacking configuration property should be spotted") }) t.Run("Init with missing private key", func(t *testing.T) { meta.Properties = getDummyOCIObjectStorageConfiguration() meta.Properties[privateKeyKey] = "" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.NotNil(t, err) assert.Equal(t, fmt.Errorf("missing or empty privateKey field from metadata"), err, "Lacking configuration property should be spotted") }) t.Run("Init with incorrect value for instancePrincipalAuthentication", func(t *testing.T) { meta.Properties = getDummyOCIObjectStorageConfiguration() meta.Properties[instancePrincipalAuthenticationKey] = "ZQWE" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.NotNil(t, err, "if instancePrincipalAuthentication is defined, it should be true or false; if not: error should be raised ") }) t.Run("Init with missing fingerprint with instancePrincipalAuthentication", func(t *testing.T) { meta.Properties = getDummyOCIObjectStorageConfiguration() meta.Properties[fingerPrintKey] = "" meta.Properties[instancePrincipalAuthenticationKey] = "true" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) if err != nil { assert.Contains(t, err.Error(), "failed to initialize client", "unit tests not run on OCI will not be able to correctly create an OCI client based on instance principal authentication") } @@ -110,7 +110,7 @@ func TestInit(t *testing.T) { meta.Properties = getDummyOCIObjectStorageConfiguration() meta.Properties[configFileAuthenticationKey] = "true" meta.Properties[configFilePathKey] = "file_does_not_exist" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.NotNil(t, err, "if configFileAuthentication is true and configFilePath does not indicate an existing file, then an error should be produced") if err != nil { assert.Contains(t, err.Error(), "does not exist", "if configFileAuthentication is true and configFilePath does not indicate an existing file, then an error should be produced that indicates this") @@ -120,7 +120,7 @@ func TestInit(t *testing.T) { meta.Properties = getDummyOCIObjectStorageConfiguration() meta.Properties[configFileAuthenticationKey] = "true" meta.Properties[configFilePathKey] = "~/some-file" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.NotNil(t, err, "if configFileAuthentication is true and configFilePath contains a value that starts with ~/ , then an error should be produced") if err != nil { assert.Contains(t, err.Error(), "~", "if configFileAuthentication is true and configFilePath starts with ~/, then an error should be produced that indicates this") @@ -130,7 +130,7 @@ func TestInit(t *testing.T) { meta.Properties = getDummyOCIObjectStorageConfiguration() meta.Properties[fingerPrintKey] = "" meta.Properties[instancePrincipalAuthenticationKey] = "false" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) assert.NotNil(t, err, "if instancePrincipalAuthentication and configFileAuthentication are both false, then fingerprint is required and an error should be raised when it is missing") }) t.Run("Init with missing fingerprint with configFileAuthentication", func(t *testing.T) { @@ -138,7 +138,7 @@ func TestInit(t *testing.T) { meta.Properties[fingerPrintKey] = "" meta.Properties[instancePrincipalAuthenticationKey] = "false" meta.Properties[configFileAuthenticationKey] = "true" - err := statestore.Init(meta) + err := statestore.Init(context.Background(), meta) if err != nil { assert.Contains(t, err.Error(), "failed to initialize client", "if configFileAuthentication is true, then fingerprint is not required and error should be raised for failed to initialize client, not for missing fingerprint") } @@ -220,11 +220,11 @@ func (c *mockedObjectStoreClient) putObject(ctx context.Context, objectname stri return nil } -func (c *mockedObjectStoreClient) initStorageBucket() error { +func (c *mockedObjectStoreClient) initStorageBucket(ctx context.Context) error { return nil } -func (c *mockedObjectStoreClient) pingBucket() error { +func (c *mockedObjectStoreClient) pingBucket(ctx context.Context) error { c.pingBucketIsCalled = true return nil } @@ -264,7 +264,7 @@ func TestInitWithMockClient(t *testing.T) { s.client = &mockedObjectStoreClient{} meta := state.Metadata{} t.Run("Test Init with incomplete configuration", func(t *testing.T) { - err := s.Init(meta) + err := s.Init(context.Background(), meta) assert.NotNil(t, err, "Init should complain about lacking configuration settings") }) } @@ -276,7 +276,7 @@ func TestPingWithMockClient(t *testing.T) { s.client = mockClient t.Run("Test Ping", func(t *testing.T) { - err := s.Ping() + err := s.Ping(context.Background()) assert.Nil(t, err) assert.True(t, mockClient.pingBucketIsCalled, "function pingBucket should be invoked on the mockClient") }) diff --git a/state/oracledatabase/dbaccess.go b/state/oracledatabase/dbaccess.go index 188cf8366..0149765cb 100644 --- a/state/oracledatabase/dbaccess.go +++ b/state/oracledatabase/dbaccess.go @@ -21,8 +21,8 @@ import ( // dbAccess is a private interface which enables unit testing of Oracle Database. type dbAccess interface { - Init(metadata state.Metadata) error - Ping() error + Init(ctx context.Context, metadata state.Metadata) error + Ping(ctx context.Context) error Set(ctx context.Context, req *state.SetRequest) error Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) Delete(ctx context.Context, req *state.DeleteRequest) error diff --git a/state/oracledatabase/oracledatabase.go b/state/oracledatabase/oracledatabase.go index 3fe7f8595..6c0c3de5d 100644 --- a/state/oracledatabase/oracledatabase.go +++ b/state/oracledatabase/oracledatabase.go @@ -48,12 +48,12 @@ func newOracleDatabaseStateStore(logger logger.Logger, dba dbAccess) *OracleData } // Init initializes the SQL server state store. -func (o *OracleDatabase) Init(metadata state.Metadata) error { - return o.dbaccess.Init(metadata) +func (o *OracleDatabase) Init(ctx context.Context, metadata state.Metadata) error { + return o.dbaccess.Init(ctx, metadata) } -func (o *OracleDatabase) Ping() error { - return o.dbaccess.Ping() +func (o *OracleDatabase) Ping(ctx context.Context) error { + return o.dbaccess.Ping(ctx) } // Features returns the features available in this state store. diff --git a/state/oracledatabase/oracledatabase_integration_test.go b/state/oracledatabase/oracledatabase_integration_test.go index fe54ed2e0..27ea72b75 100644 --- a/state/oracledatabase/oracledatabase_integration_test.go +++ b/state/oracledatabase/oracledatabase_integration_test.go @@ -66,7 +66,7 @@ func TestOracleDatabaseIntegration(t *testing.T) { defer ods.Close() }) - if initerror := ods.Init(metadata); initerror != nil { + if initerror := ods.Init(context.Background(), metadata); initerror != nil { t.Fatal(initerror) } @@ -743,7 +743,7 @@ func testInitConfiguration(t *testing.T) { Base: metadata.Base{Properties: tt.props}, } - err := p.Init(metadata) + err := p.Init(context.Background(), metadata) if tt.expectedErr == "" { assert.Nil(t, err) } else { diff --git a/state/oracledatabase/oracledatabase_test.go b/state/oracledatabase/oracledatabase_test.go index 1a714c3fe..44e1f0553 100644 --- a/state/oracledatabase/oracledatabase_test.go +++ b/state/oracledatabase/oracledatabase_test.go @@ -38,12 +38,12 @@ type fakeDBaccess struct { getExecuted bool } -func (m *fakeDBaccess) Ping() error { +func (m *fakeDBaccess) Ping(ctx context.Context) error { m.pingExecuted = true return nil } -func (m *fakeDBaccess) Init(metadata state.Metadata) error { +func (m *fakeDBaccess) Init(ctx context.Context, metadata state.Metadata) error { m.initExecuted = true return nil @@ -77,7 +77,7 @@ func (m *fakeDBaccess) Close() error { func TestInitRunsDBAccessInit(t *testing.T) { t.Parallel() ods, fake := createOracleDatabaseWithFake(t) - ods.Ping() + ods.Ping(context.Background()) assert.True(t, fake.initExecuted) } @@ -194,7 +194,7 @@ func createOracleDatabaseWithFake(t *testing.T) (*OracleDatabase, *fakeDBaccess) func TestPingRunsDBAccessPing(t *testing.T) { t.Parallel() odb, fake := createOracleDatabaseWithFake(t) - odb.Ping() + odb.Ping(context.Background()) assert.True(t, fake.pingExecuted) } @@ -212,7 +212,7 @@ func createOracleDatabase(t *testing.T) *OracleDatabase { Base: metadata.Base{Properties: map[string]string{connectionStringKey: fakeConnectionString}}, } - err := odb.Init(*metadata) + err := odb.Init(context.Background(), *metadata) assert.Nil(t, err) assert.NotNil(t, odb.dbaccess) diff --git a/state/oracledatabase/oracledatabaseaccess.go b/state/oracledatabase/oracledatabaseaccess.go index 7c0b095c7..5988d47d6 100644 --- a/state/oracledatabase/oracledatabaseaccess.go +++ b/state/oracledatabase/oracledatabaseaccess.go @@ -63,8 +63,8 @@ func newOracleDatabaseAccess(logger logger.Logger) *oracleDatabaseAccess { } } -func (o *oracleDatabaseAccess) Ping() error { - return o.db.Ping() +func (o *oracleDatabaseAccess) Ping(ctx context.Context) error { + return o.db.PingContext(ctx) } func parseMetadata(meta map[string]string) (oracleDatabaseMetadata, error) { @@ -76,7 +76,7 @@ func parseMetadata(meta map[string]string) (oracleDatabaseMetadata, error) { } // Init sets up OracleDatabase connection and ensures that the state table exists. -func (o *oracleDatabaseAccess) Init(metadata state.Metadata) error { +func (o *oracleDatabaseAccess) Init(ctx context.Context, metadata state.Metadata) error { o.logger.Debug("Initializing OracleDatabase state store") meta, err := parseMetadata(metadata.Properties) o.metadata = meta @@ -102,7 +102,7 @@ func (o *oracleDatabaseAccess) Init(metadata state.Metadata) error { o.db = db - if pingErr := db.Ping(); pingErr != nil { + if pingErr := db.PingContext(ctx); pingErr != nil { return pingErr } err = o.ensureStateTable(tableName) diff --git a/state/postgresql/dbaccess.go b/state/postgresql/dbaccess.go index 62ab38927..2664cc8a8 100644 --- a/state/postgresql/dbaccess.go +++ b/state/postgresql/dbaccess.go @@ -24,7 +24,7 @@ import ( // dbAccess is a private interface which enables unit testing of PostgreSQL. type dbAccess interface { - Init(metadata state.Metadata) error + Init(ctx context.Context, metadata state.Metadata) error Set(ctx context.Context, req *state.SetRequest) error BulkSet(ctx context.Context, req []state.SetRequest) error Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) diff --git a/state/postgresql/postgresdbaccess.go b/state/postgresql/postgresdbaccess.go index 655c900b8..9560458cb 100644 --- a/state/postgresql/postgresdbaccess.go +++ b/state/postgresql/postgresdbaccess.go @@ -20,6 +20,8 @@ import ( "errors" "fmt" "strconv" + "sync" + "sync/atomic" "time" "github.com/jackc/pgx/v5" @@ -52,8 +54,10 @@ type PostgresDBAccess struct { logger logger.Logger metadata postgresMetadataStruct db pgxPoolConn - ctx context.Context - cancel context.CancelFunc + + closeCh chan struct{} + closed atomic.Bool + wg sync.WaitGroup } // newPostgresDBAccess creates a new instance of postgresAccess. @@ -61,16 +65,15 @@ func newPostgresDBAccess(logger logger.Logger) *PostgresDBAccess { logger.Debug("Instantiating new Postgres state store") return &PostgresDBAccess{ - logger: logger, + logger: logger, + closeCh: make(chan struct{}), } } // Init sets up Postgres connection and ensures that the state table exists. -func (p *PostgresDBAccess) Init(meta state.Metadata) error { +func (p *PostgresDBAccess) Init(ctx context.Context, meta state.Metadata) error { p.logger.Debug("Initializing Postgres state store") - p.ctx, p.cancel = context.WithCancel(context.Background()) - err := p.metadata.InitWithMetadata(meta) if err != nil { p.logger.Errorf("Failed to parse metadata: %v", err) @@ -87,7 +90,7 @@ func (p *PostgresDBAccess) Init(meta state.Metadata) error { config.MaxConnIdleTime = p.metadata.ConnectionMaxIdleTime } - connCtx, connCancel := context.WithTimeout(p.ctx, p.metadata.timeout) + connCtx, connCancel := context.WithTimeout(ctx, p.metadata.timeout) p.db, err = pgxpool.NewWithConfig(connCtx, config) connCancel() if err != nil { @@ -96,7 +99,7 @@ func (p *PostgresDBAccess) Init(meta state.Metadata) error { return err } - pingCtx, pingCancel := context.WithTimeout(p.ctx, p.metadata.timeout) + pingCtx, pingCancel := context.WithTimeout(ctx, p.metadata.timeout) err = p.db.Ping(pingCtx) pingCancel() if err != nil { @@ -111,12 +114,12 @@ func (p *PostgresDBAccess) Init(meta state.Metadata) error { MetadataTableName: p.metadata.MetadataTableName, StateTableName: p.metadata.TableName, } - err = migrate.Perform(p.ctx) + err = migrate.Perform(ctx) if err != nil { return err } - p.ScheduleCleanupExpiredData(p.ctx) + p.ScheduleCleanupExpiredData(ctx) return nil } @@ -444,13 +447,15 @@ func (p *PostgresDBAccess) Query(parentCtx context.Context, req *state.QueryRequ } func (p *PostgresDBAccess) ScheduleCleanupExpiredData(ctx context.Context) { - if p.metadata.cleanupInterval == nil || *p.metadata.cleanupInterval <= 0 { + if p.metadata.cleanupInterval == nil || *p.metadata.cleanupInterval <= 0 || p.closed.Load() { return } p.logger.Infof("Schedule expired data clean up every %d seconds", int(p.metadata.cleanupInterval.Seconds())) + p.wg.Add(1) go func() { + defer p.wg.Done() ticker := time.NewTicker(*p.metadata.cleanupInterval) defer ticker.Stop() @@ -465,6 +470,9 @@ func (p *PostgresDBAccess) ScheduleCleanupExpiredData(ctx context.Context) { case <-ctx.Done(): p.logger.Debug("Stopped background cleanup of expired data") return + case <-p.closeCh: + p.logger.Debug("Stopping background because PostgresDBAccess is closing") + return } } }() @@ -520,15 +528,17 @@ func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier, // Close implements io.Close. func (p *PostgresDBAccess) Close() error { - if p.cancel != nil { - p.cancel() - p.cancel = nil + if p.closed.CompareAndSwap(false, true) { + close(p.closeCh) } + if p.db != nil { p.db.Close() p.db = nil } + p.wg.Wait() + return nil } diff --git a/state/postgresql/postgresql.go b/state/postgresql/postgresql.go index 95f26fe2e..485f44704 100644 --- a/state/postgresql/postgresql.go +++ b/state/postgresql/postgresql.go @@ -45,8 +45,8 @@ func newPostgreSQLStateStore(logger logger.Logger, dba dbAccess) *PostgreSQL { } // Init initializes the SQL server state store. -func (p *PostgreSQL) Init(metadata state.Metadata) error { - return p.dbaccess.Init(metadata) +func (p *PostgreSQL) Init(ctx context.Context, metadata state.Metadata) error { + return p.dbaccess.Init(ctx, metadata) } // Features returns the features available in this state store. diff --git a/state/postgresql/postgresql_integration_test.go b/state/postgresql/postgresql_integration_test.go index ed1f2f326..e8c6b416a 100644 --- a/state/postgresql/postgresql_integration_test.go +++ b/state/postgresql/postgresql_integration_test.go @@ -60,7 +60,7 @@ func TestPostgreSQLIntegration(t *testing.T) { defer pgs.Close() }) - error := pgs.Init(metadata) + error := pgs.Init(context.Background(), metadata) if error != nil { t.Fatal(error) } @@ -472,7 +472,7 @@ func testInitConfiguration(t *testing.T) { Base: metadata.Base{Properties: tt.props}, } - err := p.Init(metadata) + err := p.Init(context.Background(), metadata) if tt.expectedErr == nil { assert.NoError(t, err) } else { diff --git a/state/postgresql/postgresql_test.go b/state/postgresql/postgresql_test.go index 1f46b92b7..0ea3b81b4 100644 --- a/state/postgresql/postgresql_test.go +++ b/state/postgresql/postgresql_test.go @@ -38,7 +38,7 @@ type fakeDBaccess struct { deleteExecuted bool } -func (m *fakeDBaccess) Init(metadata state.Metadata) error { +func (m *fakeDBaccess) Init(ctx context.Context, metadata state.Metadata) error { m.initExecuted = true return nil @@ -110,7 +110,7 @@ func createPostgreSQL(t *testing.T) *PostgreSQL { Base: metadata.Base{Properties: map[string]string{"connectionString": fakeConnectionString}}, } - err := pgs.Init(*metadata) + err := pgs.Init(context.Background(), *metadata) assert.Nil(t, err) assert.NotNil(t, pgs.dbaccess) diff --git a/state/redis/redis.go b/state/redis/redis.go index 3ff79d4c3..640d865df 100644 --- a/state/redis/redis.go +++ b/state/redis/redis.go @@ -100,9 +100,6 @@ type StateStore struct { features []state.Feature logger logger.Logger - - ctx context.Context - cancel context.CancelFunc } // NewRedisStateStore returns a new redis state store. @@ -117,8 +114,8 @@ func NewRedisStateStore(logger logger.Logger) state.Store { return s } -func (r *StateStore) Ping() error { - if _, err := r.client.PingResult(context.Background()); err != nil { +func (r *StateStore) Ping(ctx context.Context) error { + if _, err := r.client.PingResult(ctx); err != nil { return fmt.Errorf("redis store: error connecting to redis at %s: %s", r.clientSettings.Host, err) } @@ -126,7 +123,7 @@ func (r *StateStore) Ping() error { } // Init does metadata and connection parsing. -func (r *StateStore) Init(metadata state.Metadata) error { +func (r *StateStore) Init(ctx context.Context, metadata state.Metadata) error { m, err := rediscomponent.ParseRedisMetadata(metadata.Properties) if err != nil { return err @@ -144,17 +141,15 @@ func (r *StateStore) Init(metadata state.Metadata) error { return fmt.Errorf("redis store: error parsing query index schema: %v", err) } - r.ctx, r.cancel = context.WithCancel(context.Background()) - - if _, err = r.client.PingResult(r.ctx); err != nil { + if _, err = r.client.PingResult(ctx); err != nil { return fmt.Errorf("redis store: error connecting to redis at %s: %v", r.clientSettings.Host, err) } - if r.replicas, err = r.getConnectedSlaves(); err != nil { + if r.replicas, err = r.getConnectedSlaves(ctx); err != nil { return err } - if err = r.registerSchemas(); err != nil { + if err = r.registerSchemas(ctx); err != nil { return fmt.Errorf("redis store: error registering query schemas: %v", err) } @@ -166,8 +161,8 @@ func (r *StateStore) Features() []state.Feature { return r.features } -func (r *StateStore) getConnectedSlaves() (int, error) { - res, err := r.client.DoRead(r.ctx, "INFO", "replication") +func (r *StateStore) getConnectedSlaves(ctx context.Context) (int, error) { + res, err := r.client.DoRead(ctx, "INFO", "replication") if err != nil { return 0, err } @@ -271,8 +266,8 @@ func (r *StateStore) getDefault(ctx context.Context, req *state.GetRequest) (*st }, nil } -func (r *StateStore) getJSON(req *state.GetRequest) (*state.GetResponse, error) { - res, err := r.client.DoRead(r.ctx, "JSON.GET", req.Key) +func (r *StateStore) getJSON(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { + res, err := r.client.DoRead(ctx, "JSON.GET", req.Key) if err != nil { return nil, err } @@ -311,7 +306,7 @@ func (r *StateStore) getJSON(req *state.GetRequest) (*state.GetResponse, error) // Get retrieves state from redis with a key. func (r *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { if contentType, ok := req.Metadata[daprmetadata.ContentType]; ok && contentType == contenttype.JSONContentType && rediscomponent.ClientHasJSONSupport(r.client) { - return r.getJSON(req) + return r.getJSON(ctx, req) } return r.getDefault(ctx, req) @@ -446,18 +441,18 @@ func (r *StateStore) Multi(ctx context.Context, request *state.TransactionalStat return err } -func (r *StateStore) registerSchemas() error { +func (r *StateStore) registerSchemas(ctx context.Context) error { for name, elem := range r.querySchemas { r.logger.Infof("redis: create query index %s", name) - if err := r.client.DoWrite(r.ctx, elem.schema...); err != nil { + if err := r.client.DoWrite(ctx, elem.schema...); err != nil { if err.Error() != "Index already exists" { return err } r.logger.Infof("redis: drop stale query index %s", name) - if err = r.client.DoWrite(r.ctx, "FT.DROPINDEX", name); err != nil { + if err = r.client.DoWrite(ctx, "FT.DROPINDEX", name); err != nil { return err } - if err = r.client.DoWrite(r.ctx, elem.schema...); err != nil { + if err = r.client.DoWrite(ctx, elem.schema...); err != nil { return err } } @@ -545,8 +540,6 @@ func (r *StateStore) Query(ctx context.Context, req *state.QueryRequest) (*state } func (r *StateStore) Close() error { - r.cancel() - return r.client.Close() } diff --git a/state/redis/redis_test.go b/state/redis/redis_test.go index d116ba284..3cb8cc235 100644 --- a/state/redis/redis_test.go +++ b/state/redis/redis_test.go @@ -204,7 +204,6 @@ func TestTransactionalUpsert(t *testing.T) { json: jsoniter.ConfigFastest, logger: logger.NewLogger("test"), } - ss.ctx, ss.cancel = context.WithCancel(context.Background()) err := ss.Multi(context.Background(), &state.TransactionalStateRequest{ Operations: []state.TransactionalStateOperation{ @@ -270,7 +269,6 @@ func TestTransactionalDelete(t *testing.T) { json: jsoniter.ConfigFastest, logger: logger.NewLogger("test"), } - ss.ctx, ss.cancel = context.WithCancel(context.Background()) // Insert a record first. ss.Set(context.Background(), &state.SetRequest{ @@ -307,12 +305,12 @@ func TestPing(t *testing.T) { clientSettings: &rediscomponent.Settings{}, } - err := state.Ping(ss) + err := state.Ping(context.Background(), ss) assert.NoError(t, err) s.Close() - err = state.Ping(ss) + err = state.Ping(context.Background(), ss) assert.Error(t, err) } @@ -328,14 +326,13 @@ func TestRequestsWithGlobalTTL(t *testing.T) { logger: logger.NewLogger("test"), metadata: rediscomponent.Metadata{TTLInSeconds: &globalTTLInSeconds}, } - ss.ctx, ss.cancel = context.WithCancel(context.Background()) t.Run("TTL: Only global specified", func(t *testing.T) { ss.Set(context.Background(), &state.SetRequest{ Key: "weapon100", Value: "deathstar100", }) - ttl, _ := ss.client.TTLResult(ss.ctx, "weapon100") + ttl, _ := ss.client.TTLResult(context.Background(), "weapon100") assert.Equal(t, time.Duration(globalTTLInSeconds)*time.Second, ttl) }) @@ -349,7 +346,7 @@ func TestRequestsWithGlobalTTL(t *testing.T) { "ttlInSeconds": strconv.Itoa(requestTTL), }, }) - ttl, _ := ss.client.TTLResult(ss.ctx, "weapon100") + ttl, _ := ss.client.TTLResult(context.Background(), "weapon100") assert.Equal(t, time.Duration(requestTTL)*time.Second, ttl) }) @@ -420,7 +417,6 @@ func TestSetRequestWithTTL(t *testing.T) { json: jsoniter.ConfigFastest, logger: logger.NewLogger("test"), } - ss.ctx, ss.cancel = context.WithCancel(context.Background()) t.Run("TTL specified", func(t *testing.T) { ttlInSeconds := 100 @@ -432,7 +428,7 @@ func TestSetRequestWithTTL(t *testing.T) { }, }) - ttl, _ := ss.client.TTLResult(ss.ctx, "weapon100") + ttl, _ := ss.client.TTLResult(context.Background(), "weapon100") assert.Equal(t, time.Duration(ttlInSeconds)*time.Second, ttl) }) @@ -443,7 +439,7 @@ func TestSetRequestWithTTL(t *testing.T) { Value: "deathstar200", }) - ttl, _ := ss.client.TTLResult(ss.ctx, "weapon200") + ttl, _ := ss.client.TTLResult(context.Background(), "weapon200") assert.Equal(t, time.Duration(-1), ttl) }) @@ -453,7 +449,7 @@ func TestSetRequestWithTTL(t *testing.T) { Key: "weapon300", Value: "deathstar300", }) - ttl, _ := ss.client.TTLResult(ss.ctx, "weapon300") + ttl, _ := ss.client.TTLResult(context.Background(), "weapon300") assert.Equal(t, time.Duration(-1), ttl) // make the key no longer persistent @@ -465,7 +461,7 @@ func TestSetRequestWithTTL(t *testing.T) { "ttlInSeconds": strconv.Itoa(ttlInSeconds), }, }) - ttl, _ = ss.client.TTLResult(ss.ctx, "weapon300") + ttl, _ = ss.client.TTLResult(context.Background(), "weapon300") assert.Equal(t, time.Duration(ttlInSeconds)*time.Second, ttl) // make the key persistent again @@ -476,7 +472,7 @@ func TestSetRequestWithTTL(t *testing.T) { "ttlInSeconds": strconv.Itoa(-1), }, }) - ttl, _ = ss.client.TTLResult(ss.ctx, "weapon300") + ttl, _ = ss.client.TTLResult(context.Background(), "weapon300") assert.Equal(t, time.Duration(-1), ttl) }) } @@ -490,7 +486,6 @@ func TestTransactionalDeleteNoEtag(t *testing.T) { json: jsoniter.ConfigFastest, logger: logger.NewLogger("test"), } - ss.ctx, ss.cancel = context.WithCancel(context.Background()) // Insert a record first. ss.Set(context.Background(), &state.SetRequest{ diff --git a/state/rethinkdb/rethinkdb.go b/state/rethinkdb/rethinkdb.go index bc4f60cec..5037d0bff 100644 --- a/state/rethinkdb/rethinkdb.go +++ b/state/rethinkdb/rethinkdb.go @@ -66,7 +66,7 @@ func NewRethinkDBStateStore(logger logger.Logger) state.Store { } // Init parses metadata, initializes the RethinkDB client, and ensures the state table exists. -func (s *RethinkDB) Init(metadata state.Metadata) error { +func (s *RethinkDB) Init(ctx context.Context, metadata state.Metadata) error { r.Log.Out = io.Discard r.SetTags("rethinkdb", "json") cfg, err := metadataToConfig(metadata.Properties, s.logger) @@ -87,9 +87,9 @@ func (s *RethinkDB) Init(metadata state.Metadata) error { s.config = cfg // check if table already exists - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - c, err := r.DB(s.config.Database).TableList().Run(s.session, r.RunOpts{Context: ctx}) - cancel() + listContext, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + c, err := r.DB(s.config.Database).TableList().Run(s.session, r.RunOpts{Context: listContext}) if err != nil { return errors.Wrap(err, "error checking for state table existence in DB") } @@ -106,11 +106,11 @@ func (s *RethinkDB) Init(metadata state.Metadata) error { } if !tableExists(list, s.config.Table) { - ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second) + cctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() _, err = r.DB(s.config.Database).TableCreate(s.config.Table, r.TableCreateOpts{ PrimaryKey: stateTablePKName, - }).RunWrite(s.session, r.RunOpts{Context: ctx}) - cancel() + }).RunWrite(s.session, r.RunOpts{Context: cctx}) if err != nil { return errors.Wrap(err, "error creating state table in DB") } @@ -118,21 +118,21 @@ func (s *RethinkDB) Init(metadata state.Metadata) error { if s.config.Archive && !tableExists(list, stateArchiveTableName) { // create archive table with autokey to preserve state id - ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second) + ctblCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() _, err = r.DB(s.config.Database).TableCreate(stateArchiveTableName, - r.TableCreateOpts{PrimaryKey: stateArchiveTablePKName}).RunWrite(s.session, r.RunOpts{Context: ctx}) - cancel() + r.TableCreateOpts{PrimaryKey: stateArchiveTablePKName}).RunWrite(s.session, r.RunOpts{Context: ctblCtx}) if err != nil { return errors.Wrap(err, "error creating state archive table in DB") } // index archive table for id and timestamp - ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second) + cindCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() _, err = r.DB(s.config.Database).Table(stateArchiveTableName). IndexCreateFunc("state_index", func(row r.Term) interface{} { return []interface{}{row.Field("id"), row.Field("timestamp")} - }).RunWrite(s.session, r.RunOpts{Context: ctx}) - cancel() + }).RunWrite(s.session, r.RunOpts{Context: cindCtx}) if err != nil { return errors.Wrap(err, "error creating state archive index in DB") } diff --git a/state/rethinkdb/rethinkdb_test.go b/state/rethinkdb/rethinkdb_test.go index 4d3ebb878..6079b4e3f 100644 --- a/state/rethinkdb/rethinkdb_test.go +++ b/state/rethinkdb/rethinkdb_test.go @@ -70,13 +70,13 @@ func TestRethinkDBStateStore(t *testing.T) { db := NewRethinkDBStateStore(logger.NewLogger("test")).(*RethinkDB) t.Run("With init", func(t *testing.T) { - if err := db.Init(m); err != nil { + if err := db.Init(context.Background(), m); err != nil { t.Fatalf("error initializing db: %v", err) } assert.Equal(t, stateTableNameDefault, db.config.Table) m.Properties["table"] = "test" - if err := db.Init(m); err != nil { + if err := db.Init(context.Background(), m); err != nil { t.Fatalf("error initializing db: %v", err) } assert.Equal(t, "test", db.config.Table) @@ -157,7 +157,7 @@ func TestRethinkDBStateStoreRongRun(t *testing.T) { m := state.Metadata{Base: metadata.Base{Properties: getTestMetadata()}} db := NewRethinkDBStateStore(logger.NewLogger("test")).(*RethinkDB) - if err := db.Init(m); err != nil { + if err := db.Init(context.Background(), m); err != nil { t.Fatalf("error initializing db: %v", err) } @@ -212,7 +212,7 @@ func TestRethinkDBStateStoreMulti(t *testing.T) { m := state.Metadata{Base: metadata.Base{Properties: getTestMetadata()}} db := NewRethinkDBStateStore(logger.NewLogger("test")).(*RethinkDB) - if err := db.Init(m); err != nil { + if err := db.Init(context.Background(), m); err != nil { t.Fatalf("error initializing db: %v", err) } diff --git a/state/sqlite/sqlite.go b/state/sqlite/sqlite.go index b0b96f1b0..f01a0aff0 100644 --- a/state/sqlite/sqlite.go +++ b/state/sqlite/sqlite.go @@ -45,8 +45,8 @@ func newSQLiteStateStore(logger logger.Logger, dba DBAccess) *SQLiteStore { } // Init initializes the Sql server state store. -func (s *SQLiteStore) Init(metadata state.Metadata) error { - return s.dbaccess.Init(metadata) +func (s *SQLiteStore) Init(ctx context.Context, metadata state.Metadata) error { + return s.dbaccess.Init(ctx, metadata) } func (s SQLiteStore) GetComponentMetadata() map[string]string { @@ -64,8 +64,8 @@ func (s *SQLiteStore) Features() []state.Feature { } } -func (s *SQLiteStore) Ping() error { - return s.dbaccess.Ping(context.TODO()) +func (s *SQLiteStore) Ping(ctx context.Context) error { + return s.dbaccess.Ping(ctx) } // Delete removes an entity from the store. diff --git a/state/sqlite/sqlite_dbaccess.go b/state/sqlite/sqlite_dbaccess.go index 6349a31a4..ba258d153 100644 --- a/state/sqlite/sqlite_dbaccess.go +++ b/state/sqlite/sqlite_dbaccess.go @@ -36,7 +36,7 @@ import ( // DBAccess is a private interface which enables unit testing of SQLite. type DBAccess interface { - Init(metadata state.Metadata) error + Init(ctx context.Context, metadata state.Metadata) error Ping(ctx context.Context) error Set(ctx context.Context, req *state.SetRequest) error Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) @@ -70,7 +70,7 @@ func newSqliteDBAccess(logger logger.Logger) *sqliteDBAccess { // Init sets up SQLite Database connection and ensures that the state table // exists. -func (a *sqliteDBAccess) Init(md state.Metadata) error { +func (a *sqliteDBAccess) Init(ctx context.Context, md state.Metadata) error { err := a.metadata.InitWithMetadata(md) if err != nil { return err diff --git a/state/sqlite/sqlite_integration_test.go b/state/sqlite/sqlite_integration_test.go index ea8ef85e1..6ceb0b13e 100644 --- a/state/sqlite/sqlite_integration_test.go +++ b/state/sqlite/sqlite_integration_test.go @@ -57,7 +57,7 @@ func TestSqliteIntegration(t *testing.T) { defer s.Close() }) - if initerror := s.Init(metadata); initerror != nil { + if initerror := s.Init(context.Background(), metadata); initerror != nil { t.Fatal(initerror) } @@ -647,7 +647,7 @@ func testInitConfiguration(t *testing.T) { }, } - err := p.Init(metadata) + err := p.Init(context.Background(), metadata) if tt.expectedErr == "" { assert.NoError(t, err) } else { diff --git a/state/sqlite/sqlite_test.go b/state/sqlite/sqlite_test.go index 991910525..7ed11ea11 100644 --- a/state/sqlite/sqlite_test.go +++ b/state/sqlite/sqlite_test.go @@ -259,7 +259,7 @@ func TestGetConnectionString(t *testing.T) { func TestInitRunsDBAccessInit(t *testing.T) { t.Parallel() ods, fake := createSqliteWithFake(t) - ods.Ping() + ods.Ping(context.Background()) assert.True(t, fake.initExecuted) } @@ -309,7 +309,7 @@ func TestValidMultiDeleteRequest(t *testing.T) { func TestPingRunsDBAccessPing(t *testing.T) { t.Parallel() odb, fake := createSqliteWithFake(t) - odb.Ping() + odb.Ping(context.Background()) assert.True(t, fake.pingExecuted) } @@ -327,7 +327,7 @@ func (m *fakeDBaccess) Ping(ctx context.Context) error { return nil } -func (m *fakeDBaccess) Init(metadata state.Metadata) error { +func (m *fakeDBaccess) Init(ctx context.Context, metadata state.Metadata) error { m.initExecuted = true return nil @@ -375,7 +375,7 @@ func createSqlite(t *testing.T) *SQLiteStore { }, } - err := odb.Init(*metadata) + err := odb.Init(context.Background(), *metadata) assert.NoError(t, err) assert.NotNil(t, odb.dbaccess) diff --git a/state/sqlserver/sqlserver.go b/state/sqlserver/sqlserver.go index 62b9599f4..6951d2bb2 100644 --- a/state/sqlserver/sqlserver.go +++ b/state/sqlserver/sqlserver.go @@ -165,7 +165,7 @@ func isValidIndexedPropertyType(s string) bool { } // Init initializes the SQL server state store. -func (s *SQLServer) Init(metadata state.Metadata) error { +func (s *SQLServer) Init(_ context.Context, metadata state.Metadata) error { err := s.parseMetadata(metadata.Properties) if err != nil { return err diff --git a/state/sqlserver/sqlserver_integration_test.go b/state/sqlserver/sqlserver_integration_test.go index 1a06d3310..3613156e8 100644 --- a/state/sqlserver/sqlserver_integration_test.go +++ b/state/sqlserver/sqlserver_integration_test.go @@ -124,7 +124,7 @@ func getTestStoreWithKeyType(t *testing.T, kt KeyType, indexedProperties string) schema := getUniqueDBSchema() metadata := createMetadata(schema, kt, indexedProperties) store := NewSQLServerStateStore(logger.NewLogger("test")).(*SQLServer) - err := store.Init(metadata) + err := store.Init(context.Background(), metadata) assert.Nil(t, err) return store @@ -800,7 +800,7 @@ func testMultipleInitializations(t *testing.T) { store := getTestStoreWithKeyType(t, test.kt, test.indexedProperties) store2 := NewSQLServerStateStore(logger.NewLogger("test")).(*SQLServer) - assert.Nil(t, store2.Init(createMetadata(store.schema, test.kt, test.indexedProperties))) + assert.Nil(t, store2.Init(context.Background(), createMetadata(store.schema, test.kt, test.indexedProperties))) }) } } diff --git a/state/sqlserver/sqlserver_test.go b/state/sqlserver/sqlserver_test.go index cacaf4eb0..d92c5a941 100644 --- a/state/sqlserver/sqlserver_test.go +++ b/state/sqlserver/sqlserver_test.go @@ -15,6 +15,7 @@ limitations under the License. package sqlserver import ( + "context" "errors" "testing" @@ -192,7 +193,7 @@ func TestValidConfiguration(t *testing.T) { Base: metadata.Base{Properties: tt.props}, } - err := sqlStore.Init(metadata) + err := sqlStore.Init(context.Background(), metadata) assert.Nil(t, err) assert.Equal(t, tt.expected.connectionString, sqlStore.connectionString) assert.Equal(t, tt.expected.tableName, sqlStore.tableName) @@ -329,7 +330,7 @@ func TestInvalidConfiguration(t *testing.T) { Base: metadata.Base{Properties: tt.props}, } - err := sqlStore.Init(metadata) + err := sqlStore.Init(context.Background(), metadata) assert.NotNil(t, err) if tt.expectedErr != "" { @@ -350,7 +351,7 @@ func TestExecuteMigrationFails(t *testing.T) { Base: metadata.Base{Properties: map[string]string{connectionStringKey: sampleConnectionString, tableNameKey: sampleUserTableName, databaseNameKey: "dapr_test_table"}}, } - err := sqlStore.Init(metadata) + err := sqlStore.Init(context.Background(), metadata) assert.NotNil(t, err) } diff --git a/state/store.go b/state/store.go index 803052bbe..38228b64a 100644 --- a/state/store.go +++ b/state/store.go @@ -23,7 +23,7 @@ import ( // Store is an interface to perform operations on store. type Store interface { BulkStore - Init(metadata Metadata) error + Init(ctx context.Context, metadata Metadata) error Features() []Feature Delete(ctx context.Context, req *DeleteRequest) error Get(ctx context.Context, req *GetRequest) (*GetResponse, error) @@ -31,10 +31,10 @@ type Store interface { GetComponentMetadata() map[string]string } -func Ping(store Store) error { +func Ping(ctx context.Context, store Store) error { // checks if this store has the ping option then executes if storeWithPing, ok := store.(health.Pinger); ok { - return storeWithPing.Ping() + return storeWithPing.Ping(ctx) } else { return fmt.Errorf("ping is not implemented by this state store") } diff --git a/state/store_test.go b/state/store_test.go index 0face2e2c..2efaa2cc3 100644 --- a/state/store_test.go +++ b/state/store_test.go @@ -107,7 +107,7 @@ type Store1 struct { bulkCount int } -func (s *Store1) Init(metadata Metadata) error { +func (s *Store1) Init(ctx context.Context, metadata Metadata) error { return nil } @@ -142,7 +142,7 @@ type Store2 struct { supportBulkGet bool } -func (s *Store2) Init(metadata Metadata) error { +func (s *Store2) Init(ctx context.Context, metadata Metadata) error { return nil } diff --git a/state/transactional_store.go b/state/transactional_store.go index 170f21d73..80019be34 100644 --- a/state/transactional_store.go +++ b/state/transactional_store.go @@ -19,6 +19,6 @@ import ( // TransactionalStore is an interface for initialization and support multiple transactional requests. type TransactionalStore interface { - Init(metadata Metadata) error + Init(ctx context.Context, metadata Metadata) error Multi(ctx context.Context, request *TransactionalStateRequest) error } diff --git a/state/zookeeper/zk.go b/state/zookeeper/zk.go index ee740ad75..bd2c13fe3 100644 --- a/state/zookeeper/zk.go +++ b/state/zookeeper/zk.go @@ -134,7 +134,7 @@ func NewZookeeperStateStore(logger logger.Logger) state.Store { } } -func (s *StateStore) Init(metadata state.Metadata) (err error) { +func (s *StateStore) Init(_ context.Context, metadata state.Metadata) (err error) { var c *config if c, err = newConfig(metadata.Properties); err != nil { diff --git a/tests/certification/bindings/azure/blobstorage/go.mod b/tests/certification/bindings/azure/blobstorage/go.mod index 0a86e4d7f..b4299a97d 100644 --- a/tests/certification/bindings/azure/blobstorage/go.mod +++ b/tests/certification/bindings/azure/blobstorage/go.mod @@ -162,3 +162,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../../ replace github.com/dapr/components-contrib => ../../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/bindings/azure/blobstorage/go.sum b/tests/certification/bindings/azure/blobstorage/go.sum index 16ad5aeae..319292cd4 100644 --- a/tests/certification/bindings/azure/blobstorage/go.sum +++ b/tests/certification/bindings/azure/blobstorage/go.sum @@ -76,6 +76,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0/go.mod h1:BDJ5 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -126,8 +128,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/bindings/azure/cosmosdb/go.mod b/tests/certification/bindings/azure/cosmosdb/go.mod index 4df09681c..3b699244b 100644 --- a/tests/certification/bindings/azure/cosmosdb/go.mod +++ b/tests/certification/bindings/azure/cosmosdb/go.mod @@ -164,3 +164,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../../ replace github.com/dapr/components-contrib => ../../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/bindings/azure/cosmosdb/go.sum b/tests/certification/bindings/azure/cosmosdb/go.sum index d417d771c..61d4c0cc6 100644 --- a/tests/certification/bindings/azure/cosmosdb/go.sum +++ b/tests/certification/bindings/azure/cosmosdb/go.sum @@ -78,6 +78,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0/go.mod h1:BDJ5 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -130,8 +132,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/bindings/azure/eventhubs/go.mod b/tests/certification/bindings/azure/eventhubs/go.mod index b29f8e1a4..c4f46be28 100644 --- a/tests/certification/bindings/azure/eventhubs/go.mod +++ b/tests/certification/bindings/azure/eventhubs/go.mod @@ -168,3 +168,5 @@ require ( replace github.com/dapr/components-contrib => ../../../../.. replace github.com/dapr/components-contrib/tests/certification => ../../.. + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/bindings/azure/eventhubs/go.sum b/tests/certification/bindings/azure/eventhubs/go.sum index 045f74f44..f886196e1 100644 --- a/tests/certification/bindings/azure/eventhubs/go.sum +++ b/tests/certification/bindings/azure/eventhubs/go.sum @@ -80,6 +80,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0/go.mod h1:BDJ5 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -130,8 +132,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/bindings/azure/servicebusqueues/go.mod b/tests/certification/bindings/azure/servicebusqueues/go.mod index 4fe538def..a3da0eebf 100644 --- a/tests/certification/bindings/azure/servicebusqueues/go.mod +++ b/tests/certification/bindings/azure/servicebusqueues/go.mod @@ -169,3 +169,5 @@ require ( replace github.com/dapr/components-contrib => ../../../../.. replace github.com/dapr/components-contrib/tests/certification => ../../.. + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/bindings/azure/servicebusqueues/go.sum b/tests/certification/bindings/azure/servicebusqueues/go.sum index 6873b6ce5..082af130e 100644 --- a/tests/certification/bindings/azure/servicebusqueues/go.sum +++ b/tests/certification/bindings/azure/servicebusqueues/go.sum @@ -78,6 +78,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0/go.mod h1:BDJ5 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -130,8 +132,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/bindings/azure/storagequeues/go.mod b/tests/certification/bindings/azure/storagequeues/go.mod index 8426eafff..8afa85b6f 100644 --- a/tests/certification/bindings/azure/storagequeues/go.mod +++ b/tests/certification/bindings/azure/storagequeues/go.mod @@ -165,3 +165,5 @@ require ( replace github.com/dapr/components-contrib => ../../../../.. replace github.com/dapr/components-contrib/tests/certification => ../../.. + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/bindings/azure/storagequeues/go.sum b/tests/certification/bindings/azure/storagequeues/go.sum index 1026a6850..83dae1407 100644 --- a/tests/certification/bindings/azure/storagequeues/go.sum +++ b/tests/certification/bindings/azure/storagequeues/go.sum @@ -74,6 +74,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0/go.mod h1:BDJ5 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -124,8 +126,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/bindings/cron/go.mod b/tests/certification/bindings/cron/go.mod index 289330a88..41287e9f1 100644 --- a/tests/certification/bindings/cron/go.mod +++ b/tests/certification/bindings/cron/go.mod @@ -149,3 +149,5 @@ replace github.com/dapr/components-contrib => ../../../../ // in the Dapr runtime. Don't commit with this uncommented! // // replace github.com/dapr/dapr => ../../../../../dapr + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/bindings/cron/go.sum b/tests/certification/bindings/cron/go.sum index f7859f893..02a85fe2d 100644 --- a/tests/certification/bindings/cron/go.sum +++ b/tests/certification/bindings/cron/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -89,8 +91,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/bindings/dubbo/go.mod b/tests/certification/bindings/dubbo/go.mod index 3c294e61a..d598c2aee 100644 --- a/tests/certification/bindings/dubbo/go.mod +++ b/tests/certification/bindings/dubbo/go.mod @@ -178,3 +178,5 @@ require ( replace github.com/dapr/components-contrib => ../../../.. replace github.com/dapr/components-contrib/tests/certification => ../.. + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/bindings/dubbo/go.sum b/tests/certification/bindings/dubbo/go.sum index 30472e041..84291aa4e 100644 --- a/tests/certification/bindings/dubbo/go.sum +++ b/tests/certification/bindings/dubbo/go.sum @@ -44,6 +44,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= @@ -167,8 +169,6 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creasty/defaults v1.5.2 h1:/VfB6uxpyp6h0fr7SPp7n8WJBoV8jfxQXPCnkVSjyls= github.com/creasty/defaults v1.5.2/go.mod h1:FPZ+Y0WNrbqOVw+c6av63eyHUAl6pMHZwqLPvXUZGfY= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/bindings/kafka/go.mod b/tests/certification/bindings/kafka/go.mod index f8d9cd914..64a58863b 100644 --- a/tests/certification/bindings/kafka/go.mod +++ b/tests/certification/bindings/kafka/go.mod @@ -161,3 +161,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../ replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/bindings/kafka/go.sum b/tests/certification/bindings/kafka/go.sum index 71910f926..43ef02e3d 100644 --- a/tests/certification/bindings/kafka/go.sum +++ b/tests/certification/bindings/kafka/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -91,8 +93,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/bindings/localstorage/go.mod b/tests/certification/bindings/localstorage/go.mod index bcd49dc93..0ec3e1a05 100644 --- a/tests/certification/bindings/localstorage/go.mod +++ b/tests/certification/bindings/localstorage/go.mod @@ -141,3 +141,5 @@ require ( replace github.com/dapr/components-contrib => ../../../.. replace github.com/dapr/components-contrib/tests/certification => ../.. + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/bindings/localstorage/go.sum b/tests/certification/bindings/localstorage/go.sum index d6770cee4..ae239fb84 100644 --- a/tests/certification/bindings/localstorage/go.sum +++ b/tests/certification/bindings/localstorage/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -90,8 +92,6 @@ github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWH github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/cyphar/filepath-securejoin v0.2.3 h1:YX6ebbZCZP7VkM3scTTokDgBL2TY741X51MTk3ycuNI= github.com/cyphar/filepath-securejoin v0.2.3/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/bindings/nacos/go.mod b/tests/certification/bindings/nacos/go.mod index 21f0e2e88..c9515c029 100644 --- a/tests/certification/bindings/nacos/go.mod +++ b/tests/certification/bindings/nacos/go.mod @@ -149,3 +149,5 @@ require ( replace github.com/dapr/components-contrib => ../../../.. replace github.com/dapr/components-contrib/tests/certification => ../.. + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/bindings/nacos/go.sum b/tests/certification/bindings/nacos/go.sum index 6a1222d70..9e300da84 100644 --- a/tests/certification/bindings/nacos/go.sum +++ b/tests/certification/bindings/nacos/go.sum @@ -39,6 +39,8 @@ github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -95,8 +97,6 @@ github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/bindings/postgres/go.mod b/tests/certification/bindings/postgres/go.mod index 60df3a05d..d5a996772 100644 --- a/tests/certification/bindings/postgres/go.mod +++ b/tests/certification/bindings/postgres/go.mod @@ -148,3 +148,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../ replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/bindings/postgres/go.sum b/tests/certification/bindings/postgres/go.sum index 161e6937f..5bbb2a473 100644 --- a/tests/certification/bindings/postgres/go.sum +++ b/tests/certification/bindings/postgres/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -88,8 +90,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/bindings/rabbitmq/go.mod b/tests/certification/bindings/rabbitmq/go.mod index b13bbb6b4..20bccb0bb 100644 --- a/tests/certification/bindings/rabbitmq/go.mod +++ b/tests/certification/bindings/rabbitmq/go.mod @@ -145,3 +145,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../ replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/bindings/rabbitmq/go.sum b/tests/certification/bindings/rabbitmq/go.sum index 7f8b069e2..ffd138a06 100644 --- a/tests/certification/bindings/rabbitmq/go.sum +++ b/tests/certification/bindings/rabbitmq/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -88,8 +90,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/bindings/redis/go.mod b/tests/certification/bindings/redis/go.mod index 2a2b83c31..2057e9744 100644 --- a/tests/certification/bindings/redis/go.mod +++ b/tests/certification/bindings/redis/go.mod @@ -144,3 +144,5 @@ require ( replace github.com/dapr/components-contrib => ../../../.. replace github.com/dapr/components-contrib/tests/certification => ../.. + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/bindings/redis/go.sum b/tests/certification/bindings/redis/go.sum index 21e3bf83e..f58cbfef6 100644 --- a/tests/certification/bindings/redis/go.sum +++ b/tests/certification/bindings/redis/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -90,8 +92,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/go.mod b/tests/certification/go.mod index 81b28011b..40968453f 100644 --- a/tests/certification/go.mod +++ b/tests/certification/go.mod @@ -139,3 +139,5 @@ require ( ) replace github.com/dapr/components-contrib => ../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/go.sum b/tests/certification/go.sum index eac7d8579..f012781bb 100644 --- a/tests/certification/go.sum +++ b/tests/certification/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -88,8 +90,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/pubsub/aws/snssqs/go.mod b/tests/certification/pubsub/aws/snssqs/go.mod index 73dbcc198..b3ce00121 100644 --- a/tests/certification/pubsub/aws/snssqs/go.mod +++ b/tests/certification/pubsub/aws/snssqs/go.mod @@ -146,3 +146,5 @@ require ( sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect sigs.k8s.io/yaml v1.3.0 // indirect ) + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/pubsub/aws/snssqs/go.sum b/tests/certification/pubsub/aws/snssqs/go.sum index 8e714ae90..ec70f2615 100644 --- a/tests/certification/pubsub/aws/snssqs/go.sum +++ b/tests/certification/pubsub/aws/snssqs/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -90,8 +92,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/pubsub/azure/eventhubs/go.mod b/tests/certification/pubsub/azure/eventhubs/go.mod index 8050c10fd..27ea3a1cf 100644 --- a/tests/certification/pubsub/azure/eventhubs/go.mod +++ b/tests/certification/pubsub/azure/eventhubs/go.mod @@ -167,3 +167,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../../ replace github.com/dapr/components-contrib => ../../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/pubsub/azure/eventhubs/go.sum b/tests/certification/pubsub/azure/eventhubs/go.sum index e85c9ae3a..2b021963c 100644 --- a/tests/certification/pubsub/azure/eventhubs/go.sum +++ b/tests/certification/pubsub/azure/eventhubs/go.sum @@ -80,6 +80,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0/go.mod h1:BDJ5 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -130,8 +132,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/pubsub/azure/servicebus/topics/go.mod b/tests/certification/pubsub/azure/servicebus/topics/go.mod index 44053bbf7..2bdebb21d 100644 --- a/tests/certification/pubsub/azure/servicebus/topics/go.mod +++ b/tests/certification/pubsub/azure/servicebus/topics/go.mod @@ -169,3 +169,5 @@ require ( replace github.com/dapr/components-contrib => ../../../../../.. replace github.com/dapr/components-contrib/tests/certification => ../../../.. + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/pubsub/azure/servicebus/topics/go.sum b/tests/certification/pubsub/azure/servicebus/topics/go.sum index 6873b6ce5..082af130e 100644 --- a/tests/certification/pubsub/azure/servicebus/topics/go.sum +++ b/tests/certification/pubsub/azure/servicebus/topics/go.sum @@ -78,6 +78,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0/go.mod h1:BDJ5 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -130,8 +132,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/pubsub/kafka/go.mod b/tests/certification/pubsub/kafka/go.mod index 69b928cc0..05b51b40d 100644 --- a/tests/certification/pubsub/kafka/go.mod +++ b/tests/certification/pubsub/kafka/go.mod @@ -161,3 +161,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../ replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/pubsub/kafka/go.sum b/tests/certification/pubsub/kafka/go.sum index 71910f926..43ef02e3d 100644 --- a/tests/certification/pubsub/kafka/go.sum +++ b/tests/certification/pubsub/kafka/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -91,8 +93,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/pubsub/mqtt3/go.mod b/tests/certification/pubsub/mqtt3/go.mod index 2c89c4d4b..a99099e67 100644 --- a/tests/certification/pubsub/mqtt3/go.mod +++ b/tests/certification/pubsub/mqtt3/go.mod @@ -146,3 +146,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../ replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/pubsub/mqtt3/go.sum b/tests/certification/pubsub/mqtt3/go.sum index 18729a69d..46cd91a21 100644 --- a/tests/certification/pubsub/mqtt3/go.sum +++ b/tests/certification/pubsub/mqtt3/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -88,8 +90,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/pubsub/pulsar/go.mod b/tests/certification/pubsub/pulsar/go.mod index 592fa9b82..0979c9843 100644 --- a/tests/certification/pubsub/pulsar/go.mod +++ b/tests/certification/pubsub/pulsar/go.mod @@ -162,3 +162,5 @@ require ( sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect sigs.k8s.io/yaml v1.3.0 // indirect ) + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/pubsub/pulsar/go.sum b/tests/certification/pubsub/pulsar/go.sum index 04a418f68..e4de72893 100644 --- a/tests/certification/pubsub/pulsar/go.sum +++ b/tests/certification/pubsub/pulsar/go.sum @@ -52,6 +52,8 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/DataDog/zstd v1.5.0 h1:+K/VEwIAaPcHiMtQvpLD4lqW7f0Gk3xdYZmI1hD+CXo= github.com/DataDog/zstd v1.5.0/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -116,8 +118,6 @@ github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsr github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/danieljoos/wincred v1.1.2 h1:QLdCxFs1/Yl4zduvBdcHB8goaYk9RARS2SgLLRuAyr0= github.com/danieljoos/wincred v1.1.2/go.mod h1:GijpziifJoIBfYh+S7BbkdUTU4LfM+QnGqR5Vl2tAx0= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/pubsub/rabbitmq/go.mod b/tests/certification/pubsub/rabbitmq/go.mod index eea4ed511..e905a6812 100644 --- a/tests/certification/pubsub/rabbitmq/go.mod +++ b/tests/certification/pubsub/rabbitmq/go.mod @@ -145,3 +145,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../ replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/pubsub/rabbitmq/go.sum b/tests/certification/pubsub/rabbitmq/go.sum index 7f8b069e2..ffd138a06 100644 --- a/tests/certification/pubsub/rabbitmq/go.sum +++ b/tests/certification/pubsub/rabbitmq/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -88,8 +90,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/secretstores/azure/keyvault/go.mod b/tests/certification/secretstores/azure/keyvault/go.mod index f5c037d13..e20d35daf 100644 --- a/tests/certification/secretstores/azure/keyvault/go.mod +++ b/tests/certification/secretstores/azure/keyvault/go.mod @@ -163,3 +163,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../../ replace github.com/dapr/components-contrib => ../../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/secretstores/azure/keyvault/go.sum b/tests/certification/secretstores/azure/keyvault/go.sum index 50e21ce6a..94aa0850d 100644 --- a/tests/certification/secretstores/azure/keyvault/go.sum +++ b/tests/certification/secretstores/azure/keyvault/go.sum @@ -78,6 +78,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0/go.mod h1:BDJ5 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -128,8 +130,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/secretstores/hashicorp/vault/go.mod b/tests/certification/secretstores/hashicorp/vault/go.mod index 48f83f798..21be871ee 100644 --- a/tests/certification/secretstores/hashicorp/vault/go.mod +++ b/tests/certification/secretstores/hashicorp/vault/go.mod @@ -141,3 +141,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../../ replace github.com/dapr/components-contrib => ../../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/secretstores/hashicorp/vault/go.sum b/tests/certification/secretstores/hashicorp/vault/go.sum index 193e62cd9..f55a234d2 100644 --- a/tests/certification/secretstores/hashicorp/vault/go.sum +++ b/tests/certification/secretstores/hashicorp/vault/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -88,8 +90,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/secretstores/local/env/go.mod b/tests/certification/secretstores/local/env/go.mod index 0ecc9ebd4..5d8569367 100644 --- a/tests/certification/secretstores/local/env/go.mod +++ b/tests/certification/secretstores/local/env/go.mod @@ -140,3 +140,5 @@ require ( replace github.com/dapr/components-contrib => ../../../../.. replace github.com/dapr/components-contrib/tests/certification => ../../.. + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/secretstores/local/env/go.sum b/tests/certification/secretstores/local/env/go.sum index a825e76f0..0e7a0ab69 100644 --- a/tests/certification/secretstores/local/env/go.sum +++ b/tests/certification/secretstores/local/env/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -88,8 +90,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/secretstores/local/file/go.mod b/tests/certification/secretstores/local/file/go.mod index de198efad..8d2020c73 100644 --- a/tests/certification/secretstores/local/file/go.mod +++ b/tests/certification/secretstores/local/file/go.mod @@ -140,3 +140,5 @@ require ( replace github.com/dapr/components-contrib => ../../../../.. replace github.com/dapr/components-contrib/tests/certification => ../../.. + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/secretstores/local/file/go.sum b/tests/certification/secretstores/local/file/go.sum index a825e76f0..0e7a0ab69 100644 --- a/tests/certification/secretstores/local/file/go.sum +++ b/tests/certification/secretstores/local/file/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -88,8 +90,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/state/aws/dynamodb/go.mod b/tests/certification/state/aws/dynamodb/go.mod index 1aa610e01..84d6c9126 100644 --- a/tests/certification/state/aws/dynamodb/go.mod +++ b/tests/certification/state/aws/dynamodb/go.mod @@ -142,3 +142,5 @@ require ( sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect sigs.k8s.io/yaml v1.3.0 // indirect ) + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/state/aws/dynamodb/go.sum b/tests/certification/state/aws/dynamodb/go.sum index 24795b40a..6c5c7161b 100644 --- a/tests/certification/state/aws/dynamodb/go.sum +++ b/tests/certification/state/aws/dynamodb/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -90,8 +92,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/state/azure/blobstorage/go.mod b/tests/certification/state/azure/blobstorage/go.mod index 1907fd90d..048053868 100644 --- a/tests/certification/state/azure/blobstorage/go.mod +++ b/tests/certification/state/azure/blobstorage/go.mod @@ -162,3 +162,5 @@ require ( replace github.com/dapr/components-contrib => ../../../../.. replace github.com/dapr/components-contrib/tests/certification => ../../.. + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/state/azure/blobstorage/go.sum b/tests/certification/state/azure/blobstorage/go.sum index 16ad5aeae..319292cd4 100644 --- a/tests/certification/state/azure/blobstorage/go.sum +++ b/tests/certification/state/azure/blobstorage/go.sum @@ -76,6 +76,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0/go.mod h1:BDJ5 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -126,8 +128,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/state/azure/cosmosdb/go.mod b/tests/certification/state/azure/cosmosdb/go.mod index 054cf4860..477f297e5 100644 --- a/tests/certification/state/azure/cosmosdb/go.mod +++ b/tests/certification/state/azure/cosmosdb/go.mod @@ -163,3 +163,5 @@ require ( replace github.com/dapr/components-contrib => ../../../../.. replace github.com/dapr/components-contrib/tests/certification => ../../.. + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/state/azure/cosmosdb/go.sum b/tests/certification/state/azure/cosmosdb/go.sum index c1d5e7b84..0b1be7567 100644 --- a/tests/certification/state/azure/cosmosdb/go.sum +++ b/tests/certification/state/azure/cosmosdb/go.sum @@ -78,6 +78,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0/go.mod h1:BDJ5 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -128,8 +130,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/state/azure/tablestorage/go.mod b/tests/certification/state/azure/tablestorage/go.mod index 0772e8b61..da82571a8 100644 --- a/tests/certification/state/azure/tablestorage/go.mod +++ b/tests/certification/state/azure/tablestorage/go.mod @@ -162,3 +162,5 @@ require ( replace github.com/dapr/components-contrib => ../../../../.. replace github.com/dapr/components-contrib/tests/certification => ../../.. + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/state/azure/tablestorage/go.sum b/tests/certification/state/azure/tablestorage/go.sum index a653afe67..b04f14772 100644 --- a/tests/certification/state/azure/tablestorage/go.sum +++ b/tests/certification/state/azure/tablestorage/go.sum @@ -76,6 +76,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0/go.mod h1:BDJ5 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -126,8 +128,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/state/cassandra/go.mod b/tests/certification/state/cassandra/go.mod index c4faabf2b..8fbc46161 100644 --- a/tests/certification/state/cassandra/go.mod +++ b/tests/certification/state/cassandra/go.mod @@ -144,3 +144,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../ replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/state/cassandra/go.sum b/tests/certification/state/cassandra/go.sum index d5b5acee8..a85af46dc 100644 --- a/tests/certification/state/cassandra/go.sum +++ b/tests/certification/state/cassandra/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -92,8 +94,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/state/cockroachdb/cockroachdb_test.go b/tests/certification/state/cockroachdb/cockroachdb_test.go index f62448219..74b576332 100644 --- a/tests/certification/state/cockroachdb/cockroachdb_test.go +++ b/tests/certification/state/cockroachdb/cockroachdb_test.go @@ -70,7 +70,7 @@ func TestCockroach(t *testing.T) { } defer client.Close() - err = stateStore.Ping() + err = stateStore.Ping(context.Background()) assert.Equal(t, nil, err) err = client.SaveState(ctx, stateStoreName, certificationTestPrefix+"key1", []byte("certificationdata"), nil) @@ -106,7 +106,7 @@ func TestCockroach(t *testing.T) { } defer client.Close() - err = stateStore.Ping() + err = stateStore.Ping(context.Background()) assert.Equal(t, nil, err) resp, err := stateStore.Get(context.Background(), &state.GetRequest{ diff --git a/tests/certification/state/cockroachdb/go.mod b/tests/certification/state/cockroachdb/go.mod index 3de3040e1..f715349b0 100644 --- a/tests/certification/state/cockroachdb/go.mod +++ b/tests/certification/state/cockroachdb/go.mod @@ -144,3 +144,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../ replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/state/cockroachdb/go.sum b/tests/certification/state/cockroachdb/go.sum index de43234cd..cd41961a4 100644 --- a/tests/certification/state/cockroachdb/go.sum +++ b/tests/certification/state/cockroachdb/go.sum @@ -39,6 +39,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -89,8 +91,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/state/memcached/go.mod b/tests/certification/state/memcached/go.mod index 460407ebe..620d44da0 100644 --- a/tests/certification/state/memcached/go.mod +++ b/tests/certification/state/memcached/go.mod @@ -142,3 +142,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../ replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/state/memcached/go.sum b/tests/certification/state/memcached/go.sum index 8414c3075..ec61e810c 100644 --- a/tests/certification/state/memcached/go.sum +++ b/tests/certification/state/memcached/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -90,8 +92,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/state/mongodb/go.mod b/tests/certification/state/mongodb/go.mod index 814eef5da..aa585fb03 100644 --- a/tests/certification/state/mongodb/go.mod +++ b/tests/certification/state/mongodb/go.mod @@ -149,3 +149,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../ replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/state/mongodb/go.sum b/tests/certification/state/mongodb/go.sum index 88b0aa500..766a25397 100644 --- a/tests/certification/state/mongodb/go.sum +++ b/tests/certification/state/mongodb/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -88,8 +90,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/state/mysql/go.mod b/tests/certification/state/mysql/go.mod index f0673563b..ab83eba57 100644 --- a/tests/certification/state/mysql/go.mod +++ b/tests/certification/state/mysql/go.mod @@ -141,3 +141,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../ replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/state/mysql/go.sum b/tests/certification/state/mysql/go.sum index 674022c07..8ff402bdc 100644 --- a/tests/certification/state/mysql/go.sum +++ b/tests/certification/state/mysql/go.sum @@ -39,6 +39,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -89,8 +91,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/state/mysql/mysql_test.go b/tests/certification/state/mysql/mysql_test.go index 9259a06ce..d7a40249a 100644 --- a/tests/certification/state/mysql/mysql_test.go +++ b/tests/certification/state/mysql/mysql_test.go @@ -14,6 +14,7 @@ limitations under the License. package main import ( + "context" "database/sql" "errors" "strconv" @@ -258,7 +259,7 @@ func TestMySQL(t *testing.T) { component := registeredComponents[idx] // Should fail - err = component.Ping() + err = component.Ping(context.Background()) require.Error(t, err) assert.Equal(t, "driver: bad connection", err.Error()) @@ -282,7 +283,7 @@ func TestMySQL(t *testing.T) { start := time.Now() // Should fail - err = component.Ping() + err = component.Ping(context.Background()) assert.Error(t, err) assert.Truef(t, errors.Is(err, context.DeadlineExceeded), "expected context.DeadlineExceeded but got %v", err) assert.GreaterOrEqual(t, time.Since(start), timeout) @@ -297,7 +298,7 @@ func TestMySQL(t *testing.T) { component := registeredComponents[idx] // Check connection is active - err = component.Ping() + err = component.Ping(context.Background()) require.NoError(t, err) // Close the component @@ -305,7 +306,7 @@ func TestMySQL(t *testing.T) { require.NoError(t, err) // Ensure the connection is closed - err = component.Ping() + err = component.Ping(context.Background()) require.Error(t, err) assert.Truef(t, errors.Is(err, sql.ErrConnDone), "expected sql.ErrConnDone but got %v", err) @@ -334,14 +335,14 @@ func TestMySQL(t *testing.T) { // Init the component component := stateMysql.NewMySQLStateStore(log).(*stateMysql.MySQL) - component.Init(state.Metadata{ + component.Init(context.Background(), state.Metadata{ Base: metadata.Base{ Properties: properties, }, }) // Check connection is active - err = component.Ping() + err = component.Ping(context.Background()) require.NoError(t, err) var exists int diff --git a/tests/certification/state/postgresql/go.mod b/tests/certification/state/postgresql/go.mod index c346d557a..197d34f31 100644 --- a/tests/certification/state/postgresql/go.mod +++ b/tests/certification/state/postgresql/go.mod @@ -146,3 +146,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../ replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/state/postgresql/go.sum b/tests/certification/state/postgresql/go.sum index 96af6e283..271e0ff76 100644 --- a/tests/certification/state/postgresql/go.sum +++ b/tests/certification/state/postgresql/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -88,8 +90,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/state/postgresql/postgresql_test.go b/tests/certification/state/postgresql/postgresql_test.go index 6e0947c36..4c50682db 100644 --- a/tests/certification/state/postgresql/postgresql_test.go +++ b/tests/certification/state/postgresql/postgresql_test.go @@ -111,7 +111,7 @@ func TestPostgreSQL(t *testing.T) { md.Properties[keyMetadataTableName] = "clean_metadata" // Init and perform the migrations - err := storeObj.Init(md) + err := storeObj.Init(context.Background(), md) require.NoError(t, err, "failed to init") // We should have the tables correctly created @@ -135,7 +135,7 @@ func TestPostgreSQL(t *testing.T) { md.Properties[keyMetadataTableName] = "public.clean2_metadata" // Init and perform the migrations - err := storeObj.Init(md) + err := storeObj.Init(context.Background(), md) require.NoError(t, err, "failed to init") // We should have the tables correctly created @@ -165,7 +165,7 @@ func TestPostgreSQL(t *testing.T) { assert.Equal(t, migrationLevel, level, "migration level mismatch: found '%s' but expected '%s'", level, migrationLevel) // Init and perform the migrations - err = storeObj.Init(md) + err = storeObj.Init(context.Background(), md) require.NoError(t, err, "failed to init") // Ensure migration level is correct @@ -199,7 +199,7 @@ func TestPostgreSQL(t *testing.T) { md.Properties[keyMetadataTableName] = "pre_metadata" // Init and perform the migrations - err = storeObj.Init(md) + err = storeObj.Init(context.Background(), md) require.NoError(t, err, "failed to init") // We should have the metadata table created @@ -248,7 +248,7 @@ func TestPostgreSQL(t *testing.T) { // Init and perform the migrations storeObj := state_postgres.NewPostgreSQLStateStore(l).(*state_postgres.PostgreSQL) - err := storeObj.Init(md) + err := storeObj.Init(context.Background(), md) if err != nil { errs <- fmt.Errorf("%d failed to init: %w", i, err) return @@ -506,7 +506,7 @@ func TestPostgreSQL(t *testing.T) { md.Properties[keyCleanupInterval] = "" storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL) - err := storeObj.Init(md) + err := storeObj.Init(context.Background(), md) require.NoError(t, err, "failed to init") defer storeObj.Close() @@ -523,7 +523,7 @@ func TestPostgreSQL(t *testing.T) { md.Properties[keyCleanupInterval] = "10" storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL) - err := storeObj.Init(md) + err := storeObj.Init(context.Background(), md) require.NoError(t, err, "failed to init") defer storeObj.Close() @@ -540,7 +540,7 @@ func TestPostgreSQL(t *testing.T) { md.Properties[keyCleanupInterval] = "0" storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL) - err := storeObj.Init(md) + err := storeObj.Init(context.Background(), md) require.NoError(t, err, "failed to init") defer storeObj.Close() @@ -570,7 +570,7 @@ func TestPostgreSQL(t *testing.T) { md.Properties[keyCleanupInterval] = "1" storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL) - err := storeObj.Init(md) + err := storeObj.Init(context.Background(), md) require.NoError(t, err, "failed to init") defer storeObj.Close() @@ -607,7 +607,7 @@ func TestPostgreSQL(t *testing.T) { md.Properties[keyCleanupInterval] = "3600" storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL) - err := storeObj.Init(md) + err := storeObj.Init(context.Background(), md) require.NoError(t, err, "failed to init") defer storeObj.Close() diff --git a/tests/certification/state/redis/go.mod b/tests/certification/state/redis/go.mod index 98b06006e..88e104b42 100644 --- a/tests/certification/state/redis/go.mod +++ b/tests/certification/state/redis/go.mod @@ -145,3 +145,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../ replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/state/redis/go.sum b/tests/certification/state/redis/go.sum index 21e3bf83e..f58cbfef6 100644 --- a/tests/certification/state/redis/go.sum +++ b/tests/certification/state/redis/go.sum @@ -38,6 +38,8 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go. github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -90,8 +92,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/certification/state/sqlserver/go.mod b/tests/certification/state/sqlserver/go.mod index 00ef1aa43..cd0b1a724 100644 --- a/tests/certification/state/sqlserver/go.mod +++ b/tests/certification/state/sqlserver/go.mod @@ -145,3 +145,5 @@ require ( replace github.com/dapr/components-contrib/tests/certification => ../../ replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 diff --git a/tests/certification/state/sqlserver/go.sum b/tests/certification/state/sqlserver/go.sum index 5cf708fb1..ea56cf3ee 100644 --- a/tests/certification/state/sqlserver/go.sum +++ b/tests/certification/state/sqlserver/go.sum @@ -41,6 +41,8 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0/go.mod h1:yqy467j36fJxcRV2 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181 h1:bNMar/yM9t63YDipWoMwD3Juv0dWSs4QvPAkb4yVY/0= +github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181/go.mod h1:u4qgGgAhx1bH0NJamDVnWEgffpHDiLj/CiaRjCheBA0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk= @@ -91,8 +93,6 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef h1:Af/gPpnu2Nti4gvEMheT6eCZkDp3ywSmw7O5bRljnUA= -github.com/dapr/dapr v1.9.4-0.20230126201028-81bc49c384ef/go.mod h1:8l7XfAWZlgIWDoQ5SJyjpjym9JohXmaBWBI6MOotdlU= github.com/dapr/go-sdk v1.6.0 h1:jg5A2khSCHF8bGZsig5RWN/gD0jjitszc2V6Uq2pPdY= github.com/dapr/go-sdk v1.6.0/go.mod h1:KLQBltoD9K0w5hKTihdcyg9Epob9gypwL5dYcQzPro4= github.com/dapr/kit v0.0.4 h1:i+7TIN4crC1Mo0JFyWpIkwAE8orlliA0O6/ibvs2AaE= diff --git a/tests/conformance/bindings/bindings.go b/tests/conformance/bindings/bindings.go index 80e131660..f0a08ab8a 100644 --- a/tests/conformance/bindings/bindings.go +++ b/tests/conformance/bindings/bindings.go @@ -127,7 +127,7 @@ func ConformanceTests(t *testing.T, props map[string]string, inputBinding bindin // Check for an output binding specific operation before init if config.HasOperation("operations") { testLogger.Info("Init output binding ...") - err := outputBinding.Init(bindings.Metadata{ + err := outputBinding.Init(context.Background(), bindings.Metadata{ Base: metadata.Base{Properties: props}, }) assert.NoError(t, err, "expected no error setting up output binding") @@ -135,7 +135,7 @@ func ConformanceTests(t *testing.T, props map[string]string, inputBinding bindin // Check for an input binding specific operation before init if config.HasOperation("read") { testLogger.Info("Init input binding ...") - err := inputBinding.Init(bindings.Metadata{ + err := inputBinding.Init(context.Background(), bindings.Metadata{ Base: metadata.Base{Properties: props}, }) assert.NoError(t, err, "expected no error setting up input binding") @@ -145,7 +145,7 @@ func ConformanceTests(t *testing.T, props map[string]string, inputBinding bindin t.Run("ping", func(t *testing.T) { if config.HasOperation("read") { - errInp := bindings.PingInpBinding(inputBinding) + errInp := bindings.PingInpBinding(context.Background(), inputBinding) // TODO: Ideally, all stable components should implenment ping function, // so will only assert assert.NoError(t, err) finally, i.e. when current implementation // implements ping in existing stable components @@ -156,7 +156,7 @@ func ConformanceTests(t *testing.T, props map[string]string, inputBinding bindin } } if config.HasOperation("operations") { - errOut := bindings.PingOutBinding(outputBinding) + errOut := bindings.PingOutBinding(context.Background(), outputBinding) // TODO: Ideally, all stable components should implenment ping function, // so will only assert assert.NoError(t, err) finally, i.e. when current implementation // implements ping in existing stable components @@ -297,10 +297,8 @@ func ConformanceTests(t *testing.T, props map[string]string, inputBinding bindin // Check for an input-binding specific operation before close if config.HasOperation("read") { testLogger.Info("Closing read connection ...") - if closer, ok := inputBinding.(io.Closer); ok { - err := closer.Close() - assert.NoError(t, err, "expected no error closing input binding") - } + err := inputBinding.Close() + assert.NoError(t, err, "expected no error closing input binding") } // Check for an output-binding specific operation before close if config.HasOperation("operations") { diff --git a/tests/conformance/configuration/configuration.go b/tests/conformance/configuration/configuration.go index 66cba353e..4ecb4c92d 100644 --- a/tests/conformance/configuration/configuration.go +++ b/tests/conformance/configuration/configuration.go @@ -142,7 +142,7 @@ func ConformanceTests(t *testing.T, props map[string]string, store configuration processedC3 := make(chan *configuration.UpdateEvent, keyCount*4) t.Run("init", func(t *testing.T) { - err := store.Init(configuration.Metadata{ + err := store.Init(context.Background(), configuration.Metadata{ Base: metadata.Base{Properties: props}, }) assert.Nil(t, err) diff --git a/tests/conformance/pubsub/pubsub.go b/tests/conformance/pubsub/pubsub.go index 16da5a691..523931fcd 100644 --- a/tests/conformance/pubsub/pubsub.go +++ b/tests/conformance/pubsub/pubsub.go @@ -103,14 +103,14 @@ func ConformanceTests(t *testing.T, props map[string]string, ps pubsub.PubSub, c // Init t.Run("init", func(t *testing.T) { - err := ps.Init(pubsub.Metadata{ + err := ps.Init(context.Background(), pubsub.Metadata{ Base: metadata.Base{Properties: props}, }) assert.NoError(t, err, "expected no error on setting up pubsub") }) t.Run("ping", func(t *testing.T) { - err := pubsub.Ping(ps) + err := pubsub.Ping(context.Background(), ps) // TODO: Ideally, all stable components should implenment ping function, // so will only assert assert.Nil(t, err) finally, i.e. when current implementation // implements ping in existing stable components diff --git a/tests/conformance/secretstores/secretstores.go b/tests/conformance/secretstores/secretstores.go index 163292b42..307d6f4cc 100644 --- a/tests/conformance/secretstores/secretstores.go +++ b/tests/conformance/secretstores/secretstores.go @@ -51,14 +51,14 @@ func ConformanceTests(t *testing.T, props map[string]string, store secretstores. // Init t.Run("init", func(t *testing.T) { - err := store.Init(secretstores.Metadata{ + err := store.Init(context.Background(), secretstores.Metadata{ Base: metadata.Base{Properties: props}, }) assert.NoError(t, err, "expected no error on initializing store") }) t.Run("ping", func(t *testing.T) { - err := secretstores.Ping(store) + err := secretstores.Ping(context.Background(), store) // TODO: Ideally, all stable components should implenment ping function, // so will only assert assert.Nil(t, err) finally, i.e. when current implementation // implements ping in existing stable components diff --git a/tests/conformance/state/state.go b/tests/conformance/state/state.go index 166ba087e..21f345c42 100644 --- a/tests/conformance/state/state.go +++ b/tests/conformance/state/state.go @@ -223,7 +223,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St } t.Run("init", func(t *testing.T) { - err := statestore.Init(state.Metadata{ + err := statestore.Init(context.Background(), state.Metadata{ Base: metadata.Base{Properties: props}, }) assert.Nil(t, err) @@ -235,7 +235,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St } t.Run("ping", func(t *testing.T) { - err := state.Ping(statestore) + err := state.Ping(context.Background(), statestore) // TODO: Ideally, all stable components should implenment ping function, // so will only assert assert.Nil(t, err) finally, i.e. when current implementation // implements ping in existing stable components diff --git a/tests/e2e/bindings/zeebe/helper.go b/tests/e2e/bindings/zeebe/helper.go index 58f61eb13..20617f17d 100644 --- a/tests/e2e/bindings/zeebe/helper.go +++ b/tests/e2e/bindings/zeebe/helper.go @@ -69,7 +69,7 @@ func Command() (*command.ZeebeCommand, error) { envVars := GetEnvVars() cmd := command.NewZeebeCommand(testLogger).(*command.ZeebeCommand) - err := cmd.Init(bindings.Metadata{Base: metadata.Base{ + err := cmd.Init(context.Background(), bindings.Metadata{Base: metadata.Base{ Name: "test", Properties: map[string]string{ "gatewayAddr": envVars.ZeebeBrokerHost + ":" + envVars.ZeebeBrokerGatewayPort, @@ -103,7 +103,7 @@ func JobWorker(jobType string, additionalMetadata ...MetadataPair) (*jobworker.Z } cmd := jobworker.NewZeebeJobWorker(testLogger).(*jobworker.ZeebeJobWorker) - if err := cmd.Init(metadata); err != nil { + if err := cmd.Init(context.Background(), metadata); err != nil { return nil, err } diff --git a/tests/e2e/pubsub/jetstream/go.mod b/tests/e2e/pubsub/jetstream/go.mod index 9c2998882..6282e92f9 100644 --- a/tests/e2e/pubsub/jetstream/go.mod +++ b/tests/e2e/pubsub/jetstream/go.mod @@ -22,3 +22,5 @@ require ( ) replace github.com/dapr/components-contrib => ../../../../ + +replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181