Propagate context from caller to appropriate places in the code (#2474)

* Propagates contexts to callers where appropriate.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Updates units tests with new func signature

Signed-off-by: joshvanl <me@joshvanl.dev>

* Fix linting errors

Signed-off-by: joshvanl <me@joshvanl.dev>

* Add atomic gate to alicloud rocketmq close channel.

Signed-off-by: joshvanl <me@joshvanl.dev>

* bindings/aws/kinesis use a separate ctx variable name

Signed-off-by: joshvanl <me@joshvanl.dev>

* binding/kafka: use atomic to prevent closing the channel twice

Signed-off-by: joshvanl <me@joshvanl.dev>

* bindings/mqtt3: use atomic bool to prevent close channel being closed multiple times

Signed-off-by: joshvanl <me@joshvanl.dev>

* bindings/mqtt3: use Background context for handle operations:w

Signed-off-by: joshvanl <me@joshvanl.dev>

* state/cocroachdb: add context to Ping()

Signed-off-by: joshvanl <me@joshvanl.dev>

* bindings/postgres: add comment explaining use of context.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds comment header to health/pinger.go

Signed-off-by: joshvanl <me@joshvanl.dev>

* pubsub/aws/snssqs: add waitgroup to wait for all go routines to finish
and block on Close(). Shuts down the subscription if there are no topic
handlers.

Signed-off-by: joshvanl <me@joshvanl.dev>

* pubsub/mqtt3: add atomic bool to prevent multiple channel closes. Add
wait group to block close on all goroutines to finish.

Signed-off-by: joshvanl <me@joshvanl.dev>

* pubsub/rabbitmq: fixes race conditions, uses atomic to prevent multiple
closes, add wait group to block close on all goroutines

Signed-off-by: joshvanl <me@joshvanl.dev>

* pubsub/redis: revert ctx passed when it could be cancelled. Add wait
group wait when closing.

Signed-off-by: joshvanl <me@joshvanl.dev>

* state/postges: pass context in init, and wait group on close.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Update all `Ping()` to `PingContext()` where possible.

Signed-off-by: joshvanl <me@joshvanl.dev>

* state/in-memory: add atomic bool to prevent closing channel multiple
times. Add wait group to block on close()

Signed-off-by: joshvanl <me@joshvanl.dev>

* state/mysql: don't use same ctx variable name

Signed-off-by: joshvanl <me@joshvanl.dev>

* Pass correct loop context to redis go routines

Signed-off-by: joshvanl <me@joshvanl.dev>

* Rename context when creating timeouts in state

Signed-off-by: joshvanl <me@joshvanl.dev>

* Remove state.Features() from requiring a context

Signed-off-by: joshvanl <me@joshvanl.dev>

* Revert wasm request handle Close func to be without context to
implement io.Closer interface. Add 5 second timeout. Add io.Closer
assertion in test.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Remove superfluous go lint vet directive

Signed-off-by: joshvanl <me@joshvanl.dev>

* Change Configuration Init function to take context

Signed-off-by: joshvanl <me@joshvanl.dev>

* Updates input binding interface to include a `Close() error` function. `Close`
blocks until all resources have been released and go routines have returned.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Change `Close() error` in input binding struct to `io.Closer` interface.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Update go.mod files to point to dapr/dapr PR https://github.com/dapr/dapr/pull/5831

Signed-off-by: joshvanl <me@joshvanl.dev>

* pubsub/redis: watch closeCh to shutdown worker instead of init context.

Signed-off-by: joshvanl <me@joshvanl.dev>

* pubsub/aws/snssqs + bindings/kubemq: ensure closeCh is caught so Close
correctly returns

Signed-off-by: joshvanl <me@joshvanl.dev>

* Close kubemq binding client on close. Ensure kafka consumer channel
cannot be closed more than once.

Signed-off-by: joshvanl <me@joshvanl.dev>

* Tweaks

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* Fixed cert tests

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* binding/mqtt3: add inline Background context instead of passing to
handleMessage

Signed-off-by: joshvanl <me@joshvanl.dev>

* pubsub/mqtt3: remove context from createSubscriberClientOptions

Signed-off-by: joshvanl <me@joshvanl.dev>

* pubsub/mqtt3: Remove `ResetConnection` func

Signed-off-by: joshvanl <me@joshvanl.dev>

* pubsub/kafka: Don't resubscribe if Subscribe is cancelled.

Signed-off-by: joshvanl <me@joshvanl.dev>

* binding/mqtt3: don't use context to control establishing connection

Signed-off-by: joshvanl <me@joshvanl.dev>

* bindings/mqtt3: Fix linting errors

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
This commit is contained in:
Josh van Leeuwen 2023-02-16 22:18:35 +00:00 committed by GitHub
parent 210c8c3c59
commit d098e38d6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
297 changed files with 1797 additions and 1120 deletions

View File

@ -18,3 +18,5 @@ require (
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
) )
replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181

View File

@ -33,3 +33,5 @@ require (
google.golang.org/protobuf v1.28.1 // indirect google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )
replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181

View File

@ -80,7 +80,7 @@ func NewDingTalkWebhook(l logger.Logger) bindings.InputOutputBinding {
} }
// Init performs metadata parsing. // Init performs metadata parsing.
func (t *DingTalkWebhook) Init(metadata bindings.Metadata) error { func (t *DingTalkWebhook) Init(_ context.Context, metadata bindings.Metadata) error {
var err error var err error
if err = t.settings.Decode(metadata.Properties); err != nil { if err = t.settings.Decode(metadata.Properties); err != nil {
return fmt.Errorf("dingtalk configuration error: %w", err) return fmt.Errorf("dingtalk configuration error: %w", err)
@ -107,6 +107,13 @@ func (t *DingTalkWebhook) Read(ctx context.Context, handler bindings.Handler) er
return nil return nil
} }
func (t *DingTalkWebhook) Close() error {
webhooks.Lock()
defer webhooks.Unlock()
delete(webhooks.m, t.settings.ID)
return nil
}
// Operations returns list of operations supported by dingtalk webhook binding. // Operations returns list of operations supported by dingtalk webhook binding.
func (t *DingTalkWebhook) Operations() []bindings.OperationKind { func (t *DingTalkWebhook) Operations() []bindings.OperationKind {
return []bindings.OperationKind{bindings.CreateOperation, bindings.GetOperation} return []bindings.OperationKind{bindings.CreateOperation, bindings.GetOperation}

View File

@ -57,7 +57,7 @@ func TestPublishMsg(t *testing.T) { //nolint:paralleltest
}}} }}}
d := NewDingTalkWebhook(logger.NewLogger("test")) d := NewDingTalkWebhook(logger.NewLogger("test"))
err := d.Init(m) err := d.Init(context.Background(), m)
require.NoError(t, err) require.NoError(t, err)
req := &bindings.InvokeRequest{Data: []byte(msg), Operation: bindings.CreateOperation, Metadata: map[string]string{}} req := &bindings.InvokeRequest{Data: []byte(msg), Operation: bindings.CreateOperation, Metadata: map[string]string{}}
@ -78,7 +78,7 @@ func TestBindingReadAndInvoke(t *testing.T) { //nolint:paralleltest
}} }}
d := NewDingTalkWebhook(logger.NewLogger("test")) d := NewDingTalkWebhook(logger.NewLogger("test"))
err := d.Init(m) err := d.Init(context.Background(), m)
assert.NoError(t, err) assert.NoError(t, err)
var count int32 var count int32
@ -106,3 +106,18 @@ func TestBindingReadAndInvoke(t *testing.T) { //nolint:paralleltest
require.FailNow(t, "read timeout") require.FailNow(t, "read timeout")
} }
} }
func TestBindingClose(t *testing.T) {
d := NewDingTalkWebhook(logger.NewLogger("test"))
m := bindings.Metadata{Base: metadata.Base{
Name: "test",
Properties: map[string]string{
"url": "/test",
"secret": "",
"id": "x",
},
}}
assert.NoError(t, d.Init(context.Background(), m))
assert.NoError(t, d.Close())
assert.NoError(t, d.Close(), "second close should not error")
}

View File

@ -45,7 +45,7 @@ func NewAliCloudOSS(logger logger.Logger) bindings.OutputBinding {
} }
// Init does metadata parsing and connection creation. // Init does metadata parsing and connection creation.
func (s *AliCloudOSS) Init(metadata bindings.Metadata) error { func (s *AliCloudOSS) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := s.parseMetadata(metadata) m, err := s.parseMetadata(metadata)
if err != nil { if err != nil {
return err return err

View File

@ -31,7 +31,7 @@ type Callback struct {
} }
// parse metadata field // parse metadata field
func (s *AliCloudSlsLogstorage) Init(metadata bindings.Metadata) error { func (s *AliCloudSlsLogstorage) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := s.parseMeta(metadata) m, err := s.parseMeta(metadata)
if err != nil { if err != nil {
return err return err

View File

@ -58,7 +58,7 @@ func NewAliCloudTableStore(log logger.Logger) bindings.OutputBinding {
} }
} }
func (s *AliCloudTableStore) Init(metadata bindings.Metadata) error { func (s *AliCloudTableStore) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := s.parseMetadata(metadata) m, err := s.parseMetadata(metadata)
if err != nil { if err != nil {
return err return err

View File

@ -51,7 +51,7 @@ func TestDataEncodeAndDecode(t *testing.T) {
metadata := bindings.Metadata{Base: metadata.Base{ metadata := bindings.Metadata{Base: metadata.Base{
Properties: getTestProperties(), Properties: getTestProperties(),
}} }}
aliCloudTableStore.Init(metadata) aliCloudTableStore.Init(context.Background(), metadata)
// test create // test create
putData := map[string]interface{}{ putData := map[string]interface{}{

View File

@ -78,7 +78,7 @@ func NewAPNS(logger logger.Logger) bindings.OutputBinding {
// Init will configure the APNS output binding using the metadata specified // Init will configure the APNS output binding using the metadata specified
// in the binding's configuration. // in the binding's configuration.
func (a *APNS) Init(metadata bindings.Metadata) error { func (a *APNS) Init(ctx context.Context, metadata bindings.Metadata) error {
if err := a.makeURLPrefix(metadata); err != nil { if err := a.makeURLPrefix(metadata); err != nil {
return err return err
} }

View File

@ -51,7 +51,7 @@ func TestInit(t *testing.T) {
}, },
}} }}
binding := NewAPNS(testLogger).(*APNS) binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata) err := binding.Init(context.Background(), metadata)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, developmentPrefix, binding.urlPrefix) assert.Equal(t, developmentPrefix, binding.urlPrefix)
}) })
@ -66,7 +66,7 @@ func TestInit(t *testing.T) {
}, },
}} }}
binding := NewAPNS(testLogger).(*APNS) binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata) err := binding.Init(context.Background(), metadata)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, productionPrefix, binding.urlPrefix) assert.Equal(t, productionPrefix, binding.urlPrefix)
}) })
@ -80,7 +80,7 @@ func TestInit(t *testing.T) {
}, },
}} }}
binding := NewAPNS(testLogger).(*APNS) binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata) err := binding.Init(context.Background(), metadata)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, productionPrefix, binding.urlPrefix) assert.Equal(t, productionPrefix, binding.urlPrefix)
}) })
@ -95,7 +95,7 @@ func TestInit(t *testing.T) {
}, },
}} }}
binding := NewAPNS(testLogger).(*APNS) binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata) err := binding.Init(context.Background(), metadata)
assert.Error(t, err, "invalid value for development parameter: True") assert.Error(t, err, "invalid value for development parameter: True")
}) })
@ -107,7 +107,7 @@ func TestInit(t *testing.T) {
}, },
}} }}
binding := NewAPNS(testLogger).(*APNS) binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata) err := binding.Init(context.Background(), metadata)
assert.Error(t, err, "the key-id parameter is required") assert.Error(t, err, "the key-id parameter is required")
}) })
@ -120,7 +120,7 @@ func TestInit(t *testing.T) {
}, },
}} }}
binding := NewAPNS(testLogger).(*APNS) binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata) err := binding.Init(context.Background(), metadata)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, testKeyID, binding.authorizationBuilder.keyID) assert.Equal(t, testKeyID, binding.authorizationBuilder.keyID)
}) })
@ -133,7 +133,7 @@ func TestInit(t *testing.T) {
}, },
}} }}
binding := NewAPNS(testLogger).(*APNS) binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata) err := binding.Init(context.Background(), metadata)
assert.Error(t, err, "the team-id parameter is required") assert.Error(t, err, "the team-id parameter is required")
}) })
@ -146,7 +146,7 @@ func TestInit(t *testing.T) {
}, },
}} }}
binding := NewAPNS(testLogger).(*APNS) binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata) err := binding.Init(context.Background(), metadata)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, testTeamID, binding.authorizationBuilder.teamID) assert.Equal(t, testTeamID, binding.authorizationBuilder.teamID)
}) })
@ -159,7 +159,7 @@ func TestInit(t *testing.T) {
}, },
}} }}
binding := NewAPNS(testLogger).(*APNS) binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata) err := binding.Init(context.Background(), metadata)
assert.Error(t, err, "the private-key parameter is required") assert.Error(t, err, "the private-key parameter is required")
}) })
@ -172,7 +172,7 @@ func TestInit(t *testing.T) {
}, },
}} }}
binding := NewAPNS(testLogger).(*APNS) binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata) err := binding.Init(context.Background(), metadata)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, binding.authorizationBuilder.privateKey) assert.NotNil(t, binding.authorizationBuilder.privateKey)
}) })
@ -335,7 +335,7 @@ func makeTestBinding(t *testing.T, log logger.Logger) *APNS {
privateKeyKey: testPrivateKey, privateKeyKey: testPrivateKey,
}, },
}} }}
err := testBinding.Init(bindingMetadata) err := testBinding.Init(context.Background(), bindingMetadata)
assert.Nil(t, err) assert.Nil(t, err)
return testBinding return testBinding

View File

