Merge branch 'master' of https://github.com/dapr/components-contrib into newdeps

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2023-02-16 23:35:28 +00:00
commit a484d7ebc7
297 changed files with 1793 additions and 1118 deletions

View File

@ -18,3 +18,5 @@ require (
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)
replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181

View File

@ -33,3 +33,5 @@ require (
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181

View File

@ -80,7 +80,7 @@ func NewDingTalkWebhook(l logger.Logger) bindings.InputOutputBinding {
}
// Init performs metadata parsing.
func (t *DingTalkWebhook) Init(metadata bindings.Metadata) error {
func (t *DingTalkWebhook) Init(_ context.Context, metadata bindings.Metadata) error {
var err error
if err = t.settings.Decode(metadata.Properties); err != nil {
return fmt.Errorf("dingtalk configuration error: %w", err)
@ -107,6 +107,13 @@ func (t *DingTalkWebhook) Read(ctx context.Context, handler bindings.Handler) er
return nil
}
func (t *DingTalkWebhook) Close() error {
webhooks.Lock()
defer webhooks.Unlock()
delete(webhooks.m, t.settings.ID)
return nil
}
// Operations returns list of operations supported by dingtalk webhook binding.
func (t *DingTalkWebhook) Operations() []bindings.OperationKind {
return []bindings.OperationKind{bindings.CreateOperation, bindings.GetOperation}

View File

@ -57,7 +57,7 @@ func TestPublishMsg(t *testing.T) { //nolint:paralleltest
}}}
d := NewDingTalkWebhook(logger.NewLogger("test"))
err := d.Init(m)
err := d.Init(context.Background(), m)
require.NoError(t, err)
req := &bindings.InvokeRequest{Data: []byte(msg), Operation: bindings.CreateOperation, Metadata: map[string]string{}}
@ -78,7 +78,7 @@ func TestBindingReadAndInvoke(t *testing.T) { //nolint:paralleltest
}}
d := NewDingTalkWebhook(logger.NewLogger("test"))
err := d.Init(m)
err := d.Init(context.Background(), m)
assert.NoError(t, err)
var count int32
@ -106,3 +106,18 @@ func TestBindingReadAndInvoke(t *testing.T) { //nolint:paralleltest
require.FailNow(t, "read timeout")
}
}
func TestBindingClose(t *testing.T) {
d := NewDingTalkWebhook(logger.NewLogger("test"))
m := bindings.Metadata{Base: metadata.Base{
Name: "test",
Properties: map[string]string{
"url": "/test",
"secret": "",
"id": "x",
},
}}
assert.NoError(t, d.Init(context.Background(), m))
assert.NoError(t, d.Close())
assert.NoError(t, d.Close(), "second close should not error")
}

View File

@ -45,7 +45,7 @@ func NewAliCloudOSS(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing and connection creation.
func (s *AliCloudOSS) Init(metadata bindings.Metadata) error {
func (s *AliCloudOSS) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := s.parseMetadata(metadata)
if err != nil {
return err

View File

@ -31,7 +31,7 @@ type Callback struct {
}
// parse metadata field
func (s *AliCloudSlsLogstorage) Init(metadata bindings.Metadata) error {
func (s *AliCloudSlsLogstorage) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := s.parseMeta(metadata)
if err != nil {
return err

View File

@ -58,7 +58,7 @@ func NewAliCloudTableStore(log logger.Logger) bindings.OutputBinding {
}
}
func (s *AliCloudTableStore) Init(metadata bindings.Metadata) error {
func (s *AliCloudTableStore) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := s.parseMetadata(metadata)
if err != nil {
return err

View File

@ -51,7 +51,7 @@ func TestDataEncodeAndDecode(t *testing.T) {
metadata := bindings.Metadata{Base: metadata.Base{
Properties: getTestProperties(),
}}
aliCloudTableStore.Init(metadata)
aliCloudTableStore.Init(context.Background(), metadata)
// test create
putData := map[string]interface{}{

View File

@ -78,7 +78,7 @@ func NewAPNS(logger logger.Logger) bindings.OutputBinding {
// Init will configure the APNS output binding using the metadata specified
// in the binding's configuration.
func (a *APNS) Init(metadata bindings.Metadata) error {
func (a *APNS) Init(ctx context.Context, metadata bindings.Metadata) error {
if err := a.makeURLPrefix(metadata); err != nil {
return err
}

View File

@ -51,7 +51,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.Equal(t, developmentPrefix, binding.urlPrefix)
})
@ -66,7 +66,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.Equal(t, productionPrefix, binding.urlPrefix)
})
@ -80,7 +80,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.Equal(t, productionPrefix, binding.urlPrefix)
})
@ -95,7 +95,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Error(t, err, "invalid value for development parameter: True")
})
@ -107,7 +107,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Error(t, err, "the key-id parameter is required")
})
@ -120,7 +120,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.Equal(t, testKeyID, binding.authorizationBuilder.keyID)
})
@ -133,7 +133,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Error(t, err, "the team-id parameter is required")
})
@ -146,7 +146,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.Equal(t, testTeamID, binding.authorizationBuilder.teamID)
})
@ -159,7 +159,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Error(t, err, "the private-key parameter is required")
})
@ -172,7 +172,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.NotNil(t, binding.authorizationBuilder.privateKey)
})
@ -335,7 +335,7 @@ func makeTestBinding(t *testing.T, log logger.Logger) *APNS {
privateKeyKey: testPrivateKey,
},
}}
err := testBinding.Init(bindingMetadata)
err := testBinding.Init(context.Background(), bindingMetadata)
assert.Nil(t, err)
return testBinding

View File

