Merge branch 'master' of https://github.com/dapr/components-contrib into newdeps
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
commit a484d7ebc7
@@ -18,3 +18,5 @@ require (
 	github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
 	gopkg.in/yaml.v2 v2.4.0 // indirect
 )
+
+replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181

@@ -33,3 +33,5 @@ require (
 	google.golang.org/protobuf v1.28.1 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
+
+replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181

@@ -80,7 +80,7 @@ func NewDingTalkWebhook(l logger.Logger) bindings.InputOutputBinding {
 }
 
 // Init performs metadata parsing.
-func (t *DingTalkWebhook) Init(metadata bindings.Metadata) error {
+func (t *DingTalkWebhook) Init(_ context.Context, metadata bindings.Metadata) error {
 	var err error
 	if err = t.settings.Decode(metadata.Properties); err != nil {
 		return fmt.Errorf("dingtalk configuration error: %w", err)
@@ -107,6 +107,13 @@ func (t *DingTalkWebhook) Read(ctx context.Context, handler bindings.Handler) er
 	return nil
 }
 
+func (t *DingTalkWebhook) Close() error {
+	webhooks.Lock()
+	defer webhooks.Unlock()
+	delete(webhooks.m, t.settings.ID)
+	return nil
+}
+
 // Operations returns list of operations supported by dingtalk webhook binding.
 func (t *DingTalkWebhook) Operations() []bindings.OperationKind {
 	return []bindings.OperationKind{bindings.CreateOperation, bindings.GetOperation}

@@ -57,7 +57,7 @@ func TestPublishMsg(t *testing.T) { //nolint:paralleltest
 	}}}
 
 	d := NewDingTalkWebhook(logger.NewLogger("test"))
-	err := d.Init(m)
+	err := d.Init(context.Background(), m)
 	require.NoError(t, err)
 
 	req := &bindings.InvokeRequest{Data: []byte(msg), Operation: bindings.CreateOperation, Metadata: map[string]string{}}
@@ -78,7 +78,7 @@ func TestBindingReadAndInvoke(t *testing.T) { //nolint:paralleltest
 	}}
 
 	d := NewDingTalkWebhook(logger.NewLogger("test"))
-	err := d.Init(m)
+	err := d.Init(context.Background(), m)
 	assert.NoError(t, err)
 
 	var count int32
@@ -106,3 +106,18 @@ func TestBindingReadAndInvoke(t *testing.T) { //nolint:paralleltest
 		require.FailNow(t, "read timeout")
 	}
 }
+
+func TestBindingClose(t *testing.T) {
+	d := NewDingTalkWebhook(logger.NewLogger("test"))
+	m := bindings.Metadata{Base: metadata.Base{
+		Name: "test",
+		Properties: map[string]string{
+			"url":    "/test",
+			"secret": "",
+			"id":     "x",
+		},
+	}}
+	assert.NoError(t, d.Init(context.Background(), m))
+	assert.NoError(t, d.Close())
+	assert.NoError(t, d.Close(), "second close should not error")
+}

@@ -45,7 +45,7 @@ func NewAliCloudOSS(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init does metadata parsing and connection creation.
-func (s *AliCloudOSS) Init(metadata bindings.Metadata) error {
+func (s *AliCloudOSS) Init(_ context.Context, metadata bindings.Metadata) error {
 	m, err := s.parseMetadata(metadata)
 	if err != nil {
 		return err

@@ -31,7 +31,7 @@ type Callback struct {
 }
 
 // parse metadata field
-func (s *AliCloudSlsLogstorage) Init(metadata bindings.Metadata) error {
+func (s *AliCloudSlsLogstorage) Init(_ context.Context, metadata bindings.Metadata) error {
 	m, err := s.parseMeta(metadata)
 	if err != nil {
 		return err

@@ -58,7 +58,7 @@ func NewAliCloudTableStore(log logger.Logger) bindings.OutputBinding {
 	}
 }
 
-func (s *AliCloudTableStore) Init(metadata bindings.Metadata) error {
+func (s *AliCloudTableStore) Init(_ context.Context, metadata bindings.Metadata) error {
 	m, err := s.parseMetadata(metadata)
 	if err != nil {
 		return err

@@ -51,7 +51,7 @@ func TestDataEncodeAndDecode(t *testing.T) {
 	metadata := bindings.Metadata{Base: metadata.Base{
 		Properties: getTestProperties(),
 	}}
-	aliCloudTableStore.Init(metadata)
+	aliCloudTableStore.Init(context.Background(), metadata)
 
 	// test create
 	putData := map[string]interface{}{

@@ -78,7 +78,7 @@ func NewAPNS(logger logger.Logger) bindings.OutputBinding {
 
 // Init will configure the APNS output binding using the metadata specified
 // in the binding's configuration.
-func (a *APNS) Init(metadata bindings.Metadata) error {
+func (a *APNS) Init(ctx context.Context, metadata bindings.Metadata) error {
 	if err := a.makeURLPrefix(metadata); err != nil {
 		return err
 	}

@@ -51,7 +51,7 @@ func TestInit(t *testing.T) {
 			},
 		}}
 		binding := NewAPNS(testLogger).(*APNS)
-		err := binding.Init(metadata)
+		err := binding.Init(context.Background(), metadata)
 		assert.Nil(t, err)
 		assert.Equal(t, developmentPrefix, binding.urlPrefix)
 	})
@@ -66,7 +66,7 @@ func TestInit(t *testing.T) {
 			},
 		}}
 		binding := NewAPNS(testLogger).(*APNS)
-		err := binding.Init(metadata)
+		err := binding.Init(context.Background(), metadata)
 		assert.Nil(t, err)
 		assert.Equal(t, productionPrefix, binding.urlPrefix)
 	})
@@ -80,7 +80,7 @@ func TestInit(t *testing.T) {
 			},
 		}}
 		binding := NewAPNS(testLogger).(*APNS)
-		err := binding.Init(metadata)
+		err := binding.Init(context.Background(), metadata)
 		assert.Nil(t, err)
 		assert.Equal(t, productionPrefix, binding.urlPrefix)
 	})
@@ -95,7 +95,7 @@ func TestInit(t *testing.T) {
 			},
 		}}
 		binding := NewAPNS(testLogger).(*APNS)
-		err := binding.Init(metadata)
+		err := binding.Init(context.Background(), metadata)
 		assert.Error(t, err, "invalid value for development parameter: True")
 	})
 
@@ -107,7 +107,7 @@ func TestInit(t *testing.T) {
 			},
 		}}
 		binding := NewAPNS(testLogger).(*APNS)
-		err := binding.Init(metadata)
+		err := binding.Init(context.Background(), metadata)
 		assert.Error(t, err, "the key-id parameter is required")
 	})
 
@@ -120,7 +120,7 @@ func TestInit(t *testing.T) {
 			},
 		}}
 		binding := NewAPNS(testLogger).(*APNS)
-		err := binding.Init(metadata)
+		err := binding.Init(context.Background(), metadata)
 		assert.Nil(t, err)
 		assert.Equal(t, testKeyID, binding.authorizationBuilder.keyID)
 	})
@@ -133,7 +133,7 @@ func TestInit(t *testing.T) {
 			},
 		}}
 		binding := NewAPNS(testLogger).(*APNS)
-		err := binding.Init(metadata)
+		err := binding.Init(context.Background(), metadata)
 		assert.Error(t, err, "the team-id parameter is required")
 	})
 
@@ -146,7 +146,7 @@ func TestInit(t *testing.T) {
 			},
 		}}
 		binding := NewAPNS(testLogger).(*APNS)
-		err := binding.Init(metadata)
+		err := binding.Init(context.Background(), metadata)
 		assert.Nil(t, err)
 		assert.Equal(t, testTeamID, binding.authorizationBuilder.teamID)
 	})
@@ -159,7 +159,7 @@ func TestInit(t *testing.T) {
 			},
 		}}
 		binding := NewAPNS(testLogger).(*APNS)
-		err := binding.Init(metadata)
+		err := binding.Init(context.Background(), metadata)
 		assert.Error(t, err, "the private-key parameter is required")
 	})
 
@@ -172,7 +172,7 @@ func TestInit(t *testing.T) {
 			},
 		}}
 		binding := NewAPNS(testLogger).(*APNS)
-		err := binding.Init(metadata)
+		err := binding.Init(context.Background(), metadata)
 		assert.Nil(t, err)
 		assert.NotNil(t, binding.authorizationBuilder.privateKey)
 	})
@@ -335,7 +335,7 @@ func makeTestBinding(t *testing.T, log logger.Logger) *APNS {
 			privateKeyKey: testPrivateKey,
 		},
 	}}
-	err := testBinding.Init(bindingMetadata)
+	err := testBinding.Init(context.Background(), bindingMetadata)
 	assert.Nil(t, err)
 
 	return testBinding

@@ -49,7 +49,7 @@ func NewDynamoDB(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init performs connection parsing for DynamoDB.
-func (d *DynamoDB) Init(metadata bindings.Metadata) error {
+func (d *DynamoDB) Init(_ context.Context, metadata bindings.Metadata) error {
 	meta, err := d.getDynamoDBMetadata(metadata)
 	if err != nil {
 		return err

@@ -15,7 +15,10 @@ package kinesis
 
 import (
 	"context"
+	"errors"
 	"fmt"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/aws/aws-sdk-go/aws"
@@ -45,6 +48,10 @@ type AWSKinesis struct {
 	streamARN   *string
 	consumerARN *string
 	logger      logger.Logger
+
+	closed  atomic.Bool
+	closeCh chan struct{}
+	wg      sync.WaitGroup
 }
 
 type kinesisMetadata struct {
@@ -83,11 +90,14 @@ type recordProcessor struct {
 
 // NewAWSKinesis returns a new AWS Kinesis instance.
 func NewAWSKinesis(logger logger.Logger) bindings.InputOutputBinding {
-	return &AWSKinesis{logger: logger}
+	return &AWSKinesis{
+		logger:  logger,
+		closeCh: make(chan struct{}),
+	}
 }
 
 // Init does metadata parsing and connection creation.
-func (a *AWSKinesis) Init(metadata bindings.Metadata) error {
+func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error {
 	m, err := a.parseMetadata(metadata)
 	if err != nil {
 		return err
@@ -107,7 +117,7 @@ func (a *AWSKinesis) Init(metadata bindings.Metadata) error {
 	}
 
 	streamName := aws.String(m.StreamName)
-	stream, err := client.DescribeStream(&kinesis.DescribeStreamInput{
+	stream, err := client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{
 		StreamName: streamName,
 	})
 	if err != nil {
@@ -147,6 +157,10 @@ func (a *AWSKinesis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*
 }
 
 func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err error) {
+	if a.closed.Load() {
+		return errors.New("binding is closed")
+	}
+
 	if a.metadata.KinesisConsumerMode == SharedThroughput {
 		a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.workerConfig)
 		err = a.worker.Start()
@@ -166,8 +180,13 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er
 	}
 
 	// Wait for context cancelation then stop
+	a.wg.Add(1)
 	go func() {
-		<-ctx.Done()
+		defer a.wg.Done()
+		select {
+		case <-ctx.Done():
+		case <-a.closeCh:
+		}
 		if a.metadata.KinesisConsumerMode == SharedThroughput {
 			a.worker.Shutdown()
 		} else if a.metadata.KinesisConsumerMode == ExtendedFanout {
@@ -188,14 +207,25 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
 
 	a.consumerARN = consumerARN
 
+	a.wg.Add(len(streamDesc.Shards))
 	for i, shard := range streamDesc.Shards {
-		go func(idx int, s *kinesis.Shard) error {
+		go func(idx int, s *kinesis.Shard) {
+			defer a.wg.Done()
+
 			// Reconnection backoff
 			bo := backoff.NewExponentialBackOff()
 			bo.InitialInterval = 2 * time.Second
 
-			// Repeat until context is canceled
-			for ctx.Err() == nil {
+			// Repeat until context is canceled or binding closed.
+			for {
+				select {
+				case <-ctx.Done():
+					return
+				case <-a.closeCh:
+					return
+				default:
+				}
+
 				sub, err := a.client.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{
 					ConsumerARN: consumerARN,
 					ShardId:     s.ShardId,
@@ -204,8 +234,12 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
 				if err != nil {
 					wait := bo.NextBackOff()
 					a.logger.Errorf("Error while reading from shard %v: %v. Attempting to reconnect in %s...", s.ShardId, err, wait)
-					time.Sleep(wait)
-					continue
+					select {
+					case <-ctx.Done():
+						return
+					case <-time.After(wait):
+						continue
+					}
 				}
 
 				// Reset the backoff on connection success
@@ -223,22 +257,30 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
 					}
 				}
 			}
-			return nil
 		}(i, shard)
 	}
 
 	return nil
 }
 
-func (a *AWSKinesis) ensureConsumer(parentCtx context.Context, streamARN *string) (*string, error) {
-	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
-	consumer, err := a.client.DescribeStreamConsumerWithContext(ctx, &kinesis.DescribeStreamConsumerInput{
+func (a *AWSKinesis) Close() error {
+	if a.closed.CompareAndSwap(false, true) {
+		close(a.closeCh)
+	}
+	a.wg.Wait()
+	return nil
+}
+
+func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*string, error) {
+	// Only set timeout on consumer call.
+	conCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
+	defer cancel()
+	consumer, err := a.client.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{
 		ConsumerName: &a.metadata.ConsumerName,
 		StreamARN:    streamARN,
 	})
-	cancel()
 	if err != nil {
-		return a.registerConsumer(parentCtx, streamARN)
+		return a.registerConsumer(ctx, streamARN)
 	}
 
 	return consumer.ConsumerDescription.ConsumerARN, nil

@@ -99,7 +99,7 @@ func NewAWSS3(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init does metadata parsing and connection creation.
-func (s *AWSS3) Init(metadata bindings.Metadata) error {
+func (s *AWSS3) Init(_ context.Context, metadata bindings.Metadata) error {
 	m, err := s.parseMetadata(metadata)
 	if err != nil {
 		return err

@@ -61,7 +61,7 @@ func NewAWSSES(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init does metadata parsing.
-func (a *AWSSES) Init(metadata bindings.Metadata) error {
+func (a *AWSSES) Init(_ context.Context, metadata bindings.Metadata) error {
 	// Parse input metadata
 	meta, err := a.parseMetadata(metadata)
 	if err != nil {

@@ -53,7 +53,7 @@ func NewAWSSNS(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init does metadata parsing.
-func (a *AWSSNS) Init(metadata bindings.Metadata) error {
+func (a *AWSSNS) Init(_ context.Context, metadata bindings.Metadata) error {
 	m, err := a.parseMetadata(metadata)
 	if err != nil {
 		return err

@@ -16,6 +16,9 @@ package sqs
 import (
 	"context"
 	"encoding/json"
+	"errors"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/aws/aws-sdk-go/aws"
@@ -31,7 +34,10 @@ type AWSSQS struct {
 	Client   *sqs.SQS
 	QueueURL *string
 
-	logger logger.Logger
+	logger  logger.Logger
+	wg      sync.WaitGroup
+	closeCh chan struct{}
+	closed  atomic.Bool
 }
 
 type sqsMetadata struct {
@@ -45,11 +51,14 @@ type sqsMetadata struct {
 
 // NewAWSSQS returns a new AWS SQS instance.
 func NewAWSSQS(logger logger.Logger) bindings.InputOutputBinding {
-	return &AWSSQS{logger: logger}
+	return &AWSSQS{
+		logger:  logger,
+		closeCh: make(chan struct{}),
+	}
 }
 
 // Init does metadata parsing and connection creation.
-func (a *AWSSQS) Init(metadata bindings.Metadata) error {
+func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error {
 	m, err := a.parseSQSMetadata(metadata)
 	if err != nil {
 		return err
@@ -61,7 +70,7 @@ func (a *AWSSQS) Init(metadata bindings.Metadata) error {
 	}
 
 	queueName := m.QueueName
-	resultURL, err := client.GetQueueUrl(&sqs.GetQueueUrlInput{
+	resultURL, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{
 		QueueName: aws.String(queueName),
 	})
 	if err != nil {
@@ -89,9 +98,20 @@ func (a *AWSSQS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind
 }
 
 func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error {
+	if a.closed.Load() {
+		return errors.New("binding is closed")
+	}
+
+	a.wg.Add(1)
 	go func() {
-		// Repeat until the context is canceled
-		for ctx.Err() == nil {
+		defer a.wg.Done()
+
+		// Repeat until the context is canceled or component is closed
+		for {
+			if ctx.Err() != nil || a.closed.Load() {
+				return
+			}
+
 			result, err := a.Client.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{
 				QueueUrl: a.QueueURL,
 				AttributeNames: aws.StringSlice([]string{
@@ -126,13 +146,25 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error {
 				}
 			}
 
-			time.Sleep(time.Millisecond * 50)
+			select {
+			case <-ctx.Done():
+			case <-a.closeCh:
+			case <-time.After(time.Millisecond * 50):
+			}
 		}
 	}()
 
 	return nil
 }
 
+func (a *AWSSQS) Close() error {
+	if a.closed.CompareAndSwap(false, true) {
+		close(a.closeCh)
+	}
+	a.wg.Wait()
+	return nil
+}
+
 func (a *AWSSQS) parseSQSMetadata(metadata bindings.Metadata) (*sqsMetadata, error) {
 	b, err := json.Marshal(metadata.Properties)
 	if err != nil {

@@ -92,7 +92,7 @@ func NewAzureBlobStorage(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init performs metadata parsing.
-func (a *AzureBlobStorage) Init(metadata bindings.Metadata) error {
+func (a *AzureBlobStorage) Init(_ context.Context, metadata bindings.Metadata) error {
 	var err error
 	a.containerClient, a.metadata, err = storageinternal.CreateContainerStorageClient(context.TODO(), a.logger, metadata.Properties)
 	if err != nil {

@@ -53,7 +53,7 @@ func NewCosmosDB(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init performs CosmosDB connection parsing and connecting.
-func (c *CosmosDB) Init(metadata bindings.Metadata) error {
+func (c *CosmosDB) Init(ctx context.Context, metadata bindings.Metadata) error {
 	m, err := c.parseMetadata(metadata)
 	if err != nil {
 		return err
@@ -103,9 +103,9 @@ func (c *CosmosDB) Init(metadata bindings.Metadata) error {
 	}
 
 	c.client = dbContainer
-	ctx, cancel := context.WithTimeout(context.Background(), timeoutValue*time.Second)
-	_, err = c.client.Read(ctx, nil)
-	cancel()
+	readCtx, readCancel := context.WithTimeout(ctx, timeoutValue*time.Second)
+	defer readCancel()
+	_, err = c.client.Read(readCtx, nil)
 	return err
 }
 

@@ -59,7 +59,7 @@ func NewCosmosDBGremlinAPI(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init performs CosmosDBGremlinAPI connection parsing and connecting.
-func (c *CosmosDBGremlinAPI) Init(metadata bindings.Metadata) error {
+func (c *CosmosDBGremlinAPI) Init(_ context.Context, metadata bindings.Metadata) error {
 	c.logger.Debug("Initializing Cosmos Graph DB binding")
 
 	m, err := c.parseMetadata(metadata)

@@ -19,6 +19,8 @@ import (
 	"fmt"
 	"net/http"
 	"net/url"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
@@ -42,6 +44,9 @@ const armOperationTimeout = 30 * time.Second
 type AzureEventGrid struct {
 	metadata *azureEventGridMetadata
 	logger   logger.Logger
+	closeCh  chan struct{}
+	closed   atomic.Bool
+	wg       sync.WaitGroup
 }
 
 type azureEventGridMetadata struct {
@@ -70,11 +75,14 @@ type azureEventGridMetadata struct {
 
 // NewAzureEventGrid returns a new Azure Event Grid instance.
 func NewAzureEventGrid(logger logger.Logger) bindings.InputOutputBinding {
-	return &AzureEventGrid{logger: logger}
+	return &AzureEventGrid{
+		logger:  logger,
+		closeCh: make(chan struct{}),
+	}
 }
 
 // Init performs metadata init.
-func (a *AzureEventGrid) Init(metadata bindings.Metadata) error {
+func (a *AzureEventGrid) Init(_ context.Context, metadata bindings.Metadata) error {
 	m, err := a.parseMetadata(metadata)
 	if err != nil {
 		return err
@@ -85,6 +93,10 @@ func (a *AzureEventGrid) Init(metadata bindings.Metadata) error {
 }
 
 func (a *AzureEventGrid) Read(ctx context.Context, handler bindings.Handler) error {
+	if a.closed.Load() {
+		return errors.New("binding is closed")
+	}
+
 	err := a.ensureInputBindingMetadata()
 	if err != nil {
 		return err
@@ -120,17 +132,22 @@ func (a *AzureEventGrid) Read(ctx context.Context, handler bindings.Handler) err
 	}
 
 	// Run the server in background
+	a.wg.Add(2)
 	go func() {
+		defer a.wg.Done()
 		a.logger.Debugf("About to start listening for Event Grid events at http://localhost:%s/api/events", a.metadata.HandshakePort)
 		srvErr := srv.ListenAndServe(":" + a.metadata.HandshakePort)
 		if err != nil {
 			a.logger.Errorf("Error starting server: %v", srvErr)
 		}
 	}()
 
-	// Close the server when context is canceled
+	// Close the server when context is canceled or binding closed.
 	go func() {
-		<-ctx.Done()
+		defer a.wg.Done()
+		select {
+		case <-ctx.Done():
+		case <-a.closeCh:
+		}
 		srvErr := srv.Shutdown()
 		if err != nil {
 			a.logger.Errorf("Error shutting down server: %v", srvErr)
@@ -149,6 +166,14 @@ func (a *AzureEventGrid) Operations() []bindings.OperationKind {
 	return []bindings.OperationKind{bindings.CreateOperation}
 }
 
+func (a *AzureEventGrid) Close() error {
+	if a.closed.CompareAndSwap(false, true) {
+		close(a.closeCh)
+	}
+	a.wg.Wait()
+	return nil
+}
+
 func (a *AzureEventGrid) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
 	err := a.ensureOutputBindingMetadata()
 	if err != nil {

@@ -37,7 +37,7 @@ func NewAzureEventHubs(logger logger.Logger) bindings.InputOutputBinding {
 }
 
 // Init performs metadata init.
-func (a *AzureEventHubs) Init(metadata bindings.Metadata) error {
+func (a *AzureEventHubs) Init(_ context.Context, metadata bindings.Metadata) error {
 	return a.AzureEventHubs.Init(metadata.Properties)
 }
 

@@ -102,7 +102,7 @@ func testEventHubsBindingsAADAuthentication(t *testing.T) {
 	metadata := createEventHubsBindingsAADMetadata()
 	eventHubsBindings := NewAzureEventHubs(log)
 
-	err := eventHubsBindings.Init(metadata)
+	err := eventHubsBindings.Init(context.Background(), metadata)
 	require.NoError(t, err)
 
 	req := &bindings.InvokeRequest{
@@ -146,7 +146,7 @@ func testReadIotHubEvents(t *testing.T) {
 
 	logger := logger.NewLogger("bindings.azure.eventhubs.integration.test")
 	eh := NewAzureEventHubs(logger)
-	err := eh.Init(createIotHubBindingsMetadata())
+	err := eh.Init(context.Background(), createIotHubBindingsMetadata())
 	require.NoError(t, err)
 
 	// Invoke az CLI via bash script to send test IoT device events

@@ -17,6 +17,8 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	servicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
@@ -39,17 +41,21 @@ type AzureServiceBusQueues struct {
 	client  *impl.Client
 	timeout time.Duration
 	logger  logger.Logger
+	closed  atomic.Bool
+	wg      sync.WaitGroup
+	closeCh chan struct{}
 }
 
 // NewAzureServiceBusQueues returns a new AzureServiceBusQueues instance.
 func NewAzureServiceBusQueues(logger logger.Logger) bindings.InputOutputBinding {
 	return &AzureServiceBusQueues{
-		logger: logger,
+		logger:  logger,
+		closeCh: make(chan struct{}),
 	}
 }
 
 // Init parses connection properties and creates a new Service Bus Queue client.
-func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) (err error) {
+func (a *AzureServiceBusQueues) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
 	a.metadata, err = impl.ParseMetadata(metadata.Properties, a.logger, (impl.MetadataModeBinding | impl.MetadataModeQueues))
 	if err != nil {
 		return err
@@ -62,7 +68,7 @@ func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) (err error) {
 	}
 
 	// Will do nothing if DisableEntityManagement is false
-	err = a.client.EnsureQueue(context.Background(), a.metadata.QueueName)
+	err = a.client.EnsureQueue(ctx, a.metadata.QueueName)
 	if err != nil {
 		return err
 	}
@@ -100,14 +106,33 @@ func (a *AzureServiceBusQueues) Invoke(invokeCtx context.Context, req *bindings.
 	return nil, nil
 }
 
-func (a *AzureServiceBusQueues) Read(subscribeCtx context.Context, handler bindings.Handler) error {
+func (a *AzureServiceBusQueues) Read(parentCtx context.Context, handler bindings.Handler) error {
+	if a.closed.Load() {
+		return errors.New("binding is closed")
+	}
+
 	// Reconnection backoff policy
 	bo := backoff.NewExponentialBackOff()
 	bo.MaxElapsedTime = 0
 	bo.InitialInterval = time.Duration(a.metadata.MinConnectionRecoveryInSec) * time.Second
 	bo.MaxInterval = time.Duration(a.metadata.MaxConnectionRecoveryInSec) * time.Second
 
+	subscribeCtx, subscribeCancel := context.WithCancel(parentCtx)
+
+	// Close the subscription context when the binding is closed.
+	a.wg.Add(2)
+	go func() {
+		defer a.wg.Done()
+		select {
+		case <-a.closeCh:
+			subscribeCancel()
+		case <-parentCtx.Done():
+			// nop
+		}
+	}()
+
 	go func() {
+		defer a.wg.Done()
+		// Reconnect loop.
 		for {
 			sub := impl.NewSubscription(subscribeCtx, impl.SubsriptionOptions{
@@ -165,7 +190,12 @@ func (a *AzureServiceBusQueues) Read(subscribeCtx context.Context, handler bindi
 
 			wait := bo.NextBackOff()
 			a.logger.Warnf("Subscription to queue %s lost connection, attempting to reconnect in %s...", a.metadata.QueueName, wait)
-			time.Sleep(wait)
+			select {
+			case <-time.After(wait):
+				// nop
+			case <-a.closeCh:
+				return
+			}
 		}
 	}()
 
@@ -204,7 +234,11 @@ func (a *AzureServiceBusQueues) getHandlerFn(handler bindings.Handler) impl.Hand
 }
 
 func (a *AzureServiceBusQueues) Close() (err error) {
+	if a.closed.CompareAndSwap(false, true) {
+		close(a.closeCh)
+	}
 	a.logger.Debug("Closing component")
 	a.client.CloseSender(a.metadata.QueueName)
+	a.wg.Wait()
 	return nil
 }

@@ -78,7 +78,7 @@ type SignalR struct {
 }
 
 // Init is responsible for initializing the SignalR output based on the metadata.
-func (s *SignalR) Init(metadata bindings.Metadata) (err error) {
+func (s *SignalR) Init(_ context.Context, metadata bindings.Metadata) (err error) {
 	s.userAgent = "dapr-" + logger.DaprVersion
 
 	err = s.parseMetadata(metadata.Properties)

@@ -16,9 +16,12 @@ package storagequeues
 import (
 	"context"
 	"encoding/base64"
+	"errors"
 	"fmt"
 	"net/url"
 	"strconv"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/Azure/azure-storage-queue-go/azqueue"
@@ -40,9 +43,10 @@ type consumer struct {
 
 // QueueHelper enables injection for testnig.
 type QueueHelper interface {
-	Init(metadata bindings.Metadata) (*storageQueuesMetadata, error)
+	Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error)
 	Write(ctx context.Context, data []byte, ttl *time.Duration) error
 	Read(ctx context.Context, consumer *consumer) error
+	Close() error
 }
 
 // AzureQueueHelper concrete impl of queue helper.
@@ -55,7 +59,7 @@ type AzureQueueHelper struct {
 }
 
 // Init sets up this helper.
-func (d *AzureQueueHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
+func (d *AzureQueueHelper) Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error) {
 	m, err := parseMetadata(metadata)
 	if err != nil {
 		return nil, err
@@ -89,9 +93,9 @@ func (d *AzureQueueHelper) Init(metadata bindings.Metadata) (*storageQueuesMetad
 		d.queueURL = azqueue.NewQueueURL(*URL, p)
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
-	_, err = d.queueURL.Create(ctx, azqueue.Metadata{})
-	cancel()
+	createCtx, createCancel := context.WithTimeout(ctx, 2*time.Minute)
+	_, err = d.queueURL.Create(createCtx, azqueue.Metadata{})
+	createCancel()
 	if err != nil {
 		return nil, err
 	}
@@ -128,7 +132,10 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error {
 	}
 	if res.NumMessages() == 0 {
 		// Queue was empty so back off by 10 seconds before trying again
-		time.Sleep(10 * time.Second)
+		select {
+		case <-time.After(10 * time.Second):
+		case <-ctx.Done():
+		}
 		return nil
 	}
 	mt := res.Message(0).Text
@@ -162,6 +169,10 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error {
 	return nil
 }
 
+func (d *AzureQueueHelper) Close() error {
+	return nil
+}
+
 // NewAzureQueueHelper creates new helper.
 func NewAzureQueueHelper(logger logger.Logger) QueueHelper {
 	return &AzureQueueHelper{
@@ -175,6 +186,10 @@ type AzureStorageQueues struct {
 	helper QueueHelper
 
 	logger logger.Logger
+
+	wg      sync.WaitGroup
+	closeCh chan struct{}
+	closed  atomic.Bool
 }
 
 type storageQueuesMetadata struct {
@@ -189,12 +204,16 @@ type storageQueuesMetadata struct {
 
 // NewAzureStorageQueues returns a new AzureStorageQueues instance.
 func NewAzureStorageQueues(logger logger.Logger) bindings.InputOutputBinding {
-	return &AzureStorageQueues{helper: NewAzureQueueHelper(logger), logger: logger}
+	return &AzureStorageQueues{
+		helper:  NewAzureQueueHelper(logger),
+		logger:  logger,
+		closeCh: make(chan struct{}),
+	}
}
 
 // Init parses connection properties and creates a new Storage Queue client.
-func (a *AzureStorageQueues) Init(metadata bindings.Metadata) (err error) {
-	a.metadata, err = a.helper.Init(metadata)
+func (a *AzureStorageQueues) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
+	a.metadata, err = a.helper.Init(ctx, metadata)
 	if err != nil {
 		return err
 	}
@@ -261,14 +280,32 @@ func (a *AzureStorageQueues) Invoke(ctx context.Context, req *bindings.InvokeReq
 }
 
 func (a *AzureStorageQueues) Read(ctx context.Context, handler bindings.Handler) error {
+	if a.closed.Load() {
+		return errors.New("input binding is closed")
+	}
+
 	c := consumer{
 		callback: handler,
 	}
 
+	// Close read context when binding is closed.
+	readCtx, cancel := context.WithCancel(ctx)
+	a.wg.Add(2)
+	go func() {
+		defer a.wg.Done()
+		defer cancel()
+
+		select {
+		case <-a.closeCh:
+		case <-ctx.Done():
+		}
+	}()
+	go func() {
+		defer a.wg.Done()
 		// Read until context is canceled
 		var err error
-		for ctx.Err() == nil {
-			err = a.helper.Read(ctx, &c)
+		for readCtx.Err() == nil {
+			err = a.helper.Read(readCtx, &c)
 			if err != nil {
 				a.logger.Errorf("error from c: %s", err)
 			}
@@ -277,3 +314,11 @@ func (a *AzureStorageQueues) Read(ctx context.Context, handler bindings.Handler)
 
 	return nil
 }
+
+func (a *AzureStorageQueues) Close() error {
+	if a.closed.CompareAndSwap(false, true) {
+		close(a.closeCh)
+	}
+	a.wg.Wait()
+	return nil
+}

@@ -16,6 +16,7 @@ package storagequeues
 import (
 	"context"
 	"encoding/base64"
+	"sync"
 	"testing"
 	"time"
 
@@ -32,9 +33,11 @@ type MockHelper struct {
 	mock.Mock
 	messages chan []byte
 	metadata *storageQueuesMetadata
+	closeCh  chan struct{}
+	wg       sync.WaitGroup
 }
 
-func (m *MockHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
+func (m *MockHelper) Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error) {
 	m.messages = make(chan []byte, 10)
 	var err error
 	m.metadata, err = parseMetadata(metadata)
@@ -50,12 +53,23 @@ func (m *MockHelper) Write(ctx context.Context, data []byte, ttl *time.Duration)
 func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {
 	retvals := m.Called(ctx, consumer)
 
+	readCtx, cancel := context.WithCancel(ctx)
+	m.wg.Add(2)
+	go func() {
+		defer m.wg.Done()
+		defer cancel()
+		select {
+		case <-readCtx.Done():
+		case <-m.closeCh:
+		}
+	}()
 	go func() {
+		defer m.wg.Done()
 		for msg := range m.messages {
 			if m.metadata.DecodeBase64 {
 				msg, _ = base64.StdEncoding.DecodeString(string(msg))
 			}
-			go consumer.callback(ctx, &bindings.ReadResponse{
+			go consumer.callback(readCtx, &bindings.ReadResponse{
 				Data: msg,
 			})
 		}
@@ -64,18 +78,24 @@ func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {
 	return retvals.Error(0)
 }
 
+func (m *MockHelper) Close() error {
+	defer m.wg.Wait()
+	close(m.closeCh)
+	return nil
+}
+
 func TestWriteQueue(t *testing.T) {
 	mm := new(MockHelper)
 	mm.On("Write", mock.AnythingOfType("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
 		return in == nil
 	})).Return(nil)
 
-	a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
+	a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
 
 	m := bindings.Metadata{}
 	m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
 
-	err := a.Init(m)
+	err := a.Init(context.Background(), m)
 	assert.Nil(t, err)
 
 	r := bindings.InvokeRequest{Data: []byte("This is my message")}
@@ -83,6 +103,7 @@ func TestWriteQueue(t *testing.T) {
 	_, err = a.Invoke(context.Background(), &r)
 
 	assert.Nil(t, err)
+	assert.NoError(t, a.Close())
 }
 
 func TestWriteWithTTLInQueue(t *testing.T) {
@@ -91,12 +112,12 @@ func TestWriteWithTTLInQueue(t *testing.T) {
 		return in != nil && *in == time.Second
 	})).Return(nil)
 
-	a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
+	a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
 
 	m := bindings.Metadata{}
 	m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", metadata.TTLMetadataKey: "1"}
 
-	err := a.Init(m)
+	err := a.Init(context.Background(), m)
 	assert.Nil(t, err)
 
 	r := bindings.InvokeRequest{Data: []byte("This is my message")}
@@ -104,6 +125,7 @@ func TestWriteWithTTLInQueue(t *testing.T) {
 	_, err = a.Invoke(context.Background(), &r)
 
 	assert.Nil(t, err)
+	assert.NoError(t, a.Close())
 }
 
 func TestWriteWithTTLInWrite(t *testing.T) {
@@ -112,12 +134,12 @@ func TestWriteWithTTLInWrite(t *testing.T) {
 		return in != nil && *in == time.Second
 	})).Return(nil)
 
-	a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
+	a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
 
 	m := bindings.Metadata{}
 	m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", metadata.TTLMetadataKey: "1"}
 
-	err := a.Init(m)
+	err := a.Init(context.Background(), m)
 	assert.Nil(t, err)
 
 	r := bindings.InvokeRequest{
@@ -128,6 +150,7 @@ func TestWriteWithTTLInWrite(t *testing.T) {
 	_, err = a.Invoke(context.Background(), &r)
 
 	assert.Nil(t, err)
+	assert.NoError(t, a.Close())
 }
 
 // Uncomment this function to write a message to local storage queue
@@ -138,7 +161,7 @@ func TestWriteWithTTLInWrite(t *testing.T) {
 	m := bindings.Metadata{}
 	m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
 
-	err := a.Init(m)
+	err := a.Init(context.Background(), m)
 	assert.Nil(t, err)
 
 	r := bindings.InvokeRequest{Data: []byte("This is my message")}
@@ -152,12 +175,12 @@ func TestReadQueue(t *testing.T) {
 	mm := new(MockHelper)
 	mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
 	mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
-	a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
+	a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
 
 	m := bindings.Metadata{}
 	m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
 
-	err := a.Init(m)
+	err := a.Init(context.Background(), m)
 	assert.Nil(t, err)
 
 	r := bindings.InvokeRequest{Data: []byte("This is my message")}
@@ -186,6 +209,7 @@ func TestReadQueue(t *testing.T) {
 		t.Fatal("Timeout waiting for messages")
 	}
 	assert.Equal(t, 1, received)
+	assert.NoError(t, a.Close())
 }
 
 func TestReadQueueDecode(t *testing.T) {
@@ -193,12 +217,12 @@ func TestReadQueueDecode(t *testing.T) {
 	mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
 	mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
 
-	a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
+	a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
 
 	m := bindings.Metadata{}
 	m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", "decodeBase64": "true"}
 
-	err := a.Init(m)
+	err := a.Init(context.Background(), m)
 	assert.Nil(t, err)
 
 	r := bindings.InvokeRequest{Data: []byte("VGhpcyBpcyBteSBtZXNzYWdl")}
@@ -227,6 +251,7 @@ func TestReadQueueDecode(t *testing.T) {
 		t.Fatal("Timeout waiting for messages")
 	}
 	assert.Equal(t, 1, received)
+	assert.NoError(t, a.Close())
 }
 
 // Uncomment this function to test reding from local queue
@@ -237,7 +262,7 @@ func TestReadQueueDecode(t *testing.T) {
 	m := bindings.Metadata{}
 	m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
 
-	err := a.Init(m)
+	err := a.Init(context.Background(), m)
 	assert.Nil(t, err)
 
 	r := bindings.InvokeRequest{Data: []byte("This is my message")}
@@ -263,12 +288,12 @@ func TestReadQueueNoMessage(t *testing.T) {
 	mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
 	mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
 
-	a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
+	a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
 
 	m := bindings.Metadata{}
 	m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
 
-	err := a.Init(m)
+	err := a.Init(context.Background(), m)
 	assert.Nil(t, err)
 
 	ctx, cancel := context.WithCancel(context.Background())
@@ -285,6 +310,7 @@ func TestReadQueueNoMessage(t *testing.T) {
 	time.Sleep(1 * time.Second)
 	cancel()
 	assert.Equal(t, 0, received)
+	assert.NoError(t, a.Close())
 }
 
 func TestParseMetadata(t *testing.T) {

@@ -48,7 +48,7 @@ func NewCFQueues(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init the component.
-func (q *CFQueues) Init(metadata bindings.Metadata) error {
+func (q *CFQueues) Init(_ context.Context, metadata bindings.Metadata) error {
 	// Decode the metadata
 	err := mapstructure.Decode(metadata.Properties, &q.metadata)
 	if err != nil {

@@ -51,7 +51,7 @@ func NewCommercetools(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init does metadata parsing and connection establishment.
-func (ct *Binding) Init(metadata bindings.Metadata) error {
+func (ct *Binding) Init(_ context.Context, metadata bindings.Metadata) error {
 	commercetoolsM, err := ct.getCommercetoolsMetadata(metadata)
 	if err != nil {
 		return err

@@ -16,6 +16,8 @@ package cron
 import (
 	"context"
 	"fmt"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/benbjohnson/clock"
@@ -32,6 +34,9 @@ type Binding struct {
 	schedule string
 	parser   cron.Parser
 	clk      clock.Clock
+	closed   atomic.Bool
+	closeCh  chan struct{}
+	wg       sync.WaitGroup
 }
 
 // NewCron returns a new Cron event input binding.
@@ -46,6 +51,7 @@ func NewCronWithClock(logger logger.Logger, clk clock.Clock) bindings.InputBindi
 		parser: cron.NewParser(
 			cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor,
 		),
+		closeCh: make(chan struct{}),
 	}
 }
 
@@ -54,7 +60,7 @@ func NewCronWithClock(logger logger.Logger, clk clock.Clock) bindings.InputBindi
 //
 // "15 * * * * *" - Every 15 sec
 // "0 30 * * * *" - Every 30 min
-func (b *Binding) Init(metadata bindings.Metadata) error {
+func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
 	b.name = metadata.Name
 	s, f := metadata.Properties["schedule"]
 	if !f || s == "" {
@@ -71,6 +77,10 @@ func (b *Binding) Init(metadata bindings.Metadata) error {
 
 // Read triggers the Cron scheduler.
 func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
+	if b.closed.Load() {
+		return errors.New("binding is closed")
+	}
+
 	c := cron.New(cron.WithParser(b.parser), cron.WithClock(b.clk))
 	id, err := c.AddFunc(b.schedule, func() {
 		b.logger.Debugf("name: %s, schedule fired: %v", b.name, time.Now())
@@ -87,12 +97,25 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
 	c.Start()
 	b.logger.Debugf("name: %s, next run: %v", b.name, time.Until(c.Entry(id).Next))
 
+	b.wg.Add(1)
 	go func() {
-		// Wait for context to be canceled
-		<-ctx.Done()
+		defer b.wg.Done()
+		// Wait for context to be canceled or component to be closed.
+		select {
+		case <-ctx.Done():
+		case <-b.closeCh:
+		}
 		b.logger.Debugf("name: %s, stopping schedule: %s", b.name, b.schedule)
 		c.Stop()
 	}()
 
 	return nil
 }
+
+func (b *Binding) Close() error {
+	if b.closed.CompareAndSwap(false, true) {
+		close(b.closeCh)
+	}
+	b.wg.Wait()
+	return nil
+}

@@ -16,6 +16,7 @@ package cron
 import (
 	"context"
 	"os"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -84,7 +85,7 @@ func TestCronInitSuccess(t *testing.T) {
 
 	for _, test := range initTests {
 		c := getNewCron()
-		err := c.Init(getTestMetadata(test.schedule))
+		err := c.Init(context.Background(), getTestMetadata(test.schedule))
 		if test.errorExpected {
 			assert.Errorf(t, err, "Got no error while initializing an invalid schedule: %s", test.schedule)
 		} else {
@@ -99,38 +100,41 @@ func TestCronRead(t *testing.T) {
 	clk := clock.NewMock()
 	c := getNewCronWithClock(clk)
 	schedule := "@every 1s"
-	assert.NoErrorf(t, c.Init(getTestMetadata(schedule)), "error initializing valid schedule")
-	expectedCount := 5
-	observedCount := 0
+	assert.NoErrorf(t, c.Init(context.Background(), getTestMetadata(schedule)), "error initializing valid schedule")
+	expectedCount := int32(5)
+	var observedCount atomic.Int32
 	err := c.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
 		assert.NotNil(t, res)
-		observedCount++
+		observedCount.Add(1)
 		return nil, nil
 	})
 	// Check if cron triggers 5 times in 5 seconds
-	for i := 0; i < expectedCount; i++ {
+	for i := int32(0); i < expectedCount; i++ {
 		// Add time to mock clock in 1 second intervals using loop to allow cron go routine to run
 		clk.Add(time.Second)
 	}
-	// Wait for 1 second after adding the last second to mock clock to allow cron to finish triggering
-	time.Sleep(1 * time.Second)
-	assert.Equal(t, expectedCount, observedCount, "Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount)
+	assert.Eventually(t, func() bool {
+		return observedCount.Load() == expectedCount
+	}, time.Second, time.Millisecond*10,
+		"Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount.Load())
 	assert.NoErrorf(t, err, "error on read")
+	assert.NoError(t, c.Close())
 }
 
 func TestCronReadWithContextCancellation(t *testing.T) {
 	clk := clock.NewMock()
 	c := getNewCronWithClock(clk)
 	schedule := "@every 1s"
-	assert.NoErrorf(t, c.Init(getTestMetadata(schedule)), "error initializing valid schedule")
-	expectedCount := 5
-	observedCount := 0
+	assert.NoErrorf(t, c.Init(context.Background(), getTestMetadata(schedule)), "error initializing valid schedule")
+	expectedCount := int32(5)
+	var observedCount atomic.Int32
 	ctx, cancel := context.WithCancel(context.Background())
 	err := c.Read(ctx, func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
 		assert.NotNil(t, res)
-		assert.LessOrEqualf(t, observedCount, expectedCount, "Invoke didn't stop the schedule")
-		observedCount++
-		if observedCount == expectedCount {
+		assert.LessOrEqualf(t, observedCount.Load(), expectedCount, "Invoke didn't stop the schedule")
+		observedCount.Add(1)
+		if observedCount.Load() == expectedCount {
 			// Cancel context after 5 triggers
 			cancel()
 		}
@@ -141,7 +145,10 @@ func TestCronReadWithContextCancellation(t *testing.T) {
 		// Add time to mock clock in 1 second intervals using loop to allow cron go routine to run
 		clk.Add(time.Second)
 	}
-	time.Sleep(1 * time.Second)
-	assert.Equal(t, expectedCount, observedCount, "Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount)
+	assert.Eventually(t, func() bool {
+		return observedCount.Load() == expectedCount
+	}, time.Second, time.Millisecond*10,
+		"Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount.Load())
 	assert.NoErrorf(t, err, "error on read")
+	assert.NoError(t, c.Close())
 }

@@ -47,7 +47,7 @@ func NewDubboOutput(logger logger.Logger) bindings.OutputBinding {
 	return dubboBinding
 }
 
-func (out *DubboOutputBinding) Init(_ bindings.Metadata) error {
+func (out *DubboOutputBinding) Init(_ context.Context, _ bindings.Metadata) error {
 	dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
 	return nil
 }

@@ -54,12 +54,13 @@ func TestInvoke(t *testing.T) {
 	// 0. init dapr provided and dubbo server
 	stopCh := make(chan struct{})
 	defer close(stopCh)
+	// Create output and set serializer before go routine to prevent data race.
+	output := NewDubboOutput(logger.NewLogger("test"))
+	dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
 	go func() {
 		assert.Nil(t, runDubboServer(stopCh))
 	}()
 	time.Sleep(time.Second * 3)
-	dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
-	output := NewDubboOutput(logger.NewLogger("test"))
 
 	// 1. create req/rsp value
 	reqUser := &User{Name: testName}

@@ -83,14 +83,13 @@ func NewGCPStorage(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init performs connection parsing.
-func (g *GCPStorage) Init(metadata bindings.Metadata) error {
+func (g *GCPStorage) Init(ctx context.Context, metadata bindings.Metadata) error {
 	m, b, err := g.parseMetadata(metadata)
 	if err != nil {
 		return err
 	}
 
 	clientOptions := option.WithCredentialsJSON(b)
-	ctx := context.Background()
 	client, err := storage.NewClient(ctx, clientOptions)
 	if err != nil {
 		return err

@@ -16,7 +16,10 @@ package pubsub
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
+	"sync"
+	"sync/atomic"
 
 	"cloud.google.com/go/pubsub"
 	"google.golang.org/api/option"
@@ -36,6 +39,9 @@ type GCPPubSub struct {
 	client   *pubsub.Client
 	metadata *pubSubMetadata
 	logger   logger.Logger
+	closed   atomic.Bool
+	closeCh  chan struct{}
+	wg       sync.WaitGroup
 }
 
 type pubSubMetadata struct {
@@ -55,11 +61,14 @@ type pubSubMetadata struct {
 
 // NewGCPPubSub returns a new GCPPubSub instance.
 func NewGCPPubSub(logger logger.Logger) bindings.InputOutputBinding {
-	return &GCPPubSub{logger: logger}
+	return &GCPPubSub{
+		logger:  logger,
+		closeCh: make(chan struct{}),
+	}
 }
 
 // Init parses metadata and creates a new Pub Sub client.
-func (g *GCPPubSub) Init(metadata bindings.Metadata) error {
+func (g *GCPPubSub) Init(ctx context.Context, metadata bindings.Metadata) error {
 	b, err := g.parseMetadata(metadata)
 	if err != nil {
 		return err
@@ -71,7 +80,6 @@ func (g *GCPPubSub) Init(metadata bindings.Metadata) error {
 		return err
 	}
 	clientOptions := option.WithCredentialsJSON(b)
-	ctx := context.Background()
 	pubsubClient, err := pubsub.NewClient(ctx, pubsubMeta.ProjectID, clientOptions)
 	if err != nil {
 		return fmt.Errorf("error creating pubsub client: %s", err)
@@ -88,7 +96,12 @@ func (g *GCPPubSub) parseMetadata(metadata bindings.Metadata) ([]byte, error) {
 }
 
 func (g *GCPPubSub) Read(ctx context.Context, handler bindings.Handler) error {
+	if g.closed.Load() {
+		return errors.New("binding is closed")
+	}
+	g.wg.Add(1)
 	go func() {
+		defer g.wg.Done()
 		sub := g.client.Subscription(g.metadata.Subscription)
 		err := sub.Receive(ctx, func(c context.Context, m *pubsub.Message) {
 			_, err := handler(c, &bindings.ReadResponse{
@@ -128,5 +141,9 @@ func (g *GCPPubSub) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*b
 }
 
 func (g *GCPPubSub) Close() error {
+	if g.closed.CompareAndSwap(false, true) {
+		close(g.closeCh)
+	}
+	defer g.wg.Wait()
 	return g.client.Close()
 }

@@ -58,7 +58,7 @@ func NewGraphQL(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init initializes the GraphQL binding.
-func (gql *GraphQL) Init(metadata bindings.Metadata) error {
+func (gql *GraphQL) Init(_ context.Context, metadata bindings.Metadata) error {
 	gql.logger.Debug("GraphQL Error: Initializing GraphQL binding")
 
 	p := metadata.Properties

@@ -74,7 +74,7 @@ func NewHTTP(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init performs metadata parsing.
-func (h *HTTPSource) Init(metadata bindings.Metadata) error {
+func (h *HTTPSource) Init(_ context.Context, metadata bindings.Metadata) error {
 	var err error
 	if err = mapstructure.Decode(metadata.Properties, &h.metadata); err != nil {
 		return err
@@ -104,7 +104,7 @@ func (h *HTTPSource) Init(metadata bindings.Metadata) error {
 		Transport: netTransport,
 	}
 
-	if val, ok := metadata.Properties["errorIfNot2XX"]; ok {
+	if val := metadata.Properties["errorIfNot2XX"]; val != "" {
 		h.errorIfNot2XX = utils.IsTruthy(val)
 	} else {
 		// Default behavior

@@ -132,7 +132,7 @@ func InitBinding(s *httptest.Server, extraProps map[string]string) (bindings.Out
 	}
 
 	hs := NewHTTP(logger.NewLogger("test"))
-	err := hs.Init(m)
+	err := hs.Init(context.Background(), m)
 	return hs, err
 }
 
@@ -269,7 +269,7 @@ func InitBindingForHTTPS(s *httptest.Server, extraProps map[string]string) (bind
 		m.Properties[k] = v
 	}
 	hs := NewHTTP(logger.NewLogger("test"))
-	err := hs.Init(m)
+	err := hs.Init(context.Background(), m)
 	return hs, err
 }
 

@@ -75,7 +75,7 @@ func NewHuaweiOBS(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init does metadata parsing and connection creation.
-func (o *HuaweiOBS) Init(metadata bindings.Metadata) error {
+func (o *HuaweiOBS) Init(_ context.Context, metadata bindings.Metadata) error {
 	o.logger.Debugf("initializing Huawei OBS binding and parsing metadata")
 
 	m, err := o.parseMetadata(metadata)

@@ -92,7 +92,7 @@ func TestInit(t *testing.T) {
			"accessKey": "dummy-ak",
			"secretKey": "dummy-sk",
		}
		err := obs.Init(m)
		err := obs.Init(context.Background(), m)
		assert.Nil(t, err)
	})
	t.Run("Init with missing bucket name", func(t *testing.T) {

@@ -102,7 +102,7 @@ func TestInit(t *testing.T) {
			"accessKey": "dummy-ak",
			"secretKey": "dummy-sk",
		}
		err := obs.Init(m)
		err := obs.Init(context.Background(), m)
		assert.NotNil(t, err)
		assert.Equal(t, err, fmt.Errorf("missing obs bucket name"))
	})

@@ -113,7 +113,7 @@ func TestInit(t *testing.T) {
			"endpoint":  "dummy-endpoint",
			"secretKey": "dummy-sk",
		}
		err := obs.Init(m)
		err := obs.Init(context.Background(), m)
		assert.NotNil(t, err)
		assert.Equal(t, err, fmt.Errorf("missing the huawei access key"))
	})

@@ -124,7 +124,7 @@ func TestInit(t *testing.T) {
			"endpoint":  "dummy-endpoint",
			"accessKey": "dummy-ak",
		}
		err := obs.Init(m)
		err := obs.Init(context.Background(), m)
		assert.NotNil(t, err)
		assert.Equal(t, err, fmt.Errorf("missing the huawei secret key"))
	})

@@ -135,7 +135,7 @@ func TestInit(t *testing.T) {
			"accessKey": "dummy-ak",
			"secretKey": "dummy-sk",
		}
		err := obs.Init(m)
		err := obs.Init(context.Background(), m)
		assert.NotNil(t, err)
		assert.Equal(t, err, fmt.Errorf("missing obs endpoint"))
	})
@@ -64,7 +64,7 @@ func NewInflux(logger logger.Logger) bindings.OutputBinding {
}

// Init does metadata parsing and connection establishment.
func (i *Influx) Init(metadata bindings.Metadata) error {
func (i *Influx) Init(_ context.Context, metadata bindings.Metadata) error {
	influxMeta, err := i.getInfluxMetadata(metadata)
	if err != nil {
		return err
@@ -54,7 +54,7 @@ func TestInflux_Init(t *testing.T) {
	assert.Nil(t, influx.client)

	m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{"Url": "a", "Token": "a", "Org": "a", "Bucket": "a"}}}
	err := influx.Init(m)
	err := influx.Init(context.Background(), m)
	assert.Nil(t, err)

	assert.NotNil(t, influx.queryAPI)
@@ -16,6 +16,7 @@ package bindings

import (
	"context"
	"fmt"
	"io"

	"github.com/dapr/components-contrib/health"
)

@@ -23,18 +24,21 @@ import (
// InputBinding is the interface to define a binding that triggers on incoming events.
type InputBinding interface {
	// Init passes connection and properties metadata to the binding implementation.
	Init(metadata Metadata) error
	Init(ctx context.Context, metadata Metadata) error
	// Read is a method that runs in background and triggers the callback function whenever an event arrives.
	Read(ctx context.Context, handler Handler) error
	// Close is a method that closes the connection to the binding. Must be
	// called when the binding is no longer needed to free up resources.
	io.Closer
}

// Handler is the handler used to invoke the app handler.
type Handler func(context.Context, *ReadResponse) ([]byte, error)

func PingInpBinding(inputBinding InputBinding) error {
func PingInpBinding(ctx context.Context, inputBinding InputBinding) error {
	// checks if this input binding has the ping option then executes
	if inputBindingWithPing, ok := inputBinding.(health.Pinger); ok {
		return inputBindingWithPing.Ping()
		return inputBindingWithPing.Ping(ctx)
	} else {
		return fmt.Errorf("ping is not implemented by this input binding")
	}
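Taken together, the interface changes above mean a host drives an input binding roughly like this. A sketch against the updated interface (the handler body and function name are placeholders, not repository code):

// Assumes: import ("context"; "github.com/dapr/components-contrib/bindings")
func runInputBinding(ctx context.Context, b bindings.InputBinding, md bindings.Metadata) error {
	// Init can now use the caller's context for any setup I/O.
	if err := b.Init(ctx, md); err != nil {
		return err
	}
	// Read returns promptly; events are delivered on background goroutines
	// until ctx is canceled or the binding is closed.
	err := b.Read(ctx, func(c context.Context, resp *bindings.ReadResponse) ([]byte, error) {
		_ = resp.Data // placeholder: process the event here
		return nil, nil
	})
	if err != nil {
		return err
	}
	<-ctx.Done()
	// io.Closer: stops background work and frees resources.
	return b.Close()
}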
@@ -15,7 +15,10 @@ package kafka

import (
	"context"
	"errors"
	"strings"
	"sync"
	"sync/atomic"

	"github.com/dapr/kit/logger"

@@ -29,12 +32,13 @@ const (
)

type Binding struct {
	kafka           *kafka.Kafka
	publishTopic    string
	topics          []string
	logger          logger.Logger
	subscribeCtx    context.Context
	subscribeCancel context.CancelFunc
	kafka        *kafka.Kafka
	publishTopic string
	topics       []string
	logger       logger.Logger
	closeCh      chan struct{}
	closed       atomic.Bool
	wg           sync.WaitGroup
}

// NewKafka returns a new kafka binding instance.

@@ -43,15 +47,14 @@ func NewKafka(logger logger.Logger) bindings.InputOutputBinding {
	// in kafka binding component, disable consumer retry by default
	k.DefaultConsumeRetryEnabled = false
	return &Binding{
		kafka:  k,
		logger: logger,
		kafka:   k,
		logger:  logger,
		closeCh: make(chan struct{}),
	}
}

func (b *Binding) Init(metadata bindings.Metadata) error {
	b.subscribeCtx, b.subscribeCancel = context.WithCancel(context.Background())

	err := b.kafka.Init(metadata.Properties)
func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
	err := b.kafka.Init(ctx, metadata.Properties)
	if err != nil {
		return err
	}

@@ -74,7 +77,10 @@ func (b *Binding) Operations() []bindings.OperationKind {
}

func (b *Binding) Close() (err error) {
	b.subscribeCancel()
	if b.closed.CompareAndSwap(false, true) {
		close(b.closeCh)
	}
	defer b.wg.Wait()
	return b.kafka.Close()
}

@@ -84,6 +90,10 @@ func (b *Binding) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin
}

func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
	if b.closed.Load() {
		return errors.New("error: binding is closed")
	}

	if len(b.topics) == 0 {
		b.logger.Warnf("kafka binding: no topic defined, input bindings will not be started")
		return nil

@@ -96,31 +106,22 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
	for _, t := range b.topics {
		b.kafka.AddTopicHandler(t, handlerConfig)
	}

	b.wg.Add(1)
	go func() {
		// Wait for context cancellation
		defer b.wg.Done()
		// Wait for context cancellation or closure.
		select {
		case <-ctx.Done():
		case <-b.subscribeCtx.Done():
		case <-b.closeCh:
		}

		// Remove the topic handler before restarting the subscriber
		// Remove the topic handlers.
		for _, t := range b.topics {
			b.kafka.RemoveTopicHandler(t)
		}

		// If the component's context has been canceled, do not re-subscribe
		if b.subscribeCtx.Err() != nil {
			return
		}

		err := b.kafka.Subscribe(b.subscribeCtx)
		if err != nil {
			b.logger.Errorf("kafka binding: error re-subscribing: %v", err)
		}
	}()

	return b.kafka.Subscribe(b.subscribeCtx)
	return b.kafka.Subscribe(ctx)
}

func adaptHandler(handler bindings.Handler) kafka.EventHandler {
@@ -2,8 +2,11 @@ package kubemq

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	qs "github.com/kubemq-io/kubemq-go/queues_stream"

@@ -19,31 +22,30 @@ type Kubemq interface {
}

type kubeMQ struct {
	client    *qs.QueuesStreamClient
	opts      *options
	logger    logger.Logger
	ctx       context.Context
	ctxCancel context.CancelFunc
	client  *qs.QueuesStreamClient
	opts    *options
	logger  logger.Logger
	closed  atomic.Bool
	closeCh chan struct{}
	wg      sync.WaitGroup
}

func NewKubeMQ(logger logger.Logger) Kubemq {
	return &kubeMQ{
		client:    nil,
		opts:      nil,
		logger:    logger,
		ctx:       nil,
		ctxCancel: nil,
		client:  nil,
		opts:    nil,
		logger:  logger,
		closeCh: make(chan struct{}),
	}
}

func (k *kubeMQ) Init(metadata bindings.Metadata) error {
func (k *kubeMQ) Init(ctx context.Context, metadata bindings.Metadata) error {
	opts, err := createOptions(metadata)
	if err != nil {
		return err
	}
	k.opts = opts
	k.ctx, k.ctxCancel = context.WithCancel(context.Background())
	client, err := qs.NewQueuesStreamClient(k.ctx,
	client, err := qs.NewQueuesStreamClient(ctx,
		qs.WithAddress(opts.host, opts.port),
		qs.WithCheckConnection(true),
		qs.WithAuthToken(opts.authToken),

@@ -53,22 +55,39 @@ func (k *kubeMQ) Init(metadata bindings.Metadata) error {
		k.logger.Errorf("error init kubemq client error: %s", err.Error())
		return err
	}
	k.ctx, k.ctxCancel = context.WithCancel(context.Background())
	k.client = client
	return nil
}

func (k *kubeMQ) Read(ctx context.Context, handler bindings.Handler) error {
	if k.closed.Load() {
		return errors.New("binding is closed")
	}
	k.wg.Add(2)
	processCtx, cancel := context.WithCancel(ctx)
	go func() {
		defer k.wg.Done()
		defer cancel()
		select {
		case <-k.closeCh:
		case <-processCtx.Done():
		}
	}()
	go func() {
		defer k.wg.Done()
		for {
			err := k.processQueueMessage(k.ctx, handler)
			err := k.processQueueMessage(processCtx, handler)
			if err != nil {
				k.logger.Error(err.Error())
				time.Sleep(time.Second)
			}
			if k.ctx.Err() != nil {
				return
			// If context cancelled or kubeMQ closed, exit. Otherwise, continue
			// after a second.
			select {
			case <-time.After(time.Second):
				continue
			case <-processCtx.Done():
			}
			return
		}
	}()
	return nil

@@ -82,7 +101,7 @@ func (k *kubeMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind
		SetPolicyExpirationSeconds(parsePolicyExpirationSeconds(req.Metadata)).
		SetPolicyMaxReceiveCount(parseSetPolicyMaxReceiveCount(req.Metadata)).
		SetPolicyMaxReceiveQueue(parsePolicyMaxReceiveQueue(req.Metadata))
	result, err := k.client.Send(k.ctx, queueMessage)
	result, err := k.client.Send(ctx, queueMessage)
	if err != nil {
		return nil, err
	}

@@ -101,6 +120,14 @@ func (k *kubeMQ) Operations() []bindings.OperationKind {
	return []bindings.OperationKind{bindings.CreateOperation}
}

func (k *kubeMQ) Close() error {
	if k.closed.CompareAndSwap(false, true) {
		close(k.closeCh)
	}
	defer k.wg.Wait()
	return k.client.Close()
}

func (k *kubeMQ) processQueueMessage(ctx context.Context, handler bindings.Handler) error {
	pr := qs.NewPollRequest().
		SetChannel(k.opts.channel).
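The derived-context trick in the Read above (a processCtx canceled either by the caller or by Close) generalizes. A compact, self-contained sketch of just that mechanism, with placeholder names:

package main

import (
	"context"
	"fmt"
)

// deriveStopCtx returns a context canceled when either the parent is
// canceled or closeCh is closed — the same shape kubeMQ's Read uses.
func deriveStopCtx(parent context.Context, closeCh <-chan struct{}) context.Context {
	ctx, cancel := context.WithCancel(parent)
	go func() {
		defer cancel()
		select {
		case <-closeCh:
		case <-ctx.Done():
		}
	}()
	return ctx
}

func main() {
	closeCh := make(chan struct{})
	ctx := deriveStopCtx(context.Background(), closeCh)
	close(closeCh) // simulate Close()
	<-ctx.Done()
	fmt.Println("processing loop would exit here")
}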
@@ -106,7 +106,7 @@ func Test_kubeMQ_Init(t *testing.T) {
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			kubemq := NewKubeMQ(logger.NewLogger("test"))
			err := kubemq.Init(tt.meta)
			err := kubemq.Init(context.Background(), tt.meta)
			if tt.wantErr {
				require.Error(t, err)
			} else {

@@ -120,7 +120,7 @@ func Test_kubeMQ_Invoke_Read_Single_Message(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()
	kubemq := NewKubeMQ(logger.NewLogger("test"))
	err := kubemq.Init(getDefaultMetadata("test.read.single"))
	err := kubemq.Init(context.Background(), getDefaultMetadata("test.read.single"))
	require.NoError(t, err)
	dataReadCh := make(chan []byte)
	invokeRequest := &bindings.InvokeRequest{

@@ -147,7 +147,7 @@ func Test_kubeMQ_Invoke_Read_Single_MessageWithHandlerError(t *testing.T) {
	kubemq := NewKubeMQ(logger.NewLogger("test"))
	md := getDefaultMetadata("test.read.single.error")
	md.Properties["autoAcknowledged"] = "false"
	err := kubemq.Init(md)
	err := kubemq.Init(context.Background(), md)
	require.NoError(t, err)
	invokeRequest := &bindings.InvokeRequest{
		Data: []byte("test"),

@@ -182,7 +182,7 @@ func Test_kubeMQ_Invoke_Error(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()
	kubemq := NewKubeMQ(logger.NewLogger("test"))
	err := kubemq.Init(getDefaultMetadata("***test***"))
	err := kubemq.Init(context.Background(), getDefaultMetadata("***test***"))
	require.NoError(t, err)

	invokeRequest := &bindings.InvokeRequest{
@@ -18,6 +18,8 @@ import (
	"encoding/json"
	"errors"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	v1 "k8s.io/api/core/v1"

@@ -35,6 +37,9 @@ type kubernetesInput struct {
	namespace    string
	resyncPeriod time.Duration
	logger       logger.Logger
	closed       atomic.Bool
	closeCh      chan struct{}
	wg           sync.WaitGroup
}

type EventResponse struct {

@@ -45,10 +50,13 @@ type EventResponse struct {

// NewKubernetes returns a new Kubernetes event input binding.
func NewKubernetes(logger logger.Logger) bindings.InputBinding {
	return &kubernetesInput{logger: logger}
	return &kubernetesInput{
		logger:  logger,
		closeCh: make(chan struct{}),
	}
}

func (k *kubernetesInput) Init(metadata bindings.Metadata) error {
func (k *kubernetesInput) Init(ctx context.Context, metadata bindings.Metadata) error {
	client, err := kubeclient.GetKubeClient()
	if err != nil {
		return err

@@ -78,6 +86,9 @@ func (k *kubernetesInput) parseMetadata(metadata bindings.Metadata) error {
}

func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) error {
	if k.closed.Load() {
		return errors.New("binding is closed")
	}
	watchlist := cache.NewListWatchFromClient(
		k.kubeClient.CoreV1().RESTClient(),
		"events",

@@ -126,12 +137,28 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er
		},
	)

	k.wg.Add(3)
	readCtx, cancel := context.WithCancel(ctx)

	// catch when binding is closed.
	go func() {
		defer k.wg.Done()
		defer cancel()
		select {
		case <-readCtx.Done():
		case <-k.closeCh:
		}
	}()

	// Start the controller in background
	stopCh := make(chan struct{})
	go controller.Run(stopCh)
	go func() {
		defer k.wg.Done()
		controller.Run(readCtx.Done())
	}()

	// Watch for new messages and for context cancellation
	go func() {
		defer k.wg.Done()
		var (
			obj  EventResponse
			data []byte

@@ -148,8 +175,7 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er
					Data: data,
				})
			}
		case <-ctx.Done():
			close(stopCh)
		case <-readCtx.Done():
			return
		}
	}

@@ -157,3 +183,11 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er

	return nil
}

func (k *kubernetesInput) Close() error {
	if k.closed.CompareAndSwap(false, true) {
		close(k.closeCh)
	}
	k.wg.Wait()
	return nil
}
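One detail worth noting in the Kubernetes change: client-go's cache controller takes a stop channel, so the rewrite can feed it readCtx.Done() directly instead of hand-managing a separate stopCh. A minimal sketch of the conversion, where run stands in for any Run(stopCh)-style API (illustrative, not repository code):

package main

import (
	"context"
	"fmt"
	"sync"
)

// run stands in for client-go's controller.Run(stopCh): it blocks until
// the stop channel is closed.
func run(stopCh <-chan struct{}) {
	<-stopCh
	fmt.Println("controller stopped")
}

func main() {
	readCtx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		// readCtx.Done() doubles as the stop channel, replacing a
		// hand-managed stopCh that had to be closed exactly once.
		run(readCtx.Done())
	}()
	cancel()
	wg.Wait()
}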
@@ -64,7 +64,7 @@ func NewLocalStorage(logger logger.Logger) bindings.OutputBinding {
}

// Init performs metadata parsing.
func (ls *LocalStorage) Init(metadata bindings.Metadata) error {
func (ls *LocalStorage) Init(_ context.Context, metadata bindings.Metadata) error {
	m, err := ls.parseMetadata(metadata)
	if err != nil {
		return fmt.Errorf("failed to parse metadata: %w", err)
@@ -40,30 +40,29 @@ type MQTT struct {
	logger       logger.Logger
	isSubscribed atomic.Bool
	readHandler  bindings.Handler
	ctx     context.Context
	cancel  context.CancelFunc
	backOff backoff.BackOff
	closeCh chan struct{}
	closed  atomic.Bool
	wg      sync.WaitGroup
}

// NewMQTT returns a new MQTT instance.
func NewMQTT(logger logger.Logger) bindings.InputOutputBinding {
	return &MQTT{
		logger: logger,
		logger:  logger,
		closeCh: make(chan struct{}),
	}
}

// Init does MQTT connection parsing.
func (m *MQTT) Init(metadata bindings.Metadata) (err error) {
func (m *MQTT) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
	m.metadata, err = parseMQTTMetaData(metadata, m.logger)
	if err != nil {
		return err
	}

	m.ctx, m.cancel = context.WithCancel(context.Background())

	// TODO: Make the backoff configurable for constant or exponential
	b := backoff.NewConstantBackOff(5 * time.Second)
	m.backOff = backoff.WithContext(b, m.ctx)
	m.backOff = backoff.NewConstantBackOff(5 * time.Second)

	return nil
}

@@ -104,7 +103,6 @@ func (m *MQTT) getProducer() (mqtt.Client, error) {
	return p, nil
}

func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
func (m *MQTT) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
	producer, err := m.getProducer()
	if err != nil {
		return nil, fmt.Errorf("failed to create producer connection: %w", err)

@@ -118,7 +117,7 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
	bo := backoff.WithMaxRetries(
		backoff.NewConstantBackOff(200*time.Millisecond), 3,
	)
	bo = backoff.WithContext(bo, parentCtx)
	bo = backoff.WithContext(bo, ctx)

	topic, ok := req.Metadata[mqttTopic]
	if !ok || topic == "" {

@@ -127,14 +126,13 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
	}
	return nil, retry.NotifyRecover(func() (err error) {
		token := producer.Publish(topic, m.metadata.qos, m.metadata.retain, req.Data)
		ctx, cancel := context.WithTimeout(parentCtx, defaultWait)
		defer cancel()
		select {
		case <-token.Done():
			err = token.Error()
		case <-m.ctx.Done():
			// Context canceled
			err = m.ctx.Err()
		case <-m.closeCh:
			err = errors.New("mqtt client closed")
		case <-time.After(defaultWait):
			err = errors.New("mqtt client timeout")
		case <-ctx.Done():
			// Context canceled
			err = ctx.Err()

@@ -151,6 +149,10 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
}

func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error {
	if m.closed.Load() {
		return errors.New("error: binding is closed")
	}

	// If the subscription is already active, wait 2s before retrying (in case we're still disconnecting), otherwise return an error
	if !m.isSubscribed.CompareAndSwap(false, true) {
		m.logger.Debug("Subscription is already active; waiting 2s before retrying…")

@@ -177,11 +179,14 @@ func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error {

	// In the background, watch for context cancellation and stop the connection
	// However, do not call "unsubscribe" which would cause the broker to stop tracking the last message received by this consumer group
	m.wg.Add(1)
	go func() {
		defer m.wg.Done()

		select {
		case <-ctx.Done():
			// nop
		case <-m.ctx.Done():
		case <-m.closeCh:
			// nop
		}

@@ -208,14 +213,12 @@ func (m *MQTT) connect(clientID string, isSubscriber bool) (mqtt.Client, error)
	}
	client := mqtt.NewClient(opts)

	ctx, cancel := context.WithTimeout(m.ctx, defaultWait)
	defer cancel()
	token := client.Connect()
	select {
	case <-token.Done():
		err = token.Error()
	case <-ctx.Done():
		err = ctx.Err()
	case <-time.After(defaultWait):
		err = errors.New("mqtt client timed out connecting")
	}
	if err != nil {
		return nil, fmt.Errorf("failed to connect: %w", err)

@@ -290,43 +293,46 @@ func (m *MQTT) createClientOptions(uri *url.URL, clientID string) *mqtt.ClientOp
	return opts
}

func (m *MQTT) handleMessage(client mqtt.Client, mqttMsg mqtt.Message) {
	// We're using m.ctx as context in this method because we don't have access to the Read context
	// Canceling the Read context makes Read invoke "Disconnect" anyways
	ctx := m.ctx
func (m *MQTT) handleMessage() func(client mqtt.Client, mqttMsg mqtt.Message) {
	return func(client mqtt.Client, mqttMsg mqtt.Message) {
		bo := m.backOff
		if m.metadata.backOffMaxRetries >= 0 {
			bo = backoff.WithMaxRetries(bo, uint64(m.metadata.backOffMaxRetries))
		}

	var bo backoff.BackOff = backoff.WithContext(m.backOff, ctx)
	if m.metadata.backOffMaxRetries >= 0 {
		bo = backoff.WithMaxRetries(bo, uint64(m.metadata.backOffMaxRetries))
	}
		err := retry.NotifyRecover(
			func() error {
				m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
				// Use a background context here so that the context is not tied to the
				// first Invoke that created the producer.
				// TODO: add context to mqtt library, and add a OnConnectWithContext option
				// to change this func signature to
				// func(c mqtt.Client, ctx context.Context)
				_, err := m.readHandler(context.Background(), &bindings.ReadResponse{
					Data: mqttMsg.Payload(),
					Metadata: map[string]string{
						mqttTopic: mqttMsg.Topic(),
					},
				})
				if err != nil {
					return err
				}

	err := retry.NotifyRecover(
		func() error {
			m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
			_, err := m.readHandler(ctx, &bindings.ReadResponse{
				Data: mqttMsg.Payload(),
				Metadata: map[string]string{
					mqttTopic: mqttMsg.Topic(),
				},
			})
			if err != nil {
				return err
			}

				// Ack the message on success
				mqttMsg.Ack()
				return nil
			},
			bo,
			func(err error, d time.Duration) {
				m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying…", mqttMsg.Topic(), mqttMsg.MessageID())
			},
			func() {
				m.logger.Infof("Successfully processed MQTT message after it previously failed: %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
			},
		)
		if err != nil {
			m.logger.Errorf("Failed processing MQTT message: %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err)
			// Ack the message on success
			mqttMsg.Ack()
			return nil
		},
		bo,
		func(err error, d time.Duration) {
			m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying…", mqttMsg.Topic(), mqttMsg.MessageID())
		},
		func() {
			m.logger.Infof("Successfully processed MQTT message after it previously failed: %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
		},
	)
	if err != nil {
		m.logger.Errorf("Failed processing MQTT message: %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err)
		}
	}
}

@@ -336,17 +342,15 @@ func (m *MQTT) createSubscriberClientOptions(uri *url.URL, clientID string) *mqt

	// On (re-)connection, add the topic subscription
	opts.OnConnect = func(c mqtt.Client) {
		token := c.Subscribe(m.metadata.topic, m.metadata.qos, m.handleMessage)
		token := c.Subscribe(m.metadata.topic, m.metadata.qos, m.handleMessage())

		var err error
		subscribeCtx, subscribeCancel := context.WithTimeout(m.ctx, defaultWait)
		defer subscribeCancel()
		select {
		case <-token.Done():
			// Subscription went through (successfully or not)
			err = token.Error()
		case <-subscribeCtx.Done():
			err = fmt.Errorf("error while waiting for subscription token: %w", subscribeCtx.Err())
		case <-time.After(defaultWait):
			err = errors.New("timed out waiting for subscription to complete")
		}

		// Nothing we can do in case of errors besides logging them

@@ -363,13 +367,16 @@ func (m *MQTT) Close() error {
	m.producerLock.Lock()
	defer m.producerLock.Unlock()

	// Canceling the context also causes Read to stop receiving messages
	m.cancel()
	if m.closed.CompareAndSwap(false, true) {
		close(m.closeCh)
	}

	if m.producer != nil {
		m.producer.Disconnect(200)
		m.producer = nil
	}

	m.wg.Wait()

	return nil
}
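The change from a handleMessage method to a factory returning the handler closure is worth calling out: it removes the dependency on a stored m.ctx while keeping the callback signature the MQTT library expects. A stripped-down sketch of the same move, with illustrative types standing in for the mqtt package:

package main

import "fmt"

type message struct{ topic string }

// handlerFactory mirrors the handleMessage() rewrite above: instead of a
// method that reaches for long-lived struct state (such as a stored
// context), a factory captures what the callback needs when it is built.
func handlerFactory(prefix string) func(msg message) {
	return func(msg message) {
		fmt.Println(prefix, msg.topic) // prefix captured at build time
	}
}

func main() {
	// At subscribe time the factory is invoked once, yielding the callback
	// with the library's expected signature: subscribe(topic, h).
	h := handlerFactory("got:")
	h(message{topic: "events"})
}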
@@ -49,6 +49,7 @@ func getConnectionString() string {

func TestInvokeWithTopic(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	url := getConnectionString()
	if url == "" {

@@ -79,7 +80,7 @@ func TestInvokeWithTopic(t *testing.T) {
	logger := logger.NewLogger("test")

	r := NewMQTT(logger).(*MQTT)
	err := r.Init(metadata)
	err := r.Init(ctx, metadata)
	assert.Nil(t, err)

	conn, err := r.connect(uuid.NewString(), false)

@@ -127,4 +128,5 @@ func TestInvokeWithTopic(t *testing.T) {
	assert.True(t, ok)
	assert.Equal(t, dataCustomized, mqttMessage.Payload())
	assert.Equal(t, topicCustomized, mqttMessage.Topic())
	assert.NoError(t, r.Close())
}

@@ -205,7 +205,6 @@ func TestParseMetadata(t *testing.T) {
	logger := logger.NewLogger("test")
	m := NewMQTT(logger).(*MQTT)
	m.backOff = backoff.NewConstantBackOff(5 * time.Second)
	m.ctx, m.cancel = context.WithCancel(context.Background())
	m.readHandler = func(ctx context.Context, r *bindings.ReadResponse) ([]byte, error) {
		assert.Equal(t, payload, r.Data)
		metadata := r.Metadata

@@ -215,7 +214,7 @@ func TestParseMetadata(t *testing.T) {
		return r.Data, nil
	}

	m.handleMessage(nil, &mqttMockMessage{
	m.handleMessage()(nil, &mqttMockMessage{
		topic:   topic,
		payload: payload,
	})
@@ -81,7 +81,7 @@ func NewMysql(logger logger.Logger) bindings.OutputBinding {
}

// Init initializes the MySQL binding.
func (m *Mysql) Init(metadata bindings.Metadata) error {
func (m *Mysql) Init(ctx context.Context, metadata bindings.Metadata) error {
	m.logger.Debug("Initializing MySql binding")

	p := metadata.Properties

@@ -115,7 +115,7 @@ func (m *Mysql) Init(metadata bindings.Metadata) error {
		return err
	}

	err = db.Ping()
	err = db.PingContext(ctx)
	if err != nil {
		return fmt.Errorf("unable to ping the DB: %w", err)
	}
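Switching db.Ping() to db.PingContext(ctx) is the database/sql idiom for making the startup connectivity check cancelable. A self-contained sketch under stated assumptions (the go-sql-driver/mysql import and the 5-second cap are illustrative, not from this commit):

package example

import (
	"context"
	"database/sql"
	"time"

	_ "github.com/go-sql-driver/mysql" // driver registration (assumed driver)
)

func open(ctx context.Context, dsn string) (*sql.DB, error) {
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, err
	}
	// PingContext bounds the check by the caller's context, so a hung
	// server cannot block Init past the caller's deadline.
	pingCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	if err := db.PingContext(pingCtx); err != nil {
		db.Close()
		return nil, err
	}
	return db, nil
}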
@@ -75,7 +75,7 @@ func TestMysqlIntegration(t *testing.T) {

	b := NewMysql(logger.NewLogger("test")).(*Mysql)
	m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}}
	if err := b.Init(m); err != nil {
	if err := b.Init(context.Background(), m); err != nil {
		t.Fatal(err)
	}
@@ -21,6 +21,7 @@ import (
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/nacos-group/nacos-sdk-go/v2/clients"

@@ -56,6 +57,9 @@ type Nacos struct {
	logger       logger.Logger
	configClient config_client.IConfigClient //nolint:nosnakecase
	readHandler  func(ctx context.Context, response *bindings.ReadResponse) ([]byte, error)
	wg           sync.WaitGroup
	closed       atomic.Bool
	closeCh      chan struct{}
}

// NewNacos returns a new Nacos instance.

@@ -63,11 +67,12 @@ func NewNacos(logger logger.Logger) bindings.OutputBinding {
	return &Nacos{
		logger:      logger,
		watchesLock: sync.Mutex{},
		closeCh:     make(chan struct{}),
	}
}

// Init implements InputBinding/OutputBinding's Init method.
func (n *Nacos) Init(metadata bindings.Metadata) error {
func (n *Nacos) Init(_ context.Context, metadata bindings.Metadata) error {
	n.settings = Settings{
		Timeout: defaultTimeout,
	}

@@ -146,6 +151,10 @@ func (n *Nacos) createConfigClient() error {

// Read implements InputBinding's Read method.
func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {
	if n.closed.Load() {
		return errors.New("binding is closed")
	}

	n.readHandler = handler

	n.watchesLock.Lock()

@@ -154,9 +163,14 @@ func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {
	}
	n.watchesLock.Unlock()

	n.wg.Add(1)
	go func() {
		defer n.wg.Done()
		// Cancel all listeners when the context is done
		<-ctx.Done()
		select {
		case <-ctx.Done():
		case <-n.closeCh:
		}
		n.cancelAllListeners()
	}()

@@ -165,8 +179,14 @@ func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {

// Close cancels all listeners, see https://github.com/dapr/components-contrib/issues/779
func (n *Nacos) Close() error {
	if n.closed.CompareAndSwap(false, true) {
		close(n.closeCh)
	}

	n.cancelAllListeners()

	n.wg.Wait()

	return nil
}

@@ -223,7 +243,11 @@ func (n *Nacos) addListener(ctx context.Context, config configParam) {

func (n *Nacos) addListenerFoInputBinding(ctx context.Context, config configParam) {
	if n.addToWatches(config) {
		go n.addListener(ctx, config)
		n.wg.Add(1)
		go func() {
			defer n.wg.Done()
			n.addListener(ctx, config)
		}()
	}
}
@@ -35,7 +35,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
	m.Properties, err = getNacosLocalCacheMetadata()
	require.NoError(t, err)
	n := NewNacos(logger.NewLogger("test")).(*Nacos)
	err = n.Init(m)
	err = n.Init(context.Background(), m)
	require.NoError(t, err)
	var count int32
	ch := make(chan bool, 1)
@@ -22,15 +22,15 @@ import (

// OutputBinding is the interface for an output binding, allowing users to invoke remote systems with optional payloads.
type OutputBinding interface {
	Init(metadata Metadata) error
	Init(ctx context.Context, metadata Metadata) error
	Invoke(ctx context.Context, req *InvokeRequest) (*InvokeResponse, error)
	Operations() []OperationKind
}

func PingOutBinding(outputBinding OutputBinding) error {
func PingOutBinding(ctx context.Context, outputBinding OutputBinding) error {
	// checks if this output binding has the ping option then executes
	if outputBindingWithPing, ok := outputBinding.(health.Pinger); ok {
		return outputBindingWithPing.Ping()
		return outputBindingWithPing.Ping(ctx)
	} else {
		return fmt.Errorf("ping is not implemented by this output binding")
	}
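The health.Pinger assertion above keeps Ping optional per binding, and the ping now carries the caller's context. A sketch of the shape (the Pinger interface here mirrors what the code above implies; the 2-second timeout is an illustrative choice):

package example

import (
	"context"
	"fmt"
	"time"
)

// Pinger mirrors the optional health interface the code above asserts on.
type Pinger interface {
	Ping(ctx context.Context) error
}

// pingIfSupported shows the pattern: a type assertion keeps Ping optional,
// and the caller's context (here with a timeout) bounds the check.
func pingIfSupported(ctx context.Context, binding any) error {
	p, ok := binding.(Pinger)
	if !ok {
		return fmt.Errorf("ping is not implemented by this binding")
	}
	ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()
	return p.Ping(ctx)
}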
@@ -49,7 +49,7 @@ func NewPostgres(logger logger.Logger) bindings.OutputBinding {
}

// Init initializes the PostgreSql binding.
func (p *Postgres) Init(metadata bindings.Metadata) error {
func (p *Postgres) Init(ctx context.Context, metadata bindings.Metadata) error {
	url, ok := metadata.Properties[connectionURLKey]
	if !ok || url == "" {
		return fmt.Errorf("required metadata not set: %s", connectionURLKey)

@@ -60,7 +60,9 @@ func (p *Postgres) Init(metadata bindings.Metadata) error {
		return fmt.Errorf("error opening DB connection: %w", err)
	}

	p.db, err = pgxpool.NewWithConfig(context.Background(), poolConfig)
	// This context doesn't control the lifetime of the connection pool, and is
	// only scoped to postgres creating resources at init.
	p.db, err = pgxpool.NewWithConfig(ctx, poolConfig)
	if err != nil {
		return fmt.Errorf("unable to ping the DB: %w", err)
	}
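The in-diff comment captures a subtle pgx detail: the context passed to pgxpool.NewWithConfig only bounds pool construction, not the pool's lifetime, which is why passing Init's ctx is safe here. A small sketch under that assumption (the connection string is a placeholder):

package example

import (
	"context"

	"github.com/jackc/pgx/v5/pgxpool"
)

func newPool(ctx context.Context, connString string) (*pgxpool.Pool, error) {
	cfg, err := pgxpool.ParseConfig(connString) // e.g. "postgres://user:pass@host:5432/db"
	if err != nil {
		return nil, err
	}
	// ctx bounds only the initial setup; the returned pool outlives it and
	// is released explicitly with pool.Close().
	return pgxpool.NewWithConfig(ctx, cfg)
}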
@@ -64,7 +64,7 @@ func TestPostgresIntegration(t *testing.T) {
	// live DB test
	b := NewPostgres(logger.NewLogger("test")).(*Postgres)
	m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}}
	if err := b.Init(m); err != nil {
	if err := b.Init(context.Background(), m); err != nil {
		t.Fatal(err)
	}
@@ -74,7 +74,7 @@ func (p *Postmark) parseMetadata(meta bindings.Metadata) (postmarkMetadata, erro
}

// Init does metadata parsing and not much else :).
func (p *Postmark) Init(metadata bindings.Metadata) error {
func (p *Postmark) Init(_ context.Context, metadata bindings.Metadata) error {
	// Parse input metadata
	meta, err := p.parseMetadata(metadata)
	if err != nil {
@@ -19,6 +19,8 @@ import (
	"fmt"
	"math"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	amqp "github.com/rabbitmq/amqp091-go"

@@ -50,6 +52,9 @@ type RabbitMQ struct {
	metadata rabbitMQMetadata
	logger   logger.Logger
	queue    amqp.Queue
	closed   atomic.Bool
	closeCh  chan struct{}
	wg       sync.WaitGroup
}

// Metadata is the rabbitmq config.

@@ -66,11 +71,14 @@ type rabbitMQMetadata struct {

// NewRabbitMQ returns a new rabbitmq instance.
func NewRabbitMQ(logger logger.Logger) bindings.InputOutputBinding {
	return &RabbitMQ{logger: logger}
	return &RabbitMQ{
		logger:  logger,
		closeCh: make(chan struct{}),
	}
}

// Init does metadata parsing and connection creation.
func (r *RabbitMQ) Init(metadata bindings.Metadata) error {
func (r *RabbitMQ) Init(_ context.Context, metadata bindings.Metadata) error {
	err := r.parseMetadata(metadata)
	if err != nil {
		return err

@@ -226,6 +234,10 @@ func (r *RabbitMQ) declareQueue() (amqp.Queue, error) {
}

func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
	if r.closed.Load() {
		return errors.New("binding already closed")
	}

	msgs, err := r.channel.Consume(
		r.queue.Name,
		"",

@@ -239,14 +251,27 @@ func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
		return err
	}

	readCtx, cancel := context.WithCancel(ctx)

	r.wg.Add(2)
	go func() {
		defer r.wg.Done()
		defer cancel()
		select {
		case <-r.closeCh:
		case <-readCtx.Done():
		}
	}()

	go func() {
		defer r.wg.Done()
		var err error
		for {
			select {
			case <-ctx.Done():
			case <-readCtx.Done():
				return
			case d := <-msgs:
				_, err = handler(ctx, &bindings.ReadResponse{
				_, err = handler(readCtx, &bindings.ReadResponse{
					Data: d.Body,
				})
				if err != nil {

@@ -260,3 +285,11 @@ func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {

	return nil
}

func (r *RabbitMQ) Close() error {
	if r.closed.CompareAndSwap(false, true) {
		close(r.closeCh)
	}
	defer r.wg.Wait()
	return r.channel.Close()
}
@@ -85,7 +85,7 @@ func TestQueuesWithTTL(t *testing.T) {
	logger := logger.NewLogger("test")

	r := NewRabbitMQ(logger).(*RabbitMQ)
	err := r.Init(metadata)
	err := r.Init(context.Background(), metadata)
	assert.Nil(t, err)

	// Assert that if waited too long, we won't see any message

@@ -117,6 +117,7 @@ func TestQueuesWithTTL(t *testing.T) {
	assert.True(t, ok)
	msgBody := string(msg.Body)
	assert.Equal(t, testMsgContent, msgBody)
	assert.NoError(t, r.Close())
}

func TestPublishingWithTTL(t *testing.T) {

@@ -144,7 +145,7 @@ func TestPublishingWithTTL(t *testing.T) {
	logger := logger.NewLogger("test")

	rabbitMQBinding1 := NewRabbitMQ(logger).(*RabbitMQ)
	err := rabbitMQBinding1.Init(metadata)
	err := rabbitMQBinding1.Init(context.Background(), metadata)
	assert.Nil(t, err)

	// Assert that if waited too long, we won't see any message

@@ -175,7 +176,7 @@ func TestPublishingWithTTL(t *testing.T) {

	// Getting before it is expired, should return it
	rabbitMQBinding2 := NewRabbitMQ(logger).(*RabbitMQ)
	err = rabbitMQBinding2.Init(metadata)
	err = rabbitMQBinding2.Init(context.Background(), metadata)
	assert.Nil(t, err)

	const testMsgContent = "test_msg"

@@ -193,6 +194,9 @@ func TestPublishingWithTTL(t *testing.T) {
	assert.True(t, ok)
	msgBody := string(msg.Body)
	assert.Equal(t, testMsgContent, msgBody)

	assert.NoError(t, rabbitMQBinding1.Close())
	assert.NoError(t, rabbitMQBinding2.Close())
}

func TestExclusiveQueue(t *testing.T) {

@@ -222,7 +226,7 @@ func TestExclusiveQueue(t *testing.T) {
	logger := logger.NewLogger("test")

	r := NewRabbitMQ(logger).(*RabbitMQ)
	err := r.Init(metadata)
	err := r.Init(context.Background(), metadata)
	assert.Nil(t, err)

	// Assert that if waited too long, we won't see any message

@@ -276,7 +280,7 @@ func TestPublishWithPriority(t *testing.T) {
	logger := logger.NewLogger("test")

	r := NewRabbitMQ(logger).(*RabbitMQ)
	err := r.Init(metadata)
	err := r.Init(context.Background(), metadata)
	assert.Nil(t, err)

	// Assert that if waited too long, we won't see any message
@@ -28,9 +28,6 @@ type Redis struct {
	client         rediscomponent.RedisClient
	clientSettings *rediscomponent.Settings
	logger         logger.Logger

	ctx    context.Context
	cancel context.CancelFunc
}

// NewRedis returns a new redis bindings instance.

@@ -39,15 +36,13 @@ func NewRedis(logger logger.Logger) bindings.OutputBinding {
}

// Init performs metadata parsing and connection creation.
func (r *Redis) Init(meta bindings.Metadata) (err error) {
func (r *Redis) Init(ctx context.Context, meta bindings.Metadata) (err error) {
	r.client, r.clientSettings, err = rediscomponent.ParseClientFromProperties(meta.Properties, nil)
	if err != nil {
		return err
	}

	r.ctx, r.cancel = context.WithCancel(context.Background())

	_, err = r.client.PingResult(r.ctx)
	_, err = r.client.PingResult(ctx)
	if err != nil {
		return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err)
	}

@@ -55,8 +50,8 @@ func (r *Redis) Init(meta bindings.Metadata) (err error) {
	return err
}

func (r *Redis) Ping() error {
	if _, err := r.client.PingResult(r.ctx); err != nil {
func (r *Redis) Ping(ctx context.Context) error {
	if _, err := r.client.PingResult(ctx); err != nil {
		return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err)
	}

@@ -101,7 +96,5 @@ func (r *Redis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi
}

func (r *Redis) Close() error {
	r.cancel()

	return r.client.Close()
}
@@ -40,7 +40,6 @@ func TestInvokeCreate(t *testing.T) {
		client: c,
		logger: logger.NewLogger("test"),
	}
	bind.ctx, bind.cancel = context.WithCancel(context.Background())

	_, err := c.DoRead(context.Background(), "GET", testKey)
	assert.Equal(t, redis.Nil, err)

@@ -66,7 +65,6 @@ func TestInvokeGet(t *testing.T) {
		client: c,
		logger: logger.NewLogger("test"),
	}
	bind.ctx, bind.cancel = context.WithCancel(context.Background())

	err := c.DoWrite(context.Background(), "SET", testKey, testData)
	assert.Equal(t, nil, err)

@@ -87,7 +85,6 @@ func TestInvokeDelete(t *testing.T) {
		client: c,
		logger: logger.NewLogger("test"),
	}
	bind.ctx, bind.cancel = context.WithCancel(context.Background())

	err := c.DoWrite(context.Background(), "SET", testKey, testData)
	assert.Equal(t, nil, err)
@@ -19,6 +19,8 @@ import (
	"fmt"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	r "github.com/dancannon/gorethink"

@@ -34,6 +36,9 @@ type Binding struct {
	logger  logger.Logger
	session *r.Session
	config  StateConfig
	closed  atomic.Bool
	closeCh chan struct{}
	wg      sync.WaitGroup
}

// StateConfig is the binding config.

@@ -45,12 +50,13 @@ type StateConfig struct {
// NewRethinkDBStateChangeBinding returns a new RethinkDB actor event input binding.
func NewRethinkDBStateChangeBinding(logger logger.Logger) bindings.InputBinding {
	return &Binding{
		logger: logger,
		logger:  logger,
		closeCh: make(chan struct{}),
	}
}

// Init initializes the RethinkDB binding.
func (b *Binding) Init(metadata bindings.Metadata) error {
func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
	cfg, err := metadataToConfig(metadata.Properties, b.logger)
	if err != nil {
		return fmt.Errorf("unable to parse metadata properties: %w", err)

@@ -68,6 +74,10 @@ func (b *Binding) Init(metadata bindings.Metadata) error {

// Read triggers the RethinkDB scheduler.
func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
	if b.closed.Load() {
		return errors.New("binding is closed")
	}

	b.logger.Infof("subscribing to state changes in %s.%s...", b.config.Database, b.config.Table)
	cursor, err := r.DB(b.config.Database).
		Table(b.config.Table).

@@ -81,8 +91,21 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
		return fmt.Errorf("error connecting to table '%s': %w", b.config.Table, err)
	}

	readCtx, cancel := context.WithCancel(ctx)
	b.wg.Add(2)

	go func() {
	for ctx.Err() == nil {
		defer b.wg.Done()
		defer cancel()
		select {
		case <-b.closeCh:
		case <-readCtx.Done():
		}
	}()

	go func() {
		defer b.wg.Done()
		for readCtx.Err() == nil {
			var change interface{}
			ok := cursor.Next(&change)
			if !ok {

@@ -105,7 +128,7 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
				},
			}

			if _, err := handler(ctx, resp); err != nil {
			if _, err := handler(readCtx, resp); err != nil {
				b.logger.Errorf("error invoking change handler: %v", err)
				continue
			}

@@ -117,6 +140,14 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
	return nil
}

func (b *Binding) Close() error {
	if b.closed.CompareAndSwap(false, true) {
		close(b.closeCh)
	}
	defer b.wg.Wait()
	return b.session.Close()
}

func metadataToConfig(cfg map[string]string, logger logger.Logger) (StateConfig, error) {
	c := StateConfig{}
	for k, v := range cfg {
@@ -71,7 +71,7 @@ func TestBinding(t *testing.T) {
	assert.NotNil(t, m.Properties)

	b := getNewRethinkActorBinding()
	err := b.Init(m)
	err := b.Init(context.Background(), m)
	assert.NoErrorf(t, err, "error initializing")

	ctx, cancel := context.WithCancel(context.Background())
@@ -18,6 +18,8 @@ import (
	"errors"
	"fmt"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	mqc "github.com/apache/rocketmq-client-go/v2/consumer"

@@ -35,27 +37,27 @@ type RocketMQ struct {
	settings Settings
	producer mqw.Producer

	ctx           context.Context
	cancel        context.CancelFunc
	backOffConfig retry.Config
	closeCh       chan struct{}
	closed        atomic.Bool
	wg            sync.WaitGroup
}

func NewRocketMQ(l logger.Logger) *RocketMQ {
	return &RocketMQ{ //nolint:exhaustivestruct
		logger:   l,
		producer: nil,
		closeCh:  make(chan struct{}),
	}
}

// Init performs metadata parsing.
func (a *RocketMQ) Init(metadata bindings.Metadata) error {
func (a *RocketMQ) Init(ctx context.Context, metadata bindings.Metadata) error {
	var err error
	if err = a.settings.Decode(metadata.Properties); err != nil {
		return err
	}

	a.ctx, a.cancel = context.WithCancel(context.Background())

	// Default retry configuration is used if no
	// backOff properties are set.
	if err = retry.DecodeConfigWithPrefix(

@@ -75,6 +77,10 @@ func (a *RocketMQ) Init(metadata bindings.Metadata) error {

// Read triggers the rocketmq subscription.
func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error {
	if a.closed.Load() {
		return errors.New("error: binding is closed")
	}

	a.logger.Debugf("binding rocketmq: start read input binding")

	consumer, err := a.setupConsumer()

@@ -114,10 +120,12 @@ func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error {
	a.logger.Debugf("binding-rocketmq: consumer started")

	// Listen for context cancellation to stop the subscription
	a.wg.Add(1)
	go func() {
		defer a.wg.Done()
		select {
		case <-ctx.Done():
		case <-a.ctx.Done():
		case <-a.closeCh:
		}

		innerErr := consumer.Shutdown()

@@ -131,8 +139,10 @@ func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error {

// Close cancels all listeners, see https://github.com/dapr/components-contrib/issues/779
func (a *RocketMQ) Close() error {
	a.cancel()

	defer a.wg.Wait()
	if a.closed.CompareAndSwap(false, true) {
		close(a.closeCh)
	}
	return nil
}

@@ -199,21 +209,21 @@ func (a *RocketMQ) Operations() []bindings.OperationKind {
	return []bindings.OperationKind{bindings.CreateOperation}
}

func (a *RocketMQ) Invoke(req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
func (a *RocketMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
	rst := &bindings.InvokeResponse{Data: nil, Metadata: nil}

	if req.Operation != bindings.CreateOperation {
		return rst, fmt.Errorf("binding-rocketmq error: unsupported operation %s", req.Operation)
	}

	return rst, a.sendMessage(req)
	return rst, a.sendMessage(ctx, req)
}

func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error {
func (a *RocketMQ) sendMessage(ctx context.Context, req *bindings.InvokeRequest) error {
	topic := req.Metadata[metadataRocketmqTopic]

	if topic != "" {
		_, err := a.send(topic, req.Metadata[metadataRocketmqTag], req.Metadata[metadataRocketmqKey], req.Data)
		_, err := a.send(ctx, topic, req.Metadata[metadataRocketmqTag], req.Metadata[metadataRocketmqKey], req.Data)
		if err != nil {
			return err
		}

@@ -229,7 +239,7 @@ func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error {
	if err != nil {
		return err
	}
	_, err = a.send(topic, mqExpression, req.Metadata[metadataRocketmqKey], req.Data)
	_, err = a.send(ctx, topic, mqExpression, req.Metadata[metadataRocketmqKey], req.Data)
	if err != nil {
		return err
	}

@@ -239,9 +249,9 @@ func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error {
	return nil
}

func (a *RocketMQ) send(topic, mqExpr, key string, data []byte) (bool, error) {
func (a *RocketMQ) send(ctx context.Context, topic, mqExpr, key string, data []byte) (bool, error) {
	msg := primitive.NewMessage(topic, data).WithTag(mqExpr).WithKeys([]string{key})
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	rst, err := a.producer.SendSync(ctx, msg)
	if err != nil {
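Deriving the 30-second send timeout from the caller's ctx (rather than context.Background()) means the publish honors both the deadline and any upstream cancellation. The general pattern, as a small sketch with a placeholder send function:

package example

import (
	"context"
	"time"
)

// sendWithTimeout bounds the operation by min(caller deadline, 30s);
// canceling the caller's ctx also cancels the send, which a
// context.Background()-based timeout could not do.
func sendWithTimeout(ctx context.Context, send func(context.Context) error) error {
	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	return send(ctx)
}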
@@ -35,7 +35,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
	m := bindings.Metadata{} //nolint:exhaustivestruct
	m.Properties = getTestMetadata()
	r := NewRocketMQ(logger.NewLogger("test"))
	err := r.Init(m)
	err := r.Init(context.Background(), m)
	require.NoError(t, err)

	var count int32

@@ -51,7 +51,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
	time.Sleep(5 * time.Second)
	atomic.StoreInt32(&count, 0)
	req := &bindings.InvokeRequest{Data: []byte("hello"), Operation: bindings.CreateOperation, Metadata: map[string]string{}}
	_, err = r.Invoke(req)
	_, err = r.Invoke(context.Background(), req)
	require.NoError(t, err)

	time.Sleep(10 * time.Second)
@@ -61,7 +61,7 @@ func NewSMTP(logger logger.Logger) bindings.OutputBinding {
}

// Init smtp component (parse metadata).
func (s *Mailer) Init(metadata bindings.Metadata) error {
func (s *Mailer) Init(_ context.Context, metadata bindings.Metadata) error {
	// parse metadata
	meta, err := s.parseMetadata(metadata)
	if err != nil {
@@ -84,7 +84,7 @@ func (sg *SendGrid) parseMetadata(meta bindings.Metadata) (sendGridMetadata, err
}

// Init does metadata parsing and not much else :).
func (sg *SendGrid) Init(metadata bindings.Metadata) error {
func (sg *SendGrid) Init(_ context.Context, metadata bindings.Metadata) error {
	// Parse input metadata
	meta, err := sg.parseMetadata(metadata)
	if err != nil {
@@ -60,19 +60,19 @@ func NewSMS(logger logger.Logger) bindings.OutputBinding {
	}
}

func (t *SMS) Init(metadata bindings.Metadata) error {
func (t *SMS) Init(_ context.Context, metadata bindings.Metadata) error {
	twilioM := twilioMetadata{
		timeout: time.Minute * 5,
	}

	if metadata.Properties[fromNumber] == "" {
		return errors.New("\"fromNumber\" is a required field")
		return errors.New(`"fromNumber" is a required field`)
	}
	if metadata.Properties[accountSid] == "" {
		return errors.New("\"accountSid\" is a required field")
		return errors.New(`"accountSid" is a required field`)
	}
	if metadata.Properties[authToken] == "" {
		return errors.New("\"authToken\" is a required field")
		return errors.New(`"authToken" is a required field`)
	}

	twilioM.toNumber = metadata.Properties[toNumber]
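The error-string change above is purely cosmetic but a good Go habit: a backquoted raw string literal avoids escaping inner double quotes. For illustration:

package main

import "fmt"

func main() {
	// The two literals are byte-for-byte identical; the raw (backquoted)
	// form simply reads better.
	a := "\"fromNumber\" is a required field"
	b := `"fromNumber" is a required field`
	fmt.Println(a == b) // true
}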
@@ -53,7 +53,7 @@ func TestInit(t *testing.T) {
	m := bindings.Metadata{}
	m.Properties = map[string]string{"toNumber": "toNumber", "fromNumber": "fromNumber"}
	tw := NewSMS(logger.NewLogger("test"))
	err := tw.Init(m)
	err := tw.Init(context.Background(), m)
	assert.NotNil(t, err)
}

@@ -66,7 +66,7 @@ func TestParseDuration(t *testing.T) {
		"authToken": "authToken", "timeout": "badtimeout",
	}
	tw := NewSMS(logger.NewLogger("test"))
	err := tw.Init(m)
	err := tw.Init(context.Background(), m)
	assert.NotNil(t, err)
}

@@ -85,7 +85,7 @@ func TestWriteShouldSucceed(t *testing.T) {
	tw.httpClient = &http.Client{
		Transport: httpTransport,
	}
	err := tw.Init(m)
	err := tw.Init(context.Background(), m)
	assert.NoError(t, err)

	t.Run("Should succeed with expected url and headers", func(t *testing.T) {

@@ -123,7 +123,7 @@ func TestWriteShouldFail(t *testing.T) {
	tw.httpClient = &http.Client{
		Transport: httpTransport,
	}
	err := tw.Init(m)
	err := tw.Init(context.Background(), m)
	assert.NoError(t, err)

	t.Run("Missing 'to' should fail", func(t *testing.T) {

@@ -180,7 +180,7 @@ func TestMessageBody(t *testing.T) {
	tw.httpClient = &http.Client{
		Transport: httpTransport,
	}
	err := tw.Init(m)
	err := tw.Init(context.Background(), m)
	require.NoError(t, err)

	tester := func(reqData []byte, expectBody string) func(t *testing.T) {
@@ -20,6 +20,8 @@ import (
	"errors"
	"fmt"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/dghubble/go-twitter/twitter"

@@ -31,18 +33,21 @@ import (

// Binding represents Twitter input/output binding.
type Binding struct {
	client *twitter.Client
	query  string
	logger logger.Logger
	client  *twitter.Client
	query   string
	logger  logger.Logger
	closed  atomic.Bool
	closeCh chan struct{}
	wg      sync.WaitGroup
}

// NewTwitter returns a new Twitter event input binding.
func NewTwitter(logger logger.Logger) bindings.InputOutputBinding {
	return &Binding{logger: logger}
	return &Binding{logger: logger, closeCh: make(chan struct{})}
}

// Init initializes the Twitter binding.
func (t *Binding) Init(metadata bindings.Metadata) error {
func (t *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
	ck, f := metadata.Properties["consumerKey"]
	if !f || ck == "" {
		return fmt.Errorf("consumerKey not set")

@@ -124,10 +129,17 @@ func (t *Binding) Read(ctx context.Context, handler bindings.Handler) error {
	}

	t.logger.Debug("starting handler...")
	go demux.HandleChan(stream.Messages)

	t.wg.Add(2)
	go func() {
		<-ctx.Done()
		defer t.wg.Done()
		demux.HandleChan(stream.Messages)
	}()
	go func() {
		defer t.wg.Done()
		select {
		case <-t.closeCh:
		case <-ctx.Done():
		}
		t.logger.Debug("stopping handler...")
		stream.Stop()
	}()

@@ -135,6 +147,14 @@ func (t *Binding) Read(ctx context.Context, handler bindings.Handler) error {
	return nil
}

func (t *Binding) Close() error {
	if t.closed.CompareAndSwap(false, true) {
		close(t.closeCh)
	}
	t.wg.Wait()
	return nil
}

// Invoke handles all operations.
func (t *Binding) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
	t.logger.Debugf("operation: %v", req.Operation)
@@ -60,7 +60,7 @@ func getRuntimeMetadata() map[string]string {
func TestInit(t *testing.T) {
	m := getTestMetadata()
	tw := NewTwitter(logger.NewLogger("test")).(*Binding)
	err := tw.Init(m)
	err := tw.Init(context.Background(), m)
	assert.Nilf(t, err, "error initializing valid metadata properties")
}

@@ -69,7 +69,7 @@ func TestInit(t *testing.T) {
func TestReadError(t *testing.T) {
	tw := NewTwitter(logger.NewLogger("test")).(*Binding)
	m := getTestMetadata()
	err := tw.Init(m)
	err := tw.Init(context.Background(), m)
	assert.Nilf(t, err, "error initializing valid metadata properties")

	err = tw.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {

@@ -79,6 +79,8 @@ func TestReadError(t *testing.T) {
		return nil, nil
	})
	assert.Error(t, err)

	assert.NoError(t, tw.Close())
}

// TestRead executes the Read method which calls Twitter API

@@ -93,7 +95,7 @@ func TestRead(t *testing.T) {
	m.Properties["query"] = "microsoft"
	tw := NewTwitter(logger.NewLogger("test")).(*Binding)
	tw.logger.SetOutputLevel(logger.DebugLevel)
	err := tw.Init(m)
	err := tw.Init(context.Background(), m)
	assert.Nilf(t, err, "error initializing read")

	ctx, cancel := context.WithCancel(context.Background())

@@ -116,6 +118,8 @@ func TestRead(t *testing.T) {
		cancel()
		t.Fatal("Timeout waiting for messages")
	}

	assert.NoError(t, tw.Close())
}

// TestInvoke executes the Invoke method which calls Twitter API

@@ -129,7 +133,7 @@ func TestInvoke(t *testing.T) {
	m.Properties = getRuntimeMetadata()
	tw := NewTwitter(logger.NewLogger("test")).(*Binding)
	tw.logger.SetOutputLevel(logger.DebugLevel)
	err := tw.Init(m)
	err := tw.Init(context.Background(), m)
	assert.Nilf(t, err, "error initializing Invoke")

	req := &bindings.InvokeRequest{

@@ -141,4 +145,5 @@ func TestInvoke(t *testing.T) {
	resp, err := tw.Invoke(context.Background(), req)
	assert.Nilf(t, err, "error on invoke")
	assert.NotNil(t, resp)
	assert.NoError(t, tw.Close())
}
@@ -61,7 +61,7 @@ func NewZeebeCommand(logger logger.Logger) bindings.OutputBinding {
 }

 // Init does metadata parsing and connection creation.
-func (z *ZeebeCommand) Init(metadata bindings.Metadata) error {
+func (z *ZeebeCommand) Init(ctx context.Context, metadata bindings.Metadata) error {
 	client, err := z.clientFactory.Get(metadata)
 	if err != nil {
 		return err

@@ -114,7 +114,7 @@ func (z *ZeebeCommand) Invoke(ctx context.Context, req *bindings.InvokeRequest)
 	case UpdateJobRetriesOperation:
 		return z.updateJobRetries(ctx, req)
 	case ThrowErrorOperation:
-		return z.throwError(req)
+		return z.throwError(ctx, req)
 	case bindings.GetOperation:
 		fallthrough
 	case bindings.CreateOperation:
@@ -58,7 +58,7 @@ func TestInit(t *testing.T) {
 		}

 		cmd := ZeebeCommand{clientFactory: mcf, logger: testLogger}
-		err := cmd.Init(metadata)
+		err := cmd.Init(context.Background(), metadata)
 		assert.Error(t, err, errParsing)
 	})

@@ -67,7 +67,7 @@ func TestInit(t *testing.T) {
 		mcf := mockClientFactory{}

 		cmd := ZeebeCommand{clientFactory: mcf, logger: testLogger}
-		err := cmd.Init(metadata)
+		err := cmd.Init(context.Background(), metadata)

 		assert.NoError(t, err)
@@ -30,7 +30,7 @@ type throwErrorPayload struct {
 	ErrorMessage string `json:"errorMessage"`
 }

-func (z *ZeebeCommand) throwError(req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
+func (z *ZeebeCommand) throwError(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
 	var payload throwErrorPayload
 	err := json.Unmarshal(req.Data, &payload)
 	if err != nil {

@@ -53,7 +53,7 @@ func (z *ZeebeCommand) throwError(req *bindings.InvokeRequest) (*bindings.Invoke
 		cmd = cmd.ErrorMessage(payload.ErrorMessage)
 	}

-	_, err = cmd.Send(context.Background())
+	_, err = cmd.Send(ctx)
 	if err != nil {
 		return nil, fmt.Errorf("cannot throw error for job key %d: %w", payload.JobKey, err)
 	}
@@ -19,6 +19,8 @@ import (
 	"errors"
 	"fmt"
 	"strconv"
+	"sync"
+	"sync/atomic"
 	"time"

 	"github.com/camunda/zeebe/clients/go/v8/pkg/entities"

@@ -39,6 +41,9 @@ type ZeebeJobWorker struct {
 	client        zbc.Client
 	metadata      *jobWorkerMetadata
 	logger        logger.Logger
+	closed        atomic.Bool
+	closeCh       chan struct{}
+	wg            sync.WaitGroup
 }

 // https://docs.zeebe.io/basics/job-workers.html

@@ -64,11 +69,15 @@ type jobHandler struct {

 // NewZeebeJobWorker returns a new ZeebeJobWorker instance.
 func NewZeebeJobWorker(logger logger.Logger) bindings.InputBinding {
-	return &ZeebeJobWorker{clientFactory: zeebe.NewClientFactoryImpl(logger), logger: logger}
+	return &ZeebeJobWorker{
+		clientFactory: zeebe.NewClientFactoryImpl(logger),
+		logger:        logger,
+		closeCh:       make(chan struct{}),
+	}
 }

 // Init does metadata parsing and connection creation.
-func (z *ZeebeJobWorker) Init(metadata bindings.Metadata) error {
+func (z *ZeebeJobWorker) Init(ctx context.Context, metadata bindings.Metadata) error {
 	meta, err := z.parseMetadata(metadata)
 	if err != nil {
 		return err

@@ -90,6 +99,10 @@ func (z *ZeebeJobWorker) Init(metadata bindings.Metadata) error {
 }

 func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) error {
+	if z.closed.Load() {
+		return fmt.Errorf("binding is closed")
+	}
+
 	h := jobHandler{
 		callback: handler,
 		logger:   z.logger,

@@ -99,8 +112,14 @@ func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) err

 	jobWorker := z.getJobWorker(h)

+	z.wg.Add(1)
 	go func() {
-		<-ctx.Done()
+		defer z.wg.Done()
+
+		select {
+		case <-z.closeCh:
+		case <-ctx.Done():
+		}

 		jobWorker.Close()
 		jobWorker.AwaitClose()

@@ -110,6 +129,14 @@ func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) err
 	return nil
 }

+func (z *ZeebeJobWorker) Close() error {
+	if z.closed.CompareAndSwap(false, true) {
+		close(z.closeCh)
+	}
+	z.wg.Wait()
+	return nil
+}
+
 func (z *ZeebeJobWorker) parseMetadata(meta bindings.Metadata) (*jobWorkerMetadata, error) {
 	var m jobWorkerMetadata
 	err := metadata.DecodeMetadata(meta.Properties, &m)
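Besides the same close-once machinery, the job worker gains a guard so Read refuses new subscriptions once Close has run. A stripped-down sketch of just that guard; the binding type and method signatures here are stand-ins, not the Zeebe binding's API.

	package main

	import (
		"errors"
		"fmt"
		"sync/atomic"
	)

	type binding struct {
		closed atomic.Bool
	}

	// Read rejects new work once the component has been shut down.
	func (b *binding) Read() error {
		if b.closed.Load() {
			return errors.New("binding is closed")
		}
		return nil
	}

	func (b *binding) Close() error {
		b.closed.Store(true)
		return nil
	}

	func main() {
		b := &binding{}
		fmt.Println(b.Read()) // <nil>
		_ = b.Close()
		fmt.Println(b.Read()) // binding is closed
	}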
@@ -14,6 +14,7 @@ limitations under the License.
 package jobworker

 import (
+	"context"
 	"errors"
 	"testing"

@@ -53,10 +54,11 @@ func TestInit(t *testing.T) {
 		metadata := bindings.Metadata{}
 		var mcf mockClientFactory

-		jobWorker := ZeebeJobWorker{clientFactory: &mcf, logger: testLogger}
-		err := jobWorker.Init(metadata)
+		jobWorker := ZeebeJobWorker{clientFactory: &mcf, logger: testLogger, closeCh: make(chan struct{})}
+		err := jobWorker.Init(context.Background(), metadata)

 		assert.Error(t, err, ErrMissingJobType)
+		assert.NoError(t, jobWorker.Close())
 	})

 	t.Run("sets client from client factory", func(t *testing.T) {

@@ -66,8 +68,8 @@ func TestInit(t *testing.T) {
 		mcf := mockClientFactory{
 			metadata: metadata,
 		}
-		jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger}
-		err := jobWorker.Init(metadata)
+		jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
+		err := jobWorker.Init(context.Background(), metadata)

 		assert.NoError(t, err)

@@ -76,6 +78,7 @@ func TestInit(t *testing.T) {
 		assert.NoError(t, err)
 		assert.Equal(t, mc, jobWorker.client)
 		assert.Equal(t, metadata, mcf.metadata)
+		assert.NoError(t, jobWorker.Close())
 	})

 	t.Run("returns error if client could not be instantiated properly", func(t *testing.T) {

@@ -85,9 +88,10 @@ func TestInit(t *testing.T) {
 			error: errParsing,
 		}

-		jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger}
-		err := jobWorker.Init(metadata)
+		jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
+		err := jobWorker.Init(context.Background(), metadata)
 		assert.Error(t, err, errParsing)
+		assert.NoError(t, jobWorker.Close())
 	})

 	t.Run("sets client from client factory", func(t *testing.T) {

@@ -98,8 +102,8 @@ func TestInit(t *testing.T) {
 			metadata: metadata,
 		}

-		jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger}
-		err := jobWorker.Init(metadata)
+		jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
+		err := jobWorker.Init(context.Background(), metadata)

 		assert.NoError(t, err)

@@ -108,5 +112,6 @@ func TestInit(t *testing.T) {
 		assert.NoError(t, err)
 		assert.Equal(t, mc, jobWorker.client)
 		assert.Equal(t, metadata, mcf.metadata)
+		assert.NoError(t, jobWorker.Close())
 	})
 }
@@ -73,7 +73,7 @@ func NewAzureAppConfigurationStore(logger logger.Logger) configuration.Store {
 }

 // Init does metadata and connection parsing.
-func (r *ConfigurationStore) Init(metadata configuration.Metadata) error {
+func (r *ConfigurationStore) Init(_ context.Context, metadata configuration.Metadata) error {
 	m, err := parseMetadata(metadata)
 	if err != nil {
 		return err
@@ -204,7 +204,7 @@ func TestInit(t *testing.T) {
 			Properties: testProperties,
 		}}

-		err := s.Init(m)
+		err := s.Init(context.Background(), m)
 		assert.Nil(t, err)
 		cs, ok := s.(*ConfigurationStore)
 		assert.True(t, ok)

@@ -229,7 +229,7 @@ func TestInit(t *testing.T) {
 			Properties: testProperties,
 		}}

-		err := s.Init(m)
+		err := s.Init(context.Background(), m)
 		assert.Nil(t, err)
 		cs, ok := s.(*ConfigurationStore)
 		assert.True(t, ok)
@@ -86,7 +86,7 @@ func NewPostgresConfigurationStore(logger logger.Logger) configuration.Store {
 	}
 }

-func (p *ConfigurationStore) Init(metadata configuration.Metadata) error {
+func (p *ConfigurationStore) Init(parentCtx context.Context, metadata configuration.Metadata) error {
 	p.logger.Debug(InfoStartInit)
 	if p.client != nil {
 		return fmt.Errorf(ErrorAlreadyInitialized)

@@ -98,7 +98,7 @@ func (p *ConfigurationStore) Init(metadata configuration.Metadata) error {
 		p.metadata = m
 	}
 	p.ActiveSubscriptions = make(map[string]*subscription)
-	ctx, cancel := context.WithTimeout(context.Background(), p.metadata.maxIdleTimeout)
+	ctx, cancel := context.WithTimeout(parentCtx, p.metadata.maxIdleTimeout)
 	defer cancel()
 	client, err := Connect(ctx, p.metadata.connectionString, p.metadata.maxIdleTimeout)
 	if err != nil {
@@ -143,7 +143,7 @@ func parseRedisMetadata(meta configuration.Metadata) (metadata, error) {
 }

 // Init does metadata and connection parsing.
-func (r *ConfigurationStore) Init(metadata configuration.Metadata) error {
+func (r *ConfigurationStore) Init(ctx context.Context, metadata configuration.Metadata) error {
 	m, err := parseRedisMetadata(metadata)
 	if err != nil {
 		return err

@@ -156,11 +156,11 @@ func (r *ConfigurationStore) Init(metadata configuration.Metadata) error {
 		r.client = r.newClient(m)
 	}

-	if _, err = r.client.Ping(context.TODO()).Result(); err != nil {
+	if _, err = r.client.Ping(ctx).Result(); err != nil {
 		return fmt.Errorf("redis store: error connecting to redis at %s: %s", m.Host, err)
 	}

-	r.replicas, err = r.getConnectedSlaves()
+	r.replicas, err = r.getConnectedSlaves(ctx)

 	return err
 }

@@ -204,8 +204,8 @@ func (r *ConfigurationStore) newFailoverClient(m metadata) *redis.Client {
 	return redis.NewFailoverClient(opts)
 }

-func (r *ConfigurationStore) getConnectedSlaves() (int, error) {
-	res, err := r.client.Do(context.Background(), "INFO", "replication").Result()
+func (r *ConfigurationStore) getConnectedSlaves(ctx context.Context) (int, error) {
+	res, err := r.client.Do(ctx, "INFO", "replication").Result()
 	if err != nil {
 		return 0, err
 	}
@@ -18,7 +18,7 @@ import "context"
 // Store is an interface to perform operations on store.
 type Store interface {
 	// Init configuration store.
-	Init(metadata Metadata) error
+	Init(ctx context.Context, metadata Metadata) error

 	// Get configuration.
 	Get(ctx context.Context, req *GetRequest) (*GetResponse, error)
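With ctx threaded through Init, the host can now bound how long a store may spend connecting. A self-contained sketch under assumed names; fakeStore and its timings are hypothetical, and only the shape of Init mirrors the interface above.

	package main

	import (
		"context"
		"fmt"
		"time"
	)

	// fakeStore stands in for any configuration store; it simply honors ctx.
	type fakeStore struct{}

	func (fakeStore) Init(ctx context.Context, _ map[string]string) error {
		select {
		case <-time.After(5 * time.Second): // pretend connecting is slow
			return nil
		case <-ctx.Done():
			return ctx.Err() // the caller's deadline wins
		}
	}

	func main() {
		ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
		defer cancel()
		fmt.Println(fakeStore{}.Init(ctx, nil)) // context.DeadlineExceeded
	}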
go.mod
@@ -397,3 +397,5 @@ replace github.com/toolkits/concurrent => github.com/niean/gotools v0.0.0-201512

 // this is a fork which addresses a performance issues due to go routines
 replace dubbo.apache.org/dubbo-go/v3 => dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3
+
+replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181
@@ -1,5 +1,22 @@
+/*
+Copyright 2023 The Dapr Authors
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
 package health

+import (
+	"context"
+)
+
 type Pinger interface {
-	Ping() error
+	Ping(ctx context.Context) error
 }
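Adding ctx to Ping lets implementations scope the health probe to the caller's deadline. A sketch of one plausible implementer; httpPinger and its target URL are assumptions for illustration, not code from this repository.

	package main

	import (
		"context"
		"fmt"
		"net/http"
		"time"
	)

	// httpPinger health-checks an endpoint; the request is canceled
	// together with the context.
	type httpPinger struct {
		url string
	}

	func (p httpPinger) Ping(ctx context.Context) error {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.url, nil)
		if err != nil {
			return err
		}
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return err
		}
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			return fmt.Errorf("unhealthy: %s", resp.Status)
		}
		return nil
	}

	func main() {
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		fmt.Println(httpPinger{url: "https://example.com"}.Ping(ctx))
	}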
@@ -19,6 +19,7 @@ import (
 	"fmt"
 	"strconv"
 	"sync"
+	"sync/atomic"
 	"time"

 	"github.com/Shopify/sarama"

@@ -31,6 +32,7 @@ type consumer struct {
 	k       *Kafka
 	ready   chan bool
 	running chan struct{}
+	stopped atomic.Bool
 	once    sync.Once
 	mutex   sync.Mutex
 }

@@ -275,9 +277,6 @@ func (k *Kafka) Subscribe(ctx context.Context) error {

 	k.cg = cg

-	ctx, cancel := context.WithCancel(ctx)
-	k.cancel = cancel
-
 	ready := make(chan bool)
 	k.consumer = consumer{
 		k: k,

@@ -320,7 +319,10 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
 			k.logger.Errorf("Error closing consumer group: %v", err)
 		}

-		close(k.consumer.running)
+		// Ensure running channel is only closed once.
+		if k.consumer.stopped.CompareAndSwap(false, true) {
+			close(k.consumer.running)
+		}
 	}()

 	<-ready

@@ -331,7 +333,6 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
 // Close down consumer group resources, refresh once.
 func (k *Kafka) closeSubscriptionResources() {
 	if k.cg != nil {
-		k.cancel()
 		err := k.cg.Close()
 		if err != nil {
 			k.logger.Errorf("Error closing consumer group: %v", err)
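The stopped guard exists because close on an already-closed channel panics, and the running channel can be reached from more than one shutdown path. A small sketch isolating that hazard and the fix; consumer here is a stand-in, not the Kafka type above.

	package main

	import (
		"fmt"
		"sync/atomic"
	)

	type consumer struct {
		running chan struct{}
		stopped atomic.Bool
	}

	// stop may be called from several paths; the swap ensures
	// close(running) happens exactly once.
	func (c *consumer) stop() {
		if c.stopped.CompareAndSwap(false, true) {
			close(c.running)
		}
	}

	func main() {
		c := &consumer{running: make(chan struct{})}
		c.stop()
		c.stop() // no panic: the guard makes stop idempotent
		fmt.Println("stopped cleanly")
	}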
@@ -36,7 +36,6 @@ type Kafka struct {
 	saslPassword    string
 	initialOffset   int64
 	cg              sarama.ConsumerGroup
-	cancel          context.CancelFunc
 	consumer        consumer
 	config          *sarama.Config
 	subscribeTopics TopicHandlerConfig

@@ -60,7 +59,7 @@ func NewKafka(logger logger.Logger) *Kafka {
 }

 // Init does metadata parsing and connection establishment.
-func (k *Kafka) Init(metadata map[string]string) error {
+func (k *Kafka) Init(_ context.Context, metadata map[string]string) error {
 	upgradedMetadata, err := k.upgradeMetadata(metadata)
 	if err != nil {
 		return err
@@ -107,7 +107,8 @@ func (ts *OAuthTokenSource) Token() (*sarama.AccessToken, error) {

 	oidcCfg := ccred.Config{ClientID: ts.ClientID, ClientSecret: ts.ClientSecret, Scopes: ts.Scopes, TokenURL: ts.TokenEndpoint.TokenURL, AuthStyle: ts.TokenEndpoint.AuthStyle}

-	timeoutCtx, _ := ctx.WithTimeout(ctx.TODO(), tokenRequestTimeout) //nolint:govet
+	timeoutCtx, cancel := ctx.WithTimeout(ctx.TODO(), tokenRequestTimeout)
+	defer cancel()

 	ts.configureClient()
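For context on the change above: in this file ctx appears to be an import alias for the context package, and the old code discarded the CancelFunc, which go vet's lostcancel check flags because the timeout's resources are then held until the deadline fires; hence the //nolint:govet escape hatch. The fix is the standard pairing shown in this minimal sketch.

	package main

	import (
		"context"
		"time"
	)

	func main() {
		// Before: timeoutCtx, _ := context.WithTimeout(...) leaks the
		// timer until the deadline fires, which go vet flags.
		// After: keep cancel and defer it, so resources are released
		// as soon as the surrounding function returns.
		timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		_ = timeoutCtx // use the bounded context for the token request
	}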
@@ -38,9 +38,6 @@ type StandaloneRedisLock struct {
 	metadata rediscomponent.Metadata

 	logger logger.Logger
-
-	ctx    context.Context
-	cancel context.CancelFunc
 }

 // NewStandaloneRedisLock returns a new standalone redis lock.

@@ -54,7 +51,7 @@ func NewStandaloneRedisLock(logger logger.Logger) lock.Store {
 }

 // Init StandaloneRedisLock.
-func (r *StandaloneRedisLock) InitLockStore(metadata lock.Metadata) error {
+func (r *StandaloneRedisLock) InitLockStore(ctx context.Context, metadata lock.Metadata) error {
 	// 1. parse config
 	m, err := rediscomponent.ParseRedisMetadata(metadata.Properties)
 	if err != nil {

@@ -75,13 +72,12 @@ func (r *StandaloneRedisLock) InitLockStore(metadata lock.Metadata) error {
 	if err != nil {
 		return err
 	}
-	r.ctx, r.cancel = context.WithCancel(context.Background())
 	// 3. connect to redis
-	if _, err = r.client.PingResult(r.ctx); err != nil {
+	if _, err = r.client.PingResult(ctx); err != nil {
 		return fmt.Errorf("[standaloneRedisLock]: error connecting to redis at %s: %s", r.clientSettings.Host, err)
 	}
 	// no replica
-	replicas, err := r.getConnectedSlaves()
+	replicas, err := r.getConnectedSlaves(ctx)
 	// pass the validation if error occurs,
 	// since some redis versions such as miniredis do not recognize the `INFO` command.
 	if err == nil && replicas > 0 {

@@ -101,8 +97,8 @@ func needFailover(properties map[string]string) bool {
 	return false
 }

-func (r *StandaloneRedisLock) getConnectedSlaves() (int, error) {
-	res, err := r.client.DoRead(r.ctx, "INFO", "replication")
+func (r *StandaloneRedisLock) getConnectedSlaves(ctx context.Context) (int, error) {
+	res, err := r.client.DoRead(ctx, "INFO", "replication")
 	if err != nil {
 		return 0, err
 	}

@@ -183,9 +179,6 @@ func newInternalErrorUnlockResponse() *lock.UnlockResponse {

 // Close shuts down the client's redis connections.
 func (r *StandaloneRedisLock) Close() error {
-	if r.cancel != nil {
-		r.cancel()
-	}
 	if r.client != nil {
 		closeErr := r.client.Close()
 		r.client = nil
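Dropping the stored ctx and cancel fields follows the usual Go guidance: contexts should be passed per call rather than kept in structs, so each operation is scoped to its own caller. A compact sketch contrasting the two shapes; the names are illustrative, not the lock component's API.

	package main

	import "context"

	// Discouraged: a context stored in a struct outlives the call that
	// created it and ties every operation to a single lifetime.
	type lockBefore struct {
		ctx    context.Context
		cancel context.CancelFunc
	}

	// Preferred (what the change above does): each operation receives
	// the caller's context as an explicit argument.
	type lockAfter struct{}

	func (lockAfter) getConnectedSlaves(ctx context.Context) (int, error) {
		_ = ctx // scope I/O (e.g. the INFO command) to this call only
		return 0, nil
	}

	func main() {
		_ = lockBefore{}
		_, _ = lockAfter{}.getConnectedSlaves(context.Background())
	}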
@@ -42,7 +42,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) {
 		cfg.Properties["redisPassword"] = ""

 		// init
-		err := comp.InitLockStore(cfg)
+		err := comp.InitLockStore(context.Background(), cfg)
 		assert.Error(t, err)
 	})

@@ -58,7 +58,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) {
 		cfg.Properties["redisPassword"] = ""

 		// init
-		err := comp.InitLockStore(cfg)
+		err := comp.InitLockStore(context.Background(), cfg)
 		assert.Error(t, err)
 	})

@@ -75,7 +75,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) {
 		cfg.Properties["maxRetries"] = "1 "

 		// init
-		err := comp.InitLockStore(cfg)
+		err := comp.InitLockStore(context.Background(), cfg)
 		assert.Error(t, err)
 	})
 }

@@ -96,7 +96,7 @@ func TestStandaloneRedisLock_TryLock(t *testing.T) {
 	cfg.Properties["redisHost"] = s.Addr()
 	cfg.Properties["redisPassword"] = ""
 	// init
-	err = comp.InitLockStore(cfg)
+	err = comp.InitLockStore(context.Background(), cfg)
 	assert.NoError(t, err)
 	// 1. client1 trylock
 	ownerID1 := uuid.New().String()
@@ -17,7 +17,7 @@ import "context"

 type Store interface {
 	// Init this component.
-	InitLockStore(metadata Metadata) error
+	InitLockStore(ctx context.Context, metadata Metadata) error

 	// TryLock tries to acquire a lock.
 	TryLock(ctx context.Context, req *TryLockRequest) (*TryLockResponse, error)
@@ -57,14 +57,12 @@ type Middleware struct {
 }

 // GetHandler retruns the HTTP handler provided by the middleware.
-func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
+func (m *Middleware) GetHandler(ctx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
 	meta, err := m.getNativeMetadata(metadata)
 	if err != nil {
 		return nil, err
 	}

-	ctx := context.TODO()
-
 	// Create a JWKS cache that is refreshed automatically
 	cache := jwk.NewCache(ctx,
 		jwk.WithErrSink(httprc.ErrSinkFunc(func(err error) {
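Previously the JWKS cache was built on context.TODO(), so its background refresh could never be stopped; binding it to the ctx passed into GetHandler ties the refresh loop to the middleware's lifetime. A generic sketch of that difference using a stand-in refresher, not the jwx cache API.

	package main

	import (
		"context"
		"fmt"
		"time"
	)

	// startRefresher models a cache that refreshes in the background
	// until its context ends: with context.TODO() (the old code) it
	// would run forever, while a caller-supplied ctx lets the host
	// stop it.
	func startRefresher(ctx context.Context) {
		go func() {
			ticker := time.NewTicker(20 * time.Millisecond)
			defer ticker.Stop()
			for {
				select {
				case <-ticker.C:
					fmt.Println("refreshing JWKS...")
				case <-ctx.Done():
					fmt.Println("refresher stopped")
					return
				}
			}
		}()
	}

	func main() {
		ctx, cancel := context.WithCancel(context.Background())
		startRefresher(ctx)
		time.Sleep(50 * time.Millisecond)
		cancel()
		time.Sleep(10 * time.Millisecond)
	}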
@@ -14,6 +14,7 @@ limitations under the License.
 package oauth2

 import (
+	"context"
 	"net/http"
 	"net/url"
 	"strings"

@@ -59,7 +60,7 @@ const (
 )

 // GetHandler retruns the HTTP handler provided by the middleware.
-func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
+func (m *Middleware) GetHandler(ctx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
 	meta, err := m.getNativeMetadata(metadata)
 	if err != nil {
 		return nil, err
@@ -5,6 +5,7 @@
 package mock_oauth2clientcredentials

 import (
+	"context"
 	reflect "reflect"

 	gomock "github.com/golang/mock/gomock"

@@ -36,7 +37,7 @@ func (m *MockTokenProviderInterface) EXPECT() *MockTokenProviderInterfaceMockRec
 }

 // GetToken mocks base method
-func (m *MockTokenProviderInterface) GetToken(arg0 *clientcredentials.Config) (*oauth2.Token, error) {
+func (m *MockTokenProviderInterface) GetToken(ctx context.Context, arg0 *clientcredentials.Config) (*oauth2.Token, error) {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "GetToken", arg0)
 	ret0, _ := ret[0].(*oauth2.Token)
@@ -45,7 +45,7 @@ type oAuth2ClientCredentialsMiddlewareMetadata struct {

 // TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests.
 type TokenProviderInterface interface {
-	GetToken(conf *clientcredentials.Config) (*oauth2.Token, error)
+	GetToken(ctx context.Context, conf *clientcredentials.Config) (*oauth2.Token, error)
 }

 // NewOAuth2ClientCredentialsMiddleware returns a new oAuth2 middleware.

@@ -68,7 +68,7 @@ type Middleware struct {
 }

 // GetHandler retruns the HTTP handler provided by the middleware.
-func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
+func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
 	meta, err := m.getNativeMetadata(metadata)
 	if err != nil {
 		m.log.Errorf("getNativeMetadata error: %s", err)

@@ -101,7 +101,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
 		if !found {
 			m.log.Debugf("Cached token not found, try get one")

-			token, err := m.tokenProvider.GetToken(conf)
+			token, err := m.tokenProvider.GetToken(r.Context(), conf)
 			if err != nil {
 				m.log.Errorf("Error acquiring token: %s", err)
 				return

@@ -171,8 +171,8 @@ func (m *Middleware) SetTokenProvider(tokenProvider TokenProviderInterface) {
 }

 // GetToken returns a token from the current OAuth2 ClientCredentials Configuration.
-func (m *Middleware) GetToken(conf *clientcredentials.Config) (*oauth2.Token, error) {
-	tokenSource := conf.TokenSource(context.Background())
+func (m *Middleware) GetToken(ctx context.Context, conf *clientcredentials.Config) (*oauth2.Token, error) {
+	tokenSource := conf.TokenSource(ctx)

 	return tokenSource.Token()
 }
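Routing r.Context() into GetToken means an abandoned inbound request also cancels the outbound token fetch. A runnable sketch under assumed values; the client ID, secret, and token URL are placeholders, not configuration from this repository.

	package main

	import (
		"context"
		"fmt"

		"golang.org/x/oauth2/clientcredentials"
	)

	// fetchToken mirrors the pattern above: the token round-trip
	// inherits the caller's (e.g. the inbound request's) context, so a
	// client disconnect cancels the outbound token request too.
	func fetchToken(ctx context.Context) error {
		conf := &clientcredentials.Config{
			ClientID:     "id",     // placeholder
			ClientSecret: "secret", // placeholder
			TokenURL:     "https://example.com/oauth/token",
		}
		_, err := conf.TokenSource(ctx).Token()
		return err
	}

	func main() {
		ctx, cancel := context.WithCancel(context.Background())
		cancel() // simulate the client going away
		fmt.Println(fetchToken(ctx)) // fails fast instead of hanging
	}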
@@ -14,6 +14,7 @@ limitations under the License.
 package oauth2clientcredentials

 import (
+	"context"
 	"net/http"
 	"net/http/httptest"
 	"testing"

@@ -45,7 +46,7 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
 	metadata.Properties = map[string]string{}

 	log := logger.NewLogger("oauth2clientcredentials.test")
-	_, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
+	_, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
 	assert.EqualError(t, err, "metadata errors: Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. ")

 	// Invalid authStyle (non int)

@@ -57,17 +58,17 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
 		"headerName": "someHeader",
 		"authStyle":  "asdf", // This is the value to test
 	}
-	_, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
+	_, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
 	assert.EqualError(t, err2, "metadata errors: 1 error(s) decoding:\n\n* cannot parse 'AuthStyle' as int: strconv.ParseInt: parsing \"asdf\": invalid syntax")

 	// Invalid authStyle (int > 2)
 	metadata.Properties["authStyle"] = "3"
-	_, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
+	_, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
 	assert.EqualError(t, err3, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ")

 	// Invalid authStyle (int < 0)
 	metadata.Properties["authStyle"] = "-1"
-	_, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
+	_, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
 	assert.EqualError(t, err4, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ")
 }

@@ -109,7 +110,7 @@ func TestOAuth2ClientCredentialsToken(t *testing.T) {
 	log := logger.NewLogger("oauth2clientcredentials.test")
 	oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
 	oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
-	handler, err := oauth2clientcredentialsMiddleware.GetHandler(metadata)
+	handler, err := oauth2clientcredentialsMiddleware.GetHandler(context.Background(), metadata)
 	require.NoError(t, err)

 	// First handler call should return abc Token

@@ -169,7 +170,7 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) {
 	log := logger.NewLogger("oauth2clientcredentials.test")
 	oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
 	oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
-	handler, err := oauth2clientcredentialsMiddleware.GetHandler(metadata)
+	handler, err := oauth2clientcredentialsMiddleware.GetHandler(context.Background(), metadata)
 	require.NoError(t, err)

 	// First handler call should return abc Token
@@ -106,13 +106,13 @@ func (s *Status) Valid() bool {
 }

 // GetHandler returns the HTTP handler provided by the middleware.
-func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
+func (m *Middleware) GetHandler(parentCtx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
 	meta, err := m.getNativeMetadata(metadata)
 	if err != nil {
 		return nil, err
 	}

-	ctx, cancel := context.WithTimeout(context.TODO(), time.Minute)
+	ctx, cancel := context.WithTimeout(parentCtx, time.Minute)
 	defer cancel()
 	query, err := rego.New(
 		rego.Query("result = data.http.allow"),
 		rego.Module("inline.rego", meta.Rego),
@@ -14,6 +14,7 @@ limitations under the License.
 package opa

 import (
+	"context"
 	"encoding/json"
 	"net/http"
 	"net/http/httptest"

@@ -331,7 +332,7 @@ func TestOpaPolicy(t *testing.T) {
 		t.Run(name, func(t *testing.T) {
 			opaMiddleware := NewMiddleware(log)

-			handler, err := opaMiddleware.GetHandler(test.meta)
+			handler, err := opaMiddleware.GetHandler(context.Background(), test.meta)
 			if test.shouldHandlerError {
 				require.Error(t, err)
 				return
@@ -14,6 +14,7 @@ limitations under the License.
 package ratelimit

 import (
+	"context"
 	"fmt"
 	"net/http"
 	"strconv"

@@ -46,7 +47,7 @@ func NewRateLimitMiddleware(_ logger.Logger) middleware.Middleware {
 type Middleware struct{}

 // GetHandler returns the HTTP handler provided by the middleware.
-func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
+func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
 	meta, err := m.getNativeMetadata(metadata)
 	if err != nil {
 		return nil, err
@@ -40,7 +40,7 @@ func NewMiddleware(logger logger.Logger) middleware.Middleware {
 }

 // GetHandler retruns the HTTP handler provided by the middleware.
-func (m *Middleware) GetHandler(metadata middleware.Metadata) (
+func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (
 	func(next http.Handler) http.Handler, error,
 ) {
 	if err := m.getNativeMetadata(metadata); err != nil {
@@ -14,6 +14,7 @@ limitations under the License.
 package routeralias

 import (
+	"context"
 	"io"
 	"net/http"
 	"net/http/httptest"

@@ -46,7 +47,7 @@ func TestRequestHandlerWithIllegalRouterRule(t *testing.T) {
 	}
 	log := logger.NewLogger("routeralias.test")
 	ralias := NewMiddleware(log)
-	handler, err := ralias.GetHandler(meta)
+	handler, err := ralias.GetHandler(context.Background(), meta)
 	assert.Nil(t, err)

 	t.Run("hit: change router with common request", func(t *testing.T) {