@ -49,7 +49,7 @@ func NewDynamoDB(logger logger.Logger) bindings.OutputBinding {
} }
// Init performs connection parsing for DynamoDB. // Init performs connection parsing for DynamoDB.
func (d *DynamoDB) Init(metadata bindings.Metadata) error { func (d *DynamoDB) Init(_ context.Context, metadata bindings.Metadata) error {
meta, err := d.getDynamoDBMetadata(metadata) meta, err := d.getDynamoDBMetadata(metadata)
if err != nil { if err != nil {
return err return err

View File

@ -15,7 +15,10 @@ package kinesis
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"sync"
"sync/atomic"
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
@ -45,6 +48,10 @@ type AWSKinesis struct {
streamARN *string streamARN *string
consumerARN *string consumerARN *string
logger logger.Logger logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
} }
type kinesisMetadata struct { type kinesisMetadata struct {
@ -83,11 +90,14 @@ type recordProcessor struct {
// NewAWSKinesis returns a new AWS Kinesis instance. // NewAWSKinesis returns a new AWS Kinesis instance.
func NewAWSKinesis(logger logger.Logger) bindings.InputOutputBinding { func NewAWSKinesis(logger logger.Logger) bindings.InputOutputBinding {
return &AWSKinesis{logger: logger} return &AWSKinesis{
logger: logger,
closeCh: make(chan struct{}),
}
} }
// Init does metadata parsing and connection creation. // Init does metadata parsing and connection creation.
func (a *AWSKinesis) Init(metadata bindings.Metadata) error { func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error {
m, err := a.parseMetadata(metadata) m, err := a.parseMetadata(metadata)
if err != nil { if err != nil {
return err return err
@ -107,7 +117,7 @@ func (a *AWSKinesis) Init(metadata bindings.Metadata) error {
} }
streamName := aws.String(m.StreamName) streamName := aws.String(m.StreamName)
stream, err := client.DescribeStream(&kinesis.DescribeStreamInput{ stream, err := client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{
StreamName: streamName, StreamName: streamName,
}) })
if err != nil { if err != nil {
@ -147,6 +157,10 @@ func (a *AWSKinesis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*
} }
func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err error) { func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err error) {
if a.closed.Load() {
return errors.New("binding is closed")
}
if a.metadata.KinesisConsumerMode == SharedThroughput { if a.metadata.KinesisConsumerMode == SharedThroughput {
a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.workerConfig) a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.workerConfig)
err = a.worker.Start() err = a.worker.Start()
@ -166,8 +180,13 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er
} }
// Wait for context cancelation then stop // Wait for context cancelation then stop
a.wg.Add(1)
go func() { go func() {
<-ctx.Done() defer a.wg.Done()
select {
case <-ctx.Done():
case <-a.closeCh:
}
if a.metadata.KinesisConsumerMode == SharedThroughput { if a.metadata.KinesisConsumerMode == SharedThroughput {
a.worker.Shutdown() a.worker.Shutdown()
} else if a.metadata.KinesisConsumerMode == ExtendedFanout { } else if a.metadata.KinesisConsumerMode == ExtendedFanout {
@ -188,14 +207,25 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
a.consumerARN = consumerARN a.consumerARN = consumerARN
a.wg.Add(len(streamDesc.Shards))
for i, shard := range streamDesc.Shards { for i, shard := range streamDesc.Shards {
go func(idx int, s *kinesis.Shard) error { go func(idx int, s *kinesis.Shard) {
defer a.wg.Done()
// Reconnection backoff // Reconnection backoff
bo := backoff.NewExponentialBackOff() bo := backoff.NewExponentialBackOff()
bo.InitialInterval = 2 * time.Second bo.InitialInterval = 2 * time.Second
// Repeat until context is canceled // Repeat until context is canceled or binding closed.
for ctx.Err() == nil { for {
select {
case <-ctx.Done():
return
case <-a.closeCh:
return
default:
}
sub, err := a.client.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{ sub, err := a.client.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{
ConsumerARN: consumerARN, ConsumerARN: consumerARN,
ShardId: s.ShardId, ShardId: s.ShardId,
@ -204,8 +234,12 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
if err != nil { if err != nil {
wait := bo.NextBackOff() wait := bo.NextBackOff()
a.logger.Errorf("Error while reading from shard %v: %v. Attempting to reconnect in %s...", s.ShardId, err, wait) a.logger.Errorf("Error while reading from shard %v: %v. Attempting to reconnect in %s...", s.ShardId, err, wait)
time.Sleep(wait) select {
continue case <-ctx.Done():
return
case <-time.After(wait):
continue
}
} }
// Reset the backoff on connection success // Reset the backoff on connection success
@ -223,22 +257,30 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
} }
} }
} }
return nil
}(i, shard) }(i, shard)
} }
return nil return nil
} }
func (a *AWSKinesis) ensureConsumer(parentCtx context.Context, streamARN *string) (*string, error) { func (a *AWSKinesis) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) if a.closed.CompareAndSwap(false, true) {
consumer, err := a.client.DescribeStreamConsumerWithContext(ctx, &kinesis.DescribeStreamConsumerInput{ close(a.closeCh)
}
a.wg.Wait()
return nil
}
func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*string, error) {
// Only set timeout on consumer call.
conCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
consumer, err := a.client.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{
ConsumerName: &a.metadata.ConsumerName, ConsumerName: &a.metadata.ConsumerName,
StreamARN: streamARN, StreamARN: streamARN,
}) })
cancel()
if err != nil { if err != nil {
return a.registerConsumer(parentCtx, streamARN) return a.registerConsumer(ctx, streamARN)
} }
return consumer.ConsumerDescription.ConsumerARN, nil return consumer.ConsumerDescription.ConsumerARN, nil

View File

@ -99,7 +99,7 @@ func NewAWSS3(logger logger.Logger) bindings.OutputBinding {
} }
// Init does metadata parsing and connection creation. // Init does metadata parsing and connection creation.
func (s *AWSS3) Init(metadata bindings.Metadata) error { func (s *AWSS3) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := s.parseMetadata(metadata) m, err := s.parseMetadata(metadata)
if err != nil { if err != nil {
return err return err

View File

@ -61,7 +61,7 @@ func NewAWSSES(logger logger.Logger) bindings.OutputBinding {
} }
// Init does metadata parsing. // Init does metadata parsing.
func (a *AWSSES) Init(metadata bindings.Metadata) error { func (a *AWSSES) Init(_ context.Context, metadata bindings.Metadata) error {
// Parse input metadata // Parse input metadata
meta, err := a.parseMetadata(metadata) meta, err := a.parseMetadata(metadata)
if err != nil { if err != nil {

View File

@ -53,7 +53,7 @@ func NewAWSSNS(logger logger.Logger) bindings.OutputBinding {
} }
// Init does metadata parsing. // Init does metadata parsing.
func (a *AWSSNS) Init(metadata bindings.Metadata) error { func (a *AWSSNS) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := a.parseMetadata(metadata) m, err := a.parseMetadata(metadata)
if err != nil { if err != nil {
return err return err

View File

@ -16,6 +16,9 @@ package sqs
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"sync"
"sync/atomic"
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
@ -31,7 +34,10 @@ type AWSSQS struct {
Client *sqs.SQS Client *sqs.SQS
QueueURL *string QueueURL *string
logger logger.Logger logger logger.Logger
wg sync.WaitGroup
closeCh chan struct{}
closed atomic.Bool
} }
type sqsMetadata struct { type sqsMetadata struct {
@ -45,11 +51,14 @@ type sqsMetadata struct {
// NewAWSSQS returns a new AWS SQS instance. // NewAWSSQS returns a new AWS SQS instance.
func NewAWSSQS(logger logger.Logger) bindings.InputOutputBinding { func NewAWSSQS(logger logger.Logger) bindings.InputOutputBinding {
return &AWSSQS{logger: logger} return &AWSSQS{
logger: logger,
closeCh: make(chan struct{}),
}
} }
// Init does metadata parsing and connection creation. // Init does metadata parsing and connection creation.
func (a *AWSSQS) Init(metadata bindings.Metadata) error { func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error {
m, err := a.parseSQSMetadata(metadata) m, err := a.parseSQSMetadata(metadata)
if err != nil { if err != nil {
return err return err
@ -61,7 +70,7 @@ func (a *AWSSQS) Init(metadata bindings.Metadata) error {
} }
queueName := m.QueueName queueName := m.QueueName
resultURL, err := client.GetQueueUrl(&sqs.GetQueueUrlInput{ resultURL, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{
QueueName: aws.String(queueName), QueueName: aws.String(queueName),
}) })
if err != nil { if err != nil {
@ -89,9 +98,20 @@ func (a *AWSSQS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind
} }
func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("binding is closed")
}
a.wg.Add(1)
go func() { go func() {
// Repeat until the context is canceled defer a.wg.Done()
for ctx.Err() == nil {
// Repeat until the context is canceled or component is closed
for {
if ctx.Err() != nil || a.closed.Load() {
return
}
result, err := a.Client.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ result, err := a.Client.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{
QueueUrl: a.QueueURL, QueueUrl: a.QueueURL,
AttributeNames: aws.StringSlice([]string{ AttributeNames: aws.StringSlice([]string{
@ -126,13 +146,25 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error {
} }
} }
time.Sleep(time.Millisecond * 50) select {
case <-ctx.Done():
case <-a.closeCh:
case <-time.After(time.Millisecond * 50):
}
} }
}() }()
return nil return nil
} }
func (a *AWSSQS) Close() error {
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
a.wg.Wait()
return nil
}
func (a *AWSSQS) parseSQSMetadata(metadata bindings.Metadata) (*sqsMetadata, error) { func (a *AWSSQS) parseSQSMetadata(metadata bindings.Metadata) (*sqsMetadata, error) {
b, err := json.Marshal(metadata.Properties) b, err := json.Marshal(metadata.Properties)
if err != nil { if err != nil {

View File

@ -92,7 +92,7 @@ func NewAzureBlobStorage(logger logger.Logger) bindings.OutputBinding {
} }
// Init performs metadata parsing. // Init performs metadata parsing.
func (a *AzureBlobStorage) Init(metadata bindings.Metadata) error { func (a *AzureBlobStorage) Init(_ context.Context, metadata bindings.Metadata) error {
var err error var err error
a.containerClient, a.metadata, err = storageinternal.CreateContainerStorageClient(a.logger, metadata.Properties) a.containerClient, a.metadata, err = storageinternal.CreateContainerStorageClient(a.logger, metadata.Properties)
if err != nil { if err != nil {

View File

@ -53,7 +53,7 @@ func NewCosmosDB(logger logger.Logger) bindings.OutputBinding {
} }
// Init performs CosmosDB connection parsing and connecting. // Init performs CosmosDB connection parsing and connecting.
func (c *CosmosDB) Init(metadata bindings.Metadata) error { func (c *CosmosDB) Init(ctx context.Context, metadata bindings.Metadata) error {
m, err := c.parseMetadata(metadata) m, err := c.parseMetadata(metadata)
if err != nil { if err != nil {
return err return err
@ -103,9 +103,9 @@ func (c *CosmosDB) Init(metadata bindings.Metadata) error {
} }
c.client = dbContainer c.client = dbContainer
ctx, cancel := context.WithTimeout(context.Background(), timeoutValue*time.Second) readCtx, readCancel := context.WithTimeout(ctx, timeoutValue*time.Second)
_, err = c.client.Read(ctx, nil) defer readCancel()
cancel() _, err = c.client.Read(readCtx, nil)
return err return err
} }

View File

@ -59,7 +59,7 @@ func NewCosmosDBGremlinAPI(logger logger.Logger) bindings.OutputBinding {
} }
// Init performs CosmosDBGremlinAPI connection parsing and connecting. // Init performs CosmosDBGremlinAPI connection parsing and connecting.
func (c *CosmosDBGremlinAPI) Init(metadata bindings.Metadata) error { func (c *CosmosDBGremlinAPI) Init(_ context.Context, metadata bindings.Metadata) error {
c.logger.Debug("Initializing Cosmos Graph DB binding") c.logger.Debug("Initializing Cosmos Graph DB binding")
m, err := c.parseMetadata(metadata) m, err := c.parseMetadata(metadata)

View File

@ -19,6 +19,8 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"sync"
"sync/atomic"
"time" "time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore"
@ -42,6 +44,9 @@ const armOperationTimeout = 30 * time.Second
type AzureEventGrid struct { type AzureEventGrid struct {
metadata *azureEventGridMetadata metadata *azureEventGridMetadata
logger logger.Logger logger logger.Logger
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
} }
type azureEventGridMetadata struct { type azureEventGridMetadata struct {
@ -70,11 +75,14 @@ type azureEventGridMetadata struct {
// NewAzureEventGrid returns a new Azure Event Grid instance. // NewAzureEventGrid returns a new Azure Event Grid instance.
func NewAzureEventGrid(logger logger.Logger) bindings.InputOutputBinding { func NewAzureEventGrid(logger logger.Logger) bindings.InputOutputBinding {
return &AzureEventGrid{logger: logger} return &AzureEventGrid{
logger: logger,
closeCh: make(chan struct{}),
}
} }
// Init performs metadata init. // Init performs metadata init.
func (a *AzureEventGrid) Init(metadata bindings.Metadata) error { func (a *AzureEventGrid) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := a.parseMetadata(metadata) m, err := a.parseMetadata(metadata)
if err != nil { if err != nil {
return err return err
@ -85,6 +93,10 @@ func (a *AzureEventGrid) Init(metadata bindings.Metadata) error {
} }
func (a *AzureEventGrid) Read(ctx context.Context, handler bindings.Handler) error { func (a *AzureEventGrid) Read(ctx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("binding is closed")
}
err := a.ensureInputBindingMetadata() err := a.ensureInputBindingMetadata()
if err != nil { if err != nil {
return err return err
@ -120,17 +132,22 @@ func (a *AzureEventGrid) Read(ctx context.Context, handler bindings.Handler) err
} }
// Run the server in background // Run the server in background
a.wg.Add(2)
go func() { go func() {
defer a.wg.Done()
a.logger.Debugf("About to start listening for Event Grid events at http://localhost:%s/api/events", a.metadata.HandshakePort) a.logger.Debugf("About to start listening for Event Grid events at http://localhost:%s/api/events", a.metadata.HandshakePort)
srvErr := srv.ListenAndServe(":" + a.metadata.HandshakePort) srvErr := srv.ListenAndServe(":" + a.metadata.HandshakePort)
if err != nil { if err != nil {
a.logger.Errorf("Error starting server: %v", srvErr) a.logger.Errorf("Error starting server: %v", srvErr)
} }
}() }()
// Close the server when context is canceled or binding closed.
// Close the server when context is canceled
go func() { go func() {
<-ctx.Done() defer a.wg.Done()
select {
case <-ctx.Done():
case <-a.closeCh:
}
srvErr := srv.Shutdown() srvErr := srv.Shutdown()
if err != nil { if err != nil {
a.logger.Errorf("Error shutting down server: %v", srvErr) a.logger.Errorf("Error shutting down server: %v", srvErr)
@ -149,6 +166,14 @@ func (a *AzureEventGrid) Operations() []bindings.OperationKind {
return []bindings.OperationKind{bindings.CreateOperation} return []bindings.OperationKind{bindings.CreateOperation}
} }
func (a *AzureEventGrid) Close() error {
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
a.wg.Wait()
return nil
}
func (a *AzureEventGrid) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { func (a *AzureEventGrid) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
err := a.ensureOutputBindingMetadata() err := a.ensureOutputBindingMetadata()
if err != nil { if err != nil {

View File

@ -37,7 +37,7 @@ func NewAzureEventHubs(logger logger.Logger) bindings.InputOutputBinding {
} }
// Init performs metadata init. // Init performs metadata init.
func (a *AzureEventHubs) Init(metadata bindings.Metadata) error { func (a *AzureEventHubs) Init(_ context.Context, metadata bindings.Metadata) error {
return a.AzureEventHubs.Init(metadata.Properties) return a.AzureEventHubs.Init(metadata.Properties)
} }

View File

@ -102,7 +102,7 @@ func testEventHubsBindingsAADAuthentication(t *testing.T) {
metadata := createEventHubsBindingsAADMetadata() metadata := createEventHubsBindingsAADMetadata()
eventHubsBindings := NewAzureEventHubs(log) eventHubsBindings := NewAzureEventHubs(log)
err := eventHubsBindings.Init(metadata) err := eventHubsBindings.Init(context.Background(), metadata)
require.NoError(t, err) require.NoError(t, err)
req := &bindings.InvokeRequest{ req := &bindings.InvokeRequest{
@ -146,7 +146,7 @@ func testReadIotHubEvents(t *testing.T) {
logger := logger.NewLogger("bindings.azure.eventhubs.integration.test") logger := logger.NewLogger("bindings.azure.eventhubs.integration.test")
eh := NewAzureEventHubs(logger) eh := NewAzureEventHubs(logger)
err := eh.Init(createIotHubBindingsMetadata()) err := eh.Init(context.Background(), createIotHubBindingsMetadata())
require.NoError(t, err) require.NoError(t, err)
// Invoke az CLI via bash script to send test IoT device events // Invoke az CLI via bash script to send test IoT device events

View File

@ -17,6 +17,8 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"sync"
"sync/atomic"
"time" "time"
servicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus" servicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
@ -39,17 +41,21 @@ type AzureServiceBusQueues struct {
client *impl.Client client *impl.Client
timeout time.Duration timeout time.Duration
logger logger.Logger logger logger.Logger
closed atomic.Bool
wg sync.WaitGroup
closeCh chan struct{}
} }
// NewAzureServiceBusQueues returns a new AzureServiceBusQueues instance. // NewAzureServiceBusQueues returns a new AzureServiceBusQueues instance.
func NewAzureServiceBusQueues(logger logger.Logger) bindings.InputOutputBinding { func NewAzureServiceBusQueues(logger logger.Logger) bindings.InputOutputBinding {
return &AzureServiceBusQueues{ return &AzureServiceBusQueues{
logger: logger, logger: logger,
closeCh: make(chan struct{}),
} }
} }
// Init parses connection properties and creates a new Service Bus Queue client. // Init parses connection properties and creates a new Service Bus Queue client.
func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) (err error) { func (a *AzureServiceBusQueues) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
a.metadata, err = impl.ParseMetadata(metadata.Properties, a.logger, (impl.MetadataModeBinding | impl.MetadataModeQueues)) a.metadata, err = impl.ParseMetadata(metadata.Properties, a.logger, (impl.MetadataModeBinding | impl.MetadataModeQueues))
if err != nil { if err != nil {
return err return err
@ -62,7 +68,7 @@ func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) (err error) {
} }
// Will do nothing if DisableEntityManagement is false // Will do nothing if DisableEntityManagement is false
err = a.client.EnsureQueue(context.Background(), a.metadata.QueueName) err = a.client.EnsureQueue(ctx, a.metadata.QueueName)
if err != nil { if err != nil {
return err return err
} }
@ -100,14 +106,33 @@ func (a *AzureServiceBusQueues) Invoke(invokeCtx context.Context, req *bindings.
return nil, nil return nil, nil
} }
func (a *AzureServiceBusQueues) Read(subscribeCtx context.Context, handler bindings.Handler) error { func (a *AzureServiceBusQueues) Read(parentCtx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("binding is closed")
}
// Reconnection backoff policy // Reconnection backoff policy
bo := backoff.NewExponentialBackOff() bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0 bo.MaxElapsedTime = 0
bo.InitialInterval = time.Duration(a.metadata.MinConnectionRecoveryInSec) * time.Second bo.InitialInterval = time.Duration(a.metadata.MinConnectionRecoveryInSec) * time.Second
bo.MaxInterval = time.Duration(a.metadata.MaxConnectionRecoveryInSec) * time.Second bo.MaxInterval = time.Duration(a.metadata.MaxConnectionRecoveryInSec) * time.Second
subscribeCtx, subscribeCancel := context.WithCancel(parentCtx)
// Close the subscription context when the binding is closed.
a.wg.Add(2)
go func() { go func() {
defer a.wg.Done()
select {
case <-a.closeCh:
subscribeCancel()
case <-parentCtx.Done():
// nop
}
}()
go func() {
defer a.wg.Done()
// Reconnect loop. // Reconnect loop.
for { for {
sub := impl.NewSubscription(subscribeCtx, impl.SubsriptionOptions{ sub := impl.NewSubscription(subscribeCtx, impl.SubsriptionOptions{
@ -165,7 +190,12 @@ func (a *AzureServiceBusQueues) Read(subscribeCtx context.Context, handler bindi
wait := bo.NextBackOff() wait := bo.NextBackOff()
a.logger.Warnf("Subscription to queue %s lost connection, attempting to reconnect in %s...", a.metadata.QueueName, wait) a.logger.Warnf("Subscription to queue %s lost connection, attempting to reconnect in %s...", a.metadata.QueueName, wait)
time.Sleep(wait) select {
case <-time.After(wait):
// nop
case <-a.closeCh:
return
}
} }
}() }()
@ -204,7 +234,11 @@ func (a *AzureServiceBusQueues) getHandlerFn(handler bindings.Handler) impl.Hand
} }
func (a *AzureServiceBusQueues) Close() (err error) { func (a *AzureServiceBusQueues) Close() (err error) {
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
a.logger.Debug("Closing component") a.logger.Debug("Closing component")
a.client.CloseSender(a.metadata.QueueName) a.client.CloseSender(a.metadata.QueueName)
a.wg.Wait()
return nil return nil
} }

View File

@ -78,7 +78,7 @@ type SignalR struct {
} }
// Init is responsible for initializing the SignalR output based on the metadata. // Init is responsible for initializing the SignalR output based on the metadata.
func (s *SignalR) Init(metadata bindings.Metadata) (err error) { func (s *SignalR) Init(_ context.Context, metadata bindings.Metadata) (err error) {
s.userAgent = "dapr-" + logger.DaprVersion s.userAgent = "dapr-" + logger.DaprVersion
err = s.parseMetadata(metadata.Properties) err = s.parseMetadata(metadata.Properties)

View File

@ -16,9 +16,12 @@ package storagequeues
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"net/url" "net/url"
"strconv" "strconv"
"sync"
"sync/atomic"
"time" "time"
"github.com/Azure/azure-storage-queue-go/azqueue" "github.com/Azure/azure-storage-queue-go/azqueue"
@ -40,9 +43,10 @@ type consumer struct {
// QueueHelper enables injection for testnig. // QueueHelper enables injection for testnig.
type QueueHelper interface { type QueueHelper interface {
Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error)
Write(ctx context.Context, data []byte, ttl *time.Duration) error Write(ctx context.Context, data []byte, ttl *time.Duration) error
Read(ctx context.Context, consumer *consumer) error Read(ctx context.Context, consumer *consumer) error
Close() error
} }
// AzureQueueHelper concrete impl of queue helper. // AzureQueueHelper concrete impl of queue helper.
@ -55,7 +59,7 @@ type AzureQueueHelper struct {
} }
// Init sets up this helper. // Init sets up this helper.
func (d *AzureQueueHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) { func (d *AzureQueueHelper) Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error) {
m, err := parseMetadata(metadata) m, err := parseMetadata(metadata)
if err != nil { if err != nil {
return nil, err return nil, err
@ -89,9 +93,9 @@ func (d *AzureQueueHelper) Init(metadata bindings.Metadata) (*storageQueuesMetad
d.queueURL = azqueue.NewQueueURL(*URL, p) d.queueURL = azqueue.NewQueueURL(*URL, p)
} }
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) createCtx, createCancel := context.WithTimeout(ctx, 2*time.Minute)
_, err = d.queueURL.Create(ctx, azqueue.Metadata{}) _, err = d.queueURL.Create(createCtx, azqueue.Metadata{})
cancel() createCancel()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -128,7 +132,10 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error {
} }
if res.NumMessages() == 0 { if res.NumMessages() == 0 {
// Queue was empty so back off by 10 seconds before trying again // Queue was empty so back off by 10 seconds before trying again
time.Sleep(10 * time.Second) select {
case <-time.After(10 * time.Second):
case <-ctx.Done():
}
return nil return nil
} }
mt := res.Message(0).Text mt := res.Message(0).Text
@ -162,6 +169,10 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error {
return nil return nil
} }
func (d *AzureQueueHelper) Close() error {
return nil
}
// NewAzureQueueHelper creates new helper. // NewAzureQueueHelper creates new helper.
func NewAzureQueueHelper(logger logger.Logger) QueueHelper { func NewAzureQueueHelper(logger logger.Logger) QueueHelper {
return &AzureQueueHelper{ return &AzureQueueHelper{
@ -175,6 +186,10 @@ type AzureStorageQueues struct {
helper QueueHelper helper QueueHelper
logger logger.Logger logger logger.Logger
wg sync.WaitGroup
closeCh chan struct{}
closed atomic.Bool
} }
type storageQueuesMetadata struct { type storageQueuesMetadata struct {
@ -189,12 +204,16 @@ type storageQueuesMetadata struct {
// NewAzureStorageQueues returns a new AzureStorageQueues instance. // NewAzureStorageQueues returns a new AzureStorageQueues instance.
func NewAzureStorageQueues(logger logger.Logger) bindings.InputOutputBinding { func NewAzureStorageQueues(logger logger.Logger) bindings.InputOutputBinding {
return &AzureStorageQueues{helper: NewAzureQueueHelper(logger), logger: logger} return &AzureStorageQueues{
helper: NewAzureQueueHelper(logger),
logger: logger,
closeCh: make(chan struct{}),
}
} }
// Init parses connection properties and creates a new Storage Queue client. // Init parses connection properties and creates a new Storage Queue client.
func (a *AzureStorageQueues) Init(metadata bindings.Metadata) (err error) { func (a *AzureStorageQueues) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
a.metadata, err = a.helper.Init(metadata) a.metadata, err = a.helper.Init(ctx, metadata)
if err != nil { if err != nil {
return err return err
} }
@ -261,14 +280,32 @@ func (a *AzureStorageQueues) Invoke(ctx context.Context, req *bindings.InvokeReq
} }
func (a *AzureStorageQueues) Read(ctx context.Context, handler bindings.Handler) error { func (a *AzureStorageQueues) Read(ctx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("input binding is closed")
}
c := consumer{ c := consumer{
callback: handler, callback: handler,
} }
// Close read context when binding is closed.
readCtx, cancel := context.WithCancel(ctx)
a.wg.Add(2)
go func() { go func() {
defer a.wg.Done()
defer cancel()
select {
case <-a.closeCh:
case <-ctx.Done():
}
}()
go func() {
defer a.wg.Done()
// Read until context is canceled // Read until context is canceled
var err error var err error
for ctx.Err() == nil { for readCtx.Err() == nil {
err = a.helper.Read(ctx, &c) err = a.helper.Read(readCtx, &c)
if err != nil { if err != nil {
a.logger.Errorf("error from c: %s", err) a.logger.Errorf("error from c: %s", err)
} }
@ -277,3 +314,11 @@ func (a *AzureStorageQueues) Read(ctx context.Context, handler bindings.Handler)
return nil return nil
} }
func (a *AzureStorageQueues) Close() error {
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
a.wg.Wait()
return nil
}

View File

@ -16,6 +16,7 @@ package storagequeues
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"sync"
"testing" "testing"
"time" "time"
@ -32,9 +33,11 @@ type MockHelper struct {
mock.Mock mock.Mock
messages chan []byte messages chan []byte
metadata *storageQueuesMetadata metadata *storageQueuesMetadata
closeCh chan struct{}
wg sync.WaitGroup
} }
func (m *MockHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) { func (m *MockHelper) Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error) {
m.messages = make(chan []byte, 10) m.messages = make(chan []byte, 10)
var err error var err error
m.metadata, err = parseMetadata(metadata) m.metadata, err = parseMetadata(metadata)
@ -50,12 +53,23 @@ func (m *MockHelper) Write(ctx context.Context, data []byte, ttl *time.Duration)
func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error { func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {
retvals := m.Called(ctx, consumer) retvals := m.Called(ctx, consumer)
readCtx, cancel := context.WithCancel(ctx)
m.wg.Add(2)
go func() { go func() {
defer m.wg.Done()
defer cancel()
select {
case <-readCtx.Done():
case <-m.closeCh:
}
}()
go func() {
defer m.wg.Done()
for msg := range m.messages { for msg := range m.messages {
if m.metadata.DecodeBase64 { if m.metadata.DecodeBase64 {
msg, _ = base64.StdEncoding.DecodeString(string(msg)) msg, _ = base64.StdEncoding.DecodeString(string(msg))
} }
go consumer.callback(ctx, &bindings.ReadResponse{ go consumer.callback(readCtx, &bindings.ReadResponse{
Data: msg, Data: msg,
}) })
} }
@ -64,18 +78,24 @@ func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {
return retvals.Error(0) return retvals.Error(0)
} }
func (m *MockHelper) Close() error {
defer m.wg.Wait()
close(m.closeCh)
return nil
}
func TestWriteQueue(t *testing.T) { func TestWriteQueue(t *testing.T) {
mm := new(MockHelper) mm := new(MockHelper)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool { mm.On("Write", mock.AnythingOfType("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
return in == nil return in == nil
})).Return(nil) })).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")} a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{} m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m) err := a.Init(context.Background(), m)
assert.Nil(t, err) assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")} r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -83,6 +103,7 @@ func TestWriteQueue(t *testing.T) {
_, err = a.Invoke(context.Background(), &r) _, err = a.Invoke(context.Background(), &r)
assert.Nil(t, err) assert.Nil(t, err)
assert.NoError(t, a.Close())
} }
func TestWriteWithTTLInQueue(t *testing.T) { func TestWriteWithTTLInQueue(t *testing.T) {
@ -91,12 +112,12 @@ func TestWriteWithTTLInQueue(t *testing.T) {
return in != nil && *in == time.Second return in != nil && *in == time.Second
})).Return(nil) })).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")} a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{} m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", metadata.TTLMetadataKey: "1"} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", metadata.TTLMetadataKey: "1"}
err := a.Init(m) err := a.Init(context.Background(), m)
assert.Nil(t, err) assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")} r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -104,6 +125,7 @@ func TestWriteWithTTLInQueue(t *testing.T) {
_, err = a.Invoke(context.Background(), &r) _, err = a.Invoke(context.Background(), &r)
assert.Nil(t, err) assert.Nil(t, err)
assert.NoError(t, a.Close())
} }
func TestWriteWithTTLInWrite(t *testing.T) { func TestWriteWithTTLInWrite(t *testing.T) {
@ -112,12 +134,12 @@ func TestWriteWithTTLInWrite(t *testing.T) {
return in != nil && *in == time.Second return in != nil && *in == time.Second
})).Return(nil) })).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")} a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{} m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", metadata.TTLMetadataKey: "1"} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", metadata.TTLMetadataKey: "1"}
err := a.Init(m) err := a.Init(context.Background(), m)
assert.Nil(t, err) assert.Nil(t, err)
r := bindings.InvokeRequest{ r := bindings.InvokeRequest{
@ -128,6 +150,7 @@ func TestWriteWithTTLInWrite(t *testing.T) {
_, err = a.Invoke(context.Background(), &r) _, err = a.Invoke(context.Background(), &r)
assert.Nil(t, err) assert.Nil(t, err)
assert.NoError(t, a.Close())
} }
// Uncomment this function to write a message to local storage queue // Uncomment this function to write a message to local storage queue
@ -138,7 +161,7 @@ func TestWriteWithTTLInWrite(t *testing.T) {
m := bindings.Metadata{} m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m) err := a.Init(context.Background(), m)
assert.Nil(t, err) assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")} r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -152,12 +175,12 @@ func TestReadQueue(t *testing.T) {
mm := new(MockHelper) mm := new(MockHelper)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil) mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil) mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")} a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{} m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m) err := a.Init(context.Background(), m)
assert.Nil(t, err) assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")} r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -186,6 +209,7 @@ func TestReadQueue(t *testing.T) {
t.Fatal("Timeout waiting for messages") t.Fatal("Timeout waiting for messages")
} }
assert.Equal(t, 1, received) assert.Equal(t, 1, received)
assert.NoError(t, a.Close())
} }
func TestReadQueueDecode(t *testing.T) { func TestReadQueueDecode(t *testing.T) {
@ -193,12 +217,12 @@ func TestReadQueueDecode(t *testing.T) {
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil) mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil) mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")} a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{} m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", "decodeBase64": "true"} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", "decodeBase64": "true"}
err := a.Init(m) err := a.Init(context.Background(), m)
assert.Nil(t, err) assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("VGhpcyBpcyBteSBtZXNzYWdl")} r := bindings.InvokeRequest{Data: []byte("VGhpcyBpcyBteSBtZXNzYWdl")}
@ -227,6 +251,7 @@ func TestReadQueueDecode(t *testing.T) {
t.Fatal("Timeout waiting for messages") t.Fatal("Timeout waiting for messages")
} }
assert.Equal(t, 1, received) assert.Equal(t, 1, received)
assert.NoError(t, a.Close())
} }
// Uncomment this function to test reding from local queue // Uncomment this function to test reding from local queue
@ -237,7 +262,7 @@ func TestReadQueueDecode(t *testing.T) {
m := bindings.Metadata{} m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m) err := a.Init(context.Background(), m)
assert.Nil(t, err) assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")} r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -263,12 +288,12 @@ func TestReadQueueNoMessage(t *testing.T) {
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil) mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil) mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")} a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{} m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"} m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m) err := a.Init(context.Background(), m)
assert.Nil(t, err) assert.Nil(t, err)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -285,6 +310,7 @@ func TestReadQueueNoMessage(t *testing.T) {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
cancel() cancel()
assert.Equal(t, 0, received) assert.Equal(t, 0, received)
assert.NoError(t, a.Close())
} }
func TestParseMetadata(t *testing.T) { func TestParseMetadata(t *testing.T) {

View File

@ -48,7 +48,7 @@ func NewCFQueues(logger logger.Logger) bindings.OutputBinding {
} }
// Init the component. // Init the component.
func (q *CFQueues) Init(metadata bindings.Metadata) error { func (q *CFQueues) Init(_ context.Context, metadata bindings.Metadata) error {
// Decode the metadata // Decode the metadata
err := mapstructure.Decode(metadata.Properties, &q.metadata) err := mapstructure.Decode(metadata.Properties, &q.metadata)
if err != nil { if err != nil {

View File

@ -51,7 +51,7 @@ func NewCommercetools(logger logger.Logger) bindings.OutputBinding {
} }
// Init does metadata parsing and connection establishment. // Init does metadata parsing and connection establishment.
func (ct *Binding) Init(metadata bindings.Metadata) error { func (ct *Binding) Init(_ context.Context, metadata bindings.Metadata) error {
commercetoolsM, err := ct.getCommercetoolsMetadata(metadata) commercetoolsM, err := ct.getCommercetoolsMetadata(metadata)
if err != nil { if err != nil {
return err return err

View File

@ -16,6 +16,8 @@ package cron
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
"sync/atomic"
"time" "time"
"github.com/benbjohnson/clock" "github.com/benbjohnson/clock"
@ -34,6 +36,9 @@ type Binding struct {
schedule string schedule string
parser cron.Parser parser cron.Parser
clk clock.Clock clk clock.Clock
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
} }
// NewCron returns a new Cron event input binding. // NewCron returns a new Cron event input binding.
@ -48,6 +53,7 @@ func NewCronWithClock(logger logger.Logger, clk clock.Clock) bindings.InputBindi
parser: cron.NewParser( parser: cron.NewParser(
cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor, cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor,
), ),
closeCh: make(chan struct{}),
} }
} }
@ -56,7 +62,7 @@ func NewCronWithClock(logger logger.Logger, clk clock.Clock) bindings.InputBindi
// //
// "15 * * * * *" - Every 15 sec // "15 * * * * *" - Every 15 sec
// "0 30 * * * *" - Every 30 min // "0 30 * * * *" - Every 30 min
func (b *Binding) Init(metadata bindings.Metadata) error { func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
b.name = metadata.Name b.name = metadata.Name
s, f := metadata.Properties["schedule"] s, f := metadata.Properties["schedule"]
if !f || s == "" { if !f || s == "" {
@ -73,6 +79,10 @@ func (b *Binding) Init(metadata bindings.Metadata) error {
// Read triggers the Cron scheduler. // Read triggers the Cron scheduler.
func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error { func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
if b.closed.Load() {
return errors.New("binding is closed")
}
c := cron.New(cron.WithParser(b.parser), cron.WithClock(b.clk)) c := cron.New(cron.WithParser(b.parser), cron.WithClock(b.clk))
id, err := c.AddFunc(b.schedule, func() { id, err := c.AddFunc(b.schedule, func() {
b.logger.Debugf("name: %s, schedule fired: %v", b.name, time.Now()) b.logger.Debugf("name: %s, schedule fired: %v", b.name, time.Now())
@ -89,12 +99,25 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
c.Start() c.Start()
b.logger.Debugf("name: %s, next run: %v", b.name, time.Until(c.Entry(id).Next)) b.logger.Debugf("name: %s, next run: %v", b.name, time.Until(c.Entry(id).Next))
b.wg.Add(1)
go func() { go func() {
// Wait for context to be canceled defer b.wg.Done()
<-ctx.Done() // Wait for context to be canceled or component to be closed.
select {
case <-ctx.Done():
case <-b.closeCh:
}
b.logger.Debugf("name: %s, stopping schedule: %s", b.name, b.schedule) b.logger.Debugf("name: %s, stopping schedule: %s", b.name, b.schedule)
c.Stop() c.Stop()
}() }()
return nil return nil
} }
func (b *Binding) Close() error {
if b.closed.CompareAndSwap(false, true) {
close(b.closeCh)
}
b.wg.Wait()
return nil
}

View File

@ -16,6 +16,7 @@ package cron
import ( import (
"context" "context"
"os" "os"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -84,7 +85,7 @@ func TestCronInitSuccess(t *testing.T) {
for _, test := range initTests { for _, test := range initTests {
c := getNewCron() c := getNewCron()
err := c.Init(getTestMetadata(test.schedule)) err := c.Init(context.Background(), getTestMetadata(test.schedule))
if test.errorExpected { if test.errorExpected {
assert.Errorf(t, err, "Got no error while initializing an invalid schedule: %s", test.schedule) assert.Errorf(t, err, "Got no error while initializing an invalid schedule: %s", test.schedule)
} else { } else {
@ -99,38 +100,41 @@ func TestCronRead(t *testing.T) {
clk := clock.NewMock() clk := clock.NewMock()
c := getNewCronWithClock(clk) c := getNewCronWithClock(clk)
schedule := "@every 1s" schedule := "@every 1s"
assert.NoErrorf(t, c.Init(getTestMetadata(schedule)), "error initializing valid schedule") assert.NoErrorf(t, c.Init(context.Background(), getTestMetadata(schedule)), "error initializing valid schedule")
expectedCount := 5 expectedCount := int32(5)
observedCount := 0 var observedCount atomic.Int32
err := c.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) { err := c.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
assert.NotNil(t, res) assert.NotNil(t, res)
observedCount++ observedCount.Add(1)
return nil, nil return nil, nil
}) })
// Check if cron triggers 5 times in 5 seconds // Check if cron triggers 5 times in 5 seconds
for i := 0; i < expectedCount; i++ { for i := int32(0); i < expectedCount; i++ {
// Add time to mock clock in 1 second intervals using loop to allow cron go routine to run // Add time to mock clock in 1 second intervals using loop to allow cron go routine to run
clk.Add(time.Second) clk.Add(time.Second)
} }
// Wait for 1 second after adding the last second to mock clock to allow cron to finish triggering // Wait for 1 second after adding the last second to mock clock to allow cron to finish triggering
time.Sleep(1 * time.Second) assert.Eventually(t, func() bool {
assert.Equal(t, expectedCount, observedCount, "Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount) return observedCount.Load() == expectedCount
}, time.Second, time.Millisecond*10,
"Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount.Load())
assert.NoErrorf(t, err, "error on read") assert.NoErrorf(t, err, "error on read")
assert.NoError(t, c.Close())
} }
func TestCronReadWithContextCancellation(t *testing.T) { func TestCronReadWithContextCancellation(t *testing.T) {
clk := clock.NewMock() clk := clock.NewMock()
c := getNewCronWithClock(clk) c := getNewCronWithClock(clk)
schedule := "@every 1s" schedule := "@every 1s"
assert.NoErrorf(t, c.Init(getTestMetadata(schedule)), "error initializing valid schedule") assert.NoErrorf(t, c.Init(context.Background(), getTestMetadata(schedule)), "error initializing valid schedule")
expectedCount := 5 expectedCount := int32(5)
observedCount := 0 var observedCount atomic.Int32
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
err := c.Read(ctx, func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) { err := c.Read(ctx, func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
assert.NotNil(t, res) assert.NotNil(t, res)
assert.LessOrEqualf(t, observedCount, expectedCount, "Invoke didn't stop the schedule") assert.LessOrEqualf(t, observedCount.Load(), expectedCount, "Invoke didn't stop the schedule")
observedCount++ observedCount.Add(1)
if observedCount == expectedCount { if observedCount.Load() == expectedCount {
// Cancel context after 5 triggers // Cancel context after 5 triggers
cancel() cancel()
} }
@ -141,7 +145,10 @@ func TestCronReadWithContextCancellation(t *testing.T) {
// Add time to mock clock in 1 second intervals using loop to allow cron go routine to run // Add time to mock clock in 1 second intervals using loop to allow cron go routine to run
clk.Add(time.Second) clk.Add(time.Second)
} }
time.Sleep(1 * time.Second) assert.Eventually(t, func() bool {
assert.Equal(t, expectedCount, observedCount, "Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount) return observedCount.Load() == expectedCount
}, time.Second, time.Millisecond*10,
"Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount.Load())
assert.NoErrorf(t, err, "error on read") assert.NoErrorf(t, err, "error on read")
assert.NoError(t, c.Close())
} }

View File

@ -47,7 +47,7 @@ func NewDubboOutput(logger logger.Logger) bindings.OutputBinding {
return dubboBinding return dubboBinding
} }
func (out *DubboOutputBinding) Init(_ bindings.Metadata) error { func (out *DubboOutputBinding) Init(_ context.Context, _ bindings.Metadata) error {
dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{}) dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
return nil return nil
} }

View File

@ -54,12 +54,13 @@ func TestInvoke(t *testing.T) {
// 0. init dapr provided and dubbo server // 0. init dapr provided and dubbo server
stopCh := make(chan struct{}) stopCh := make(chan struct{})
defer close(stopCh) defer close(stopCh)
// Create output and set serializer before go routine to prevent data race.
output := NewDubboOutput(logger.NewLogger("test"))
dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
go func() { go func() {
assert.Nil(t, runDubboServer(stopCh)) assert.Nil(t, runDubboServer(stopCh))
}() }()
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
output := NewDubboOutput(logger.NewLogger("test"))
// 1. create req/rsp value // 1. create req/rsp value
reqUser := &User{Name: testName} reqUser := &User{Name: testName}

View File

@ -83,14 +83,13 @@ func NewGCPStorage(logger logger.Logger) bindings.OutputBinding {
} }
// Init performs connection parsing. // Init performs connection parsing.
func (g *GCPStorage) Init(metadata bindings.Metadata) error { func (g *GCPStorage) Init(ctx context.Context, metadata bindings.Metadata) error {
m, b, err := g.parseMetadata(metadata) m, b, err := g.parseMetadata(metadata)
if err != nil { if err != nil {
return err return err
} }
clientOptions := option.WithCredentialsJSON(b) clientOptions := option.WithCredentialsJSON(b)
ctx := context.Background()
client, err := storage.NewClient(ctx, clientOptions) client, err := storage.NewClient(ctx, clientOptions)
if err != nil { if err != nil {
return err return err

View File

@ -16,7 +16,10 @@ package pubsub
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"sync"
"sync/atomic"
"cloud.google.com/go/pubsub" "cloud.google.com/go/pubsub"
"google.golang.org/api/option" "google.golang.org/api/option"
@ -36,6 +39,9 @@ type GCPPubSub struct {
client *pubsub.Client client *pubsub.Client
metadata *pubSubMetadata metadata *pubSubMetadata
logger logger.Logger logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
} }
type pubSubMetadata struct { type pubSubMetadata struct {
@ -55,11 +61,14 @@ type pubSubMetadata struct {
// NewGCPPubSub returns a new GCPPubSub instance. // NewGCPPubSub returns a new GCPPubSub instance.
func NewGCPPubSub(logger logger.Logger) bindings.InputOutputBinding { func NewGCPPubSub(logger logger.Logger) bindings.InputOutputBinding {
return &GCPPubSub{logger: logger} return &GCPPubSub{
logger: logger,
closeCh: make(chan struct{}),
}
} }
// Init parses metadata and creates a new Pub Sub client. // Init parses metadata and creates a new Pub Sub client.
func (g *GCPPubSub) Init(metadata bindings.Metadata) error { func (g *GCPPubSub) Init(ctx context.Context, metadata bindings.Metadata) error {
b, err := g.parseMetadata(metadata) b, err := g.parseMetadata(metadata)
if err != nil { if err != nil {
return err return err
@ -71,7 +80,6 @@ func (g *GCPPubSub) Init(metadata bindings.Metadata) error {
return err return err
} }
clientOptions := option.WithCredentialsJSON(b) clientOptions := option.WithCredentialsJSON(b)
ctx := context.Background()
pubsubClient, err := pubsub.NewClient(ctx, pubsubMeta.ProjectID, clientOptions) pubsubClient, err := pubsub.NewClient(ctx, pubsubMeta.ProjectID, clientOptions)
if err != nil { if err != nil {
return fmt.Errorf("error creating pubsub client: %s", err) return fmt.Errorf("error creating pubsub client: %s", err)
@ -88,7 +96,12 @@ func (g *GCPPubSub) parseMetadata(metadata bindings.Metadata) ([]byte, error) {
} }
func (g *GCPPubSub) Read(ctx context.Context, handler bindings.Handler) error { func (g *GCPPubSub) Read(ctx context.Context, handler bindings.Handler) error {
if g.closed.Load() {
return errors.New("binding is closed")
}
g.wg.Add(1)
go func() { go func() {
defer g.wg.Done()
sub := g.client.Subscription(g.metadata.Subscription) sub := g.client.Subscription(g.metadata.Subscription)
err := sub.Receive(ctx, func(c context.Context, m *pubsub.Message) { err := sub.Receive(ctx, func(c context.Context, m *pubsub.Message) {
_, err := handler(c, &bindings.ReadResponse{ _, err := handler(c, &bindings.ReadResponse{
@ -128,5 +141,9 @@ func (g *GCPPubSub) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*b
} }
func (g *GCPPubSub) Close() error { func (g *GCPPubSub) Close() error {
if g.closed.CompareAndSwap(false, true) {
close(g.closeCh)
}
defer g.wg.Wait()
return g.client.Close() return g.client.Close()
} }

View File

@ -58,7 +58,7 @@ func NewGraphQL(logger logger.Logger) bindings.OutputBinding {
} }
// Init initializes the GraphQL binding. // Init initializes the GraphQL binding.
func (gql *GraphQL) Init(metadata bindings.Metadata) error { func (gql *GraphQL) Init(_ context.Context, metadata bindings.Metadata) error {
gql.logger.Debug("GraphQL Error: Initializing GraphQL binding") gql.logger.Debug("GraphQL Error: Initializing GraphQL binding")
p := metadata.Properties p := metadata.Properties

View File

@ -74,7 +74,7 @@ func NewHTTP(logger logger.Logger) bindings.OutputBinding {
} }
// Init performs metadata parsing. // Init performs metadata parsing.
func (h *HTTPSource) Init(metadata bindings.Metadata) error { func (h *HTTPSource) Init(_ context.Context, metadata bindings.Metadata) error {
var err error var err error
if err = mapstructure.Decode(metadata.Properties, &h.metadata); err != nil { if err = mapstructure.Decode(metadata.Properties, &h.metadata); err != nil {
return err return err
@ -104,7 +104,7 @@ func (h *HTTPSource) Init(metadata bindings.Metadata) error {
Transport: netTransport, Transport: netTransport,
} }
if val, ok := metadata.Properties["errorIfNot2XX"]; ok { if val := metadata.Properties["errorIfNot2XX"]; val != "" {
h.errorIfNot2XX = utils.IsTruthy(val) h.errorIfNot2XX = utils.IsTruthy(val)
} else { } else {
// Default behavior // Default behavior

View File

@ -132,7 +132,7 @@ func InitBinding(s *httptest.Server, extraProps map[string]string) (bindings.Out
} }
hs := NewHTTP(logger.NewLogger("test")) hs := NewHTTP(logger.NewLogger("test"))
err := hs.Init(m) err := hs.Init(context.Background(), m)
return hs, err return hs, err
} }
@ -269,7 +269,7 @@ func InitBindingForHTTPS(s *httptest.Server, extraProps map[string]string) (bind
m.Properties[k] = v m.Properties[k] = v
} }
hs := NewHTTP(logger.NewLogger("test")) hs := NewHTTP(logger.NewLogger("test"))
err := hs.Init(m) err := hs.Init(context.Background(), m)
return hs, err return hs, err
} }

View File

@ -75,7 +75,7 @@ func NewHuaweiOBS(logger logger.Logger) bindings.OutputBinding {
} }
// Init does metadata parsing and connection creation. // Init does metadata parsing and connection creation.
func (o *HuaweiOBS) Init(metadata bindings.Metadata) error { func (o *HuaweiOBS) Init(_ context.Context, metadata bindings.Metadata) error {
o.logger.Debugf("initializing Huawei OBS binding and parsing metadata") o.logger.Debugf("initializing Huawei OBS binding and parsing metadata")
m, err := o.parseMetadata(metadata) m, err := o.parseMetadata(metadata)

View File

@ -92,7 +92,7 @@ func TestInit(t *testing.T) {
"accessKey": "dummy-ak", "accessKey": "dummy-ak",
"secretKey": "dummy-sk", "secretKey": "dummy-sk",
} }
err := obs.Init(m) err := obs.Init(context.Background(), m)
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("Init with missing bucket name", func(t *testing.T) { t.Run("Init with missing bucket name", func(t *testing.T) {
@ -102,7 +102,7 @@ func TestInit(t *testing.T) {
"accessKey": "dummy-ak", "accessKey": "dummy-ak",
"secretKey": "dummy-sk", "secretKey": "dummy-sk",
} }
err := obs.Init(m) err := obs.Init(context.Background(), m)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, err, fmt.Errorf("missing obs bucket name")) assert.Equal(t, err, fmt.Errorf("missing obs bucket name"))
}) })
@ -113,7 +113,7 @@ func TestInit(t *testing.T) {
"endpoint": "dummy-endpoint", "endpoint": "dummy-endpoint",
"secretKey": "dummy-sk", "secretKey": "dummy-sk",
} }
err := obs.Init(m) err := obs.Init(context.Background(), m)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, err, fmt.Errorf("missing the huawei access key")) assert.Equal(t, err, fmt.Errorf("missing the huawei access key"))
}) })
@ -124,7 +124,7 @@ func TestInit(t *testing.T) {
"endpoint": "dummy-endpoint", "endpoint": "dummy-endpoint",
"accessKey": "dummy-ak", "accessKey": "dummy-ak",
} }
err := obs.Init(m) err := obs.Init(context.Background(), m)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, err, fmt.Errorf("missing the huawei secret key")) assert.Equal(t, err, fmt.Errorf("missing the huawei secret key"))
}) })
@ -135,7 +135,7 @@ func TestInit(t *testing.T) {
"accessKey": "dummy-ak", "accessKey": "dummy-ak",
"secretKey": "dummy-sk", "secretKey": "dummy-sk",
} }
err := obs.Init(m) err := obs.Init(context.Background(), m)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, err, fmt.Errorf("missing obs endpoint")) assert.Equal(t, err, fmt.Errorf("missing obs endpoint"))
}) })

View File

@ -64,7 +64,7 @@ func NewInflux(logger logger.Logger) bindings.OutputBinding {
} }
// Init does metadata parsing and connection establishment. // Init does metadata parsing and connection establishment.
func (i *Influx) Init(metadata bindings.Metadata) error { func (i *Influx) Init(_ context.Context, metadata bindings.Metadata) error {
influxMeta, err := i.getInfluxMetadata(metadata) influxMeta, err := i.getInfluxMetadata(metadata)
if err != nil { if err != nil {
return err return err

View File

@ -54,7 +54,7 @@ func TestInflux_Init(t *testing.T) {
assert.Nil(t, influx.client) assert.Nil(t, influx.client)
m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{"Url": "a", "Token": "a", "Org": "a", "Bucket": "a"}}} m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{"Url": "a", "Token": "a", "Org": "a", "Bucket": "a"}}}
err := influx.Init(m) err := influx.Init(context.Background(), m)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, influx.queryAPI) assert.NotNil(t, influx.queryAPI)

View File

@ -16,6 +16,7 @@ package bindings
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"github.com/dapr/components-contrib/health" "github.com/dapr/components-contrib/health"
) )
@ -23,18 +24,21 @@ import (
// InputBinding is the interface to define a binding that triggers on incoming events. // InputBinding is the interface to define a binding that triggers on incoming events.
type InputBinding interface { type InputBinding interface {
// Init passes connection and properties metadata to the binding implementation. // Init passes connection and properties metadata to the binding implementation.
Init(metadata Metadata) error Init(ctx context.Context, metadata Metadata) error
// Read is a method that runs in background and triggers the callback function whenever an event arrives. // Read is a method that runs in background and triggers the callback function whenever an event arrives.
Read(ctx context.Context, handler Handler) error Read(ctx context.Context, handler Handler) error
// Close is a method that closes the connection to the binding. Must be
// called when the binding is no longer needed to free up resources.
io.Closer
} }
// Handler is the handler used to invoke the app handler. // Handler is the handler used to invoke the app handler.
type Handler func(context.Context, *ReadResponse) ([]byte, error) type Handler func(context.Context, *ReadResponse) ([]byte, error)
func PingInpBinding(inputBinding InputBinding) error { func PingInpBinding(ctx context.Context, inputBinding InputBinding) error {
// checks if this input binding has the ping option then executes // checks if this input binding has the ping option then executes
if inputBindingWithPing, ok := inputBinding.(health.Pinger); ok { if inputBindingWithPing, ok := inputBinding.(health.Pinger); ok {
return inputBindingWithPing.Ping() return inputBindingWithPing.Ping(ctx)
} else { } else {
return fmt.Errorf("ping is not implemented by this input binding") return fmt.Errorf("ping is not implemented by this input binding")
} }

View File

@ -15,7 +15,10 @@ package kafka
import ( import (
"context" "context"
"errors"
"strings" "strings"
"sync"
"sync/atomic"
"github.com/dapr/kit/logger" "github.com/dapr/kit/logger"
@ -29,12 +32,13 @@ const (
) )
type Binding struct { type Binding struct {
kafka *kafka.Kafka kafka *kafka.Kafka
publishTopic string publishTopic string
topics []string topics []string
logger logger.Logger logger logger.Logger
subscribeCtx context.Context closeCh chan struct{}
subscribeCancel context.CancelFunc closed atomic.Bool
wg sync.WaitGroup
} }
// NewKafka returns a new kafka binding instance. // NewKafka returns a new kafka binding instance.
@ -43,15 +47,14 @@ func NewKafka(logger logger.Logger) bindings.InputOutputBinding {
// in kafka binding component, disable consumer retry by default // in kafka binding component, disable consumer retry by default
k.DefaultConsumeRetryEnabled = false k.DefaultConsumeRetryEnabled = false
return &Binding{ return &Binding{
kafka: k, kafka: k,
logger: logger, logger: logger,
closeCh: make(chan struct{}),
} }
} }
func (b *Binding) Init(metadata bindings.Metadata) error { func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
b.subscribeCtx, b.subscribeCancel = context.WithCancel(context.Background()) err := b.kafka.Init(ctx, metadata.Properties)
err := b.kafka.Init(metadata.Properties)
if err != nil { if err != nil {
return err return err
} }
@ -74,7 +77,10 @@ func (b *Binding) Operations() []bindings.OperationKind {
} }
func (b *Binding) Close() (err error) { func (b *Binding) Close() (err error) {
b.subscribeCancel() if b.closed.CompareAndSwap(false, true) {
close(b.closeCh)
}
defer b.wg.Wait()
return b.kafka.Close() return b.kafka.Close()
} }
@ -84,6 +90,10 @@ func (b *Binding) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin
} }
func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error { func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
if b.closed.Load() {
return errors.New("error: binding is closed")
}
if len(b.topics) == 0 { if len(b.topics) == 0 {
b.logger.Warnf("kafka binding: no topic defined, input bindings will not be started") b.logger.Warnf("kafka binding: no topic defined, input bindings will not be started")
return nil return nil
@ -96,31 +106,22 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
for _, t := range b.topics { for _, t := range b.topics {
b.kafka.AddTopicHandler(t, handlerConfig) b.kafka.AddTopicHandler(t, handlerConfig)
} }
b.wg.Add(1)
go func() { go func() {
// Wait for context cancelation defer b.wg.Done()
// Wait for context cancelation or closure.
select { select {
case <-ctx.Done(): case <-ctx.Done():
case <-b.subscribeCtx.Done(): case <-b.closeCh:
} }
// Remove the topic handler before restarting the subscriber // Remove the topic handlers.
for _, t := range b.topics { for _, t := range b.topics {
b.kafka.RemoveTopicHandler(t) b.kafka.RemoveTopicHandler(t)
} }
// If the component's context has been canceled, do not re-subscribe
if b.subscribeCtx.Err() != nil {
return
}
err := b.kafka.Subscribe(b.subscribeCtx)
if err != nil {
b.logger.Errorf("kafka binding: error re-subscribing: %v", err)
}
}() }()
return b.kafka.Subscribe(b.subscribeCtx) return b.kafka.Subscribe(ctx)
} }
func adaptHandler(handler bindings.Handler) kafka.EventHandler { func adaptHandler(handler bindings.Handler) kafka.EventHandler {

View File

@ -2,8 +2,11 @@ package kubemq
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strings" "strings"
"sync"
"sync/atomic"
"time" "time"
qs "github.com/kubemq-io/kubemq-go/queues_stream" qs "github.com/kubemq-io/kubemq-go/queues_stream"
@ -19,31 +22,30 @@ type Kubemq interface {
} }
type kubeMQ struct { type kubeMQ struct {
client *qs.QueuesStreamClient client *qs.QueuesStreamClient
opts *options opts *options
logger logger.Logger logger logger.Logger
ctx context.Context closed atomic.Bool
ctxCancel context.CancelFunc closeCh chan struct{}
wg sync.WaitGroup
} }
func NewKubeMQ(logger logger.Logger) Kubemq { func NewKubeMQ(logger logger.Logger) Kubemq {
return &kubeMQ{ return &kubeMQ{
client: nil, client: nil,
opts: nil, opts: nil,
logger: logger, logger: logger,
ctx: nil, closeCh: make(chan struct{}),
ctxCancel: nil,
} }
} }
func (k *kubeMQ) Init(metadata bindings.Metadata) error { func (k *kubeMQ) Init(ctx context.Context, metadata bindings.Metadata) error {
opts, err := createOptions(metadata) opts, err := createOptions(metadata)
if err != nil { if err != nil {
return err return err
} }
k.opts = opts k.opts = opts
k.ctx, k.ctxCancel = context.WithCancel(context.Background()) client, err := qs.NewQueuesStreamClient(ctx,
client, err := qs.NewQueuesStreamClient(k.ctx,
qs.WithAddress(opts.host, opts.port), qs.WithAddress(opts.host, opts.port),
qs.WithCheckConnection(true), qs.WithCheckConnection(true),
qs.WithAuthToken(opts.authToken), qs.WithAuthToken(opts.authToken),
@ -53,22 +55,39 @@ func (k *kubeMQ) Init(metadata bindings.Metadata) error {
k.logger.Errorf("error init kubemq client error: %s", err.Error()) k.logger.Errorf("error init kubemq client error: %s", err.Error())
return err return err
} }
k.ctx, k.ctxCancel = context.WithCancel(context.Background())
k.client = client k.client = client
return nil return nil
} }
func (k *kubeMQ) Read(ctx context.Context, handler bindings.Handler) error { func (k *kubeMQ) Read(ctx context.Context, handler bindings.Handler) error {
if k.closed.Load() {
return errors.New("binding is closed")
}
k.wg.Add(2)
processCtx, cancel := context.WithCancel(ctx)
go func() { go func() {
defer k.wg.Done()
defer cancel()
select {
case <-k.closeCh:
case <-processCtx.Done():
}
}()
go func() {
defer k.wg.Done()
for { for {
err := k.processQueueMessage(k.ctx, handler) err := k.processQueueMessage(processCtx, handler)
if err != nil { if err != nil {
k.logger.Error(err.Error()) k.logger.Error(err.Error())
time.Sleep(time.Second)
} }
if k.ctx.Err() != nil { // If context cancelled or kubeMQ closed, exit. Otherwise, continue
return // after a second.
select {
case <-time.After(time.Second):
continue
case <-processCtx.Done():
} }
return
} }
}() }()
return nil return nil
@ -82,7 +101,7 @@ func (k *kubeMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind
SetPolicyExpirationSeconds(parsePolicyExpirationSeconds(req.Metadata)). SetPolicyExpirationSeconds(parsePolicyExpirationSeconds(req.Metadata)).
SetPolicyMaxReceiveCount(parseSetPolicyMaxReceiveCount(req.Metadata)). SetPolicyMaxReceiveCount(parseSetPolicyMaxReceiveCount(req.Metadata)).
SetPolicyMaxReceiveQueue(parsePolicyMaxReceiveQueue(req.Metadata)) SetPolicyMaxReceiveQueue(parsePolicyMaxReceiveQueue(req.Metadata))
result, err := k.client.Send(k.ctx, queueMessage) result, err := k.client.Send(ctx, queueMessage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -101,6 +120,14 @@ func (k *kubeMQ) Operations() []bindings.OperationKind {
return []bindings.OperationKind{bindings.CreateOperation} return []bindings.OperationKind{bindings.CreateOperation}
} }
func (k *kubeMQ) Close() error {
if k.closed.CompareAndSwap(false, true) {
close(k.closeCh)
}
defer k.wg.Wait()
return k.client.Close()
}
func (k *kubeMQ) processQueueMessage(ctx context.Context, handler bindings.Handler) error { func (k *kubeMQ) processQueueMessage(ctx context.Context, handler bindings.Handler) error {
pr := qs.NewPollRequest(). pr := qs.NewPollRequest().
SetChannel(k.opts.channel). SetChannel(k.opts.channel).

View File

@ -106,7 +106,7 @@ func Test_kubeMQ_Init(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
kubemq := NewKubeMQ(logger.NewLogger("test")) kubemq := NewKubeMQ(logger.NewLogger("test"))
err := kubemq.Init(tt.meta) err := kubemq.Init(context.Background(), tt.meta)
if tt.wantErr { if tt.wantErr {
require.Error(t, err) require.Error(t, err)
} else { } else {
@ -120,7 +120,7 @@ func Test_kubeMQ_Invoke_Read_Single_Message(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel() defer cancel()
kubemq := NewKubeMQ(logger.NewLogger("test")) kubemq := NewKubeMQ(logger.NewLogger("test"))
err := kubemq.Init(getDefaultMetadata("test.read.single")) err := kubemq.Init(context.Background(), getDefaultMetadata("test.read.single"))
require.NoError(t, err) require.NoError(t, err)
dataReadCh := make(chan []byte) dataReadCh := make(chan []byte)
invokeRequest := &bindings.InvokeRequest{ invokeRequest := &bindings.InvokeRequest{
@ -147,7 +147,7 @@ func Test_kubeMQ_Invoke_Read_Single_MessageWithHandlerError(t *testing.T) {
kubemq := NewKubeMQ(logger.NewLogger("test")) kubemq := NewKubeMQ(logger.NewLogger("test"))
md := getDefaultMetadata("test.read.single.error") md := getDefaultMetadata("test.read.single.error")
md.Properties["autoAcknowledged"] = "false" md.Properties["autoAcknowledged"] = "false"
err := kubemq.Init(md) err := kubemq.Init(context.Background(), md)
require.NoError(t, err) require.NoError(t, err)
invokeRequest := &bindings.InvokeRequest{ invokeRequest := &bindings.InvokeRequest{
Data: []byte("test"), Data: []byte("test"),
@ -182,7 +182,7 @@ func Test_kubeMQ_Invoke_Error(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel() defer cancel()
kubemq := NewKubeMQ(logger.NewLogger("test")) kubemq := NewKubeMQ(logger.NewLogger("test"))
err := kubemq.Init(getDefaultMetadata("***test***")) err := kubemq.Init(context.Background(), getDefaultMetadata("***test***"))
require.NoError(t, err) require.NoError(t, err)
invokeRequest := &bindings.InvokeRequest{ invokeRequest := &bindings.InvokeRequest{

View File

@ -18,6 +18,8 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"strconv" "strconv"
"sync"
"sync/atomic"
"time" "time"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
@ -35,6 +37,9 @@ type kubernetesInput struct {
namespace string namespace string
resyncPeriod time.Duration resyncPeriod time.Duration
logger logger.Logger logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
} }
type EventResponse struct { type EventResponse struct {
@ -45,10 +50,13 @@ type EventResponse struct {
// NewKubernetes returns a new Kubernetes event input binding. // NewKubernetes returns a new Kubernetes event input binding.
func NewKubernetes(logger logger.Logger) bindings.InputBinding { func NewKubernetes(logger logger.Logger) bindings.InputBinding {
return &kubernetesInput{logger: logger} return &kubernetesInput{
logger: logger,
closeCh: make(chan struct{}),
}
} }
func (k *kubernetesInput) Init(metadata bindings.Metadata) error { func (k *kubernetesInput) Init(ctx context.Context, metadata bindings.Metadata) error {
client, err := kubeclient.GetKubeClient() client, err := kubeclient.GetKubeClient()
if err != nil { if err != nil {
return err return err
@ -78,6 +86,9 @@ func (k *kubernetesInput) parseMetadata(metadata bindings.Metadata) error {
} }
func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) error { func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) error {
if k.closed.Load() {
return errors.New("binding is closed")
}
watchlist := cache.NewListWatchFromClient( watchlist := cache.NewListWatchFromClient(
k.kubeClient.CoreV1().RESTClient(), k.kubeClient.CoreV1().RESTClient(),
"events", "events",
@ -126,12 +137,28 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er
}, },
) )
k.wg.Add(3)
readCtx, cancel := context.WithCancel(ctx)
// catch when binding is closed.
go func() {
defer k.wg.Done()
defer cancel()
select {
case <-readCtx.Done():
case <-k.closeCh:
}
}()
// Start the controller in backgound // Start the controller in backgound
stopCh := make(chan struct{}) go func() {
go controller.Run(stopCh) defer k.wg.Done()
controller.Run(readCtx.Done())
}()
// Watch for new messages and for context cancellation // Watch for new messages and for context cancellation
go func() { go func() {
defer k.wg.Done()
var ( var (
obj EventResponse obj EventResponse
data []byte data []byte
@ -148,8 +175,7 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er
Data: data, Data: data,
}) })
} }
case <-ctx.Done(): case <-readCtx.Done():
close(stopCh)
return return
} }
} }
@ -157,3 +183,11 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er
return nil return nil
} }
func (k *kubernetesInput) Close() error {
if k.closed.CompareAndSwap(false, true) {
close(k.closeCh)
}
k.wg.Wait()
return nil
}

View File

@ -64,7 +64,7 @@ func NewLocalStorage(logger logger.Logger) bindings.OutputBinding {
} }
// Init performs metadata parsing. // Init performs metadata parsing.
func (ls *LocalStorage) Init(metadata bindings.Metadata) error { func (ls *LocalStorage) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := ls.parseMetadata(metadata) m, err := ls.parseMetadata(metadata)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse metadata: %w", err) return fmt.Errorf("failed to parse metadata: %w", err)

View File

@ -40,30 +40,29 @@ type MQTT struct {
logger logger.Logger logger logger.Logger
isSubscribed atomic.Bool isSubscribed atomic.Bool
readHandler bindings.Handler readHandler bindings.Handler
ctx context.Context
cancel context.CancelFunc
backOff backoff.BackOff backOff backoff.BackOff
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
} }
// NewMQTT returns a new MQTT instance. // NewMQTT returns a new MQTT instance.
func NewMQTT(logger logger.Logger) bindings.InputOutputBinding { func NewMQTT(logger logger.Logger) bindings.InputOutputBinding {
return &MQTT{ return &MQTT{
logger: logger, logger: logger,
closeCh: make(chan struct{}),
} }
} }
// Init does MQTT connection parsing. // Init does MQTT connection parsing.
func (m *MQTT) Init(metadata bindings.Metadata) (err error) { func (m *MQTT) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
m.metadata, err = parseMQTTMetaData(metadata, m.logger) m.metadata, err = parseMQTTMetaData(metadata, m.logger)
if err != nil { if err != nil {
return err return err
} }
m.ctx, m.cancel = context.WithCancel(context.Background())
// TODO: Make the backoff configurable for constant or exponential // TODO: Make the backoff configurable for constant or exponential
b := backoff.NewConstantBackOff(5 * time.Second) m.backOff = backoff.NewConstantBackOff(5 * time.Second)
m.backOff = backoff.WithContext(b, m.ctx)
return nil return nil
} }
@ -104,7 +103,7 @@ func (m *MQTT) getProducer() (mqtt.Client, error) {
return p, nil return p, nil
} }
func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { func (m *MQTT) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
producer, err := m.getProducer() producer, err := m.getProducer()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create producer connection: %w", err) return nil, fmt.Errorf("failed to create producer connection: %w", err)
@ -118,7 +117,7 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
bo := backoff.WithMaxRetries( bo := backoff.WithMaxRetries(
backoff.NewConstantBackOff(200*time.Millisecond), 3, backoff.NewConstantBackOff(200*time.Millisecond), 3,
) )
bo = backoff.WithContext(bo, parentCtx) bo = backoff.WithContext(bo, ctx)
topic, ok := req.Metadata[mqttTopic] topic, ok := req.Metadata[mqttTopic]
if !ok || topic == "" { if !ok || topic == "" {
@ -127,14 +126,13 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
} }
return nil, retry.NotifyRecover(func() (err error) { return nil, retry.NotifyRecover(func() (err error) {
token := producer.Publish(topic, m.metadata.qos, m.metadata.retain, req.Data) token := producer.Publish(topic, m.metadata.qos, m.metadata.retain, req.Data)
ctx, cancel := context.WithTimeout(parentCtx, defaultWait)
defer cancel()
select { select {
case <-token.Done(): case <-token.Done():
err = token.Error() err = token.Error()
case <-m.ctx.Done(): case <-m.closeCh:
// Context canceled err = errors.New("mqtt client closed")
err = m.ctx.Err() case <-time.After(defaultWait):
err = errors.New("mqtt client timeout")
case <-ctx.Done(): case <-ctx.Done():
// Context canceled // Context canceled
err = ctx.Err() err = ctx.Err()
@ -151,6 +149,10 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
} }
func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error { func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error {
if m.closed.Load() {
return errors.New("error: binding is closed")
}
// If the subscription is already active, wait 2s before retrying (in case we're still disconnecting), otherwise return an error // If the subscription is already active, wait 2s before retrying (in case we're still disconnecting), otherwise return an error
if !m.isSubscribed.CompareAndSwap(false, true) { if !m.isSubscribed.CompareAndSwap(false, true) {
m.logger.Debug("Subscription is already active; waiting 2s before retrying…") m.logger.Debug("Subscription is already active; waiting 2s before retrying…")
@ -177,11 +179,14 @@ func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error {
// In background, watch for contexts cancelation and stop the connection // In background, watch for contexts cancelation and stop the connection
// However, do not call "unsubscribe" which would cause the broker to stop tracking the last message received by this consumer group // However, do not call "unsubscribe" which would cause the broker to stop tracking the last message received by this consumer group
m.wg.Add(1)
go func() { go func() {
defer m.wg.Done()
select { select {
case <-ctx.Done(): case <-ctx.Done():
// nop // nop
case <-m.ctx.Done(): case <-m.closeCh:
// nop // nop
} }
@ -208,14 +213,12 @@ func (m *MQTT) connect(clientID string, isSubscriber bool) (mqtt.Client, error)
} }
client := mqtt.NewClient(opts) client := mqtt.NewClient(opts)
ctx, cancel := context.WithTimeout(m.ctx, defaultWait)
defer cancel()
token := client.Connect() token := client.Connect()
select { select {
case <-token.Done(): case <-token.Done():
err = token.Error() err = token.Error()
case <-ctx.Done(): case <-time.After(defaultWait):
err = ctx.Err() err = errors.New("mqtt client timed out connecting")
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect: %w", err) return nil, fmt.Errorf("failed to connect: %w", err)
@ -290,43 +293,46 @@ func (m *MQTT) createClientOptions(uri *url.URL, clientID string) *mqtt.ClientOp
return opts return opts
} }
func (m *MQTT) handleMessage(client mqtt.Client, mqttMsg mqtt.Message) { func (m *MQTT) handleMessage() func(client mqtt.Client, mqttMsg mqtt.Message) {
// We're using m.ctx as context in this method because we don't have access to the Read context return func(client mqtt.Client, mqttMsg mqtt.Message) {
// Canceling the Read context makes Read invoke "Disconnect" anyways bo := m.backOff
ctx := m.ctx if m.metadata.backOffMaxRetries >= 0 {
bo = backoff.WithMaxRetries(bo, uint64(m.metadata.backOffMaxRetries))
}
var bo backoff.BackOff = backoff.WithContext(m.backOff, ctx) err := retry.NotifyRecover(
if m.metadata.backOffMaxRetries >= 0 { func() error {
bo = backoff.WithMaxRetries(bo, uint64(m.metadata.backOffMaxRetries)) m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
} // Use a background context here so that the context is not tied to the
// first Invoke first created the producer.
// TODO: add context to mqtt library, and add a OnConnectWithContext option
// to change this func signature to
// func(c mqtt.Client, ctx context.Context)
_, err := m.readHandler(context.Background(), &bindings.ReadResponse{
Data: mqttMsg.Payload(),
Metadata: map[string]string{
mqttTopic: mqttMsg.Topic(),
},
})
if err != nil {
return err
}
err := retry.NotifyRecover( // Ack the message on success
func() error { mqttMsg.Ack()
m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID()) return nil
_, err := m.readHandler(ctx, &bindings.ReadResponse{ },
Data: mqttMsg.Payload(), bo,
Metadata: map[string]string{ func(err error, d time.Duration) {
mqttTopic: mqttMsg.Topic(), m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying…", mqttMsg.Topic(), mqttMsg.MessageID())
}, },
}) func() {
if err != nil { m.logger.Infof("Successfully processed MQTT message after it previously failed: %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
return err },
} )
if err != nil {
// Ack the message on success m.logger.Errorf("Failed processing MQTT message: %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err)
mqttMsg.Ack() }
return nil
},
bo,
func(err error, d time.Duration) {
m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying…", mqttMsg.Topic(), mqttMsg.MessageID())
},
func() {
m.logger.Infof("Successfully processed MQTT message after it previously failed: %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
},
)
if err != nil {
m.logger.Errorf("Failed processing MQTT message: %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err)
} }
} }
@ -336,17 +342,15 @@ func (m *MQTT) createSubscriberClientOptions(uri *url.URL, clientID string) *mqt
// On (re-)connection, add the topic subscription // On (re-)connection, add the topic subscription
opts.OnConnect = func(c mqtt.Client) { opts.OnConnect = func(c mqtt.Client) {
token := c.Subscribe(m.metadata.topic, m.metadata.qos, m.handleMessage) token := c.Subscribe(m.metadata.topic, m.metadata.qos, m.handleMessage())
var err error var err error
subscribeCtx, subscribeCancel := context.WithTimeout(m.ctx, defaultWait)
defer subscribeCancel()
select { select {
case <-token.Done(): case <-token.Done():
// Subscription went through (sucecessfully or not) // Subscription went through (sucecessfully or not)
err = token.Error() err = token.Error()
case <-subscribeCtx.Done(): case <-time.After(defaultWait):
err = fmt.Errorf("error while waiting for subscription token: %w", subscribeCtx.Err()) err = errors.New("timed out waiting for subscription to complete")
} }
// Nothing we can do in case of errors besides logging them // Nothing we can do in case of errors besides logging them
@ -363,13 +367,16 @@ func (m *MQTT) Close() error {
m.producerLock.Lock() m.producerLock.Lock()
defer m.producerLock.Unlock() defer m.producerLock.Unlock()
// Canceling the context also causes Read to stop receiving messages if m.closed.CompareAndSwap(false, true) {
m.cancel() close(m.closeCh)
}
if m.producer != nil { if m.producer != nil {
m.producer.Disconnect(200) m.producer.Disconnect(200)
m.producer = nil m.producer = nil
} }
m.wg.Wait()
return nil return nil
} }

View File

@ -49,6 +49,7 @@ func getConnectionString() string {
func TestInvokeWithTopic(t *testing.T) { func TestInvokeWithTopic(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background()
url := getConnectionString() url := getConnectionString()
if url == "" { if url == "" {
@ -79,7 +80,7 @@ func TestInvokeWithTopic(t *testing.T) {
logger := logger.NewLogger("test") logger := logger.NewLogger("test")
r := NewMQTT(logger).(*MQTT) r := NewMQTT(logger).(*MQTT)
err := r.Init(metadata) err := r.Init(ctx, metadata)
assert.Nil(t, err) assert.Nil(t, err)
conn, err := r.connect(uuid.NewString(), false) conn, err := r.connect(uuid.NewString(), false)
@ -127,4 +128,5 @@ func TestInvokeWithTopic(t *testing.T) {
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, dataCustomized, mqttMessage.Payload()) assert.Equal(t, dataCustomized, mqttMessage.Payload())
assert.Equal(t, topicCustomized, mqttMessage.Topic()) assert.Equal(t, topicCustomized, mqttMessage.Topic())
assert.NoError(t, r.Close())
} }

View File

@ -205,7 +205,6 @@ func TestParseMetadata(t *testing.T) {
logger := logger.NewLogger("test") logger := logger.NewLogger("test")
m := NewMQTT(logger).(*MQTT) m := NewMQTT(logger).(*MQTT)
m.backOff = backoff.NewConstantBackOff(5 * time.Second) m.backOff = backoff.NewConstantBackOff(5 * time.Second)
m.ctx, m.cancel = context.WithCancel(context.Background())
m.readHandler = func(ctx context.Context, r *bindings.ReadResponse) ([]byte, error) { m.readHandler = func(ctx context.Context, r *bindings.ReadResponse) ([]byte, error) {
assert.Equal(t, payload, r.Data) assert.Equal(t, payload, r.Data)
metadata := r.Metadata metadata := r.Metadata
@ -215,7 +214,7 @@ func TestParseMetadata(t *testing.T) {
return r.Data, nil return r.Data, nil
} }
m.handleMessage(nil, &mqttMockMessage{ m.handleMessage()(nil, &mqttMockMessage{
topic: topic, topic: topic,
payload: payload, payload: payload,
}) })

View File

@ -81,7 +81,7 @@ func NewMysql(logger logger.Logger) bindings.OutputBinding {
} }
// Init initializes the MySQL binding. // Init initializes the MySQL binding.
func (m *Mysql) Init(metadata bindings.Metadata) error { func (m *Mysql) Init(ctx context.Context, metadata bindings.Metadata) error {
m.logger.Debug("Initializing MySql binding") m.logger.Debug("Initializing MySql binding")
p := metadata.Properties p := metadata.Properties
@ -115,7 +115,7 @@ func (m *Mysql) Init(metadata bindings.Metadata) error {
return err return err
} }
err = db.Ping() err = db.PingContext(ctx)
if err != nil { if err != nil {
return fmt.Errorf("unable to ping the DB: %w", err) return fmt.Errorf("unable to ping the DB: %w", err)
} }

View File

@ -75,7 +75,7 @@ func TestMysqlIntegration(t *testing.T) {
b := NewMysql(logger.NewLogger("test")).(*Mysql) b := NewMysql(logger.NewLogger("test")).(*Mysql)
m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}} m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}}
if err := b.Init(m); err != nil { if err := b.Init(context.Background(), m); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -21,6 +21,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/nacos-group/nacos-sdk-go/v2/clients" "github.com/nacos-group/nacos-sdk-go/v2/clients"
@ -56,6 +57,9 @@ type Nacos struct {
logger logger.Logger logger logger.Logger
configClient config_client.IConfigClient //nolint:nosnakecase configClient config_client.IConfigClient //nolint:nosnakecase
readHandler func(ctx context.Context, response *bindings.ReadResponse) ([]byte, error) readHandler func(ctx context.Context, response *bindings.ReadResponse) ([]byte, error)
wg sync.WaitGroup
closed atomic.Bool
closeCh chan struct{}
} }
// NewNacos returns a new Nacos instance. // NewNacos returns a new Nacos instance.
@ -63,11 +67,12 @@ func NewNacos(logger logger.Logger) bindings.OutputBinding {
return &Nacos{ return &Nacos{
logger: logger, logger: logger,
watchesLock: sync.Mutex{}, watchesLock: sync.Mutex{},
closeCh: make(chan struct{}),
} }
} }
// Init implements InputBinding/OutputBinding's Init method. // Init implements InputBinding/OutputBinding's Init method.
func (n *Nacos) Init(metadata bindings.Metadata) error { func (n *Nacos) Init(_ context.Context, metadata bindings.Metadata) error {
n.settings = Settings{ n.settings = Settings{
Timeout: defaultTimeout, Timeout: defaultTimeout,
} }
@ -146,6 +151,10 @@ func (n *Nacos) createConfigClient() error {
// Read implements InputBinding's Read method. // Read implements InputBinding's Read method.
func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error { func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {
if n.closed.Load() {
return errors.New("binding is closed")
}
n.readHandler = handler n.readHandler = handler
n.watchesLock.Lock() n.watchesLock.Lock()
@ -154,9 +163,14 @@ func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {
} }
n.watchesLock.Unlock() n.watchesLock.Unlock()
n.wg.Add(1)
go func() { go func() {
defer n.wg.Done()
// Cancel all listeners when the context is done // Cancel all listeners when the context is done
<-ctx.Done() select {
case <-ctx.Done():
case <-n.closeCh:
}
n.cancelAllListeners() n.cancelAllListeners()
}() }()
@ -165,8 +179,14 @@ func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {
// Close implements cancel all listeners, see https://github.com/dapr/components-contrib/issues/779 // Close implements cancel all listeners, see https://github.com/dapr/components-contrib/issues/779
func (n *Nacos) Close() error { func (n *Nacos) Close() error {
if n.closed.CompareAndSwap(false, true) {
close(n.closeCh)
}
n.cancelAllListeners() n.cancelAllListeners()
n.wg.Wait()
return nil return nil
} }
@ -223,7 +243,11 @@ func (n *Nacos) addListener(ctx context.Context, config configParam) {
func (n *Nacos) addListenerFoInputBinding(ctx context.Context, config configParam) { func (n *Nacos) addListenerFoInputBinding(ctx context.Context, config configParam) {
if n.addToWatches(config) { if n.addToWatches(config) {
go n.addListener(ctx, config) n.wg.Add(1)
go func() {
defer n.wg.Done()
n.addListener(ctx, config)
}()
} }
} }

View File

@ -35,7 +35,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
m.Properties, err = getNacosLocalCacheMetadata() m.Properties, err = getNacosLocalCacheMetadata()
require.NoError(t, err) require.NoError(t, err)
n := NewNacos(logger.NewLogger("test")).(*Nacos) n := NewNacos(logger.NewLogger("test")).(*Nacos)
err = n.Init(m) err = n.Init(context.Background(), m)
require.NoError(t, err) require.NoError(t, err)
var count int32 var count int32
ch := make(chan bool, 1) ch := make(chan bool, 1)

View File

@ -22,15 +22,15 @@ import (
// OutputBinding is the interface for an output binding, allowing users to invoke remote systems with optional payloads. // OutputBinding is the interface for an output binding, allowing users to invoke remote systems with optional payloads.
type OutputBinding interface { type OutputBinding interface {
Init(metadata Metadata) error Init(ctx context.Context, metadata Metadata) error
Invoke(ctx context.Context, req *InvokeRequest) (*InvokeResponse, error) Invoke(ctx context.Context, req *InvokeRequest) (*InvokeResponse, error)
Operations() []OperationKind Operations() []OperationKind
} }
func PingOutBinding(outputBinding OutputBinding) error { func PingOutBinding(ctx context.Context, outputBinding OutputBinding) error {
// checks if this output binding has the ping option then executes // checks if this output binding has the ping option then executes
if outputBindingWithPing, ok := outputBinding.(health.Pinger); ok { if outputBindingWithPing, ok := outputBinding.(health.Pinger); ok {
return outputBindingWithPing.Ping() return outputBindingWithPing.Ping(ctx)
} else { } else {
return fmt.Errorf("ping is not implemented by this output binding") return fmt.Errorf("ping is not implemented by this output binding")
} }

View File

@ -48,7 +48,7 @@ func NewPostgres(logger logger.Logger) bindings.OutputBinding {
} }
// Init initializes the PostgreSql binding. // Init initializes the PostgreSql binding.
func (p *Postgres) Init(metadata bindings.Metadata) error { func (p *Postgres) Init(ctx context.Context, metadata bindings.Metadata) error {
url, ok := metadata.Properties[connectionURLKey] url, ok := metadata.Properties[connectionURLKey]
if !ok || url == "" { if !ok || url == "" {
return errors.Errorf("required metadata not set: %s", connectionURLKey) return errors.Errorf("required metadata not set: %s", connectionURLKey)
@ -59,7 +59,9 @@ func (p *Postgres) Init(metadata bindings.Metadata) error {
return errors.Wrap(err, "error opening DB connection") return errors.Wrap(err, "error opening DB connection")
} }
p.db, err = pgxpool.NewWithConfig(context.Background(), poolConfig) // This context doesn't control the lifetime of the connection pool, and is
// only scoped to postgres creating resources at init.
p.db, err = pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil { if err != nil {
return errors.Wrap(err, "unable to ping the DB") return errors.Wrap(err, "unable to ping the DB")
} }

View File

@ -64,7 +64,7 @@ func TestPostgresIntegration(t *testing.T) {
// live DB test // live DB test
b := NewPostgres(logger.NewLogger("test")).(*Postgres) b := NewPostgres(logger.NewLogger("test")).(*Postgres)
m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}} m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}}
if err := b.Init(m); err != nil { if err := b.Init(context.Background(), m); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -74,7 +74,7 @@ func (p *Postmark) parseMetadata(meta bindings.Metadata) (postmarkMetadata, erro
} }
// Init does metadata parsing and not much else :). // Init does metadata parsing and not much else :).
func (p *Postmark) Init(metadata bindings.Metadata) error { func (p *Postmark) Init(_ context.Context, metadata bindings.Metadata) error {
// Parse input metadata // Parse input metadata
meta, err := p.parseMetadata(metadata) meta, err := p.parseMetadata(metadata)
if err != nil { if err != nil {

View File

@ -19,6 +19,8 @@ import (
"fmt" "fmt"
"math" "math"
"strconv" "strconv"
"sync"
"sync/atomic"
"time" "time"
amqp "github.com/rabbitmq/amqp091-go" amqp "github.com/rabbitmq/amqp091-go"
@ -50,6 +52,9 @@ type RabbitMQ struct {
metadata rabbitMQMetadata metadata rabbitMQMetadata
logger logger.Logger logger logger.Logger
queue amqp.Queue queue amqp.Queue
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
} }
// Metadata is the rabbitmq config. // Metadata is the rabbitmq config.
@ -66,11 +71,14 @@ type rabbitMQMetadata struct {
// NewRabbitMQ returns a new rabbitmq instance. // NewRabbitMQ returns a new rabbitmq instance.
func NewRabbitMQ(logger logger.Logger) bindings.InputOutputBinding { func NewRabbitMQ(logger logger.Logger) bindings.InputOutputBinding {
return &RabbitMQ{logger: logger} return &RabbitMQ{
logger: logger,
closeCh: make(chan struct{}),
}
} }
// Init does metadata parsing and connection creation. // Init does metadata parsing and connection creation.
func (r *RabbitMQ) Init(metadata bindings.Metadata) error { func (r *RabbitMQ) Init(_ context.Context, metadata bindings.Metadata) error {
err := r.parseMetadata(metadata) err := r.parseMetadata(metadata)
if err != nil { if err != nil {
return err return err
@ -226,6 +234,10 @@ func (r *RabbitMQ) declareQueue() (amqp.Queue, error) {
} }
func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error { func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
if r.closed.Load() {
return errors.New("binding already closed")
}
msgs, err := r.channel.Consume( msgs, err := r.channel.Consume(
r.queue.Name, r.queue.Name,
"", "",
@ -239,14 +251,27 @@ func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
return err return err
} }
readCtx, cancel := context.WithCancel(ctx)
r.wg.Add(2)
go func() { go func() {
defer r.wg.Done()
defer cancel()
select {
case <-r.closeCh:
case <-readCtx.Done():
}
}()
go func() {
defer r.wg.Done()
var err error var err error
for { for {
select { select {
case <-ctx.Done(): case <-readCtx.Done():
return return
case d := <-msgs: case d := <-msgs:
_, err = handler(ctx, &bindings.ReadResponse{ _, err = handler(readCtx, &bindings.ReadResponse{
Data: d.Body, Data: d.Body,
}) })
if err != nil { if err != nil {
@ -260,3 +285,11 @@ func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
return nil return nil
} }
func (r *RabbitMQ) Close() error {
if r.closed.CompareAndSwap(false, true) {
close(r.closeCh)
}
defer r.wg.Wait()
return r.channel.Close()
}

View File

@ -85,7 +85,7 @@ func TestQueuesWithTTL(t *testing.T) {
logger := logger.NewLogger("test") logger := logger.NewLogger("test")
r := NewRabbitMQ(logger).(*RabbitMQ) r := NewRabbitMQ(logger).(*RabbitMQ)
err := r.Init(metadata) err := r.Init(context.Background(), metadata)
assert.Nil(t, err) assert.Nil(t, err)
// Assert that if waited too long, we won't see any message // Assert that if waited too long, we won't see any message
@ -117,6 +117,7 @@ func TestQueuesWithTTL(t *testing.T) {
assert.True(t, ok) assert.True(t, ok)
msgBody := string(msg.Body) msgBody := string(msg.Body)
assert.Equal(t, testMsgContent, msgBody) assert.Equal(t, testMsgContent, msgBody)
assert.NoError(t, r.Close())
} }
func TestPublishingWithTTL(t *testing.T) { func TestPublishingWithTTL(t *testing.T) {
@ -144,7 +145,7 @@ func TestPublishingWithTTL(t *testing.T) {
logger := logger.NewLogger("test") logger := logger.NewLogger("test")
rabbitMQBinding1 := NewRabbitMQ(logger).(*RabbitMQ) rabbitMQBinding1 := NewRabbitMQ(logger).(*RabbitMQ)
err := rabbitMQBinding1.Init(metadata) err := rabbitMQBinding1.Init(context.Background(), metadata)
assert.Nil(t, err) assert.Nil(t, err)
// Assert that if waited too long, we won't see any message // Assert that if waited too long, we won't see any message
@ -175,7 +176,7 @@ func TestPublishingWithTTL(t *testing.T) {
// Getting before it is expired, should return it // Getting before it is expired, should return it
rabbitMQBinding2 := NewRabbitMQ(logger).(*RabbitMQ) rabbitMQBinding2 := NewRabbitMQ(logger).(*RabbitMQ)
err = rabbitMQBinding2.Init(metadata) err = rabbitMQBinding2.Init(context.Background(), metadata)
assert.Nil(t, err) assert.Nil(t, err)
const testMsgContent = "test_msg" const testMsgContent = "test_msg"
@ -193,6 +194,9 @@ func TestPublishingWithTTL(t *testing.T) {
assert.True(t, ok) assert.True(t, ok)
msgBody := string(msg.Body) msgBody := string(msg.Body)
assert.Equal(t, testMsgContent, msgBody) assert.Equal(t, testMsgContent, msgBody)
assert.NoError(t, rabbitMQBinding1.Close())
assert.NoError(t, rabbitMQBinding1.Close())
} }
func TestExclusiveQueue(t *testing.T) { func TestExclusiveQueue(t *testing.T) {
@ -222,7 +226,7 @@ func TestExclusiveQueue(t *testing.T) {
logger := logger.NewLogger("test") logger := logger.NewLogger("test")
r := NewRabbitMQ(logger).(*RabbitMQ) r := NewRabbitMQ(logger).(*RabbitMQ)
err := r.Init(metadata) err := r.Init(context.Background(), metadata)
assert.Nil(t, err) assert.Nil(t, err)
// Assert that if waited too long, we won't see any message // Assert that if waited too long, we won't see any message
@ -276,7 +280,7 @@ func TestPublishWithPriority(t *testing.T) {
logger := logger.NewLogger("test") logger := logger.NewLogger("test")
r := NewRabbitMQ(logger).(*RabbitMQ) r := NewRabbitMQ(logger).(*RabbitMQ)
err := r.Init(metadata) err := r.Init(context.Background(), metadata)
assert.Nil(t, err) assert.Nil(t, err)
// Assert that if waited too long, we won't see any message // Assert that if waited too long, we won't see any message

View File

@ -28,9 +28,6 @@ type Redis struct {
client rediscomponent.RedisClient client rediscomponent.RedisClient
clientSettings *rediscomponent.Settings clientSettings *rediscomponent.Settings
logger logger.Logger logger logger.Logger
ctx context.Context
cancel context.CancelFunc
} }
// NewRedis returns a new redis bindings instance. // NewRedis returns a new redis bindings instance.
@ -39,15 +36,13 @@ func NewRedis(logger logger.Logger) bindings.OutputBinding {
} }
// Init performs metadata parsing and connection creation. // Init performs metadata parsing and connection creation.
func (r *Redis) Init(meta bindings.Metadata) (err error) { func (r *Redis) Init(ctx context.Context, meta bindings.Metadata) (err error) {
r.client, r.clientSettings, err = rediscomponent.ParseClientFromProperties(meta.Properties, nil) r.client, r.clientSettings, err = rediscomponent.ParseClientFromProperties(meta.Properties, nil)
if err != nil { if err != nil {
return err return err
} }
r.ctx, r.cancel = context.WithCancel(context.Background()) _, err = r.client.PingResult(ctx)
_, err = r.client.PingResult(r.ctx)
if err != nil { if err != nil {
return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err) return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err)
} }
@ -55,8 +50,8 @@ func (r *Redis) Init(meta bindings.Metadata) (err error) {
return err return err
} }
func (r *Redis) Ping() error { func (r *Redis) Ping(ctx context.Context) error {
if _, err := r.client.PingResult(r.ctx); err != nil { if _, err := r.client.PingResult(ctx); err != nil {
return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err) return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err)
} }
@ -101,7 +96,5 @@ func (r *Redis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi
} }
func (r *Redis) Close() error { func (r *Redis) Close() error {
r.cancel()
return r.client.Close() return r.client.Close()
} }

View File

@ -40,7 +40,6 @@ func TestInvokeCreate(t *testing.T) {
client: c, client: c,
logger: logger.NewLogger("test"), logger: logger.NewLogger("test"),
} }
bind.ctx, bind.cancel = context.WithCancel(context.Background())
_, err := c.DoRead(context.Background(), "GET", testKey) _, err := c.DoRead(context.Background(), "GET", testKey)
assert.Equal(t, redis.Nil, err) assert.Equal(t, redis.Nil, err)
@ -66,7 +65,6 @@ func TestInvokeGet(t *testing.T) {
client: c, client: c,
logger: logger.NewLogger("test"), logger: logger.NewLogger("test"),
} }
bind.ctx, bind.cancel = context.WithCancel(context.Background())
err := c.DoWrite(context.Background(), "SET", testKey, testData) err := c.DoWrite(context.Background(), "SET", testKey, testData)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
@ -87,7 +85,6 @@ func TestInvokeDelete(t *testing.T) {
client: c, client: c,
logger: logger.NewLogger("test"), logger: logger.NewLogger("test"),
} }
bind.ctx, bind.cancel = context.WithCancel(context.Background())
err := c.DoWrite(context.Background(), "SET", testKey, testData) err := c.DoWrite(context.Background(), "SET", testKey, testData)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)

View File

@ -18,6 +18,8 @@ import (
"encoding/json" "encoding/json"
"strconv" "strconv"
"strings" "strings"
"sync"
"sync/atomic"
"time" "time"
r "github.com/dancannon/gorethink" r "github.com/dancannon/gorethink"
@ -34,6 +36,9 @@ type Binding struct {
logger logger.Logger logger logger.Logger
session *r.Session session *r.Session
config StateConfig config StateConfig
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
} }
// StateConfig is the binding config. // StateConfig is the binding config.
@ -45,12 +50,13 @@ type StateConfig struct {
// NewRethinkDBStateChangeBinding returns a new RethinkDB actor event input binding. // NewRethinkDBStateChangeBinding returns a new RethinkDB actor event input binding.
func NewRethinkDBStateChangeBinding(logger logger.Logger) bindings.InputBinding { func NewRethinkDBStateChangeBinding(logger logger.Logger) bindings.InputBinding {
return &Binding{ return &Binding{
logger: logger, logger: logger,
closeCh: make(chan struct{}),
} }
} }
// Init initializes the RethinkDB binding. // Init initializes the RethinkDB binding.
func (b *Binding) Init(metadata bindings.Metadata) error { func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
cfg, err := metadataToConfig(metadata.Properties, b.logger) cfg, err := metadataToConfig(metadata.Properties, b.logger)
if err != nil { if err != nil {
return errors.Wrap(err, "unable to parse metadata properties") return errors.Wrap(err, "unable to parse metadata properties")
@ -68,6 +74,10 @@ func (b *Binding) Init(metadata bindings.Metadata) error {
// Read triggers the RethinkDB scheduler. // Read triggers the RethinkDB scheduler.
func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error { func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
if b.closed.Load() {
return errors.New("binding is closed")
}
b.logger.Infof("subscribing to state changes in %s.%s...", b.config.Database, b.config.Table) b.logger.Infof("subscribing to state changes in %s.%s...", b.config.Database, b.config.Table)
cursor, err := r.DB(b.config.Database). cursor, err := r.DB(b.config.Database).
Table(b.config.Table). Table(b.config.Table).
@ -81,8 +91,21 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
errors.Wrapf(err, "error connecting to table %s", b.config.Table) errors.Wrapf(err, "error connecting to table %s", b.config.Table)
} }
readCtx, cancel := context.WithCancel(ctx)
b.wg.Add(2)
go func() { go func() {
for ctx.Err() == nil { defer b.wg.Done()
defer cancel()
select {
case <-b.closeCh:
case <-readCtx.Done():
}
}()
go func() {
defer b.wg.Done()
for readCtx.Err() == nil {
var change interface{} var change interface{}
ok := cursor.Next(&change) ok := cursor.Next(&change)
if !ok { if !ok {
@ -105,7 +128,7 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
}, },
} }
if _, err := handler(ctx, resp); err != nil { if _, err := handler(readCtx, resp); err != nil {
b.logger.Errorf("error invoking change handler: %v", err) b.logger.Errorf("error invoking change handler: %v", err)
continue continue
} }
@ -117,6 +140,14 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
return nil return nil
} }
func (b *Binding) Close() error {
if b.closed.CompareAndSwap(false, true) {
close(b.closeCh)
}
defer b.wg.Wait()
return b.session.Close()
}
func metadataToConfig(cfg map[string]string, logger logger.Logger) (StateConfig, error) { func metadataToConfig(cfg map[string]string, logger logger.Logger) (StateConfig, error) {
c := StateConfig{} c := StateConfig{}
for k, v := range cfg { for k, v := range cfg {

View File

@ -71,7 +71,7 @@ func TestBinding(t *testing.T) {
assert.NotNil(t, m.Properties) assert.NotNil(t, m.Properties)
b := getNewRethinkActorBinding() b := getNewRethinkActorBinding()
err := b.Init(m) err := b.Init(context.Background(), m)
assert.NoErrorf(t, err, "error initializing") assert.NoErrorf(t, err, "error initializing")
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())

View File

@ -18,6 +18,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"sync"
"sync/atomic"
"time" "time"
mqc "github.com/apache/rocketmq-client-go/v2/consumer" mqc "github.com/apache/rocketmq-client-go/v2/consumer"
@ -35,27 +37,27 @@ type RocketMQ struct {
settings Settings settings Settings
producer mqw.Producer producer mqw.Producer
ctx context.Context
cancel context.CancelFunc
backOffConfig retry.Config backOffConfig retry.Config
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
} }
func NewRocketMQ(l logger.Logger) *RocketMQ { func NewRocketMQ(l logger.Logger) *RocketMQ {
return &RocketMQ{ //nolint:exhaustivestruct return &RocketMQ{ //nolint:exhaustivestruct
logger: l, logger: l,
producer: nil, producer: nil,
closeCh: make(chan struct{}),
} }
} }
// Init performs metadata parsing. // Init performs metadata parsing.
func (a *RocketMQ) Init(metadata bindings.Metadata) error { func (a *RocketMQ) Init(ctx context.Context, metadata bindings.Metadata) error {
var err error var err error
if err = a.settings.Decode(metadata.Properties); err != nil { if err = a.settings.Decode(metadata.Properties); err != nil {
return err return err
} }
a.ctx, a.cancel = context.WithCancel(context.Background())
// Default retry configuration is used if no // Default retry configuration is used if no
// backOff properties are set. // backOff properties are set.
if err = retry.DecodeConfigWithPrefix( if err = retry.DecodeConfigWithPrefix(
@ -75,6 +77,10 @@ func (a *RocketMQ) Init(metadata bindings.Metadata) error {
// Read triggers the rocketmq subscription. // Read triggers the rocketmq subscription.
func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error { func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("error: binding is closed")
}
a.logger.Debugf("binding rocketmq: start read input binding") a.logger.Debugf("binding rocketmq: start read input binding")
consumer, err := a.setupConsumer() consumer, err := a.setupConsumer()
@ -114,10 +120,12 @@ func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error {
a.logger.Debugf("binding-rocketmq: consumer started") a.logger.Debugf("binding-rocketmq: consumer started")
// Listen for context cancelation to stop the subscription // Listen for context cancelation to stop the subscription
a.wg.Add(1)
go func() { go func() {
defer a.wg.Done()
select { select {
case <-ctx.Done(): case <-ctx.Done():
case <-a.ctx.Done(): case <-a.closeCh:
} }
innerErr := consumer.Shutdown() innerErr := consumer.Shutdown()
@ -131,8 +139,10 @@ func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error {
// Close implements cancel all listeners, see https://github.com/dapr/components-contrib/issues/779 // Close implements cancel all listeners, see https://github.com/dapr/components-contrib/issues/779
func (a *RocketMQ) Close() error { func (a *RocketMQ) Close() error {
a.cancel() defer a.wg.Wait()
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
return nil return nil
} }
@ -199,21 +209,21 @@ func (a *RocketMQ) Operations() []bindings.OperationKind {
return []bindings.OperationKind{bindings.CreateOperation} return []bindings.OperationKind{bindings.CreateOperation}
} }
func (a *RocketMQ) Invoke(req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { func (a *RocketMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
rst := &bindings.InvokeResponse{Data: nil, Metadata: nil} rst := &bindings.InvokeResponse{Data: nil, Metadata: nil}
if req.Operation != bindings.CreateOperation { if req.Operation != bindings.CreateOperation {
return rst, fmt.Errorf("binding-rocketmq error: unsupported operation %s", req.Operation) return rst, fmt.Errorf("binding-rocketmq error: unsupported operation %s", req.Operation)
} }
return rst, a.sendMessage(req) return rst, a.sendMessage(ctx, req)
} }
func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error { func (a *RocketMQ) sendMessage(ctx context.Context, req *bindings.InvokeRequest) error {
topic := req.Metadata[metadataRocketmqTopic] topic := req.Metadata[metadataRocketmqTopic]
if topic != "" { if topic != "" {
_, err := a.send(topic, req.Metadata[metadataRocketmqTag], req.Metadata[metadataRocketmqKey], req.Data) _, err := a.send(ctx, topic, req.Metadata[metadataRocketmqTag], req.Metadata[metadataRocketmqKey], req.Data)
if err != nil { if err != nil {
return err return err
} }
@ -229,7 +239,7 @@ func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error {
if err != nil { if err != nil {
return err return err
} }
_, err = a.send(topic, mqExpression, req.Metadata[metadataRocketmqKey], req.Data) _, err = a.send(ctx, topic, mqExpression, req.Metadata[metadataRocketmqKey], req.Data)
if err != nil { if err != nil {
return err return err
} }
@ -239,9 +249,9 @@ func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error {
return nil return nil
} }
func (a *RocketMQ) send(topic, mqExpr, key string, data []byte) (bool, error) { func (a *RocketMQ) send(ctx context.Context, topic, mqExpr, key string, data []byte) (bool, error) {
msg := primitive.NewMessage(topic, data).WithTag(mqExpr).WithKeys([]string{key}) msg := primitive.NewMessage(topic, data).WithTag(mqExpr).WithKeys([]string{key})
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel() defer cancel()
rst, err := a.producer.SendSync(ctx, msg) rst, err := a.producer.SendSync(ctx, msg)
if err != nil { if err != nil {

View File

@ -35,7 +35,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
m := bindings.Metadata{} //nolint:exhaustivestruct m := bindings.Metadata{} //nolint:exhaustivestruct
m.Properties = getTestMetadata() m.Properties = getTestMetadata()
r := NewRocketMQ(logger.NewLogger("test")) r := NewRocketMQ(logger.NewLogger("test"))
err := r.Init(m) err := r.Init(context.Background(), m)
require.NoError(t, err) require.NoError(t, err)
var count int32 var count int32
@ -51,7 +51,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
atomic.StoreInt32(&count, 0) atomic.StoreInt32(&count, 0)
req := &bindings.InvokeRequest{Data: []byte("hello"), Operation: bindings.CreateOperation, Metadata: map[string]string{}} req := &bindings.InvokeRequest{Data: []byte("hello"), Operation: bindings.CreateOperation, Metadata: map[string]string{}}
_, err = r.Invoke(req) _, err = r.Invoke(context.Background(), req)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(10 * time.Second) time.Sleep(10 * time.Second)

View File

@ -61,7 +61,7 @@ func NewSMTP(logger logger.Logger) bindings.OutputBinding {
} }
// Init smtp component (parse metadata). // Init smtp component (parse metadata).
func (s *Mailer) Init(metadata bindings.Metadata) error { func (s *Mailer) Init(_ context.Context, metadata bindings.Metadata) error {
// parse metadata // parse metadata
meta, err := s.parseMetadata(metadata) meta, err := s.parseMetadata(metadata)
if err != nil { if err != nil {

View File

@ -84,7 +84,7 @@ func (sg *SendGrid) parseMetadata(meta bindings.Metadata) (sendGridMetadata, err
} }
// Init does metadata parsing and not much else :). // Init does metadata parsing and not much else :).
func (sg *SendGrid) Init(metadata bindings.Metadata) error { func (sg *SendGrid) Init(_ context.Context, metadata bindings.Metadata) error {
// Parse input metadata // Parse input metadata
meta, err := sg.parseMetadata(metadata) meta, err := sg.parseMetadata(metadata)
if err != nil { if err != nil {

View File

@ -60,19 +60,19 @@ func NewSMS(logger logger.Logger) bindings.OutputBinding {
} }
} }
func (t *SMS) Init(metadata bindings.Metadata) error { func (t *SMS) Init(_ context.Context, metadata bindings.Metadata) error {
twilioM := twilioMetadata{ twilioM := twilioMetadata{
timeout: time.Minute * 5, timeout: time.Minute * 5,
} }
if metadata.Properties[fromNumber] == "" { if metadata.Properties[fromNumber] == "" {
return errors.New("\"fromNumber\" is a required field") return errors.New(`"fromNumber" is a required field`)
} }
if metadata.Properties[accountSid] == "" { if metadata.Properties[accountSid] == "" {
return errors.New("\"accountSid\" is a required field") return errors.New(`"accountSid" is a required field`)
} }
if metadata.Properties[authToken] == "" { if metadata.Properties[authToken] == "" {
return errors.New("\"authToken\" is a required field") return errors.New(`"authToken" is a required field`)
} }
twilioM.toNumber = metadata.Properties[toNumber] twilioM.toNumber = metadata.Properties[toNumber]

View File

@ -53,7 +53,7 @@ func TestInit(t *testing.T) {
m := bindings.Metadata{} m := bindings.Metadata{}
m.Properties = map[string]string{"toNumber": "toNumber", "fromNumber": "fromNumber"} m.Properties = map[string]string{"toNumber": "toNumber", "fromNumber": "fromNumber"}
tw := NewSMS(logger.NewLogger("test")) tw := NewSMS(logger.NewLogger("test"))
err := tw.Init(m) err := tw.Init(context.Background(), m)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -66,7 +66,7 @@ func TestParseDuration(t *testing.T) {
"authToken": "authToken", "timeout": "badtimeout", "authToken": "authToken", "timeout": "badtimeout",
} }
tw := NewSMS(logger.NewLogger("test")) tw := NewSMS(logger.NewLogger("test"))
err := tw.Init(m) err := tw.Init(context.Background(), m)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -85,7 +85,7 @@ func TestWriteShouldSucceed(t *testing.T) {
tw.httpClient = &http.Client{ tw.httpClient = &http.Client{
Transport: httpTransport, Transport: httpTransport,
} }
err := tw.Init(m) err := tw.Init(context.Background(), m)
assert.NoError(t, err) assert.NoError(t, err)
t.Run("Should succeed with expected url and headers", func(t *testing.T) { t.Run("Should succeed with expected url and headers", func(t *testing.T) {
@ -123,7 +123,7 @@ func TestWriteShouldFail(t *testing.T) {
tw.httpClient = &http.Client{ tw.httpClient = &http.Client{
Transport: httpTransport, Transport: httpTransport,
} }
err := tw.Init(m) err := tw.Init(context.Background(), m)
assert.NoError(t, err) assert.NoError(t, err)
t.Run("Missing 'to' should fail", func(t *testing.T) { t.Run("Missing 'to' should fail", func(t *testing.T) {
@ -180,7 +180,7 @@ func TestMessageBody(t *testing.T) {
tw.httpClient = &http.Client{ tw.httpClient = &http.Client{
Transport: httpTransport, Transport: httpTransport,
} }
err := tw.Init(m) err := tw.Init(context.Background(), m)
require.NoError(t, err) require.NoError(t, err)
tester := func(reqData []byte, expectBody string) func(t *testing.T) { tester := func(reqData []byte, expectBody string) func(t *testing.T) {

View File

@ -19,6 +19,8 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv" "strconv"
"sync"
"sync/atomic"
"time" "time"
"github.com/dghubble/go-twitter/twitter" "github.com/dghubble/go-twitter/twitter"
@ -31,18 +33,21 @@ import (
// Binding represents Twitter input/output binding. // Binding represents Twitter input/output binding.
type Binding struct { type Binding struct {
client *twitter.Client client *twitter.Client
query string query string
logger logger.Logger logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
} }
// NewTwitter returns a new Twitter event input binding. // NewTwitter returns a new Twitter event input binding.
func NewTwitter(logger logger.Logger) bindings.InputOutputBinding { func NewTwitter(logger logger.Logger) bindings.InputOutputBinding {
return &Binding{logger: logger} return &Binding{logger: logger, closeCh: make(chan struct{})}
} }
// Init initializes the Twitter binding. // Init initializes the Twitter binding.
func (t *Binding) Init(metadata bindings.Metadata) error { func (t *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
ck, f := metadata.Properties["consumerKey"] ck, f := metadata.Properties["consumerKey"]
if !f || ck == "" { if !f || ck == "" {
return fmt.Errorf("consumerKey not set") return fmt.Errorf("consumerKey not set")
@ -124,10 +129,17 @@ func (t *Binding) Read(ctx context.Context, handler bindings.Handler) error {
} }
t.logger.Debug("starting handler...") t.logger.Debug("starting handler...")
go demux.HandleChan(stream.Messages) t.wg.Add(2)
go func() { go func() {
<-ctx.Done() defer t.wg.Done()
demux.HandleChan(stream.Messages)
}()
go func() {
defer t.wg.Done()
select {
case <-t.closeCh:
case <-ctx.Done():
}
t.logger.Debug("stopping handler...") t.logger.Debug("stopping handler...")
stream.Stop() stream.Stop()
}() }()
@ -135,6 +147,14 @@ func (t *Binding) Read(ctx context.Context, handler bindings.Handler) error {
return nil return nil
} }
func (t *Binding) Close() error {
if t.closed.CompareAndSwap(false, true) {
close(t.closeCh)
}
t.wg.Wait()
return nil
}
// Invoke handles all operations. // Invoke handles all operations.
func (t *Binding) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { func (t *Binding) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
t.logger.Debugf("operation: %v", req.Operation) t.logger.Debugf("operation: %v", req.Operation)

View File

@ -60,7 +60,7 @@ func getRuntimeMetadata() map[string]string {
func TestInit(t *testing.T) { func TestInit(t *testing.T) {
m := getTestMetadata() m := getTestMetadata()
tw := NewTwitter(logger.NewLogger("test")).(*Binding) tw := NewTwitter(logger.NewLogger("test")).(*Binding)
err := tw.Init(m) err := tw.Init(context.Background(), m)
assert.Nilf(t, err, "error initializing valid metadata properties") assert.Nilf(t, err, "error initializing valid metadata properties")
} }
@ -69,7 +69,7 @@ func TestInit(t *testing.T) {
func TestReadError(t *testing.T) { func TestReadError(t *testing.T) {
tw := NewTwitter(logger.NewLogger("test")).(*Binding) tw := NewTwitter(logger.NewLogger("test")).(*Binding)
m := getTestMetadata() m := getTestMetadata()
err := tw.Init(m) err := tw.Init(context.Background(), m)
assert.Nilf(t, err, "error initializing valid metadata properties") assert.Nilf(t, err, "error initializing valid metadata properties")
err = tw.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) { err = tw.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
@ -79,6 +79,8 @@ func TestReadError(t *testing.T) {
return nil, nil return nil, nil
}) })
assert.Error(t, err) assert.Error(t, err)
assert.NoError(t, tw.Close())
} }
// TestRead executes the Read method which calls Twiter API // TestRead executes the Read method which calls Twiter API
@ -93,7 +95,7 @@ func TestRead(t *testing.T) {
m.Properties["query"] = "microsoft" m.Properties["query"] = "microsoft"
tw := NewTwitter(logger.NewLogger("test")).(*Binding) tw := NewTwitter(logger.NewLogger("test")).(*Binding)
tw.logger.SetOutputLevel(logger.DebugLevel) tw.logger.SetOutputLevel(logger.DebugLevel)
err := tw.Init(m) err := tw.Init(context.Background(), m)
assert.Nilf(t, err, "error initializing read") assert.Nilf(t, err, "error initializing read")
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -116,6 +118,8 @@ func TestRead(t *testing.T) {
cancel() cancel()
t.Fatal("Timeout waiting for messages") t.Fatal("Timeout waiting for messages")
} }
assert.NoError(t, tw.Close())
} }
// TestInvoke executes the Invoke method which calls Twiter API // TestInvoke executes the Invoke method which calls Twiter API
@ -129,7 +133,7 @@ func TestInvoke(t *testing.T) {
m.Properties = getRuntimeMetadata() m.Properties = getRuntimeMetadata()
tw := NewTwitter(logger.NewLogger("test")).(*Binding) tw := NewTwitter(logger.NewLogger("test")).(*Binding)
tw.logger.SetOutputLevel(logger.DebugLevel) tw.logger.SetOutputLevel(logger.DebugLevel)
err := tw.Init(m) err := tw.Init(context.Background(), m)
assert.Nilf(t, err, "error initializing Invoke") assert.Nilf(t, err, "error initializing Invoke")
req := &bindings.InvokeRequest{ req := &bindings.InvokeRequest{
@ -141,4 +145,5 @@ func TestInvoke(t *testing.T) {
resp, err := tw.Invoke(context.Background(), req) resp, err := tw.Invoke(context.Background(), req)
assert.Nilf(t, err, "error on invoke") assert.Nilf(t, err, "error on invoke")
assert.NotNil(t, resp) assert.NotNil(t, resp)
assert.NoError(t, tw.Close())
} }

View File

@ -61,7 +61,7 @@ func NewZeebeCommand(logger logger.Logger) bindings.OutputBinding {
} }
// Init does metadata parsing and connection creation. // Init does metadata parsing and connection creation.
func (z *ZeebeCommand) Init(metadata bindings.Metadata) error { func (z *ZeebeCommand) Init(ctx context.Context, metadata bindings.Metadata) error {
client, err := z.clientFactory.Get(metadata) client, err := z.clientFactory.Get(metadata)
if err != nil { if err != nil {
return err return err
@ -114,7 +114,7 @@ func (z *ZeebeCommand) Invoke(ctx context.Context, req *bindings.InvokeRequest)
case UpdateJobRetriesOperation: case UpdateJobRetriesOperation:
return z.updateJobRetries(ctx, req) return z.updateJobRetries(ctx, req)
case ThrowErrorOperation: case ThrowErrorOperation:
return z.throwError(req) return z.throwError(ctx, req)
case bindings.GetOperation: case bindings.GetOperation:
fallthrough fallthrough
case bindings.CreateOperation: case bindings.CreateOperation:

View File

@ -58,7 +58,7 @@ func TestInit(t *testing.T) {
} }
cmd := ZeebeCommand{clientFactory: mcf, logger: testLogger} cmd := ZeebeCommand{clientFactory: mcf, logger: testLogger}
err := cmd.Init(metadata) err := cmd.Init(context.Background(), metadata)
assert.Error(t, err, errParsing) assert.Error(t, err, errParsing)
}) })
@ -67,7 +67,7 @@ func TestInit(t *testing.T) {
mcf := mockClientFactory{} mcf := mockClientFactory{}
cmd := ZeebeCommand{clientFactory: mcf, logger: testLogger} cmd := ZeebeCommand{clientFactory: mcf, logger: testLogger}
err := cmd.Init(metadata) err := cmd.Init(context.Background(), metadata)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -30,7 +30,7 @@ type throwErrorPayload struct {
ErrorMessage string `json:"errorMessage"` ErrorMessage string `json:"errorMessage"`
} }
func (z *ZeebeCommand) throwError(req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { func (z *ZeebeCommand) throwError(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
var payload throwErrorPayload var payload throwErrorPayload
err := json.Unmarshal(req.Data, &payload) err := json.Unmarshal(req.Data, &payload)
if err != nil { if err != nil {
@ -53,7 +53,7 @@ func (z *ZeebeCommand) throwError(req *bindings.InvokeRequest) (*bindings.Invoke
cmd = cmd.ErrorMessage(payload.ErrorMessage) cmd = cmd.ErrorMessage(payload.ErrorMessage)
} }
_, err = cmd.Send(context.Background()) _, err = cmd.Send(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot throw error for job key %d: %w", payload.JobKey, err) return nil, fmt.Errorf("cannot throw error for job key %d: %w", payload.JobKey, err)
} }

View File

@ -19,6 +19,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"sync"
"sync/atomic"
"time" "time"
"github.com/camunda/zeebe/clients/go/v8/pkg/entities" "github.com/camunda/zeebe/clients/go/v8/pkg/entities"
@ -39,6 +41,9 @@ type ZeebeJobWorker struct {
client zbc.Client client zbc.Client
metadata *jobWorkerMetadata metadata *jobWorkerMetadata
logger logger.Logger logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
} }
// https://docs.zeebe.io/basics/job-workers.html // https://docs.zeebe.io/basics/job-workers.html
@ -64,11 +69,15 @@ type jobHandler struct {
// NewZeebeJobWorker returns a new ZeebeJobWorker instance. // NewZeebeJobWorker returns a new ZeebeJobWorker instance.
func NewZeebeJobWorker(logger logger.Logger) bindings.InputBinding { func NewZeebeJobWorker(logger logger.Logger) bindings.InputBinding {
return &ZeebeJobWorker{clientFactory: zeebe.NewClientFactoryImpl(logger), logger: logger} return &ZeebeJobWorker{
clientFactory: zeebe.NewClientFactoryImpl(logger),
logger: logger,
closeCh: make(chan struct{}),
}
} }
// Init does metadata parsing and connection creation. // Init does metadata parsing and connection creation.
func (z *ZeebeJobWorker) Init(metadata bindings.Metadata) error { func (z *ZeebeJobWorker) Init(ctx context.Context, metadata bindings.Metadata) error {
meta, err := z.parseMetadata(metadata) meta, err := z.parseMetadata(metadata)
if err != nil { if err != nil {
return err return err
@ -90,6 +99,10 @@ func (z *ZeebeJobWorker) Init(metadata bindings.Metadata) error {
} }
func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) error { func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) error {
if z.closed.Load() {
return fmt.Errorf("binding is closed")
}
h := jobHandler{ h := jobHandler{
callback: handler, callback: handler,
logger: z.logger, logger: z.logger,
@ -99,8 +112,14 @@ func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) err
jobWorker := z.getJobWorker(h) jobWorker := z.getJobWorker(h)
z.wg.Add(1)
go func() { go func() {
<-ctx.Done() defer z.wg.Done()
select {
case <-z.closeCh:
case <-ctx.Done():
}
jobWorker.Close() jobWorker.Close()
jobWorker.AwaitClose() jobWorker.AwaitClose()
@ -110,6 +129,14 @@ func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) err
return nil return nil
} }
func (z *ZeebeJobWorker) Close() error {
if z.closed.CompareAndSwap(false, true) {
close(z.closeCh)
}
z.wg.Wait()
return nil
}
func (z *ZeebeJobWorker) parseMetadata(meta bindings.Metadata) (*jobWorkerMetadata, error) { func (z *ZeebeJobWorker) parseMetadata(meta bindings.Metadata) (*jobWorkerMetadata, error) {
var m jobWorkerMetadata var m jobWorkerMetadata
err := metadata.DecodeMetadata(meta.Properties, &m) err := metadata.DecodeMetadata(meta.Properties, &m)

View File

@ -14,6 +14,7 @@ limitations under the License.
package jobworker package jobworker
import ( import (
"context"
"errors" "errors"
"testing" "testing"
@ -53,10 +54,11 @@ func TestInit(t *testing.T) {
metadata := bindings.Metadata{} metadata := bindings.Metadata{}
var mcf mockClientFactory var mcf mockClientFactory
jobWorker := ZeebeJobWorker{clientFactory: &mcf, logger: testLogger} jobWorker := ZeebeJobWorker{clientFactory: &mcf, logger: testLogger, closeCh: make(chan struct{})}
err := jobWorker.Init(metadata) err := jobWorker.Init(context.Background(), metadata)
assert.Error(t, err, ErrMissingJobType) assert.Error(t, err, ErrMissingJobType)
assert.NoError(t, jobWorker.Close())
}) })
t.Run("sets client from client factory", func(t *testing.T) { t.Run("sets client from client factory", func(t *testing.T) {
@ -66,8 +68,8 @@ func TestInit(t *testing.T) {
mcf := mockClientFactory{ mcf := mockClientFactory{
metadata: metadata, metadata: metadata,
} }
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger} jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
err := jobWorker.Init(metadata) err := jobWorker.Init(context.Background(), metadata)
assert.NoError(t, err) assert.NoError(t, err)
@ -76,6 +78,7 @@ func TestInit(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, mc, jobWorker.client) assert.Equal(t, mc, jobWorker.client)
assert.Equal(t, metadata, mcf.metadata) assert.Equal(t, metadata, mcf.metadata)
assert.NoError(t, jobWorker.Close())
}) })
t.Run("returns error if client could not be instantiated properly", func(t *testing.T) { t.Run("returns error if client could not be instantiated properly", func(t *testing.T) {
@ -85,9 +88,10 @@ func TestInit(t *testing.T) {
error: errParsing, error: errParsing,
} }
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger} jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
err := jobWorker.Init(metadata) err := jobWorker.Init(context.Background(), metadata)
assert.Error(t, err, errParsing) assert.Error(t, err, errParsing)
assert.NoError(t, jobWorker.Close())
}) })
t.Run("sets client from client factory", func(t *testing.T) { t.Run("sets client from client factory", func(t *testing.T) {
@ -98,8 +102,8 @@ func TestInit(t *testing.T) {
metadata: metadata, metadata: metadata,
} }
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger} jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
err := jobWorker.Init(metadata) err := jobWorker.Init(context.Background(), metadata)
assert.NoError(t, err) assert.NoError(t, err)
@ -108,5 +112,6 @@ func TestInit(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, mc, jobWorker.client) assert.Equal(t, mc, jobWorker.client)
assert.Equal(t, metadata, mcf.metadata) assert.Equal(t, metadata, mcf.metadata)
assert.NoError(t, jobWorker.Close())
}) })
} }

View File

@ -73,7 +73,7 @@ func NewAzureAppConfigurationStore(logger logger.Logger) configuration.Store {
} }
// Init does metadata and connection parsing. // Init does metadata and connection parsing.
func (r *ConfigurationStore) Init(metadata configuration.Metadata) error { func (r *ConfigurationStore) Init(_ context.Context, metadata configuration.Metadata) error {
m, err := parseMetadata(metadata) m, err := parseMetadata(metadata)
if err != nil { if err != nil {
return err return err

View File

@ -204,7 +204,7 @@ func TestInit(t *testing.T) {
Properties: testProperties, Properties: testProperties,
}} }}
err := s.Init(m) err := s.Init(context.Background(), m)
assert.Nil(t, err) assert.Nil(t, err)
cs, ok := s.(*ConfigurationStore) cs, ok := s.(*ConfigurationStore)
assert.True(t, ok) assert.True(t, ok)
@ -229,7 +229,7 @@ func TestInit(t *testing.T) {
Properties: testProperties, Properties: testProperties,
}} }}
err := s.Init(m) err := s.Init(context.Background(), m)
assert.Nil(t, err) assert.Nil(t, err)
cs, ok := s.(*ConfigurationStore) cs, ok := s.(*ConfigurationStore)
assert.True(t, ok) assert.True(t, ok)

View File

@ -86,7 +86,7 @@ func NewPostgresConfigurationStore(logger logger.Logger) configuration.Store {
} }
} }
func (p *ConfigurationStore) Init(metadata configuration.Metadata) error { func (p *ConfigurationStore) Init(parentCtx context.Context, metadata configuration.Metadata) error {
p.logger.Debug(InfoStartInit) p.logger.Debug(InfoStartInit)
if p.client != nil { if p.client != nil {
return fmt.Errorf(ErrorAlreadyInitialized) return fmt.Errorf(ErrorAlreadyInitialized)
@ -98,7 +98,7 @@ func (p *ConfigurationStore) Init(metadata configuration.Metadata) error {
p.metadata = m p.metadata = m
} }
p.ActiveSubscriptions = make(map[string]*subscription) p.ActiveSubscriptions = make(map[string]*subscription)
ctx, cancel := context.WithTimeout(context.Background(), p.metadata.maxIdleTimeout) ctx, cancel := context.WithTimeout(parentCtx, p.metadata.maxIdleTimeout)
defer cancel() defer cancel()
client, err := Connect(ctx, p.metadata.connectionString, p.metadata.maxIdleTimeout) client, err := Connect(ctx, p.metadata.connectionString, p.metadata.maxIdleTimeout)
if err != nil { if err != nil {

View File

@ -143,7 +143,7 @@ func parseRedisMetadata(meta configuration.Metadata) (metadata, error) {
} }
// Init does metadata and connection parsing. // Init does metadata and connection parsing.
func (r *ConfigurationStore) Init(metadata configuration.Metadata) error { func (r *ConfigurationStore) Init(ctx context.Context, metadata configuration.Metadata) error {
m, err := parseRedisMetadata(metadata) m, err := parseRedisMetadata(metadata)
if err != nil { if err != nil {
return err return err
@ -156,11 +156,11 @@ func (r *ConfigurationStore) Init(metadata configuration.Metadata) error {
r.client = r.newClient(m) r.client = r.newClient(m)
} }
if _, err = r.client.Ping(context.TODO()).Result(); err != nil { if _, err = r.client.Ping(ctx).Result(); err != nil {
return fmt.Errorf("redis store: error connecting to redis at %s: %s", m.Host, err) return fmt.Errorf("redis store: error connecting to redis at %s: %s", m.Host, err)
} }
r.replicas, err = r.getConnectedSlaves() r.replicas, err = r.getConnectedSlaves(ctx)
return err return err
} }
@ -204,8 +204,8 @@ func (r *ConfigurationStore) newFailoverClient(m metadata) *redis.Client {
return redis.NewFailoverClient(opts) return redis.NewFailoverClient(opts)
} }
func (r *ConfigurationStore) getConnectedSlaves() (int, error) { func (r *ConfigurationStore) getConnectedSlaves(ctx context.Context) (int, error) {
res, err := r.client.Do(context.Background(), "INFO", "replication").Result() res, err := r.client.Do(ctx, "INFO", "replication").Result()
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -18,7 +18,7 @@ import "context"
// Store is an interface to perform operations on store. // Store is an interface to perform operations on store.
type Store interface { type Store interface {
// Init configuration store. // Init configuration store.
Init(metadata Metadata) error Init(ctx context.Context, metadata Metadata) error
// Get configuration. // Get configuration.
Get(ctx context.Context, req *GetRequest) (*GetResponse, error) Get(ctx context.Context, req *GetRequest) (*GetResponse, error)

2
go.mod
View File

@ -400,3 +400,5 @@ replace github.com/toolkits/concurrent => github.com/niean/gotools v0.0.0-201512
// this is a fork which addresses a performance issues due to go routines // this is a fork which addresses a performance issues due to go routines
replace dubbo.apache.org/dubbo-go/v3 => dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3 replace dubbo.apache.org/dubbo-go/v3 => dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3
replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181

View File

@ -1,5 +1,22 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package health package health
import (
"context"
)
type Pinger interface { type Pinger interface {
Ping() error Ping(ctx context.Context) error
} }

View File

@ -19,6 +19,7 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
@ -31,6 +32,7 @@ type consumer struct {
k *Kafka k *Kafka
ready chan bool ready chan bool
running chan struct{} running chan struct{}
stopped atomic.Bool
once sync.Once once sync.Once
mutex sync.Mutex mutex sync.Mutex
} }
@ -275,9 +277,6 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
k.cg = cg k.cg = cg
ctx, cancel := context.WithCancel(ctx)
k.cancel = cancel
ready := make(chan bool) ready := make(chan bool)
k.consumer = consumer{ k.consumer = consumer{
k: k, k: k,
@ -320,7 +319,10 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
k.logger.Errorf("Error closing consumer group: %v", err) k.logger.Errorf("Error closing consumer group: %v", err)
} }
close(k.consumer.running) // Ensure running channel is only closed once.
if k.consumer.stopped.CompareAndSwap(false, true) {
close(k.consumer.running)
}
}() }()
<-ready <-ready
@ -331,7 +333,6 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
// Close down consumer group resources, refresh once. // Close down consumer group resources, refresh once.
func (k *Kafka) closeSubscriptionResources() { func (k *Kafka) closeSubscriptionResources() {
if k.cg != nil { if k.cg != nil {
k.cancel()
err := k.cg.Close() err := k.cg.Close()
if err != nil { if err != nil {
k.logger.Errorf("Error closing consumer group: %v", err) k.logger.Errorf("Error closing consumer group: %v", err)

View File

@ -36,7 +36,6 @@ type Kafka struct {
saslPassword string saslPassword string
initialOffset int64 initialOffset int64
cg sarama.ConsumerGroup cg sarama.ConsumerGroup
cancel context.CancelFunc
consumer consumer consumer consumer
config *sarama.Config config *sarama.Config
subscribeTopics TopicHandlerConfig subscribeTopics TopicHandlerConfig
@ -60,7 +59,7 @@ func NewKafka(logger logger.Logger) *Kafka {
} }
// Init does metadata parsing and connection establishment. // Init does metadata parsing and connection establishment.
func (k *Kafka) Init(metadata map[string]string) error { func (k *Kafka) Init(_ context.Context, metadata map[string]string) error {
upgradedMetadata, err := k.upgradeMetadata(metadata) upgradedMetadata, err := k.upgradeMetadata(metadata)
if err != nil { if err != nil {
return err return err

View File

@ -107,7 +107,8 @@ func (ts *OAuthTokenSource) Token() (*sarama.AccessToken, error) {
oidcCfg := ccred.Config{ClientID: ts.ClientID, ClientSecret: ts.ClientSecret, Scopes: ts.Scopes, TokenURL: ts.TokenEndpoint.TokenURL, AuthStyle: ts.TokenEndpoint.AuthStyle} oidcCfg := ccred.Config{ClientID: ts.ClientID, ClientSecret: ts.ClientSecret, Scopes: ts.Scopes, TokenURL: ts.TokenEndpoint.TokenURL, AuthStyle: ts.TokenEndpoint.AuthStyle}
timeoutCtx, _ := ctx.WithTimeout(ctx.TODO(), tokenRequestTimeout) //nolint:govet timeoutCtx, cancel := ctx.WithTimeout(ctx.TODO(), tokenRequestTimeout)
defer cancel()
ts.configureClient() ts.configureClient()

View File

@ -38,9 +38,6 @@ type StandaloneRedisLock struct {
metadata rediscomponent.Metadata metadata rediscomponent.Metadata
logger logger.Logger logger logger.Logger
ctx context.Context
cancel context.CancelFunc
} }
// NewStandaloneRedisLock returns a new standalone redis lock. // NewStandaloneRedisLock returns a new standalone redis lock.
@ -54,7 +51,7 @@ func NewStandaloneRedisLock(logger logger.Logger) lock.Store {
} }
// Init StandaloneRedisLock. // Init StandaloneRedisLock.
func (r *StandaloneRedisLock) InitLockStore(metadata lock.Metadata) error { func (r *StandaloneRedisLock) InitLockStore(ctx context.Context, metadata lock.Metadata) error {
// 1. parse config // 1. parse config
m, err := rediscomponent.ParseRedisMetadata(metadata.Properties) m, err := rediscomponent.ParseRedisMetadata(metadata.Properties)
if err != nil { if err != nil {
@ -75,13 +72,12 @@ func (r *StandaloneRedisLock) InitLockStore(metadata lock.Metadata) error {
if err != nil { if err != nil {
return err return err
} }
r.ctx, r.cancel = context.WithCancel(context.Background())
// 3. connect to redis // 3. connect to redis
if _, err = r.client.PingResult(r.ctx); err != nil { if _, err = r.client.PingResult(ctx); err != nil {
return fmt.Errorf("[standaloneRedisLock]: error connecting to redis at %s: %s", r.clientSettings.Host, err) return fmt.Errorf("[standaloneRedisLock]: error connecting to redis at %s: %s", r.clientSettings.Host, err)
} }
// no replica // no replica
replicas, err := r.getConnectedSlaves() replicas, err := r.getConnectedSlaves(ctx)
// pass the validation if error occurs, // pass the validation if error occurs,
// since some redis versions such as miniredis do not recognize the `INFO` command. // since some redis versions such as miniredis do not recognize the `INFO` command.
if err == nil && replicas > 0 { if err == nil && replicas > 0 {
@ -101,8 +97,8 @@ func needFailover(properties map[string]string) bool {
return false return false
} }
func (r *StandaloneRedisLock) getConnectedSlaves() (int, error) { func (r *StandaloneRedisLock) getConnectedSlaves(ctx context.Context) (int, error) {
res, err := r.client.DoRead(r.ctx, "INFO", "replication") res, err := r.client.DoRead(ctx, "INFO", "replication")
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -183,9 +179,6 @@ func newInternalErrorUnlockResponse() *lock.UnlockResponse {
// Close shuts down the client's redis connections. // Close shuts down the client's redis connections.
func (r *StandaloneRedisLock) Close() error { func (r *StandaloneRedisLock) Close() error {
if r.cancel != nil {
r.cancel()
}
if r.client != nil { if r.client != nil {
closeErr := r.client.Close() closeErr := r.client.Close()
r.client = nil r.client = nil

View File

@ -42,7 +42,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) {
cfg.Properties["redisPassword"] = "" cfg.Properties["redisPassword"] = ""
// init // init
err := comp.InitLockStore(cfg) err := comp.InitLockStore(context.Background(), cfg)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -58,7 +58,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) {
cfg.Properties["redisPassword"] = "" cfg.Properties["redisPassword"] = ""
// init // init
err := comp.InitLockStore(cfg) err := comp.InitLockStore(context.Background(), cfg)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -75,7 +75,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) {
cfg.Properties["maxRetries"] = "1 " cfg.Properties["maxRetries"] = "1 "
// init // init
err := comp.InitLockStore(cfg) err := comp.InitLockStore(context.Background(), cfg)
assert.Error(t, err) assert.Error(t, err)
}) })
} }
@ -96,7 +96,7 @@ func TestStandaloneRedisLock_TryLock(t *testing.T) {
cfg.Properties["redisHost"] = s.Addr() cfg.Properties["redisHost"] = s.Addr()
cfg.Properties["redisPassword"] = "" cfg.Properties["redisPassword"] = ""
// init // init
err = comp.InitLockStore(cfg) err = comp.InitLockStore(context.Background(), cfg)
assert.NoError(t, err) assert.NoError(t, err)
// 1. client1 trylock // 1. client1 trylock
ownerID1 := uuid.New().String() ownerID1 := uuid.New().String()

View File

@ -17,7 +17,7 @@ import "context"
type Store interface { type Store interface {
// Init this component. // Init this component.
InitLockStore(metadata Metadata) error InitLockStore(ctx context.Context, metadata Metadata) error
// TryLock tries to acquire a lock. // TryLock tries to acquire a lock.
TryLock(ctx context.Context, req *TryLockRequest) (*TryLockResponse, error) TryLock(ctx context.Context, req *TryLockRequest) (*TryLockResponse, error)

View File

@ -45,13 +45,13 @@ const (
) )
// GetHandler retruns the HTTP handler provided by the middleware. // GetHandler retruns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { func (m *Middleware) GetHandler(ctx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata) meta, err := m.getNativeMetadata(metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
provider, err := oidc.NewProvider(context.Background(), meta.IssuerURL) provider, err := oidc.NewProvider(ctx, meta.IssuerURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package oauth2 package oauth2
import ( import (
"context"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -59,7 +60,7 @@ const (
) )
// GetHandler retruns the HTTP handler provided by the middleware. // GetHandler retruns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { func (m *Middleware) GetHandler(ctx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata) meta, err := m.getNativeMetadata(metadata)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -5,6 +5,7 @@
package mock_oauth2clientcredentials package mock_oauth2clientcredentials
import ( import (
"context"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@ -36,7 +37,7 @@ func (m *MockTokenProviderInterface) EXPECT() *MockTokenProviderInterfaceMockRec
} }
// GetToken mocks base method // GetToken mocks base method
func (m *MockTokenProviderInterface) GetToken(arg0 *clientcredentials.Config) (*oauth2.Token, error) { func (m *MockTokenProviderInterface) GetToken(ctx context.Context, arg0 *clientcredentials.Config) (*oauth2.Token, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetToken", arg0) ret := m.ctrl.Call(m, "GetToken", arg0)
ret0, _ := ret[0].(*oauth2.Token) ret0, _ := ret[0].(*oauth2.Token)

View File

@ -45,7 +45,7 @@ type oAuth2ClientCredentialsMiddlewareMetadata struct {
// TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests. // TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests.
type TokenProviderInterface interface { type TokenProviderInterface interface {
GetToken(conf *clientcredentials.Config) (*oauth2.Token, error) GetToken(ctx context.Context, conf *clientcredentials.Config) (*oauth2.Token, error)
} }
// NewOAuth2ClientCredentialsMiddleware returns a new oAuth2 middleware. // NewOAuth2ClientCredentialsMiddleware returns a new oAuth2 middleware.
@ -68,7 +68,7 @@ type Middleware struct {
} }
// GetHandler retruns the HTTP handler provided by the middleware. // GetHandler retruns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata) meta, err := m.getNativeMetadata(metadata)
if err != nil { if err != nil {
m.log.Errorf("getNativeMetadata error: %s", err) m.log.Errorf("getNativeMetadata error: %s", err)
@ -101,7 +101,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
if !found { if !found {
m.log.Debugf("Cached token not found, try get one") m.log.Debugf("Cached token not found, try get one")
token, err := m.tokenProvider.GetToken(conf) token, err := m.tokenProvider.GetToken(r.Context(), conf)
if err != nil { if err != nil {
m.log.Errorf("Error acquiring token: %s", err) m.log.Errorf("Error acquiring token: %s", err)
return return
@ -171,8 +171,8 @@ func (m *Middleware) SetTokenProvider(tokenProvider TokenProviderInterface) {
} }
// GetToken returns a token from the current OAuth2 ClientCredentials Configuration. // GetToken returns a token from the current OAuth2 ClientCredentials Configuration.
func (m *Middleware) GetToken(conf *clientcredentials.Config) (*oauth2.Token, error) { func (m *Middleware) GetToken(ctx context.Context, conf *clientcredentials.Config) (*oauth2.Token, error) {
tokenSource := conf.TokenSource(context.Background()) tokenSource := conf.TokenSource(ctx)
return tokenSource.Token() return tokenSource.Token()
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
package oauth2clientcredentials package oauth2clientcredentials
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -45,7 +46,7 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
metadata.Properties = map[string]string{} metadata.Properties = map[string]string{}
log := logger.NewLogger("oauth2clientcredentials.test") log := logger.NewLogger("oauth2clientcredentials.test")
_, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) _, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
assert.EqualError(t, err, "metadata errors: Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. ") assert.EqualError(t, err, "metadata errors: Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. ")
// Invalid authStyle (non int) // Invalid authStyle (non int)
@ -57,17 +58,17 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
"headerName": "someHeader", "headerName": "someHeader",
"authStyle": "asdf", // This is the value to test "authStyle": "asdf", // This is the value to test
} }
_, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) _, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
assert.EqualError(t, err2, "metadata errors: 1 error(s) decoding:\n\n* cannot parse 'AuthStyle' as int: strconv.ParseInt: parsing \"asdf\": invalid syntax") assert.EqualError(t, err2, "metadata errors: 1 error(s) decoding:\n\n* cannot parse 'AuthStyle' as int: strconv.ParseInt: parsing \"asdf\": invalid syntax")
// Invalid authStyle (int > 2) // Invalid authStyle (int > 2)
metadata.Properties["authStyle"] = "3" metadata.Properties["authStyle"] = "3"
_, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) _, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
assert.EqualError(t, err3, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ") assert.EqualError(t, err3, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ")
// Invalid authStyle (int < 0) // Invalid authStyle (int < 0)
metadata.Properties["authStyle"] = "-1" metadata.Properties["authStyle"] = "-1"
_, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata) _, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
assert.EqualError(t, err4, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ") assert.EqualError(t, err4, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ")
} }
@ -109,7 +110,7 @@ func TestOAuth2ClientCredentialsToken(t *testing.T) {
log := logger.NewLogger("oauth2clientcredentials.test") log := logger.NewLogger("oauth2clientcredentials.test")
oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware) oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider) oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
handler, err := oauth2clientcredentialsMiddleware.GetHandler(metadata) handler, err := oauth2clientcredentialsMiddleware.GetHandler(context.Background(), metadata)
require.NoError(t, err) require.NoError(t, err)
// First handler call should return abc Token // First handler call should return abc Token
@ -169,7 +170,7 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) {
log := logger.NewLogger("oauth2clientcredentials.test") log := logger.NewLogger("oauth2clientcredentials.test")
oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware) oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider) oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
handler, err := oauth2clientcredentialsMiddleware.GetHandler(metadata) handler, err := oauth2clientcredentialsMiddleware.GetHandler(context.Background(), metadata)
require.NoError(t, err) require.NoError(t, err)
// First handler call should return abc Token // First handler call should return abc Token

View File

@ -106,13 +106,13 @@ func (s *Status) Valid() bool {
} }
// GetHandler returns the HTTP handler provided by the middleware. // GetHandler returns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { func (m *Middleware) GetHandler(parentCtx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata) meta, err := m.getNativeMetadata(metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx, cancel := context.WithTimeout(context.TODO(), time.Minute) ctx, cancel := context.WithTimeout(parentCtx, time.Minute)
query, err := rego.New( query, err := rego.New(
rego.Query("result = data.http.allow"), rego.Query("result = data.http.allow"),
rego.Module("inline.rego", meta.Rego), rego.Module("inline.rego", meta.Rego),

View File

@ -14,6 +14,7 @@ limitations under the License.
package opa package opa
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -331,7 +332,7 @@ func TestOpaPolicy(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
opaMiddleware := NewMiddleware(log) opaMiddleware := NewMiddleware(log)
handler, err := opaMiddleware.GetHandler(test.meta) handler, err := opaMiddleware.GetHandler(context.Background(), test.meta)
if test.shouldHandlerError { if test.shouldHandlerError {
require.Error(t, err) require.Error(t, err)
return return

View File

@ -14,6 +14,7 @@ limitations under the License.
package ratelimit package ratelimit
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"strconv" "strconv"
@ -46,7 +47,7 @@ func NewRateLimitMiddleware(_ logger.Logger) middleware.Middleware {
type Middleware struct{} type Middleware struct{}
// GetHandler returns the HTTP handler provided by the middleware. // GetHandler returns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) { func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata) meta, err := m.getNativeMetadata(metadata)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -40,7 +40,7 @@ func NewMiddleware(logger logger.Logger) middleware.Middleware {
} }
// GetHandler retruns the HTTP handler provided by the middleware. // GetHandler retruns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) ( func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (
func(next http.Handler) http.Handler, error, func(next http.Handler) http.Handler, error,
) { ) {
if err := m.getNativeMetadata(metadata); err != nil { if err := m.getNativeMetadata(metadata); err != nil {

View File

@ -14,6 +14,7 @@ limitations under the License.
package routeralias package routeralias
import ( import (
"context"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -46,7 +47,7 @@ func TestRequestHandlerWithIllegalRouterRule(t *testing.T) {
} }
log := logger.NewLogger("routeralias.test") log := logger.NewLogger("routeralias.test")
ralias := NewMiddleware(log) ralias := NewMiddleware(log)
handler, err := ralias.GetHandler(meta) handler, err := ralias.GetHandler(context.Background(), meta)
assert.Nil(t, err) assert.Nil(t, err)
t.Run("hit: change router with common request", func(t *testing.T) { t.Run("hit: change router with common request", func(t *testing.T) {

Some files were not shown because too many files have changed in this diff Show More