@ -49,7 +49,7 @@ func NewDynamoDB(logger logger.Logger) bindings.OutputBinding {
}
// Init performs connection parsing for DynamoDB.
func (d *DynamoDB) Init(metadata bindings.Metadata) error {
func (d *DynamoDB) Init(_ context.Context, metadata bindings.Metadata) error {
meta, err := d.getDynamoDBMetadata(metadata)
if err != nil {
return err

View File

@ -15,7 +15,10 @@ package kinesis
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/aws/aws-sdk-go/aws"
@ -45,6 +48,10 @@ type AWSKinesis struct {
streamARN *string
consumerARN *string
logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
type kinesisMetadata struct {
@ -83,11 +90,14 @@ type recordProcessor struct {
// NewAWSKinesis returns a new AWS Kinesis instance.
func NewAWSKinesis(logger logger.Logger) bindings.InputOutputBinding {
return &AWSKinesis{logger: logger}
return &AWSKinesis{
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init does metadata parsing and connection creation.
func (a *AWSKinesis) Init(metadata bindings.Metadata) error {
func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error {
m, err := a.parseMetadata(metadata)
if err != nil {
return err
@ -107,7 +117,7 @@ func (a *AWSKinesis) Init(metadata bindings.Metadata) error {
}
streamName := aws.String(m.StreamName)
stream, err := client.DescribeStream(&kinesis.DescribeStreamInput{
stream, err := client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{
StreamName: streamName,
})
if err != nil {
@ -147,6 +157,10 @@ func (a *AWSKinesis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*
}
func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err error) {
if a.closed.Load() {
return errors.New("binding is closed")
}
if a.metadata.KinesisConsumerMode == SharedThroughput {
a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.workerConfig)
err = a.worker.Start()
@ -166,8 +180,13 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er
}
// Wait for context cancelation then stop
a.wg.Add(1)
go func() {
<-ctx.Done()
defer a.wg.Done()
select {
case <-ctx.Done():
case <-a.closeCh:
}
if a.metadata.KinesisConsumerMode == SharedThroughput {
a.worker.Shutdown()
} else if a.metadata.KinesisConsumerMode == ExtendedFanout {
@ -188,14 +207,25 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
a.consumerARN = consumerARN
a.wg.Add(len(streamDesc.Shards))
for i, shard := range streamDesc.Shards {
go func(idx int, s *kinesis.Shard) error {
go func(idx int, s *kinesis.Shard) {
defer a.wg.Done()
// Reconnection backoff
bo := backoff.NewExponentialBackOff()
bo.InitialInterval = 2 * time.Second
// Repeat until context is canceled
for ctx.Err() == nil {
// Repeat until context is canceled or binding closed.
for {
select {
case <-ctx.Done():
return
case <-a.closeCh:
return
default:
}
sub, err := a.client.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{
ConsumerARN: consumerARN,
ShardId: s.ShardId,
@ -204,8 +234,12 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
if err != nil {
wait := bo.NextBackOff()
a.logger.Errorf("Error while reading from shard %v: %v. Attempting to reconnect in %s...", s.ShardId, err, wait)
time.Sleep(wait)
continue
select {
case <-ctx.Done():
return
case <-time.After(wait):
continue
}
}
// Reset the backoff on connection success
@ -223,22 +257,30 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
}
}
}
return nil
}(i, shard)
}
return nil
}
func (a *AWSKinesis) ensureConsumer(parentCtx context.Context, streamARN *string) (*string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
consumer, err := a.client.DescribeStreamConsumerWithContext(ctx, &kinesis.DescribeStreamConsumerInput{
func (a *AWSKinesis) Close() error {
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
a.wg.Wait()
return nil
}
func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*string, error) {
// Only set timeout on consumer call.
conCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
consumer, err := a.client.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{
ConsumerName: &a.metadata.ConsumerName,
StreamARN: streamARN,
})
cancel()
if err != nil {
return a.registerConsumer(parentCtx, streamARN)
return a.registerConsumer(ctx, streamARN)
}
return consumer.ConsumerDescription.ConsumerARN, nil

View File

@ -99,7 +99,7 @@ func NewAWSS3(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing and connection creation.
func (s *AWSS3) Init(metadata bindings.Metadata) error {
func (s *AWSS3) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := s.parseMetadata(metadata)
if err != nil {
return err

View File

@ -61,7 +61,7 @@ func NewAWSSES(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing.
func (a *AWSSES) Init(metadata bindings.Metadata) error {
func (a *AWSSES) Init(_ context.Context, metadata bindings.Metadata) error {
// Parse input metadata
meta, err := a.parseMetadata(metadata)
if err != nil {

View File

@ -53,7 +53,7 @@ func NewAWSSNS(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing.
func (a *AWSSNS) Init(metadata bindings.Metadata) error {
func (a *AWSSNS) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := a.parseMetadata(metadata)
if err != nil {
return err

View File

@ -16,6 +16,9 @@ package sqs
import (
"context"
"encoding/json"
"errors"
"sync"
"sync/atomic"
"time"
"github.com/aws/aws-sdk-go/aws"
@ -31,7 +34,10 @@ type AWSSQS struct {
Client *sqs.SQS
QueueURL *string
logger logger.Logger
logger logger.Logger
wg sync.WaitGroup
closeCh chan struct{}
closed atomic.Bool
}
type sqsMetadata struct {
@ -45,11 +51,14 @@ type sqsMetadata struct {
// NewAWSSQS returns a new AWS SQS instance.
func NewAWSSQS(logger logger.Logger) bindings.InputOutputBinding {
return &AWSSQS{logger: logger}
return &AWSSQS{
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init does metadata parsing and connection creation.
func (a *AWSSQS) Init(metadata bindings.Metadata) error {
func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error {
m, err := a.parseSQSMetadata(metadata)
if err != nil {
return err
@ -61,7 +70,7 @@ func (a *AWSSQS) Init(metadata bindings.Metadata) error {
}
queueName := m.QueueName
resultURL, err := client.GetQueueUrl(&sqs.GetQueueUrlInput{
resultURL, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{
QueueName: aws.String(queueName),
})
if err != nil {
@ -89,9 +98,20 @@ func (a *AWSSQS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind
}
func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("binding is closed")
}
a.wg.Add(1)
go func() {
// Repeat until the context is canceled
for ctx.Err() == nil {
defer a.wg.Done()
// Repeat until the context is canceled or component is closed
for {
if ctx.Err() != nil || a.closed.Load() {
return
}
result, err := a.Client.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{
QueueUrl: a.QueueURL,
AttributeNames: aws.StringSlice([]string{
@ -126,13 +146,25 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error {
}
}
time.Sleep(time.Millisecond * 50)
select {
case <-ctx.Done():
case <-a.closeCh:
case <-time.After(time.Millisecond * 50):
}
}
}()
return nil
}
func (a *AWSSQS) Close() error {
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
a.wg.Wait()
return nil
}
func (a *AWSSQS) parseSQSMetadata(metadata bindings.Metadata) (*sqsMetadata, error) {
b, err := json.Marshal(metadata.Properties)
if err != nil {

View File

@ -92,7 +92,7 @@ func NewAzureBlobStorage(logger logger.Logger) bindings.OutputBinding {
}
// Init performs metadata parsing.
func (a *AzureBlobStorage) Init(metadata bindings.Metadata) error {
func (a *AzureBlobStorage) Init(_ context.Context, metadata bindings.Metadata) error {
var err error
a.containerClient, a.metadata, err = storageinternal.CreateContainerStorageClient(context.TODO(), a.logger, metadata.Properties)
if err != nil {

View File

@ -53,7 +53,7 @@ func NewCosmosDB(logger logger.Logger) bindings.OutputBinding {
}
// Init performs CosmosDB connection parsing and connecting.
func (c *CosmosDB) Init(metadata bindings.Metadata) error {
func (c *CosmosDB) Init(ctx context.Context, metadata bindings.Metadata) error {
m, err := c.parseMetadata(metadata)
if err != nil {
return err
@ -103,9 +103,9 @@ func (c *CosmosDB) Init(metadata bindings.Metadata) error {
}
c.client = dbContainer
ctx, cancel := context.WithTimeout(context.Background(), timeoutValue*time.Second)
_, err = c.client.Read(ctx, nil)
cancel()
readCtx, readCancel := context.WithTimeout(ctx, timeoutValue*time.Second)
defer readCancel()
_, err = c.client.Read(readCtx, nil)
return err
}

View File

@ -59,7 +59,7 @@ func NewCosmosDBGremlinAPI(logger logger.Logger) bindings.OutputBinding {
}
// Init performs CosmosDBGremlinAPI connection parsing and connecting.
func (c *CosmosDBGremlinAPI) Init(metadata bindings.Metadata) error {
func (c *CosmosDBGremlinAPI) Init(_ context.Context, metadata bindings.Metadata) error {
c.logger.Debug("Initializing Cosmos Graph DB binding")
m, err := c.parseMetadata(metadata)

View File

@ -19,6 +19,8 @@ import (
"fmt"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
@ -42,6 +44,9 @@ const armOperationTimeout = 30 * time.Second
type AzureEventGrid struct {
metadata *azureEventGridMetadata
logger logger.Logger
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
}
type azureEventGridMetadata struct {
@ -70,11 +75,14 @@ type azureEventGridMetadata struct {
// NewAzureEventGrid returns a new Azure Event Grid instance.
func NewAzureEventGrid(logger logger.Logger) bindings.InputOutputBinding {
return &AzureEventGrid{logger: logger}
return &AzureEventGrid{
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init performs metadata init.
func (a *AzureEventGrid) Init(metadata bindings.Metadata) error {
func (a *AzureEventGrid) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := a.parseMetadata(metadata)
if err != nil {
return err
@ -85,6 +93,10 @@ func (a *AzureEventGrid) Init(metadata bindings.Metadata) error {
}
func (a *AzureEventGrid) Read(ctx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("binding is closed")
}
err := a.ensureInputBindingMetadata()
if err != nil {
return err
@ -120,17 +132,22 @@ func (a *AzureEventGrid) Read(ctx context.Context, handler bindings.Handler) err
}
// Run the server in background
a.wg.Add(2)
go func() {
defer a.wg.Done()
a.logger.Debugf("About to start listening for Event Grid events at http://localhost:%s/api/events", a.metadata.HandshakePort)
srvErr := srv.ListenAndServe(":" + a.metadata.HandshakePort)
if err != nil {
a.logger.Errorf("Error starting server: %v", srvErr)
}
}()
// Close the server when context is canceled
// Close the server when context is canceled or binding closed.
go func() {
<-ctx.Done()
defer a.wg.Done()
select {
case <-ctx.Done():
case <-a.closeCh:
}
srvErr := srv.Shutdown()
if err != nil {
a.logger.Errorf("Error shutting down server: %v", srvErr)
@ -149,6 +166,14 @@ func (a *AzureEventGrid) Operations() []bindings.OperationKind {
return []bindings.OperationKind{bindings.CreateOperation}
}
func (a *AzureEventGrid) Close() error {
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
a.wg.Wait()
return nil
}
func (a *AzureEventGrid) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
err := a.ensureOutputBindingMetadata()
if err != nil {

View File

@ -37,7 +37,7 @@ func NewAzureEventHubs(logger logger.Logger) bindings.InputOutputBinding {
}
// Init performs metadata init.
func (a *AzureEventHubs) Init(metadata bindings.Metadata) error {
func (a *AzureEventHubs) Init(_ context.Context, metadata bindings.Metadata) error {
return a.AzureEventHubs.Init(metadata.Properties)
}

View File

@ -102,7 +102,7 @@ func testEventHubsBindingsAADAuthentication(t *testing.T) {
metadata := createEventHubsBindingsAADMetadata()
eventHubsBindings := NewAzureEventHubs(log)
err := eventHubsBindings.Init(metadata)
err := eventHubsBindings.Init(context.Background(), metadata)
require.NoError(t, err)
req := &bindings.InvokeRequest{
@ -146,7 +146,7 @@ func testReadIotHubEvents(t *testing.T) {
logger := logger.NewLogger("bindings.azure.eventhubs.integration.test")
eh := NewAzureEventHubs(logger)
err := eh.Init(createIotHubBindingsMetadata())
err := eh.Init(context.Background(), createIotHubBindingsMetadata())
require.NoError(t, err)
// Invoke az CLI via bash script to send test IoT device events

View File

@ -17,6 +17,8 @@ import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
servicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
@ -39,17 +41,21 @@ type AzureServiceBusQueues struct {
client *impl.Client
timeout time.Duration
logger logger.Logger
closed atomic.Bool
wg sync.WaitGroup
closeCh chan struct{}
}
// NewAzureServiceBusQueues returns a new AzureServiceBusQueues instance.
func NewAzureServiceBusQueues(logger logger.Logger) bindings.InputOutputBinding {
return &AzureServiceBusQueues{
logger: logger,
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init parses connection properties and creates a new Service Bus Queue client.
func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) (err error) {
func (a *AzureServiceBusQueues) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
a.metadata, err = impl.ParseMetadata(metadata.Properties, a.logger, (impl.MetadataModeBinding | impl.MetadataModeQueues))
if err != nil {
return err
@ -62,7 +68,7 @@ func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) (err error) {
}
// Will do nothing if DisableEntityManagement is false
err = a.client.EnsureQueue(context.Background(), a.metadata.QueueName)
err = a.client.EnsureQueue(ctx, a.metadata.QueueName)
if err != nil {
return err
}
@ -100,14 +106,33 @@ func (a *AzureServiceBusQueues) Invoke(invokeCtx context.Context, req *bindings.
return nil, nil
}
func (a *AzureServiceBusQueues) Read(subscribeCtx context.Context, handler bindings.Handler) error {
func (a *AzureServiceBusQueues) Read(parentCtx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("binding is closed")
}
// Reconnection backoff policy
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
bo.InitialInterval = time.Duration(a.metadata.MinConnectionRecoveryInSec) * time.Second
bo.MaxInterval = time.Duration(a.metadata.MaxConnectionRecoveryInSec) * time.Second
subscribeCtx, subscribeCancel := context.WithCancel(parentCtx)
// Close the subscription context when the binding is closed.
a.wg.Add(2)
go func() {
defer a.wg.Done()
select {
case <-a.closeCh:
subscribeCancel()
case <-parentCtx.Done():
// nop
}
}()
go func() {
defer a.wg.Done()
// Reconnect loop.
for {
sub := impl.NewSubscription(subscribeCtx, impl.SubsriptionOptions{
@ -165,7 +190,12 @@ func (a *AzureServiceBusQueues) Read(subscribeCtx context.Context, handler bindi
wait := bo.NextBackOff()
a.logger.Warnf("Subscription to queue %s lost connection, attempting to reconnect in %s...", a.metadata.QueueName, wait)
time.Sleep(wait)
select {
case <-time.After(wait):
// nop
case <-a.closeCh:
return
}
}
}()
@ -204,7 +234,11 @@ func (a *AzureServiceBusQueues) getHandlerFn(handler bindings.Handler) impl.Hand
}
func (a *AzureServiceBusQueues) Close() (err error) {
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
a.logger.Debug("Closing component")
a.client.CloseSender(a.metadata.QueueName)
a.wg.Wait()
return nil
}

View File

@ -78,7 +78,7 @@ type SignalR struct {
}
// Init is responsible for initializing the SignalR output based on the metadata.
func (s *SignalR) Init(metadata bindings.Metadata) (err error) {
func (s *SignalR) Init(_ context.Context, metadata bindings.Metadata) (err error) {
s.userAgent = "dapr-" + logger.DaprVersion
err = s.parseMetadata(metadata.Properties)

View File

@ -16,9 +16,12 @@ package storagequeues
import (
"context"
"encoding/base64"
"errors"
"fmt"
"net/url"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/Azure/azure-storage-queue-go/azqueue"
@ -40,9 +43,10 @@ type consumer struct {
// QueueHelper enables injection for testnig.
type QueueHelper interface {
Init(metadata bindings.Metadata) (*storageQueuesMetadata, error)
Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error)
Write(ctx context.Context, data []byte, ttl *time.Duration) error
Read(ctx context.Context, consumer *consumer) error
Close() error
}
// AzureQueueHelper concrete impl of queue helper.
@ -55,7 +59,7 @@ type AzureQueueHelper struct {
}
// Init sets up this helper.
func (d *AzureQueueHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
func (d *AzureQueueHelper) Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error) {
m, err := parseMetadata(metadata)
if err != nil {
return nil, err
@ -89,9 +93,9 @@ func (d *AzureQueueHelper) Init(metadata bindings.Metadata) (*storageQueuesMetad
d.queueURL = azqueue.NewQueueURL(*URL, p)
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
_, err = d.queueURL.Create(ctx, azqueue.Metadata{})
cancel()
createCtx, createCancel := context.WithTimeout(ctx, 2*time.Minute)
_, err = d.queueURL.Create(createCtx, azqueue.Metadata{})
createCancel()
if err != nil {
return nil, err
}
@ -128,7 +132,10 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error {
}
if res.NumMessages() == 0 {
// Queue was empty so back off by 10 seconds before trying again
time.Sleep(10 * time.Second)
select {
case <-time.After(10 * time.Second):
case <-ctx.Done():
}
return nil
}
mt := res.Message(0).Text
@ -162,6 +169,10 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error {
return nil
}
func (d *AzureQueueHelper) Close() error {
return nil
}
// NewAzureQueueHelper creates new helper.
func NewAzureQueueHelper(logger logger.Logger) QueueHelper {
return &AzureQueueHelper{
@ -175,6 +186,10 @@ type AzureStorageQueues struct {
helper QueueHelper
logger logger.Logger
wg sync.WaitGroup
closeCh chan struct{}
closed atomic.Bool
}
type storageQueuesMetadata struct {
@ -189,12 +204,16 @@ type storageQueuesMetadata struct {
// NewAzureStorageQueues returns a new AzureStorageQueues instance.
func NewAzureStorageQueues(logger logger.Logger) bindings.InputOutputBinding {
return &AzureStorageQueues{helper: NewAzureQueueHelper(logger), logger: logger}
return &AzureStorageQueues{
helper: NewAzureQueueHelper(logger),
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init parses connection properties and creates a new Storage Queue client.
func (a *AzureStorageQueues) Init(metadata bindings.Metadata) (err error) {
a.metadata, err = a.helper.Init(metadata)
func (a *AzureStorageQueues) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
a.metadata, err = a.helper.Init(ctx, metadata)
if err != nil {
return err
}
@ -261,14 +280,32 @@ func (a *AzureStorageQueues) Invoke(ctx context.Context, req *bindings.InvokeReq
}
func (a *AzureStorageQueues) Read(ctx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("input binding is closed")
}
c := consumer{
callback: handler,
}
// Close read context when binding is closed.
readCtx, cancel := context.WithCancel(ctx)
a.wg.Add(2)
go func() {
defer a.wg.Done()
defer cancel()
select {
case <-a.closeCh:
case <-ctx.Done():
}
}()
go func() {
defer a.wg.Done()
// Read until context is canceled
var err error
for ctx.Err() == nil {
err = a.helper.Read(ctx, &c)
for readCtx.Err() == nil {
err = a.helper.Read(readCtx, &c)
if err != nil {
a.logger.Errorf("error from c: %s", err)
}
@ -277,3 +314,11 @@ func (a *AzureStorageQueues) Read(ctx context.Context, handler bindings.Handler)
return nil
}
func (a *AzureStorageQueues) Close() error {
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
a.wg.Wait()
return nil
}

View File

@ -16,6 +16,7 @@ package storagequeues
import (
"context"
"encoding/base64"
"sync"
"testing"
"time"
@ -32,9 +33,11 @@ type MockHelper struct {
mock.Mock
messages chan []byte
metadata *storageQueuesMetadata
closeCh chan struct{}
wg sync.WaitGroup
}
func (m *MockHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
func (m *MockHelper) Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error) {
m.messages = make(chan []byte, 10)
var err error
m.metadata, err = parseMetadata(metadata)
@ -50,12 +53,23 @@ func (m *MockHelper) Write(ctx context.Context, data []byte, ttl *time.Duration)
func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {
retvals := m.Called(ctx, consumer)
readCtx, cancel := context.WithCancel(ctx)
m.wg.Add(2)
go func() {
defer m.wg.Done()
defer cancel()
select {
case <-readCtx.Done():
case <-m.closeCh:
}
}()
go func() {
defer m.wg.Done()
for msg := range m.messages {
if m.metadata.DecodeBase64 {
msg, _ = base64.StdEncoding.DecodeString(string(msg))
}
go consumer.callback(ctx, &bindings.ReadResponse{
go consumer.callback(readCtx, &bindings.ReadResponse{
Data: msg,
})
}
@ -64,18 +78,24 @@ func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {
return retvals.Error(0)
}
func (m *MockHelper) Close() error {
defer m.wg.Wait()
close(m.closeCh)
return nil
}
func TestWriteQueue(t *testing.T) {
mm := new(MockHelper)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
return in == nil
})).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -83,6 +103,7 @@ func TestWriteQueue(t *testing.T) {
_, err = a.Invoke(context.Background(), &r)
assert.Nil(t, err)
assert.NoError(t, a.Close())
}
func TestWriteWithTTLInQueue(t *testing.T) {
@ -91,12 +112,12 @@ func TestWriteWithTTLInQueue(t *testing.T) {
return in != nil && *in == time.Second
})).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", metadata.TTLMetadataKey: "1"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -104,6 +125,7 @@ func TestWriteWithTTLInQueue(t *testing.T) {
_, err = a.Invoke(context.Background(), &r)
assert.Nil(t, err)
assert.NoError(t, a.Close())
}
func TestWriteWithTTLInWrite(t *testing.T) {
@ -112,12 +134,12 @@ func TestWriteWithTTLInWrite(t *testing.T) {
return in != nil && *in == time.Second
})).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", metadata.TTLMetadataKey: "1"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
r := bindings.InvokeRequest{
@ -128,6 +150,7 @@ func TestWriteWithTTLInWrite(t *testing.T) {
_, err = a.Invoke(context.Background(), &r)
assert.Nil(t, err)
assert.NoError(t, a.Close())
}
// Uncomment this function to write a message to local storage queue
@ -138,7 +161,7 @@ func TestWriteWithTTLInWrite(t *testing.T) {
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -152,12 +175,12 @@ func TestReadQueue(t *testing.T) {
mm := new(MockHelper)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -186,6 +209,7 @@ func TestReadQueue(t *testing.T) {
t.Fatal("Timeout waiting for messages")
}
assert.Equal(t, 1, received)
assert.NoError(t, a.Close())
}
func TestReadQueueDecode(t *testing.T) {
@ -193,12 +217,12 @@ func TestReadQueueDecode(t *testing.T) {
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", "decodeBase64": "true"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("VGhpcyBpcyBteSBtZXNzYWdl")}
@ -227,6 +251,7 @@ func TestReadQueueDecode(t *testing.T) {
t.Fatal("Timeout waiting for messages")
}
assert.Equal(t, 1, received)
assert.NoError(t, a.Close())
}
// Uncomment this function to test reding from local queue
@ -237,7 +262,7 @@ func TestReadQueueDecode(t *testing.T) {
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -263,12 +288,12 @@ func TestReadQueueNoMessage(t *testing.T) {
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
ctx, cancel := context.WithCancel(context.Background())
@ -285,6 +310,7 @@ func TestReadQueueNoMessage(t *testing.T) {
time.Sleep(1 * time.Second)
cancel()
assert.Equal(t, 0, received)
assert.NoError(t, a.Close())
}
func TestParseMetadata(t *testing.T) {

View File

@ -48,7 +48,7 @@ func NewCFQueues(logger logger.Logger) bindings.OutputBinding {
}
// Init the component.
func (q *CFQueues) Init(metadata bindings.Metadata) error {
func (q *CFQueues) Init(_ context.Context, metadata bindings.Metadata) error {
// Decode the metadata
err := mapstructure.Decode(metadata.Properties, &q.metadata)
if err != nil {

View File

@ -51,7 +51,7 @@ func NewCommercetools(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing and connection establishment.
func (ct *Binding) Init(metadata bindings.Metadata) error {
func (ct *Binding) Init(_ context.Context, metadata bindings.Metadata) error {
commercetoolsM, err := ct.getCommercetoolsMetadata(metadata)
if err != nil {
return err

View File

@ -16,6 +16,8 @@ package cron
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/benbjohnson/clock"
@ -32,6 +34,9 @@ type Binding struct {
schedule string
parser cron.Parser
clk clock.Clock
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
// NewCron returns a new Cron event input binding.
@ -46,6 +51,7 @@ func NewCronWithClock(logger logger.Logger, clk clock.Clock) bindings.InputBindi
parser: cron.NewParser(
cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor,
),
closeCh: make(chan struct{}),
}
}
@ -54,7 +60,7 @@ func NewCronWithClock(logger logger.Logger, clk clock.Clock) bindings.InputBindi
//
// "15 * * * * *" - Every 15 sec
// "0 30 * * * *" - Every 30 min
func (b *Binding) Init(metadata bindings.Metadata) error {
func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
b.name = metadata.Name
s, f := metadata.Properties["schedule"]
if !f || s == "" {
@ -71,6 +77,10 @@ func (b *Binding) Init(metadata bindings.Metadata) error {
// Read triggers the Cron scheduler.
func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
if b.closed.Load() {
return errors.New("binding is closed")
}
c := cron.New(cron.WithParser(b.parser), cron.WithClock(b.clk))
id, err := c.AddFunc(b.schedule, func() {
b.logger.Debugf("name: %s, schedule fired: %v", b.name, time.Now())
@ -87,12 +97,25 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
c.Start()
b.logger.Debugf("name: %s, next run: %v", b.name, time.Until(c.Entry(id).Next))
b.wg.Add(1)
go func() {
// Wait for context to be canceled
<-ctx.Done()
defer b.wg.Done()
// Wait for context to be canceled or component to be closed.
select {
case <-ctx.Done():
case <-b.closeCh:
}
b.logger.Debugf("name: %s, stopping schedule: %s", b.name, b.schedule)
c.Stop()
}()
return nil
}
func (b *Binding) Close() error {
if b.closed.CompareAndSwap(false, true) {
close(b.closeCh)
}
b.wg.Wait()
return nil
}

View File

@ -16,6 +16,7 @@ package cron
import (
"context"
"os"
"sync/atomic"
"testing"
"time"
@ -84,7 +85,7 @@ func TestCronInitSuccess(t *testing.T) {
for _, test := range initTests {
c := getNewCron()
err := c.Init(getTestMetadata(test.schedule))
err := c.Init(context.Background(), getTestMetadata(test.schedule))
if test.errorExpected {
assert.Errorf(t, err, "Got no error while initializing an invalid schedule: %s", test.schedule)
} else {
@ -99,38 +100,41 @@ func TestCronRead(t *testing.T) {
clk := clock.NewMock()
c := getNewCronWithClock(clk)
schedule := "@every 1s"
assert.NoErrorf(t, c.Init(getTestMetadata(schedule)), "error initializing valid schedule")
expectedCount := 5
observedCount := 0
assert.NoErrorf(t, c.Init(context.Background(), getTestMetadata(schedule)), "error initializing valid schedule")
expectedCount := int32(5)
var observedCount atomic.Int32
err := c.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
assert.NotNil(t, res)
observedCount++
observedCount.Add(1)
return nil, nil
})
// Check if cron triggers 5 times in 5 seconds
for i := 0; i < expectedCount; i++ {
for i := int32(0); i < expectedCount; i++ {
// Add time to mock clock in 1 second intervals using loop to allow cron go routine to run
clk.Add(time.Second)
}
// Wait for 1 second after adding the last second to mock clock to allow cron to finish triggering
time.Sleep(1 * time.Second)
assert.Equal(t, expectedCount, observedCount, "Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount)
assert.Eventually(t, func() bool {
return observedCount.Load() == expectedCount
}, time.Second, time.Millisecond*10,
"Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount.Load())
assert.NoErrorf(t, err, "error on read")
assert.NoError(t, c.Close())
}
func TestCronReadWithContextCancellation(t *testing.T) {
clk := clock.NewMock()
c := getNewCronWithClock(clk)
schedule := "@every 1s"
assert.NoErrorf(t, c.Init(getTestMetadata(schedule)), "error initializing valid schedule")
expectedCount := 5
observedCount := 0
assert.NoErrorf(t, c.Init(context.Background(), getTestMetadata(schedule)), "error initializing valid schedule")
expectedCount := int32(5)
var observedCount atomic.Int32
ctx, cancel := context.WithCancel(context.Background())
err := c.Read(ctx, func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
assert.NotNil(t, res)
assert.LessOrEqualf(t, observedCount, expectedCount, "Invoke didn't stop the schedule")
observedCount++
if observedCount == expectedCount {
assert.LessOrEqualf(t, observedCount.Load(), expectedCount, "Invoke didn't stop the schedule")
observedCount.Add(1)
if observedCount.Load() == expectedCount {
// Cancel context after 5 triggers
cancel()
}
@ -141,7 +145,10 @@ func TestCronReadWithContextCancellation(t *testing.T) {
// Add time to mock clock in 1 second intervals using loop to allow cron go routine to run
clk.Add(time.Second)
}
time.Sleep(1 * time.Second)
assert.Equal(t, expectedCount, observedCount, "Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount)
assert.Eventually(t, func() bool {
return observedCount.Load() == expectedCount
}, time.Second, time.Millisecond*10,
"Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount.Load())
assert.NoErrorf(t, err, "error on read")
assert.NoError(t, c.Close())
}

View File

@ -47,7 +47,7 @@ func NewDubboOutput(logger logger.Logger) bindings.OutputBinding {
return dubboBinding
}
func (out *DubboOutputBinding) Init(_ bindings.Metadata) error {
func (out *DubboOutputBinding) Init(_ context.Context, _ bindings.Metadata) error {
dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
return nil
}

View File

@ -54,12 +54,13 @@ func TestInvoke(t *testing.T) {
// 0. init dapr provided and dubbo server
stopCh := make(chan struct{})
defer close(stopCh)
// Create output and set serializer before go routine to prevent data race.
output := NewDubboOutput(logger.NewLogger("test"))
dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
go func() {
assert.Nil(t, runDubboServer(stopCh))
}()
time.Sleep(time.Second * 3)
dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
output := NewDubboOutput(logger.NewLogger("test"))
// 1. create req/rsp value
reqUser := &User{Name: testName}

View File

@ -83,14 +83,13 @@ func NewGCPStorage(logger logger.Logger) bindings.OutputBinding {
}
// Init performs connection parsing.
func (g *GCPStorage) Init(metadata bindings.Metadata) error {
func (g *GCPStorage) Init(ctx context.Context, metadata bindings.Metadata) error {
m, b, err := g.parseMetadata(metadata)
if err != nil {
return err
}
clientOptions := option.WithCredentialsJSON(b)
ctx := context.Background()
client, err := storage.NewClient(ctx, clientOptions)
if err != nil {
return err

View File

@ -16,7 +16,10 @@ package pubsub
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
"sync/atomic"
"cloud.google.com/go/pubsub"
"google.golang.org/api/option"
@ -36,6 +39,9 @@ type GCPPubSub struct {
client *pubsub.Client
metadata *pubSubMetadata
logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
type pubSubMetadata struct {
@ -55,11 +61,14 @@ type pubSubMetadata struct {
// NewGCPPubSub returns a new GCPPubSub instance.
func NewGCPPubSub(logger logger.Logger) bindings.InputOutputBinding {
return &GCPPubSub{logger: logger}
return &GCPPubSub{
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init parses metadata and creates a new Pub Sub client.
func (g *GCPPubSub) Init(metadata bindings.Metadata) error {
func (g *GCPPubSub) Init(ctx context.Context, metadata bindings.Metadata) error {
b, err := g.parseMetadata(metadata)
if err != nil {
return err
@ -71,7 +80,6 @@ func (g *GCPPubSub) Init(metadata bindings.Metadata) error {
return err
}
clientOptions := option.WithCredentialsJSON(b)
ctx := context.Background()
pubsubClient, err := pubsub.NewClient(ctx, pubsubMeta.ProjectID, clientOptions)
if err != nil {
return fmt.Errorf("error creating pubsub client: %s", err)
@ -88,7 +96,12 @@ func (g *GCPPubSub) parseMetadata(metadata bindings.Metadata) ([]byte, error) {
}
func (g *GCPPubSub) Read(ctx context.Context, handler bindings.Handler) error {
if g.closed.Load() {
return errors.New("binding is closed")
}
g.wg.Add(1)
go func() {
defer g.wg.Done()
sub := g.client.Subscription(g.metadata.Subscription)
err := sub.Receive(ctx, func(c context.Context, m *pubsub.Message) {
_, err := handler(c, &bindings.ReadResponse{
@ -128,5 +141,9 @@ func (g *GCPPubSub) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*b
}
func (g *GCPPubSub) Close() error {
if g.closed.CompareAndSwap(false, true) {
close(g.closeCh)
}
defer g.wg.Wait()
return g.client.Close()
}

View File

@ -58,7 +58,7 @@ func NewGraphQL(logger logger.Logger) bindings.OutputBinding {
}
// Init initializes the GraphQL binding.
func (gql *GraphQL) Init(metadata bindings.Metadata) error {
func (gql *GraphQL) Init(_ context.Context, metadata bindings.Metadata) error {
gql.logger.Debug("GraphQL Error: Initializing GraphQL binding")
p := metadata.Properties

View File

@ -74,7 +74,7 @@ func NewHTTP(logger logger.Logger) bindings.OutputBinding {
}
// Init performs metadata parsing.
func (h *HTTPSource) Init(metadata bindings.Metadata) error {
func (h *HTTPSource) Init(_ context.Context, metadata bindings.Metadata) error {
var err error
if err = mapstructure.Decode(metadata.Properties, &h.metadata); err != nil {
return err
@ -104,7 +104,7 @@ func (h *HTTPSource) Init(metadata bindings.Metadata) error {
Transport: netTransport,
}
if val, ok := metadata.Properties["errorIfNot2XX"]; ok {
if val := metadata.Properties["errorIfNot2XX"]; val != "" {
h.errorIfNot2XX = utils.IsTruthy(val)
} else {
// Default behavior

View File

@ -132,7 +132,7 @@ func InitBinding(s *httptest.Server, extraProps map[string]string) (bindings.Out
}
hs := NewHTTP(logger.NewLogger("test"))
err := hs.Init(m)
err := hs.Init(context.Background(), m)
return hs, err
}
@ -269,7 +269,7 @@ func InitBindingForHTTPS(s *httptest.Server, extraProps map[string]string) (bind
m.Properties[k] = v
}
hs := NewHTTP(logger.NewLogger("test"))
err := hs.Init(m)
err := hs.Init(context.Background(), m)
return hs, err
}

View File

@ -75,7 +75,7 @@ func NewHuaweiOBS(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing and connection creation.
func (o *HuaweiOBS) Init(metadata bindings.Metadata) error {
func (o *HuaweiOBS) Init(_ context.Context, metadata bindings.Metadata) error {
o.logger.Debugf("initializing Huawei OBS binding and parsing metadata")
m, err := o.parseMetadata(metadata)

View File

@ -92,7 +92,7 @@ func TestInit(t *testing.T) {
"accessKey": "dummy-ak",
"secretKey": "dummy-sk",
}
err := obs.Init(m)
err := obs.Init(context.Background(), m)
assert.Nil(t, err)
})
t.Run("Init with missing bucket name", func(t *testing.T) {
@ -102,7 +102,7 @@ func TestInit(t *testing.T) {
"accessKey": "dummy-ak",
"secretKey": "dummy-sk",
}
err := obs.Init(m)
err := obs.Init(context.Background(), m)
assert.NotNil(t, err)
assert.Equal(t, err, fmt.Errorf("missing obs bucket name"))
})
@ -113,7 +113,7 @@ func TestInit(t *testing.T) {
"endpoint": "dummy-endpoint",
"secretKey": "dummy-sk",
}
err := obs.Init(m)
err := obs.Init(context.Background(), m)
assert.NotNil(t, err)
assert.Equal(t, err, fmt.Errorf("missing the huawei access key"))
})
@ -124,7 +124,7 @@ func TestInit(t *testing.T) {
"endpoint": "dummy-endpoint",
"accessKey": "dummy-ak",
}
err := obs.Init(m)
err := obs.Init(context.Background(), m)
assert.NotNil(t, err)
assert.Equal(t, err, fmt.Errorf("missing the huawei secret key"))
})
@ -135,7 +135,7 @@ func TestInit(t *testing.T) {
"accessKey": "dummy-ak",
"secretKey": "dummy-sk",
}
err := obs.Init(m)
err := obs.Init(context.Background(), m)
assert.NotNil(t, err)
assert.Equal(t, err, fmt.Errorf("missing obs endpoint"))
})

View File

@ -64,7 +64,7 @@ func NewInflux(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing and connection establishment.
func (i *Influx) Init(metadata bindings.Metadata) error {
func (i *Influx) Init(_ context.Context, metadata bindings.Metadata) error {
influxMeta, err := i.getInfluxMetadata(metadata)
if err != nil {
return err

View File

@ -54,7 +54,7 @@ func TestInflux_Init(t *testing.T) {
assert.Nil(t, influx.client)
m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{"Url": "a", "Token": "a", "Org": "a", "Bucket": "a"}}}
err := influx.Init(m)
err := influx.Init(context.Background(), m)
assert.Nil(t, err)
assert.NotNil(t, influx.queryAPI)

View File

@ -16,6 +16,7 @@ package bindings
import (
"context"
"fmt"
"io"
"github.com/dapr/components-contrib/health"
)
@ -23,18 +24,21 @@ import (
// InputBinding is the interface to define a binding that triggers on incoming events.
type InputBinding interface {
// Init passes connection and properties metadata to the binding implementation.
Init(metadata Metadata) error
Init(ctx context.Context, metadata Metadata) error
// Read is a method that runs in background and triggers the callback function whenever an event arrives.
Read(ctx context.Context, handler Handler) error
// Close is a method that closes the connection to the binding. Must be
// called when the binding is no longer needed to free up resources.
io.Closer
}
// Handler is the handler used to invoke the app handler.
type Handler func(context.Context, *ReadResponse) ([]byte, error)
func PingInpBinding(inputBinding InputBinding) error {
func PingInpBinding(ctx context.Context, inputBinding InputBinding) error {
// checks if this input binding has the ping option then executes
if inputBindingWithPing, ok := inputBinding.(health.Pinger); ok {
return inputBindingWithPing.Ping()
return inputBindingWithPing.Ping(ctx)
} else {
return fmt.Errorf("ping is not implemented by this input binding")
}

View File

@ -15,7 +15,10 @@ package kafka
import (
"context"
"errors"
"strings"
"sync"
"sync/atomic"
"github.com/dapr/kit/logger"
@ -29,12 +32,13 @@ const (
)
type Binding struct {
kafka *kafka.Kafka
publishTopic string
topics []string
logger logger.Logger
subscribeCtx context.Context
subscribeCancel context.CancelFunc
kafka *kafka.Kafka
publishTopic string
topics []string
logger logger.Logger
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
}
// NewKafka returns a new kafka binding instance.
@ -43,15 +47,14 @@ func NewKafka(logger logger.Logger) bindings.InputOutputBinding {
// in kafka binding component, disable consumer retry by default
k.DefaultConsumeRetryEnabled = false
return &Binding{
kafka: k,
logger: logger,
kafka: k,
logger: logger,
closeCh: make(chan struct{}),
}
}
func (b *Binding) Init(metadata bindings.Metadata) error {
b.subscribeCtx, b.subscribeCancel = context.WithCancel(context.Background())
err := b.kafka.Init(metadata.Properties)
func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
err := b.kafka.Init(ctx, metadata.Properties)
if err != nil {
return err
}
@ -74,7 +77,10 @@ func (b *Binding) Operations() []bindings.OperationKind {
}
func (b *Binding) Close() (err error) {
b.subscribeCancel()
if b.closed.CompareAndSwap(false, true) {
close(b.closeCh)
}
defer b.wg.Wait()
return b.kafka.Close()
}
@ -84,6 +90,10 @@ func (b *Binding) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin
}
func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
if b.closed.Load() {
return errors.New("error: binding is closed")
}
if len(b.topics) == 0 {
b.logger.Warnf("kafka binding: no topic defined, input bindings will not be started")
return nil
@ -96,31 +106,22 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
for _, t := range b.topics {
b.kafka.AddTopicHandler(t, handlerConfig)
}
b.wg.Add(1)
go func() {
// Wait for context cancelation
defer b.wg.Done()
// Wait for context cancelation or closure.
select {
case <-ctx.Done():
case <-b.subscribeCtx.Done():
case <-b.closeCh:
}
// Remove the topic handler before restarting the subscriber
// Remove the topic handlers.
for _, t := range b.topics {
b.kafka.RemoveTopicHandler(t)
}
// If the component's context has been canceled, do not re-subscribe
if b.subscribeCtx.Err() != nil {
return
}
err := b.kafka.Subscribe(b.subscribeCtx)
if err != nil {
b.logger.Errorf("kafka binding: error re-subscribing: %v", err)
}
}()
return b.kafka.Subscribe(b.subscribeCtx)
return b.kafka.Subscribe(ctx)
}
func adaptHandler(handler bindings.Handler) kafka.EventHandler {

View File

@ -2,8 +2,11 @@ package kubemq
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
qs "github.com/kubemq-io/kubemq-go/queues_stream"
@ -19,31 +22,30 @@ type Kubemq interface {
}
type kubeMQ struct {
client *qs.QueuesStreamClient
opts *options
logger logger.Logger
ctx context.Context
ctxCancel context.CancelFunc
client *qs.QueuesStreamClient
opts *options
logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
func NewKubeMQ(logger logger.Logger) Kubemq {
return &kubeMQ{
client: nil,
opts: nil,
logger: logger,
ctx: nil,
ctxCancel: nil,
client: nil,
opts: nil,
logger: logger,
closeCh: make(chan struct{}),
}
}
func (k *kubeMQ) Init(metadata bindings.Metadata) error {
func (k *kubeMQ) Init(ctx context.Context, metadata bindings.Metadata) error {
opts, err := createOptions(metadata)
if err != nil {
return err
}
k.opts = opts
k.ctx, k.ctxCancel = context.WithCancel(context.Background())
client, err := qs.NewQueuesStreamClient(k.ctx,
client, err := qs.NewQueuesStreamClient(ctx,
qs.WithAddress(opts.host, opts.port),
qs.WithCheckConnection(true),
qs.WithAuthToken(opts.authToken),
@ -53,22 +55,39 @@ func (k *kubeMQ) Init(metadata bindings.Metadata) error {
k.logger.Errorf("error init kubemq client error: %s", err.Error())
return err
}
k.ctx, k.ctxCancel = context.WithCancel(context.Background())
k.client = client
return nil
}
func (k *kubeMQ) Read(ctx context.Context, handler bindings.Handler) error {
if k.closed.Load() {
return errors.New("binding is closed")
}
k.wg.Add(2)
processCtx, cancel := context.WithCancel(ctx)
go func() {
defer k.wg.Done()
defer cancel()
select {
case <-k.closeCh:
case <-processCtx.Done():
}
}()
go func() {
defer k.wg.Done()
for {
err := k.processQueueMessage(k.ctx, handler)
err := k.processQueueMessage(processCtx, handler)
if err != nil {
k.logger.Error(err.Error())
time.Sleep(time.Second)
}
if k.ctx.Err() != nil {
return
// If context cancelled or kubeMQ closed, exit. Otherwise, continue
// after a second.
select {
case <-time.After(time.Second):
continue
case <-processCtx.Done():
}
return
}
}()
return nil
@ -82,7 +101,7 @@ func (k *kubeMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind
SetPolicyExpirationSeconds(parsePolicyExpirationSeconds(req.Metadata)).
SetPolicyMaxReceiveCount(parseSetPolicyMaxReceiveCount(req.Metadata)).
SetPolicyMaxReceiveQueue(parsePolicyMaxReceiveQueue(req.Metadata))
result, err := k.client.Send(k.ctx, queueMessage)
result, err := k.client.Send(ctx, queueMessage)
if err != nil {
return nil, err
}
@ -101,6 +120,14 @@ func (k *kubeMQ) Operations() []bindings.OperationKind {
return []bindings.OperationKind{bindings.CreateOperation}
}
func (k *kubeMQ) Close() error {
if k.closed.CompareAndSwap(false, true) {
close(k.closeCh)
}
defer k.wg.Wait()
return k.client.Close()
}
func (k *kubeMQ) processQueueMessage(ctx context.Context, handler bindings.Handler) error {
pr := qs.NewPollRequest().
SetChannel(k.opts.channel).

View File

@ -106,7 +106,7 @@ func Test_kubeMQ_Init(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
kubemq := NewKubeMQ(logger.NewLogger("test"))
err := kubemq.Init(tt.meta)
err := kubemq.Init(context.Background(), tt.meta)
if tt.wantErr {
require.Error(t, err)
} else {
@ -120,7 +120,7 @@ func Test_kubeMQ_Invoke_Read_Single_Message(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
kubemq := NewKubeMQ(logger.NewLogger("test"))
err := kubemq.Init(getDefaultMetadata("test.read.single"))
err := kubemq.Init(context.Background(), getDefaultMetadata("test.read.single"))
require.NoError(t, err)
dataReadCh := make(chan []byte)
invokeRequest := &bindings.InvokeRequest{
@ -147,7 +147,7 @@ func Test_kubeMQ_Invoke_Read_Single_MessageWithHandlerError(t *testing.T) {
kubemq := NewKubeMQ(logger.NewLogger("test"))
md := getDefaultMetadata("test.read.single.error")
md.Properties["autoAcknowledged"] = "false"
err := kubemq.Init(md)
err := kubemq.Init(context.Background(), md)
require.NoError(t, err)
invokeRequest := &bindings.InvokeRequest{
Data: []byte("test"),
@ -182,7 +182,7 @@ func Test_kubeMQ_Invoke_Error(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
kubemq := NewKubeMQ(logger.NewLogger("test"))
err := kubemq.Init(getDefaultMetadata("***test***"))
err := kubemq.Init(context.Background(), getDefaultMetadata("***test***"))
require.NoError(t, err)
invokeRequest := &bindings.InvokeRequest{

View File

@ -18,6 +18,8 @@ import (
"encoding/json"
"errors"
"strconv"
"sync"
"sync/atomic"
"time"
v1 "k8s.io/api/core/v1"
@ -35,6 +37,9 @@ type kubernetesInput struct {
namespace string
resyncPeriod time.Duration
logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
type EventResponse struct {
@ -45,10 +50,13 @@ type EventResponse struct {
// NewKubernetes returns a new Kubernetes event input binding.
func NewKubernetes(logger logger.Logger) bindings.InputBinding {
return &kubernetesInput{logger: logger}
return &kubernetesInput{
logger: logger,
closeCh: make(chan struct{}),
}
}
func (k *kubernetesInput) Init(metadata bindings.Metadata) error {
func (k *kubernetesInput) Init(ctx context.Context, metadata bindings.Metadata) error {
client, err := kubeclient.GetKubeClient()
if err != nil {
return err
@ -78,6 +86,9 @@ func (k *kubernetesInput) parseMetadata(metadata bindings.Metadata) error {
}
func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) error {
if k.closed.Load() {
return errors.New("binding is closed")
}
watchlist := cache.NewListWatchFromClient(
k.kubeClient.CoreV1().RESTClient(),
"events",
@ -126,12 +137,28 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er
},
)
k.wg.Add(3)
readCtx, cancel := context.WithCancel(ctx)
// catch when binding is closed.
go func() {
defer k.wg.Done()
defer cancel()
select {
case <-readCtx.Done():
case <-k.closeCh:
}
}()
// Start the controller in backgound
stopCh := make(chan struct{})
go controller.Run(stopCh)
go func() {
defer k.wg.Done()
controller.Run(readCtx.Done())
}()
// Watch for new messages and for context cancellation
go func() {
defer k.wg.Done()
var (
obj EventResponse
data []byte
@ -148,8 +175,7 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er
Data: data,
})
}
case <-ctx.Done():
close(stopCh)
case <-readCtx.Done():
return
}
}
@ -157,3 +183,11 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er
return nil
}
func (k *kubernetesInput) Close() error {
if k.closed.CompareAndSwap(false, true) {
close(k.closeCh)
}
k.wg.Wait()
return nil
}

View File

@ -64,7 +64,7 @@ func NewLocalStorage(logger logger.Logger) bindings.OutputBinding {
}
// Init performs metadata parsing.
func (ls *LocalStorage) Init(metadata bindings.Metadata) error {
func (ls *LocalStorage) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := ls.parseMetadata(metadata)
if err != nil {
return fmt.Errorf("failed to parse metadata: %w", err)

View File

@ -40,30 +40,29 @@ type MQTT struct {
logger logger.Logger
isSubscribed atomic.Bool
readHandler bindings.Handler
ctx context.Context
cancel context.CancelFunc
backOff backoff.BackOff
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
}
// NewMQTT returns a new MQTT instance.
func NewMQTT(logger logger.Logger) bindings.InputOutputBinding {
return &MQTT{
logger: logger,
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init does MQTT connection parsing.
func (m *MQTT) Init(metadata bindings.Metadata) (err error) {
func (m *MQTT) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
m.metadata, err = parseMQTTMetaData(metadata, m.logger)
if err != nil {
return err
}
m.ctx, m.cancel = context.WithCancel(context.Background())
// TODO: Make the backoff configurable for constant or exponential
b := backoff.NewConstantBackOff(5 * time.Second)
m.backOff = backoff.WithContext(b, m.ctx)
m.backOff = backoff.NewConstantBackOff(5 * time.Second)
return nil
}
@ -104,7 +103,7 @@ func (m *MQTT) getProducer() (mqtt.Client, error) {
return p, nil
}
func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
func (m *MQTT) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
producer, err := m.getProducer()
if err != nil {
return nil, fmt.Errorf("failed to create producer connection: %w", err)
@ -118,7 +117,7 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
bo := backoff.WithMaxRetries(
backoff.NewConstantBackOff(200*time.Millisecond), 3,
)
bo = backoff.WithContext(bo, parentCtx)
bo = backoff.WithContext(bo, ctx)
topic, ok := req.Metadata[mqttTopic]
if !ok || topic == "" {
@ -127,14 +126,13 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
}
return nil, retry.NotifyRecover(func() (err error) {
token := producer.Publish(topic, m.metadata.qos, m.metadata.retain, req.Data)
ctx, cancel := context.WithTimeout(parentCtx, defaultWait)
defer cancel()
select {
case <-token.Done():
err = token.Error()
case <-m.ctx.Done():
// Context canceled
err = m.ctx.Err()
case <-m.closeCh:
err = errors.New("mqtt client closed")
case <-time.After(defaultWait):
err = errors.New("mqtt client timeout")
case <-ctx.Done():
// Context canceled
err = ctx.Err()
@ -151,6 +149,10 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
}
func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error {
if m.closed.Load() {
return errors.New("error: binding is closed")
}
// If the subscription is already active, wait 2s before retrying (in case we're still disconnecting), otherwise return an error
if !m.isSubscribed.CompareAndSwap(false, true) {
m.logger.Debug("Subscription is already active; waiting 2s before retrying…")
@ -177,11 +179,14 @@ func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error {
// In background, watch for contexts cancelation and stop the connection
// However, do not call "unsubscribe" which would cause the broker to stop tracking the last message received by this consumer group
m.wg.Add(1)
go func() {
defer m.wg.Done()
select {
case <-ctx.Done():
// nop
case <-m.ctx.Done():
case <-m.closeCh:
// nop
}
@ -208,14 +213,12 @@ func (m *MQTT) connect(clientID string, isSubscriber bool) (mqtt.Client, error)
}
client := mqtt.NewClient(opts)
ctx, cancel := context.WithTimeout(m.ctx, defaultWait)
defer cancel()
token := client.Connect()
select {
case <-token.Done():
err = token.Error()
case <-ctx.Done():
err = ctx.Err()
case <-time.After(defaultWait):
err = errors.New("mqtt client timed out connecting")
}
if err != nil {
return nil, fmt.Errorf("failed to connect: %w", err)
@ -290,43 +293,46 @@ func (m *MQTT) createClientOptions(uri *url.URL, clientID string) *mqtt.ClientOp
return opts
}
func (m *MQTT) handleMessage(client mqtt.Client, mqttMsg mqtt.Message) {
// We're using m.ctx as context in this method because we don't have access to the Read context
// Canceling the Read context makes Read invoke "Disconnect" anyways
ctx := m.ctx
func (m *MQTT) handleMessage() func(client mqtt.Client, mqttMsg mqtt.Message) {
return func(client mqtt.Client, mqttMsg mqtt.Message) {
bo := m.backOff
if m.metadata.backOffMaxRetries >= 0 {
bo = backoff.WithMaxRetries(bo, uint64(m.metadata.backOffMaxRetries))
}
var bo backoff.BackOff = backoff.WithContext(m.backOff, ctx)
if m.metadata.backOffMaxRetries >= 0 {
bo = backoff.WithMaxRetries(bo, uint64(m.metadata.backOffMaxRetries))
}
err := retry.NotifyRecover(
func() error {
m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
// Use a background context here so that the context is not tied to the
// first Invoke first created the producer.
// TODO: add context to mqtt library, and add a OnConnectWithContext option
// to change this func signature to
// func(c mqtt.Client, ctx context.Context)
_, err := m.readHandler(context.Background(), &bindings.ReadResponse{
Data: mqttMsg.Payload(),
Metadata: map[string]string{
mqttTopic: mqttMsg.Topic(),
},
})
if err != nil {
return err
}
err := retry.NotifyRecover(
func() error {
m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
_, err := m.readHandler(ctx, &bindings.ReadResponse{
Data: mqttMsg.Payload(),
Metadata: map[string]string{
mqttTopic: mqttMsg.Topic(),
},
})
if err != nil {
return err
}
// Ack the message on success
mqttMsg.Ack()
return nil
},
bo,
func(err error, d time.Duration) {
m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying…", mqttMsg.Topic(), mqttMsg.MessageID())
},
func() {
m.logger.Infof("Successfully processed MQTT message after it previously failed: %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
},
)
if err != nil {
m.logger.Errorf("Failed processing MQTT message: %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err)
// Ack the message on success
mqttMsg.Ack()
return nil
},
bo,
func(err error, d time.Duration) {
m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying…", mqttMsg.Topic(), mqttMsg.MessageID())
},
func() {
m.logger.Infof("Successfully processed MQTT message after it previously failed: %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
},
)
if err != nil {
m.logger.Errorf("Failed processing MQTT message: %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err)
}
}
}
@ -336,17 +342,15 @@ func (m *MQTT) createSubscriberClientOptions(uri *url.URL, clientID string) *mqt
// On (re-)connection, add the topic subscription
opts.OnConnect = func(c mqtt.Client) {
token := c.Subscribe(m.metadata.topic, m.metadata.qos, m.handleMessage)
token := c.Subscribe(m.metadata.topic, m.metadata.qos, m.handleMessage())
var err error
subscribeCtx, subscribeCancel := context.WithTimeout(m.ctx, defaultWait)
defer subscribeCancel()
select {
case <-token.Done():
// Subscription went through (sucecessfully or not)
err = token.Error()
case <-subscribeCtx.Done():
err = fmt.Errorf("error while waiting for subscription token: %w", subscribeCtx.Err())
case <-time.After(defaultWait):
err = errors.New("timed out waiting for subscription to complete")
}
// Nothing we can do in case of errors besides logging them
@ -363,13 +367,16 @@ func (m *MQTT) Close() error {
m.producerLock.Lock()
defer m.producerLock.Unlock()
// Canceling the context also causes Read to stop receiving messages
m.cancel()
if m.closed.CompareAndSwap(false, true) {
close(m.closeCh)
}
if m.producer != nil {
m.producer.Disconnect(200)
m.producer = nil
}
m.wg.Wait()
return nil
}

View File

@ -49,6 +49,7 @@ func getConnectionString() string {
func TestInvokeWithTopic(t *testing.T) {
t.Parallel()
ctx := context.Background()
url := getConnectionString()
if url == "" {
@ -79,7 +80,7 @@ func TestInvokeWithTopic(t *testing.T) {
logger := logger.NewLogger("test")
r := NewMQTT(logger).(*MQTT)
err := r.Init(metadata)
err := r.Init(ctx, metadata)
assert.Nil(t, err)
conn, err := r.connect(uuid.NewString(), false)
@ -127,4 +128,5 @@ func TestInvokeWithTopic(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, dataCustomized, mqttMessage.Payload())
assert.Equal(t, topicCustomized, mqttMessage.Topic())
assert.NoError(t, r.Close())
}

View File

@ -205,7 +205,6 @@ func TestParseMetadata(t *testing.T) {
logger := logger.NewLogger("test")
m := NewMQTT(logger).(*MQTT)
m.backOff = backoff.NewConstantBackOff(5 * time.Second)
m.ctx, m.cancel = context.WithCancel(context.Background())
m.readHandler = func(ctx context.Context, r *bindings.ReadResponse) ([]byte, error) {
assert.Equal(t, payload, r.Data)
metadata := r.Metadata
@ -215,7 +214,7 @@ func TestParseMetadata(t *testing.T) {
return r.Data, nil
}
m.handleMessage(nil, &mqttMockMessage{
m.handleMessage()(nil, &mqttMockMessage{
topic: topic,
payload: payload,
})

View File

@ -81,7 +81,7 @@ func NewMysql(logger logger.Logger) bindings.OutputBinding {
}
// Init initializes the MySQL binding.
func (m *Mysql) Init(metadata bindings.Metadata) error {
func (m *Mysql) Init(ctx context.Context, metadata bindings.Metadata) error {
m.logger.Debug("Initializing MySql binding")
p := metadata.Properties
@ -115,7 +115,7 @@ func (m *Mysql) Init(metadata bindings.Metadata) error {
return err
}
err = db.Ping()
err = db.PingContext(ctx)
if err != nil {
return fmt.Errorf("unable to ping the DB: %w", err)
}

View File

@ -75,7 +75,7 @@ func TestMysqlIntegration(t *testing.T) {
b := NewMysql(logger.NewLogger("test")).(*Mysql)
m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}}
if err := b.Init(m); err != nil {
if err := b.Init(context.Background(), m); err != nil {
t.Fatal(err)
}

View File

@ -21,6 +21,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/nacos-group/nacos-sdk-go/v2/clients"
@ -56,6 +57,9 @@ type Nacos struct {
logger logger.Logger
configClient config_client.IConfigClient //nolint:nosnakecase
readHandler func(ctx context.Context, response *bindings.ReadResponse) ([]byte, error)
wg sync.WaitGroup
closed atomic.Bool
closeCh chan struct{}
}
// NewNacos returns a new Nacos instance.
@ -63,11 +67,12 @@ func NewNacos(logger logger.Logger) bindings.OutputBinding {
return &Nacos{
logger: logger,
watchesLock: sync.Mutex{},
closeCh: make(chan struct{}),
}
}
// Init implements InputBinding/OutputBinding's Init method.
func (n *Nacos) Init(metadata bindings.Metadata) error {
func (n *Nacos) Init(_ context.Context, metadata bindings.Metadata) error {
n.settings = Settings{
Timeout: defaultTimeout,
}
@ -146,6 +151,10 @@ func (n *Nacos) createConfigClient() error {
// Read implements InputBinding's Read method.
func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {
if n.closed.Load() {
return errors.New("binding is closed")
}
n.readHandler = handler
n.watchesLock.Lock()
@ -154,9 +163,14 @@ func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {
}
n.watchesLock.Unlock()
n.wg.Add(1)
go func() {
defer n.wg.Done()
// Cancel all listeners when the context is done
<-ctx.Done()
select {
case <-ctx.Done():
case <-n.closeCh:
}
n.cancelAllListeners()
}()
@ -165,8 +179,14 @@ func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {
// Close implements cancel all listeners, see https://github.com/dapr/components-contrib/issues/779
func (n *Nacos) Close() error {
if n.closed.CompareAndSwap(false, true) {
close(n.closeCh)
}
n.cancelAllListeners()
n.wg.Wait()
return nil
}
@ -223,7 +243,11 @@ func (n *Nacos) addListener(ctx context.Context, config configParam) {
func (n *Nacos) addListenerFoInputBinding(ctx context.Context, config configParam) {
if n.addToWatches(config) {
go n.addListener(ctx, config)
n.wg.Add(1)
go func() {
defer n.wg.Done()
n.addListener(ctx, config)
}()
}
}

View File

@ -35,7 +35,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
m.Properties, err = getNacosLocalCacheMetadata()
require.NoError(t, err)
n := NewNacos(logger.NewLogger("test")).(*Nacos)
err = n.Init(m)
err = n.Init(context.Background(), m)
require.NoError(t, err)
var count int32
ch := make(chan bool, 1)

View File

@ -22,15 +22,15 @@ import (
// OutputBinding is the interface for an output binding, allowing users to invoke remote systems with optional payloads.
type OutputBinding interface {
Init(metadata Metadata) error
Init(ctx context.Context, metadata Metadata) error
Invoke(ctx context.Context, req *InvokeRequest) (*InvokeResponse, error)
Operations() []OperationKind
}
func PingOutBinding(outputBinding OutputBinding) error {
func PingOutBinding(ctx context.Context, outputBinding OutputBinding) error {
// checks if this output binding has the ping option then executes
if outputBindingWithPing, ok := outputBinding.(health.Pinger); ok {
return outputBindingWithPing.Ping()
return outputBindingWithPing.Ping(ctx)
} else {
return fmt.Errorf("ping is not implemented by this output binding")
}

View File

@ -49,7 +49,7 @@ func NewPostgres(logger logger.Logger) bindings.OutputBinding {
}
// Init initializes the PostgreSql binding.
func (p *Postgres) Init(metadata bindings.Metadata) error {
func (p *Postgres) Init(ctx context.Context, metadata bindings.Metadata) error {
url, ok := metadata.Properties[connectionURLKey]
if !ok || url == "" {
return fmt.Errorf("required metadata not set: %s", connectionURLKey)
@ -60,7 +60,9 @@ func (p *Postgres) Init(metadata bindings.Metadata) error {
return fmt.Errorf("error opening DB connection: %w", err)
}
p.db, err = pgxpool.NewWithConfig(context.Background(), poolConfig)
// This context doesn't control the lifetime of the connection pool, and is
// only scoped to postgres creating resources at init.
p.db, err = pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
return fmt.Errorf("unable to ping the DB: %w", err)
}

View File

@ -64,7 +64,7 @@ func TestPostgresIntegration(t *testing.T) {
// live DB test
b := NewPostgres(logger.NewLogger("test")).(*Postgres)
m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}}
if err := b.Init(m); err != nil {
if err := b.Init(context.Background(), m); err != nil {
t.Fatal(err)
}

View File

@ -74,7 +74,7 @@ func (p *Postmark) parseMetadata(meta bindings.Metadata) (postmarkMetadata, erro
}
// Init does metadata parsing and not much else :).
func (p *Postmark) Init(metadata bindings.Metadata) error {
func (p *Postmark) Init(_ context.Context, metadata bindings.Metadata) error {
// Parse input metadata
meta, err := p.parseMetadata(metadata)
if err != nil {

View File

@ -19,6 +19,8 @@ import (
"fmt"
"math"
"strconv"
"sync"
"sync/atomic"
"time"
amqp "github.com/rabbitmq/amqp091-go"
@ -50,6 +52,9 @@ type RabbitMQ struct {
metadata rabbitMQMetadata
logger logger.Logger
queue amqp.Queue
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
// Metadata is the rabbitmq config.
@ -66,11 +71,14 @@ type rabbitMQMetadata struct {
// NewRabbitMQ returns a new rabbitmq instance.
func NewRabbitMQ(logger logger.Logger) bindings.InputOutputBinding {
return &RabbitMQ{logger: logger}
return &RabbitMQ{
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init does metadata parsing and connection creation.
func (r *RabbitMQ) Init(metadata bindings.Metadata) error {
func (r *RabbitMQ) Init(_ context.Context, metadata bindings.Metadata) error {
err := r.parseMetadata(metadata)
if err != nil {
return err
@ -226,6 +234,10 @@ func (r *RabbitMQ) declareQueue() (amqp.Queue, error) {
}
func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
if r.closed.Load() {
return errors.New("binding already closed")
}
msgs, err := r.channel.Consume(
r.queue.Name,
"",
@ -239,14 +251,27 @@ func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
return err
}
readCtx, cancel := context.WithCancel(ctx)
r.wg.Add(2)
go func() {
defer r.wg.Done()
defer cancel()
select {
case <-r.closeCh:
case <-readCtx.Done():
}
}()
go func() {
defer r.wg.Done()
var err error
for {
select {
case <-ctx.Done():
case <-readCtx.Done():
return
case d := <-msgs:
_, err = handler(ctx, &bindings.ReadResponse{
_, err = handler(readCtx, &bindings.ReadResponse{
Data: d.Body,
})
if err != nil {
@ -260,3 +285,11 @@ func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
return nil
}
func (r *RabbitMQ) Close() error {
if r.closed.CompareAndSwap(false, true) {
close(r.closeCh)
}
defer r.wg.Wait()
return r.channel.Close()
}

View File

@ -85,7 +85,7 @@ func TestQueuesWithTTL(t *testing.T) {
logger := logger.NewLogger("test")
r := NewRabbitMQ(logger).(*RabbitMQ)
err := r.Init(metadata)
err := r.Init(context.Background(), metadata)
assert.Nil(t, err)
// Assert that if waited too long, we won't see any message
@ -117,6 +117,7 @@ func TestQueuesWithTTL(t *testing.T) {
assert.True(t, ok)
msgBody := string(msg.Body)
assert.Equal(t, testMsgContent, msgBody)
assert.NoError(t, r.Close())
}
func TestPublishingWithTTL(t *testing.T) {
@ -144,7 +145,7 @@ func TestPublishingWithTTL(t *testing.T) {
logger := logger.NewLogger("test")
rabbitMQBinding1 := NewRabbitMQ(logger).(*RabbitMQ)
err := rabbitMQBinding1.Init(metadata)
err := rabbitMQBinding1.Init(context.Background(), metadata)
assert.Nil(t, err)
// Assert that if waited too long, we won't see any message
@ -175,7 +176,7 @@ func TestPublishingWithTTL(t *testing.T) {
// Getting before it is expired, should return it
rabbitMQBinding2 := NewRabbitMQ(logger).(*RabbitMQ)
err = rabbitMQBinding2.Init(metadata)
err = rabbitMQBinding2.Init(context.Background(), metadata)
assert.Nil(t, err)
const testMsgContent = "test_msg"
@ -193,6 +194,9 @@ func TestPublishingWithTTL(t *testing.T) {
assert.True(t, ok)
msgBody := string(msg.Body)
assert.Equal(t, testMsgContent, msgBody)
assert.NoError(t, rabbitMQBinding1.Close())
assert.NoError(t, rabbitMQBinding1.Close())
}
func TestExclusiveQueue(t *testing.T) {
@ -222,7 +226,7 @@ func TestExclusiveQueue(t *testing.T) {
logger := logger.NewLogger("test")
r := NewRabbitMQ(logger).(*RabbitMQ)
err := r.Init(metadata)
err := r.Init(context.Background(), metadata)
assert.Nil(t, err)
// Assert that if waited too long, we won't see any message
@ -276,7 +280,7 @@ func TestPublishWithPriority(t *testing.T) {
logger := logger.NewLogger("test")
r := NewRabbitMQ(logger).(*RabbitMQ)
err := r.Init(metadata)
err := r.Init(context.Background(), metadata)
assert.Nil(t, err)
// Assert that if waited too long, we won't see any message

View File

@ -28,9 +28,6 @@ type Redis struct {
client rediscomponent.RedisClient
clientSettings *rediscomponent.Settings
logger logger.Logger
ctx context.Context
cancel context.CancelFunc
}
// NewRedis returns a new redis bindings instance.
@ -39,15 +36,13 @@ func NewRedis(logger logger.Logger) bindings.OutputBinding {
}
// Init performs metadata parsing and connection creation.
func (r *Redis) Init(meta bindings.Metadata) (err error) {
func (r *Redis) Init(ctx context.Context, meta bindings.Metadata) (err error) {
r.client, r.clientSettings, err = rediscomponent.ParseClientFromProperties(meta.Properties, nil)
if err != nil {
return err
}
r.ctx, r.cancel = context.WithCancel(context.Background())
_, err = r.client.PingResult(r.ctx)
_, err = r.client.PingResult(ctx)
if err != nil {
return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err)
}
@ -55,8 +50,8 @@ func (r *Redis) Init(meta bindings.Metadata) (err error) {
return err
}
func (r *Redis) Ping() error {
if _, err := r.client.PingResult(r.ctx); err != nil {
func (r *Redis) Ping(ctx context.Context) error {
if _, err := r.client.PingResult(ctx); err != nil {
return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err)
}
@ -101,7 +96,5 @@ func (r *Redis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi
}
func (r *Redis) Close() error {
r.cancel()
return r.client.Close()
}

View File

@ -40,7 +40,6 @@ func TestInvokeCreate(t *testing.T) {
client: c,
logger: logger.NewLogger("test"),
}
bind.ctx, bind.cancel = context.WithCancel(context.Background())
_, err := c.DoRead(context.Background(), "GET", testKey)
assert.Equal(t, redis.Nil, err)
@ -66,7 +65,6 @@ func TestInvokeGet(t *testing.T) {
client: c,
logger: logger.NewLogger("test"),
}
bind.ctx, bind.cancel = context.WithCancel(context.Background())
err := c.DoWrite(context.Background(), "SET", testKey, testData)
assert.Equal(t, nil, err)
@ -87,7 +85,6 @@ func TestInvokeDelete(t *testing.T) {
client: c,
logger: logger.NewLogger("test"),
}
bind.ctx, bind.cancel = context.WithCancel(context.Background())
err := c.DoWrite(context.Background(), "SET", testKey, testData)
assert.Equal(t, nil, err)

View File

@ -19,6 +19,8 @@ import (
"fmt"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
r "github.com/dancannon/gorethink"
@ -34,6 +36,9 @@ type Binding struct {
logger logger.Logger
session *r.Session
config StateConfig
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
// StateConfig is the binding config.
@ -45,12 +50,13 @@ type StateConfig struct {
// NewRethinkDBStateChangeBinding returns a new RethinkDB actor event input binding.
func NewRethinkDBStateChangeBinding(logger logger.Logger) bindings.InputBinding {
return &Binding{
logger: logger,
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init initializes the RethinkDB binding.
func (b *Binding) Init(metadata bindings.Metadata) error {
func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
cfg, err := metadataToConfig(metadata.Properties, b.logger)
if err != nil {
return fmt.Errorf("unable to parse metadata properties: %w", err)
@ -68,6 +74,10 @@ func (b *Binding) Init(metadata bindings.Metadata) error {
// Read triggers the RethinkDB scheduler.
func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
if b.closed.Load() {
return errors.New("binding is closed")
}
b.logger.Infof("subscribing to state changes in %s.%s...", b.config.Database, b.config.Table)
cursor, err := r.DB(b.config.Database).
Table(b.config.Table).
@ -81,8 +91,21 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
return fmt.Errorf("error connecting to table '%s': %w", b.config.Table, err)
}
readCtx, cancel := context.WithCancel(ctx)
b.wg.Add(2)
go func() {
for ctx.Err() == nil {
defer b.wg.Done()
defer cancel()
select {
case <-b.closeCh:
case <-readCtx.Done():
}
}()
go func() {
defer b.wg.Done()
for readCtx.Err() == nil {
var change interface{}
ok := cursor.Next(&change)
if !ok {
@ -105,7 +128,7 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
},
}
if _, err := handler(ctx, resp); err != nil {
if _, err := handler(readCtx, resp); err != nil {
b.logger.Errorf("error invoking change handler: %v", err)
continue
}
@ -117,6 +140,14 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
return nil
}
func (b *Binding) Close() error {
if b.closed.CompareAndSwap(false, true) {
close(b.closeCh)
}
defer b.wg.Wait()
return b.session.Close()
}
func metadataToConfig(cfg map[string]string, logger logger.Logger) (StateConfig, error) {
c := StateConfig{}
for k, v := range cfg {

View File

@ -71,7 +71,7 @@ func TestBinding(t *testing.T) {
assert.NotNil(t, m.Properties)
b := getNewRethinkActorBinding()
err := b.Init(m)
err := b.Init(context.Background(), m)
assert.NoErrorf(t, err, "error initializing")
ctx, cancel := context.WithCancel(context.Background())

View File

@ -18,6 +18,8 @@ import (
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
mqc "github.com/apache/rocketmq-client-go/v2/consumer"
@ -35,27 +37,27 @@ type RocketMQ struct {
settings Settings
producer mqw.Producer
ctx context.Context
cancel context.CancelFunc
backOffConfig retry.Config
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
}
func NewRocketMQ(l logger.Logger) *RocketMQ {
return &RocketMQ{ //nolint:exhaustivestruct
logger: l,
producer: nil,
closeCh: make(chan struct{}),
}
}
// Init performs metadata parsing.
func (a *RocketMQ) Init(metadata bindings.Metadata) error {
func (a *RocketMQ) Init(ctx context.Context, metadata bindings.Metadata) error {
var err error
if err = a.settings.Decode(metadata.Properties); err != nil {
return err
}
a.ctx, a.cancel = context.WithCancel(context.Background())
// Default retry configuration is used if no
// backOff properties are set.
if err = retry.DecodeConfigWithPrefix(
@ -75,6 +77,10 @@ func (a *RocketMQ) Init(metadata bindings.Metadata) error {
// Read triggers the rocketmq subscription.
func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("error: binding is closed")
}
a.logger.Debugf("binding rocketmq: start read input binding")
consumer, err := a.setupConsumer()
@ -114,10 +120,12 @@ func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error {
a.logger.Debugf("binding-rocketmq: consumer started")
// Listen for context cancelation to stop the subscription
a.wg.Add(1)
go func() {
defer a.wg.Done()
select {
case <-ctx.Done():
case <-a.ctx.Done():
case <-a.closeCh:
}
innerErr := consumer.Shutdown()
@ -131,8 +139,10 @@ func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error {
// Close implements cancel all listeners, see https://github.com/dapr/components-contrib/issues/779
func (a *RocketMQ) Close() error {
a.cancel()
defer a.wg.Wait()
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
return nil
}
@ -199,21 +209,21 @@ func (a *RocketMQ) Operations() []bindings.OperationKind {
return []bindings.OperationKind{bindings.CreateOperation}
}
func (a *RocketMQ) Invoke(req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
func (a *RocketMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
rst := &bindings.InvokeResponse{Data: nil, Metadata: nil}
if req.Operation != bindings.CreateOperation {
return rst, fmt.Errorf("binding-rocketmq error: unsupported operation %s", req.Operation)
}
return rst, a.sendMessage(req)
return rst, a.sendMessage(ctx, req)
}
func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error {
func (a *RocketMQ) sendMessage(ctx context.Context, req *bindings.InvokeRequest) error {
topic := req.Metadata[metadataRocketmqTopic]
if topic != "" {
_, err := a.send(topic, req.Metadata[metadataRocketmqTag], req.Metadata[metadataRocketmqKey], req.Data)
_, err := a.send(ctx, topic, req.Metadata[metadataRocketmqTag], req.Metadata[metadataRocketmqKey], req.Data)
if err != nil {
return err
}
@ -229,7 +239,7 @@ func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error {
if err != nil {
return err
}
_, err = a.send(topic, mqExpression, req.Metadata[metadataRocketmqKey], req.Data)
_, err = a.send(ctx, topic, mqExpression, req.Metadata[metadataRocketmqKey], req.Data)
if err != nil {
return err
}
@ -239,9 +249,9 @@ func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error {
return nil
}
func (a *RocketMQ) send(topic, mqExpr, key string, data []byte) (bool, error) {
func (a *RocketMQ) send(ctx context.Context, topic, mqExpr, key string, data []byte) (bool, error) {
msg := primitive.NewMessage(topic, data).WithTag(mqExpr).WithKeys([]string{key})
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
rst, err := a.producer.SendSync(ctx, msg)
if err != nil {

View File

@ -35,7 +35,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
m := bindings.Metadata{} //nolint:exhaustivestruct
m.Properties = getTestMetadata()
r := NewRocketMQ(logger.NewLogger("test"))
err := r.Init(m)
err := r.Init(context.Background(), m)
require.NoError(t, err)
var count int32
@ -51,7 +51,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
time.Sleep(5 * time.Second)
atomic.StoreInt32(&count, 0)
req := &bindings.InvokeRequest{Data: []byte("hello"), Operation: bindings.CreateOperation, Metadata: map[string]string{}}
_, err = r.Invoke(req)
_, err = r.Invoke(context.Background(), req)
require.NoError(t, err)
time.Sleep(10 * time.Second)

View File

@ -61,7 +61,7 @@ func NewSMTP(logger logger.Logger) bindings.OutputBinding {
}
// Init smtp component (parse metadata).
func (s *Mailer) Init(metadata bindings.Metadata) error {
func (s *Mailer) Init(_ context.Context, metadata bindings.Metadata) error {
// parse metadata
meta, err := s.parseMetadata(metadata)
if err != nil {

View File

@ -84,7 +84,7 @@ func (sg *SendGrid) parseMetadata(meta bindings.Metadata) (sendGridMetadata, err
}
// Init does metadata parsing and not much else :).
func (sg *SendGrid) Init(metadata bindings.Metadata) error {
func (sg *SendGrid) Init(_ context.Context, metadata bindings.Metadata) error {
// Parse input metadata
meta, err := sg.parseMetadata(metadata)
if err != nil {

View File

@ -60,19 +60,19 @@ func NewSMS(logger logger.Logger) bindings.OutputBinding {
}
}
func (t *SMS) Init(metadata bindings.Metadata) error {
func (t *SMS) Init(_ context.Context, metadata bindings.Metadata) error {
twilioM := twilioMetadata{
timeout: time.Minute * 5,
}
if metadata.Properties[fromNumber] == "" {
return errors.New("\"fromNumber\" is a required field")
return errors.New(`"fromNumber" is a required field`)
}
if metadata.Properties[accountSid] == "" {
return errors.New("\"accountSid\" is a required field")
return errors.New(`"accountSid" is a required field`)
}
if metadata.Properties[authToken] == "" {
return errors.New("\"authToken\" is a required field")
return errors.New(`"authToken" is a required field`)
}
twilioM.toNumber = metadata.Properties[toNumber]

View File

@ -53,7 +53,7 @@ func TestInit(t *testing.T) {
m := bindings.Metadata{}
m.Properties = map[string]string{"toNumber": "toNumber", "fromNumber": "fromNumber"}
tw := NewSMS(logger.NewLogger("test"))
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.NotNil(t, err)
}
@ -66,7 +66,7 @@ func TestParseDuration(t *testing.T) {
"authToken": "authToken", "timeout": "badtimeout",
}
tw := NewSMS(logger.NewLogger("test"))
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.NotNil(t, err)
}
@ -85,7 +85,7 @@ func TestWriteShouldSucceed(t *testing.T) {
tw.httpClient = &http.Client{
Transport: httpTransport,
}
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.NoError(t, err)
t.Run("Should succeed with expected url and headers", func(t *testing.T) {
@ -123,7 +123,7 @@ func TestWriteShouldFail(t *testing.T) {
tw.httpClient = &http.Client{
Transport: httpTransport,
}
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.NoError(t, err)
t.Run("Missing 'to' should fail", func(t *testing.T) {
@ -180,7 +180,7 @@ func TestMessageBody(t *testing.T) {
tw.httpClient = &http.Client{
Transport: httpTransport,
}
err := tw.Init(m)
err := tw.Init(context.Background(), m)
require.NoError(t, err)
tester := func(reqData []byte, expectBody string) func(t *testing.T) {

View File

@ -20,6 +20,8 @@ import (
"errors"
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/dghubble/go-twitter/twitter"
@ -31,18 +33,21 @@ import (
// Binding represents Twitter input/output binding.
type Binding struct {
client *twitter.Client
query string
logger logger.Logger
client *twitter.Client
query string
logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
// NewTwitter returns a new Twitter event input binding.
func NewTwitter(logger logger.Logger) bindings.InputOutputBinding {
return &Binding{logger: logger}
return &Binding{logger: logger, closeCh: make(chan struct{})}
}
// Init initializes the Twitter binding.
func (t *Binding) Init(metadata bindings.Metadata) error {
func (t *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
ck, f := metadata.Properties["consumerKey"]
if !f || ck == "" {
return fmt.Errorf("consumerKey not set")
@ -124,10 +129,17 @@ func (t *Binding) Read(ctx context.Context, handler bindings.Handler) error {
}
t.logger.Debug("starting handler...")
go demux.HandleChan(stream.Messages)
t.wg.Add(2)
go func() {
<-ctx.Done()
defer t.wg.Done()
demux.HandleChan(stream.Messages)
}()
go func() {
defer t.wg.Done()
select {
case <-t.closeCh:
case <-ctx.Done():
}
t.logger.Debug("stopping handler...")
stream.Stop()
}()
@ -135,6 +147,14 @@ func (t *Binding) Read(ctx context.Context, handler bindings.Handler) error {
return nil
}
func (t *Binding) Close() error {
if t.closed.CompareAndSwap(false, true) {
close(t.closeCh)
}
t.wg.Wait()
return nil
}
// Invoke handles all operations.
func (t *Binding) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
t.logger.Debugf("operation: %v", req.Operation)

View File

@ -60,7 +60,7 @@ func getRuntimeMetadata() map[string]string {
func TestInit(t *testing.T) {
m := getTestMetadata()
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.Nilf(t, err, "error initializing valid metadata properties")
}
@ -69,7 +69,7 @@ func TestInit(t *testing.T) {
func TestReadError(t *testing.T) {
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
m := getTestMetadata()
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.Nilf(t, err, "error initializing valid metadata properties")
err = tw.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
@ -79,6 +79,8 @@ func TestReadError(t *testing.T) {
return nil, nil
})
assert.Error(t, err)
assert.NoError(t, tw.Close())
}
// TestRead executes the Read method which calls Twiter API
@ -93,7 +95,7 @@ func TestRead(t *testing.T) {
m.Properties["query"] = "microsoft"
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
tw.logger.SetOutputLevel(logger.DebugLevel)
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.Nilf(t, err, "error initializing read")
ctx, cancel := context.WithCancel(context.Background())
@ -116,6 +118,8 @@ func TestRead(t *testing.T) {
cancel()
t.Fatal("Timeout waiting for messages")
}
assert.NoError(t, tw.Close())
}
// TestInvoke executes the Invoke method which calls Twiter API
@ -129,7 +133,7 @@ func TestInvoke(t *testing.T) {
m.Properties = getRuntimeMetadata()
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
tw.logger.SetOutputLevel(logger.DebugLevel)
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.Nilf(t, err, "error initializing Invoke")
req := &bindings.InvokeRequest{
@ -141,4 +145,5 @@ func TestInvoke(t *testing.T) {
resp, err := tw.Invoke(context.Background(), req)
assert.Nilf(t, err, "error on invoke")
assert.NotNil(t, resp)
assert.NoError(t, tw.Close())
}

View File

@ -61,7 +61,7 @@ func NewZeebeCommand(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing and connection creation.
func (z *ZeebeCommand) Init(metadata bindings.Metadata) error {
func (z *ZeebeCommand) Init(ctx context.Context, metadata bindings.Metadata) error {
client, err := z.clientFactory.Get(metadata)
if err != nil {
return err
@ -114,7 +114,7 @@ func (z *ZeebeCommand) Invoke(ctx context.Context, req *bindings.InvokeRequest)
case UpdateJobRetriesOperation:
return z.updateJobRetries(ctx, req)
case ThrowErrorOperation:
return z.throwError(req)
return z.throwError(ctx, req)
case bindings.GetOperation:
fallthrough
case bindings.CreateOperation:

View File

@ -58,7 +58,7 @@ func TestInit(t *testing.T) {
}
cmd := ZeebeCommand{clientFactory: mcf, logger: testLogger}
err := cmd.Init(metadata)
err := cmd.Init(context.Background(), metadata)
assert.Error(t, err, errParsing)
})
@ -67,7 +67,7 @@ func TestInit(t *testing.T) {
mcf := mockClientFactory{}
cmd := ZeebeCommand{clientFactory: mcf, logger: testLogger}
err := cmd.Init(metadata)
err := cmd.Init(context.Background(), metadata)
assert.NoError(t, err)

View File

@ -30,7 +30,7 @@ type throwErrorPayload struct {
ErrorMessage string `json:"errorMessage"`
}
func (z *ZeebeCommand) throwError(req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
func (z *ZeebeCommand) throwError(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
var payload throwErrorPayload
err := json.Unmarshal(req.Data, &payload)
if err != nil {
@ -53,7 +53,7 @@ func (z *ZeebeCommand) throwError(req *bindings.InvokeRequest) (*bindings.Invoke
cmd = cmd.ErrorMessage(payload.ErrorMessage)
}
_, err = cmd.Send(context.Background())
_, err = cmd.Send(ctx)
if err != nil {
return nil, fmt.Errorf("cannot throw error for job key %d: %w", payload.JobKey, err)
}

View File

@ -19,6 +19,8 @@ import (
"errors"
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/camunda/zeebe/clients/go/v8/pkg/entities"
@ -39,6 +41,9 @@ type ZeebeJobWorker struct {
client zbc.Client
metadata *jobWorkerMetadata
logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
// https://docs.zeebe.io/basics/job-workers.html
@ -64,11 +69,15 @@ type jobHandler struct {
// NewZeebeJobWorker returns a new ZeebeJobWorker instance.
func NewZeebeJobWorker(logger logger.Logger) bindings.InputBinding {
return &ZeebeJobWorker{clientFactory: zeebe.NewClientFactoryImpl(logger), logger: logger}
return &ZeebeJobWorker{
clientFactory: zeebe.NewClientFactoryImpl(logger),
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init does metadata parsing and connection creation.
func (z *ZeebeJobWorker) Init(metadata bindings.Metadata) error {
func (z *ZeebeJobWorker) Init(ctx context.Context, metadata bindings.Metadata) error {
meta, err := z.parseMetadata(metadata)
if err != nil {
return err
@ -90,6 +99,10 @@ func (z *ZeebeJobWorker) Init(metadata bindings.Metadata) error {
}
func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) error {
if z.closed.Load() {
return fmt.Errorf("binding is closed")
}
h := jobHandler{
callback: handler,
logger: z.logger,
@ -99,8 +112,14 @@ func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) err
jobWorker := z.getJobWorker(h)
z.wg.Add(1)
go func() {
<-ctx.Done()
defer z.wg.Done()
select {
case <-z.closeCh:
case <-ctx.Done():
}
jobWorker.Close()
jobWorker.AwaitClose()
@ -110,6 +129,14 @@ func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) err
return nil
}
func (z *ZeebeJobWorker) Close() error {
if z.closed.CompareAndSwap(false, true) {
close(z.closeCh)
}
z.wg.Wait()
return nil
}
func (z *ZeebeJobWorker) parseMetadata(meta bindings.Metadata) (*jobWorkerMetadata, error) {
var m jobWorkerMetadata
err := metadata.DecodeMetadata(meta.Properties, &m)

View File

@ -14,6 +14,7 @@ limitations under the License.
package jobworker
import (
"context"
"errors"
"testing"
@ -53,10 +54,11 @@ func TestInit(t *testing.T) {
metadata := bindings.Metadata{}
var mcf mockClientFactory
jobWorker := ZeebeJobWorker{clientFactory: &mcf, logger: testLogger}
err := jobWorker.Init(metadata)
jobWorker := ZeebeJobWorker{clientFactory: &mcf, logger: testLogger, closeCh: make(chan struct{})}
err := jobWorker.Init(context.Background(), metadata)
assert.Error(t, err, ErrMissingJobType)
assert.NoError(t, jobWorker.Close())
})
t.Run("sets client from client factory", func(t *testing.T) {
@ -66,8 +68,8 @@ func TestInit(t *testing.T) {
mcf := mockClientFactory{
metadata: metadata,
}
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger}
err := jobWorker.Init(metadata)
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
err := jobWorker.Init(context.Background(), metadata)
assert.NoError(t, err)
@ -76,6 +78,7 @@ func TestInit(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, mc, jobWorker.client)
assert.Equal(t, metadata, mcf.metadata)
assert.NoError(t, jobWorker.Close())
})
t.Run("returns error if client could not be instantiated properly", func(t *testing.T) {
@ -85,9 +88,10 @@ func TestInit(t *testing.T) {
error: errParsing,
}
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger}
err := jobWorker.Init(metadata)
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
err := jobWorker.Init(context.Background(), metadata)
assert.Error(t, err, errParsing)
assert.NoError(t, jobWorker.Close())
})
t.Run("sets client from client factory", func(t *testing.T) {
@ -98,8 +102,8 @@ func TestInit(t *testing.T) {
metadata: metadata,
}
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger}
err := jobWorker.Init(metadata)
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
err := jobWorker.Init(context.Background(), metadata)
assert.NoError(t, err)
@ -108,5 +112,6 @@ func TestInit(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, mc, jobWorker.client)
assert.Equal(t, metadata, mcf.metadata)
assert.NoError(t, jobWorker.Close())
})
}

View File

@ -73,7 +73,7 @@ func NewAzureAppConfigurationStore(logger logger.Logger) configuration.Store {
}
// Init does metadata and connection parsing.
func (r *ConfigurationStore) Init(metadata configuration.Metadata) error {
func (r *ConfigurationStore) Init(_ context.Context, metadata configuration.Metadata) error {
m, err := parseMetadata(metadata)
if err != nil {
return err

View File

@ -204,7 +204,7 @@ func TestInit(t *testing.T) {
Properties: testProperties,
}}
err := s.Init(m)
err := s.Init(context.Background(), m)
assert.Nil(t, err)
cs, ok := s.(*ConfigurationStore)
assert.True(t, ok)
@ -229,7 +229,7 @@ func TestInit(t *testing.T) {
Properties: testProperties,
}}
err := s.Init(m)
err := s.Init(context.Background(), m)
assert.Nil(t, err)
cs, ok := s.(*ConfigurationStore)
assert.True(t, ok)

View File

@ -86,7 +86,7 @@ func NewPostgresConfigurationStore(logger logger.Logger) configuration.Store {
}
}
func (p *ConfigurationStore) Init(metadata configuration.Metadata) error {
func (p *ConfigurationStore) Init(parentCtx context.Context, metadata configuration.Metadata) error {
p.logger.Debug(InfoStartInit)
if p.client != nil {
return fmt.Errorf(ErrorAlreadyInitialized)
@ -98,7 +98,7 @@ func (p *ConfigurationStore) Init(metadata configuration.Metadata) error {
p.metadata = m
}
p.ActiveSubscriptions = make(map[string]*subscription)
ctx, cancel := context.WithTimeout(context.Background(), p.metadata.maxIdleTimeout)
ctx, cancel := context.WithTimeout(parentCtx, p.metadata.maxIdleTimeout)
defer cancel()
client, err := Connect(ctx, p.metadata.connectionString, p.metadata.maxIdleTimeout)
if err != nil {

View File

@ -143,7 +143,7 @@ func parseRedisMetadata(meta configuration.Metadata) (metadata, error) {
}
// Init does metadata and connection parsing.
func (r *ConfigurationStore) Init(metadata configuration.Metadata) error {
func (r *ConfigurationStore) Init(ctx context.Context, metadata configuration.Metadata) error {
m, err := parseRedisMetadata(metadata)
if err != nil {
return err
@ -156,11 +156,11 @@ func (r *ConfigurationStore) Init(metadata configuration.Metadata) error {
r.client = r.newClient(m)
}
if _, err = r.client.Ping(context.TODO()).Result(); err != nil {
if _, err = r.client.Ping(ctx).Result(); err != nil {
return fmt.Errorf("redis store: error connecting to redis at %s: %s", m.Host, err)
}
r.replicas, err = r.getConnectedSlaves()
r.replicas, err = r.getConnectedSlaves(ctx)
return err
}
@ -204,8 +204,8 @@ func (r *ConfigurationStore) newFailoverClient(m metadata) *redis.Client {
return redis.NewFailoverClient(opts)
}
func (r *ConfigurationStore) getConnectedSlaves() (int, error) {
res, err := r.client.Do(context.Background(), "INFO", "replication").Result()
func (r *ConfigurationStore) getConnectedSlaves(ctx context.Context) (int, error) {
res, err := r.client.Do(ctx, "INFO", "replication").Result()
if err != nil {
return 0, err
}

View File

@ -18,7 +18,7 @@ import "context"
// Store is an interface to perform operations on store.
type Store interface {
// Init configuration store.
Init(metadata Metadata) error
Init(ctx context.Context, metadata Metadata) error
// Get configuration.
Get(ctx context.Context, req *GetRequest) (*GetResponse, error)

2
go.mod
View File

@ -397,3 +397,5 @@ replace github.com/toolkits/concurrent => github.com/niean/gotools v0.0.0-201512
// this is a fork which addresses a performance issues due to go routines
replace dubbo.apache.org/dubbo-go/v3 => dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3
replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181

View File

@ -1,5 +1,22 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package health
import (
"context"
)
type Pinger interface {
Ping() error
Ping(ctx context.Context) error
}

View File

@ -19,6 +19,7 @@ import (
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/Shopify/sarama"
@ -31,6 +32,7 @@ type consumer struct {
k *Kafka
ready chan bool
running chan struct{}
stopped atomic.Bool
once sync.Once
mutex sync.Mutex
}
@ -275,9 +277,6 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
k.cg = cg
ctx, cancel := context.WithCancel(ctx)
k.cancel = cancel
ready := make(chan bool)
k.consumer = consumer{
k: k,
@ -320,7 +319,10 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
k.logger.Errorf("Error closing consumer group: %v", err)
}
close(k.consumer.running)
// Ensure running channel is only closed once.
if k.consumer.stopped.CompareAndSwap(false, true) {
close(k.consumer.running)
}
}()
<-ready
@ -331,7 +333,6 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
// Close down consumer group resources, refresh once.
func (k *Kafka) closeSubscriptionResources() {
if k.cg != nil {
k.cancel()
err := k.cg.Close()
if err != nil {
k.logger.Errorf("Error closing consumer group: %v", err)

View File

@ -36,7 +36,6 @@ type Kafka struct {
saslPassword string
initialOffset int64
cg sarama.ConsumerGroup
cancel context.CancelFunc
consumer consumer
config *sarama.Config
subscribeTopics TopicHandlerConfig
@ -60,7 +59,7 @@ func NewKafka(logger logger.Logger) *Kafka {
}
// Init does metadata parsing and connection establishment.
func (k *Kafka) Init(metadata map[string]string) error {
func (k *Kafka) Init(_ context.Context, metadata map[string]string) error {
upgradedMetadata, err := k.upgradeMetadata(metadata)
if err != nil {
return err

View File

@ -107,7 +107,8 @@ func (ts *OAuthTokenSource) Token() (*sarama.AccessToken, error) {
oidcCfg := ccred.Config{ClientID: ts.ClientID, ClientSecret: ts.ClientSecret, Scopes: ts.Scopes, TokenURL: ts.TokenEndpoint.TokenURL, AuthStyle: ts.TokenEndpoint.AuthStyle}
timeoutCtx, _ := ctx.WithTimeout(ctx.TODO(), tokenRequestTimeout) //nolint:govet
timeoutCtx, cancel := ctx.WithTimeout(ctx.TODO(), tokenRequestTimeout)
defer cancel()
ts.configureClient()

View File

@ -38,9 +38,6 @@ type StandaloneRedisLock struct {
metadata rediscomponent.Metadata
logger logger.Logger
ctx context.Context
cancel context.CancelFunc
}
// NewStandaloneRedisLock returns a new standalone redis lock.
@ -54,7 +51,7 @@ func NewStandaloneRedisLock(logger logger.Logger) lock.Store {
}
// Init StandaloneRedisLock.
func (r *StandaloneRedisLock) InitLockStore(metadata lock.Metadata) error {
func (r *StandaloneRedisLock) InitLockStore(ctx context.Context, metadata lock.Metadata) error {
// 1. parse config
m, err := rediscomponent.ParseRedisMetadata(metadata.Properties)
if err != nil {
@ -75,13 +72,12 @@ func (r *StandaloneRedisLock) InitLockStore(metadata lock.Metadata) error {
if err != nil {
return err
}
r.ctx, r.cancel = context.WithCancel(context.Background())
// 3. connect to redis
if _, err = r.client.PingResult(r.ctx); err != nil {
if _, err = r.client.PingResult(ctx); err != nil {
return fmt.Errorf("[standaloneRedisLock]: error connecting to redis at %s: %s", r.clientSettings.Host, err)
}
// no replica
replicas, err := r.getConnectedSlaves()
replicas, err := r.getConnectedSlaves(ctx)
// pass the validation if error occurs,
// since some redis versions such as miniredis do not recognize the `INFO` command.
if err == nil && replicas > 0 {
@ -101,8 +97,8 @@ func needFailover(properties map[string]string) bool {
return false
}
func (r *StandaloneRedisLock) getConnectedSlaves() (int, error) {
res, err := r.client.DoRead(r.ctx, "INFO", "replication")
func (r *StandaloneRedisLock) getConnectedSlaves(ctx context.Context) (int, error) {
res, err := r.client.DoRead(ctx, "INFO", "replication")
if err != nil {
return 0, err
}
@ -183,9 +179,6 @@ func newInternalErrorUnlockResponse() *lock.UnlockResponse {
// Close shuts down the client's redis connections.
func (r *StandaloneRedisLock) Close() error {
if r.cancel != nil {
r.cancel()
}
if r.client != nil {
closeErr := r.client.Close()
r.client = nil

View File

@ -42,7 +42,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) {
cfg.Properties["redisPassword"] = ""
// init
err := comp.InitLockStore(cfg)
err := comp.InitLockStore(context.Background(), cfg)
assert.Error(t, err)
})
@ -58,7 +58,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) {
cfg.Properties["redisPassword"] = ""
// init
err := comp.InitLockStore(cfg)
err := comp.InitLockStore(context.Background(), cfg)
assert.Error(t, err)
})
@ -75,7 +75,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) {
cfg.Properties["maxRetries"] = "1 "
// init
err := comp.InitLockStore(cfg)
err := comp.InitLockStore(context.Background(), cfg)
assert.Error(t, err)
})
}
@ -96,7 +96,7 @@ func TestStandaloneRedisLock_TryLock(t *testing.T) {
cfg.Properties["redisHost"] = s.Addr()
cfg.Properties["redisPassword"] = ""
// init
err = comp.InitLockStore(cfg)
err = comp.InitLockStore(context.Background(), cfg)
assert.NoError(t, err)
// 1. client1 trylock
ownerID1 := uuid.New().String()

View File

@ -17,7 +17,7 @@ import "context"
type Store interface {
// Init this component.
InitLockStore(metadata Metadata) error
InitLockStore(ctx context.Context, metadata Metadata) error
// TryLock tries to acquire a lock.
TryLock(ctx context.Context, req *TryLockRequest) (*TryLockResponse, error)

View File

@ -57,14 +57,12 @@ type Middleware struct {
}
// GetHandler retruns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
func (m *Middleware) GetHandler(ctx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata)
if err != nil {
return nil, err
}
ctx := context.TODO()
// Create a JWKS cache that is refreshed automatically
cache := jwk.NewCache(ctx,
jwk.WithErrSink(httprc.ErrSinkFunc(func(err error) {

View File

@ -14,6 +14,7 @@ limitations under the License.
package oauth2
import (
"context"
"net/http"
"net/url"
"strings"
@ -59,7 +60,7 @@ const (
)
// GetHandler retruns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
func (m *Middleware) GetHandler(ctx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata)
if err != nil {
return nil, err

View File

@ -5,6 +5,7 @@
package mock_oauth2clientcredentials
import (
"context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@ -36,7 +37,7 @@ func (m *MockTokenProviderInterface) EXPECT() *MockTokenProviderInterfaceMockRec
}
// GetToken mocks base method
func (m *MockTokenProviderInterface) GetToken(arg0 *clientcredentials.Config) (*oauth2.Token, error) {
func (m *MockTokenProviderInterface) GetToken(ctx context.Context, arg0 *clientcredentials.Config) (*oauth2.Token, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetToken", arg0)
ret0, _ := ret[0].(*oauth2.Token)

View File

@ -45,7 +45,7 @@ type oAuth2ClientCredentialsMiddlewareMetadata struct {
// TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests.
type TokenProviderInterface interface {
GetToken(conf *clientcredentials.Config) (*oauth2.Token, error)
GetToken(ctx context.Context, conf *clientcredentials.Config) (*oauth2.Token, error)
}
// NewOAuth2ClientCredentialsMiddleware returns a new oAuth2 middleware.
@ -68,7 +68,7 @@ type Middleware struct {
}
// GetHandler retruns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata)
if err != nil {
m.log.Errorf("getNativeMetadata error: %s", err)
@ -101,7 +101,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
if !found {
m.log.Debugf("Cached token not found, try get one")
token, err := m.tokenProvider.GetToken(conf)
token, err := m.tokenProvider.GetToken(r.Context(), conf)
if err != nil {
m.log.Errorf("Error acquiring token: %s", err)
return
@ -171,8 +171,8 @@ func (m *Middleware) SetTokenProvider(tokenProvider TokenProviderInterface) {
}
// GetToken returns a token from the current OAuth2 ClientCredentials Configuration.
func (m *Middleware) GetToken(conf *clientcredentials.Config) (*oauth2.Token, error) {
tokenSource := conf.TokenSource(context.Background())
func (m *Middleware) GetToken(ctx context.Context, conf *clientcredentials.Config) (*oauth2.Token, error) {
tokenSource := conf.TokenSource(ctx)
return tokenSource.Token()
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package oauth2clientcredentials
import (
"context"
"net/http"
"net/http/httptest"
"testing"
@ -45,7 +46,7 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
metadata.Properties = map[string]string{}
log := logger.NewLogger("oauth2clientcredentials.test")
_, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
_, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
assert.EqualError(t, err, "metadata errors: Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. ")
// Invalid authStyle (non int)
@ -57,17 +58,17 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
"headerName": "someHeader",
"authStyle": "asdf", // This is the value to test
}
_, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
_, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
assert.EqualError(t, err2, "metadata errors: 1 error(s) decoding:\n\n* cannot parse 'AuthStyle' as int: strconv.ParseInt: parsing \"asdf\": invalid syntax")
// Invalid authStyle (int > 2)
metadata.Properties["authStyle"] = "3"
_, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
_, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
assert.EqualError(t, err3, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ")
// Invalid authStyle (int < 0)
metadata.Properties["authStyle"] = "-1"
_, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
_, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
assert.EqualError(t, err4, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ")
}
@ -109,7 +110,7 @@ func TestOAuth2ClientCredentialsToken(t *testing.T) {
log := logger.NewLogger("oauth2clientcredentials.test")
oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
handler, err := oauth2clientcredentialsMiddleware.GetHandler(metadata)
handler, err := oauth2clientcredentialsMiddleware.GetHandler(context.Background(), metadata)
require.NoError(t, err)
// First handler call should return abc Token
@ -169,7 +170,7 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) {
log := logger.NewLogger("oauth2clientcredentials.test")
oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
handler, err := oauth2clientcredentialsMiddleware.GetHandler(metadata)
handler, err := oauth2clientcredentialsMiddleware.GetHandler(context.Background(), metadata)
require.NoError(t, err)
// First handler call should return abc Token

View File

@ -106,13 +106,13 @@ func (s *Status) Valid() bool {
}
// GetHandler returns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
func (m *Middleware) GetHandler(parentCtx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.TODO(), time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, time.Minute)
query, err := rego.New(
rego.Query("result = data.http.allow"),
rego.Module("inline.rego", meta.Rego),

View File

@ -14,6 +14,7 @@ limitations under the License.
package opa
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
@ -331,7 +332,7 @@ func TestOpaPolicy(t *testing.T) {
t.Run(name, func(t *testing.T) {
opaMiddleware := NewMiddleware(log)
handler, err := opaMiddleware.GetHandler(test.meta)
handler, err := opaMiddleware.GetHandler(context.Background(), test.meta)
if test.shouldHandlerError {
require.Error(t, err)
return

View File

@ -14,6 +14,7 @@ limitations under the License.
package ratelimit
import (
"context"
"fmt"
"net/http"
"strconv"
@ -46,7 +47,7 @@ func NewRateLimitMiddleware(_ logger.Logger) middleware.Middleware {
type Middleware struct{}
// GetHandler returns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
meta, err := m.getNativeMetadata(metadata)
if err != nil {
return nil, err

View File

@ -40,7 +40,7 @@ func NewMiddleware(logger logger.Logger) middleware.Middleware {
}
// GetHandler retruns the HTTP handler provided by the middleware.
func (m *Middleware) GetHandler(metadata middleware.Metadata) (
func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (
func(next http.Handler) http.Handler, error,
) {
if err := m.getNativeMetadata(metadata); err != nil {

View File

@ -14,6 +14,7 @@ limitations under the License.
package routeralias
import (
"context"
"io"
"net/http"
"net/http/httptest"
@ -46,7 +47,7 @@ func TestRequestHandlerWithIllegalRouterRule(t *testing.T) {
}
log := logger.NewLogger("routeralias.test")
ralias := NewMiddleware(log)
handler, err := ralias.GetHandler(meta)
handler, err := ralias.GetHandler(context.Background(), meta)
assert.Nil(t, err)
t.Run("hit: change router with common request", func(t *testing.T) {

Some files were not shown because too many files have changed in this diff Show More