Propagate context from caller to appropriate places in the code (#2474)
* Propagates contexts to callers where appropriate.
* Updates unit tests with the new func signature.
* Fix linting errors.
* Add atomic gate to the alicloud rocketmq close channel.
* bindings/aws/kinesis: use a separate ctx variable name.
* bindings/kafka: use an atomic to prevent closing the channel twice.
* bindings/mqtt3: use an atomic bool to prevent the close channel being closed multiple times.
* bindings/mqtt3: use a Background context for handle operations.
* state/cockroachdb: add context to Ping().
* bindings/postgres: add a comment explaining the use of context.
* Adds a comment header to health/pinger.go.
* pubsub/aws/snssqs: add a wait group to wait for all goroutines to finish and block on Close(); shut down the subscription if there are no topic handlers.
* pubsub/mqtt3: add an atomic bool to prevent multiple channel closes; add a wait group so Close blocks until all goroutines finish.
* pubsub/rabbitmq: fix race conditions, use an atomic to prevent multiple closes, and add a wait group so Close blocks on all goroutines.
* pubsub/redis: revert passing a ctx that could be cancelled; wait on the wait group when closing.
* state/postgres: pass context in Init, and wait on the wait group on Close.
* Update all `Ping()` calls to `PingContext()` where possible.
* state/in-memory: add an atomic bool to prevent closing the channel multiple times; add a wait group to block on Close().
* state/mysql: don't reuse the same ctx variable name.
* Pass the correct loop context to redis goroutines.
* Rename the context when creating timeouts in state.
* Remove the context requirement from state.Features().
* Revert the wasm request handle Close func to be without context so it implements the io.Closer interface; add a 5-second timeout and an io.Closer assertion in the test.
* Remove superfluous go lint vet directive.
* Change the Configuration Init function to take a context.
* Update the input binding interface to include a `Close() error` function; `Close` blocks until all resources have been released and goroutines have returned.
* Change `Close() error` in the input binding struct to the `io.Closer` interface.
* Update go.mod files to point to dapr/dapr PR https://github.com/dapr/dapr/pull/5831.
* pubsub/redis: watch closeCh to shut down the worker instead of the init context.
* pubsub/aws/snssqs + bindings/kubemq: ensure closeCh is caught so Close correctly returns.
* Close the kubemq binding client on close; ensure the kafka consumer channel cannot be closed more than once.
* Tweaks. (ItalyPaleAle)
* Fixed cert tests. (ItalyPaleAle)
* bindings/mqtt3: add an inline Background context instead of passing it to handleMessage.
* pubsub/mqtt3: remove context from createSubscriberClientOptions.
* pubsub/mqtt3: remove the `ResetConnection` func.
* pubsub/kafka: don't resubscribe if Subscribe is cancelled.
* bindings/mqtt3: don't use the context to control establishing the connection.
* bindings/mqtt3: fix linting errors.

Signed-off-by: joshvanl <me@joshvanl.dev>
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
This commit is contained in:
parent 210c8c3c59
commit d098e38d6a
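The diff below repeats the same few changes across many components: Init gains a context.Context parameter, input bindings gain a Close() error method, and each binding tracks shutdown with a closed atomic.Bool, a closeCh channel, and a sync.WaitGroup. The following self-contained Go sketch shows the shape of that pattern; the type and function names (closableBinding, handle, and so on) are illustrative only, not the actual dapr/components-contrib types.

package main

import (
    "context"
    "errors"
    "fmt"
    "sync"
    "sync/atomic"
    "time"
)

// closableBinding mirrors the fields added to the bindings in this PR:
// an atomic "closed" gate, a closeCh that is closed exactly once, and a
// WaitGroup that Close waits on so all goroutines have returned.
type closableBinding struct {
    wg      sync.WaitGroup
    closeCh chan struct{}
    closed  atomic.Bool
}

func newClosableBinding() *closableBinding {
    return &closableBinding{closeCh: make(chan struct{})}
}

// Read refuses to start after Close and registers its worker goroutine
// with the WaitGroup, the same shape the Read methods below use.
func (b *closableBinding) Read(ctx context.Context, handle func()) error {
    if b.closed.Load() {
        return errors.New("binding is closed")
    }
    b.wg.Add(1)
    go func() {
        defer b.wg.Done()
        for {
            select {
            case <-ctx.Done():
                return
            case <-b.closeCh:
                return
            case <-time.After(50 * time.Millisecond):
                handle() // poll-style work between cancellation checks
            }
        }
    }()
    return nil
}

// Close is idempotent: CompareAndSwap guarantees the channel is closed
// only once, and Wait blocks until every registered goroutine exits.
func (b *closableBinding) Close() error {
    if b.closed.CompareAndSwap(false, true) {
        close(b.closeCh)
    }
    b.wg.Wait()
    return nil
}

func main() {
    b := newClosableBinding()
    _ = b.Read(context.Background(), func() { fmt.Println("poll") })
    time.Sleep(120 * time.Millisecond)
    _ = b.Close() // safe to call more than once
    _ = b.Close()
}

Because of the CompareAndSwap gate, a second Close only waits on an already-drained WaitGroup and returns nil, which is what the new binding tests below assert.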
@@ -18,3 +18,5 @@ require (
 github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
 gopkg.in/yaml.v2 v2.4.0 // indirect
 )
+
+replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181

@@ -33,3 +33,5 @@ require (
 google.golang.org/protobuf v1.28.1 // indirect
 gopkg.in/yaml.v3 v3.0.1 // indirect
 )
+
+replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181
@@ -80,7 +80,7 @@ func NewDingTalkWebhook(l logger.Logger) bindings.InputOutputBinding {
 }
 
 // Init performs metadata parsing.
-func (t *DingTalkWebhook) Init(metadata bindings.Metadata) error {
+func (t *DingTalkWebhook) Init(_ context.Context, metadata bindings.Metadata) error {
 var err error
 if err = t.settings.Decode(metadata.Properties); err != nil {
 return fmt.Errorf("dingtalk configuration error: %w", err)

@@ -107,6 +107,13 @@ func (t *DingTalkWebhook) Read(ctx context.Context, handler bindings.Handler) er
 return nil
 }
 
+func (t *DingTalkWebhook) Close() error {
+webhooks.Lock()
+defer webhooks.Unlock()
+delete(webhooks.m, t.settings.ID)
+return nil
+}
+
 // Operations returns list of operations supported by dingtalk webhook binding.
 func (t *DingTalkWebhook) Operations() []bindings.OperationKind {
 return []bindings.OperationKind{bindings.CreateOperation, bindings.GetOperation}

@@ -57,7 +57,7 @@ func TestPublishMsg(t *testing.T) { //nolint:paralleltest
 }}}
 
 d := NewDingTalkWebhook(logger.NewLogger("test"))
-err := d.Init(m)
+err := d.Init(context.Background(), m)
 require.NoError(t, err)
 
 req := &bindings.InvokeRequest{Data: []byte(msg), Operation: bindings.CreateOperation, Metadata: map[string]string{}}

@@ -78,7 +78,7 @@ func TestBindingReadAndInvoke(t *testing.T) { //nolint:paralleltest
 }}
 
 d := NewDingTalkWebhook(logger.NewLogger("test"))
-err := d.Init(m)
+err := d.Init(context.Background(), m)
 assert.NoError(t, err)
 
 var count int32

@@ -106,3 +106,18 @@ func TestBindingReadAndInvoke(t *testing.T) { //nolint:paralleltest
 require.FailNow(t, "read timeout")
 }
 }
+
+func TestBindingClose(t *testing.T) {
+d := NewDingTalkWebhook(logger.NewLogger("test"))
+m := bindings.Metadata{Base: metadata.Base{
+Name: "test",
+Properties: map[string]string{
+"url": "/test",
+"secret": "",
+"id": "x",
+},
+}}
+assert.NoError(t, d.Init(context.Background(), m))
+assert.NoError(t, d.Close())
+assert.NoError(t, d.Close(), "second close should not error")
+}
@@ -45,7 +45,7 @@ func NewAliCloudOSS(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init does metadata parsing and connection creation.
-func (s *AliCloudOSS) Init(metadata bindings.Metadata) error {
+func (s *AliCloudOSS) Init(_ context.Context, metadata bindings.Metadata) error {
 m, err := s.parseMetadata(metadata)
 if err != nil {
 return err

@@ -31,7 +31,7 @@ type Callback struct {
 }
 
 // parse metadata field
-func (s *AliCloudSlsLogstorage) Init(metadata bindings.Metadata) error {
+func (s *AliCloudSlsLogstorage) Init(_ context.Context, metadata bindings.Metadata) error {
 m, err := s.parseMeta(metadata)
 if err != nil {
 return err

@@ -58,7 +58,7 @@ func NewAliCloudTableStore(log logger.Logger) bindings.OutputBinding {
 }
 }
 
-func (s *AliCloudTableStore) Init(metadata bindings.Metadata) error {
+func (s *AliCloudTableStore) Init(_ context.Context, metadata bindings.Metadata) error {
 m, err := s.parseMetadata(metadata)
 if err != nil {
 return err

@@ -51,7 +51,7 @@ func TestDataEncodeAndDecode(t *testing.T) {
 metadata := bindings.Metadata{Base: metadata.Base{
 Properties: getTestProperties(),
 }}
-aliCloudTableStore.Init(metadata)
+aliCloudTableStore.Init(context.Background(), metadata)
 
 // test create
 putData := map[string]interface{}{
@@ -78,7 +78,7 @@ func NewAPNS(logger logger.Logger) bindings.OutputBinding {
 
 // Init will configure the APNS output binding using the metadata specified
 // in the binding's configuration.
-func (a *APNS) Init(metadata bindings.Metadata) error {
+func (a *APNS) Init(ctx context.Context, metadata bindings.Metadata) error {
 if err := a.makeURLPrefix(metadata); err != nil {
 return err
 }

@@ -51,7 +51,7 @@ func TestInit(t *testing.T) {
 },
 }}
 binding := NewAPNS(testLogger).(*APNS)
-err := binding.Init(metadata)
+err := binding.Init(context.Background(), metadata)
 assert.Nil(t, err)
 assert.Equal(t, developmentPrefix, binding.urlPrefix)
 })

@@ -66,7 +66,7 @@ func TestInit(t *testing.T) {
 },
 }}
 binding := NewAPNS(testLogger).(*APNS)
-err := binding.Init(metadata)
+err := binding.Init(context.Background(), metadata)
 assert.Nil(t, err)
 assert.Equal(t, productionPrefix, binding.urlPrefix)
 })

@@ -80,7 +80,7 @@ func TestInit(t *testing.T) {
 },
 }}
 binding := NewAPNS(testLogger).(*APNS)
-err := binding.Init(metadata)
+err := binding.Init(context.Background(), metadata)
 assert.Nil(t, err)
 assert.Equal(t, productionPrefix, binding.urlPrefix)
 })

@@ -95,7 +95,7 @@ func TestInit(t *testing.T) {
 },
 }}
 binding := NewAPNS(testLogger).(*APNS)
-err := binding.Init(metadata)
+err := binding.Init(context.Background(), metadata)
 assert.Error(t, err, "invalid value for development parameter: True")
 })

@@ -107,7 +107,7 @@ func TestInit(t *testing.T) {
 },
 }}
 binding := NewAPNS(testLogger).(*APNS)
-err := binding.Init(metadata)
+err := binding.Init(context.Background(), metadata)
 assert.Error(t, err, "the key-id parameter is required")
 })

@@ -120,7 +120,7 @@ func TestInit(t *testing.T) {
 },
 }}
 binding := NewAPNS(testLogger).(*APNS)
-err := binding.Init(metadata)
+err := binding.Init(context.Background(), metadata)
 assert.Nil(t, err)
 assert.Equal(t, testKeyID, binding.authorizationBuilder.keyID)
 })

@@ -133,7 +133,7 @@ func TestInit(t *testing.T) {
 },
 }}
 binding := NewAPNS(testLogger).(*APNS)
-err := binding.Init(metadata)
+err := binding.Init(context.Background(), metadata)
 assert.Error(t, err, "the team-id parameter is required")
 })

@@ -146,7 +146,7 @@ func TestInit(t *testing.T) {
 },
 }}
 binding := NewAPNS(testLogger).(*APNS)
-err := binding.Init(metadata)
+err := binding.Init(context.Background(), metadata)
 assert.Nil(t, err)
 assert.Equal(t, testTeamID, binding.authorizationBuilder.teamID)
 })

@@ -159,7 +159,7 @@ func TestInit(t *testing.T) {
 },
 }}
 binding := NewAPNS(testLogger).(*APNS)
-err := binding.Init(metadata)
+err := binding.Init(context.Background(), metadata)
 assert.Error(t, err, "the private-key parameter is required")
 })

@@ -172,7 +172,7 @@ func TestInit(t *testing.T) {
 },
 }}
 binding := NewAPNS(testLogger).(*APNS)
-err := binding.Init(metadata)
+err := binding.Init(context.Background(), metadata)
 assert.Nil(t, err)
 assert.NotNil(t, binding.authorizationBuilder.privateKey)
 })

@@ -335,7 +335,7 @@ func makeTestBinding(t *testing.T, log logger.Logger) *APNS {
 privateKeyKey: testPrivateKey,
 },
 }}
-err := testBinding.Init(bindingMetadata)
+err := testBinding.Init(context.Background(), bindingMetadata)
 assert.Nil(t, err)
 
 return testBinding
@@ -49,7 +49,7 @@ func NewDynamoDB(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init performs connection parsing for DynamoDB.
-func (d *DynamoDB) Init(metadata bindings.Metadata) error {
+func (d *DynamoDB) Init(_ context.Context, metadata bindings.Metadata) error {
 meta, err := d.getDynamoDBMetadata(metadata)
 if err != nil {
 return err

@@ -15,7 +15,10 @@ package kinesis
 
 import (
 "context"
+"errors"
 "fmt"
+"sync"
+"sync/atomic"
 "time"
 
 "github.com/aws/aws-sdk-go/aws"

@@ -45,6 +48,10 @@ type AWSKinesis struct {
 streamARN *string
 consumerARN *string
 logger logger.Logger
+
+closed atomic.Bool
+closeCh chan struct{}
+wg sync.WaitGroup
 }
 
 type kinesisMetadata struct {

@@ -83,11 +90,14 @@ type recordProcessor struct {
 
 // NewAWSKinesis returns a new AWS Kinesis instance.
 func NewAWSKinesis(logger logger.Logger) bindings.InputOutputBinding {
-return &AWSKinesis{logger: logger}
+return &AWSKinesis{
+logger: logger,
+closeCh: make(chan struct{}),
+}
 }
 
 // Init does metadata parsing and connection creation.
-func (a *AWSKinesis) Init(metadata bindings.Metadata) error {
+func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error {
 m, err := a.parseMetadata(metadata)
 if err != nil {
 return err

@@ -107,7 +117,7 @@ func (a *AWSKinesis) Init(metadata bindings.Metadata) error {
 }
 
 streamName := aws.String(m.StreamName)
-stream, err := client.DescribeStream(&kinesis.DescribeStreamInput{
+stream, err := client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{
 StreamName: streamName,
 })
 if err != nil {

@@ -147,6 +157,10 @@ func (a *AWSKinesis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*
 }
 
 func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err error) {
+if a.closed.Load() {
+return errors.New("binding is closed")
+}
+
 if a.metadata.KinesisConsumerMode == SharedThroughput {
 a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.workerConfig)
 err = a.worker.Start()

@@ -166,8 +180,13 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er
 }
 
 // Wait for context cancelation then stop
+a.wg.Add(1)
 go func() {
-<-ctx.Done()
+defer a.wg.Done()
+select {
+case <-ctx.Done():
+case <-a.closeCh:
+}
 if a.metadata.KinesisConsumerMode == SharedThroughput {
 a.worker.Shutdown()
 } else if a.metadata.KinesisConsumerMode == ExtendedFanout {

@@ -188,14 +207,25 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
 
 a.consumerARN = consumerARN
 
+a.wg.Add(len(streamDesc.Shards))
 for i, shard := range streamDesc.Shards {
-go func(idx int, s *kinesis.Shard) error {
+go func(idx int, s *kinesis.Shard) {
+defer a.wg.Done()
+
 // Reconnection backoff
 bo := backoff.NewExponentialBackOff()
 bo.InitialInterval = 2 * time.Second
 
-// Repeat until context is canceled
-for ctx.Err() == nil {
+// Repeat until context is canceled or binding closed.
+for {
+select {
+case <-ctx.Done():
+return
+case <-a.closeCh:
+return
+default:
+}
+
 sub, err := a.client.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{
 ConsumerARN: consumerARN,
 ShardId: s.ShardId,

@@ -204,8 +234,12 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
 if err != nil {
 wait := bo.NextBackOff()
 a.logger.Errorf("Error while reading from shard %v: %v. Attempting to reconnect in %s...", s.ShardId, err, wait)
-time.Sleep(wait)
-continue
+select {
+case <-ctx.Done():
+return
+case <-time.After(wait):
+continue
+}
 }
 
 // Reset the backoff on connection success

@@ -223,22 +257,30 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
 }
 }
 }
-return nil
 }(i, shard)
 }
 
 return nil
 }
 
-func (a *AWSKinesis) ensureConsumer(parentCtx context.Context, streamARN *string) (*string, error) {
-ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
-consumer, err := a.client.DescribeStreamConsumerWithContext(ctx, &kinesis.DescribeStreamConsumerInput{
+func (a *AWSKinesis) Close() error {
+if a.closed.CompareAndSwap(false, true) {
+close(a.closeCh)
+}
+a.wg.Wait()
+return nil
+}
+
+func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*string, error) {
+// Only set timeout on consumer call.
+conCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
+defer cancel()
+consumer, err := a.client.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{
 ConsumerName: &a.metadata.ConsumerName,
 StreamARN: streamARN,
 })
-cancel()
 if err != nil {
-return a.registerConsumer(parentCtx, streamARN)
+return a.registerConsumer(ctx, streamARN)
 }
 
 return consumer.ConsumerDescription.ConsumerARN, nil
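The Kinesis hunks above also replace fixed time.Sleep backoff waits with a select that honors both the caller's context and the binding's close channel, so a reconnect loop can stop promptly instead of finishing its sleep first. A minimal sketch of that wait, using a hypothetical waitOrDone helper rather than anything in the dapr code:

package retry

import (
    "context"
    "time"
)

// waitOrDone sleeps for the backoff duration but returns false early if
// the context is canceled or closeCh is closed, so the caller's reconnect
// loop can exit immediately.
func waitOrDone(ctx context.Context, closeCh <-chan struct{}, wait time.Duration) bool {
    select {
    case <-time.After(wait):
        return true // slept the full duration; caller may retry
    case <-ctx.Done():
        return false
    case <-closeCh:
        return false
    }
}

A reconnect loop would then call waitOrDone with bo.NextBackOff() and return as soon as it reports false.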
@@ -99,7 +99,7 @@ func NewAWSS3(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init does metadata parsing and connection creation.
-func (s *AWSS3) Init(metadata bindings.Metadata) error {
+func (s *AWSS3) Init(_ context.Context, metadata bindings.Metadata) error {
 m, err := s.parseMetadata(metadata)
 if err != nil {
 return err

@@ -61,7 +61,7 @@ func NewAWSSES(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init does metadata parsing.
-func (a *AWSSES) Init(metadata bindings.Metadata) error {
+func (a *AWSSES) Init(_ context.Context, metadata bindings.Metadata) error {
 // Parse input metadata
 meta, err := a.parseMetadata(metadata)
 if err != nil {

@@ -53,7 +53,7 @@ func NewAWSSNS(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init does metadata parsing.
-func (a *AWSSNS) Init(metadata bindings.Metadata) error {
+func (a *AWSSNS) Init(_ context.Context, metadata bindings.Metadata) error {
 m, err := a.parseMetadata(metadata)
 if err != nil {
 return err

@@ -16,6 +16,9 @@ package sqs
 import (
 "context"
 "encoding/json"
+"errors"
+"sync"
+"sync/atomic"
 "time"
 
 "github.com/aws/aws-sdk-go/aws"

@@ -31,7 +34,10 @@ type AWSSQS struct {
 Client *sqs.SQS
 QueueURL *string
 
 logger logger.Logger
+wg sync.WaitGroup
+closeCh chan struct{}
+closed atomic.Bool
 }
 
 type sqsMetadata struct {

@@ -45,11 +51,14 @@ type sqsMetadata struct {
 
 // NewAWSSQS returns a new AWS SQS instance.
 func NewAWSSQS(logger logger.Logger) bindings.InputOutputBinding {
-return &AWSSQS{logger: logger}
+return &AWSSQS{
+logger: logger,
+closeCh: make(chan struct{}),
+}
 }
 
 // Init does metadata parsing and connection creation.
-func (a *AWSSQS) Init(metadata bindings.Metadata) error {
+func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error {
 m, err := a.parseSQSMetadata(metadata)
 if err != nil {
 return err

@@ -61,7 +70,7 @@ func (a *AWSSQS) Init(metadata bindings.Metadata) error {
 }
 
 queueName := m.QueueName
-resultURL, err := client.GetQueueUrl(&sqs.GetQueueUrlInput{
+resultURL, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{
 QueueName: aws.String(queueName),
 })
 if err != nil {

@@ -89,9 +98,20 @@ func (a *AWSSQS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind
 }
 
 func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error {
+if a.closed.Load() {
+return errors.New("binding is closed")
+}
+
+a.wg.Add(1)
 go func() {
-// Repeat until the context is canceled
-for ctx.Err() == nil {
+defer a.wg.Done()
+
+// Repeat until the context is canceled or component is closed
+for {
+if ctx.Err() != nil || a.closed.Load() {
+return
+}
+
 result, err := a.Client.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{
 QueueUrl: a.QueueURL,
 AttributeNames: aws.StringSlice([]string{

@@ -126,13 +146,25 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error {
 }
 }
 
-time.Sleep(time.Millisecond * 50)
+select {
+case <-ctx.Done():
+case <-a.closeCh:
+case <-time.After(time.Millisecond * 50):
+}
 }
 }()
 
 return nil
 }
 
+func (a *AWSSQS) Close() error {
+if a.closed.CompareAndSwap(false, true) {
+close(a.closeCh)
+}
+a.wg.Wait()
+return nil
+}
+
 func (a *AWSSQS) parseSQSMetadata(metadata bindings.Metadata) (*sqsMetadata, error) {
 b, err := json.Marshal(metadata.Properties)
 if err != nil {
@@ -92,7 +92,7 @@ func NewAzureBlobStorage(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init performs metadata parsing.
-func (a *AzureBlobStorage) Init(metadata bindings.Metadata) error {
+func (a *AzureBlobStorage) Init(_ context.Context, metadata bindings.Metadata) error {
 var err error
 a.containerClient, a.metadata, err = storageinternal.CreateContainerStorageClient(a.logger, metadata.Properties)
 if err != nil {

@@ -53,7 +53,7 @@ func NewCosmosDB(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init performs CosmosDB connection parsing and connecting.
-func (c *CosmosDB) Init(metadata bindings.Metadata) error {
+func (c *CosmosDB) Init(ctx context.Context, metadata bindings.Metadata) error {
 m, err := c.parseMetadata(metadata)
 if err != nil {
 return err

@@ -103,9 +103,9 @@ func (c *CosmosDB) Init(metadata bindings.Metadata) error {
 }
 
 c.client = dbContainer
-ctx, cancel := context.WithTimeout(context.Background(), timeoutValue*time.Second)
-_, err = c.client.Read(ctx, nil)
-cancel()
+readCtx, readCancel := context.WithTimeout(ctx, timeoutValue*time.Second)
+defer readCancel()
+_, err = c.client.Read(readCtx, nil)
 return err
 }
 

@@ -59,7 +59,7 @@ func NewCosmosDBGremlinAPI(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init performs CosmosDBGremlinAPI connection parsing and connecting.
-func (c *CosmosDBGremlinAPI) Init(metadata bindings.Metadata) error {
+func (c *CosmosDBGremlinAPI) Init(_ context.Context, metadata bindings.Metadata) error {
 c.logger.Debug("Initializing Cosmos Graph DB binding")
 
 m, err := c.parseMetadata(metadata)
@@ -19,6 +19,8 @@ import (
 "fmt"
 "net/http"
 "net/url"
+"sync"
+"sync/atomic"
 "time"
 
 "github.com/Azure/azure-sdk-for-go/sdk/azcore"

@@ -42,6 +44,9 @@ const armOperationTimeout = 30 * time.Second
 type AzureEventGrid struct {
 metadata *azureEventGridMetadata
 logger logger.Logger
+closeCh chan struct{}
+closed atomic.Bool
+wg sync.WaitGroup
 }
 
 type azureEventGridMetadata struct {

@@ -70,11 +75,14 @@ type azureEventGridMetadata struct {
 
 // NewAzureEventGrid returns a new Azure Event Grid instance.
 func NewAzureEventGrid(logger logger.Logger) bindings.InputOutputBinding {
-return &AzureEventGrid{logger: logger}
+return &AzureEventGrid{
+logger: logger,
+closeCh: make(chan struct{}),
+}
 }
 
 // Init performs metadata init.
-func (a *AzureEventGrid) Init(metadata bindings.Metadata) error {
+func (a *AzureEventGrid) Init(_ context.Context, metadata bindings.Metadata) error {
 m, err := a.parseMetadata(metadata)
 if err != nil {
 return err

@@ -85,6 +93,10 @@ func (a *AzureEventGrid) Init(metadata bindings.Metadata) error {
 }
 
 func (a *AzureEventGrid) Read(ctx context.Context, handler bindings.Handler) error {
+if a.closed.Load() {
+return errors.New("binding is closed")
+}
+
 err := a.ensureInputBindingMetadata()
 if err != nil {
 return err

@@ -120,17 +132,22 @@ func (a *AzureEventGrid) Read(ctx context.Context, handler bindings.Handler) err
 }
 
 // Run the server in background
+a.wg.Add(2)
 go func() {
+defer a.wg.Done()
 a.logger.Debugf("About to start listening for Event Grid events at http://localhost:%s/api/events", a.metadata.HandshakePort)
 srvErr := srv.ListenAndServe(":" + a.metadata.HandshakePort)
 if err != nil {
 a.logger.Errorf("Error starting server: %v", srvErr)
 }
 }()
-
-// Close the server when context is canceled
+// Close the server when context is canceled or binding closed.
 go func() {
-<-ctx.Done()
+defer a.wg.Done()
+select {
+case <-ctx.Done():
+case <-a.closeCh:
+}
 srvErr := srv.Shutdown()
 if err != nil {
 a.logger.Errorf("Error shutting down server: %v", srvErr)

@@ -149,6 +166,14 @@ func (a *AzureEventGrid) Operations() []bindings.OperationKind {
 return []bindings.OperationKind{bindings.CreateOperation}
 }
 
+func (a *AzureEventGrid) Close() error {
+if a.closed.CompareAndSwap(false, true) {
+close(a.closeCh)
+}
+a.wg.Wait()
+return nil
+}
+
 func (a *AzureEventGrid) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
 err := a.ensureOutputBindingMetadata()
 if err != nil {
@@ -37,7 +37,7 @@ func NewAzureEventHubs(logger logger.Logger) bindings.InputOutputBinding {
 }
 
 // Init performs metadata init.
-func (a *AzureEventHubs) Init(metadata bindings.Metadata) error {
+func (a *AzureEventHubs) Init(_ context.Context, metadata bindings.Metadata) error {
 return a.AzureEventHubs.Init(metadata.Properties)
 }
 

@@ -102,7 +102,7 @@ func testEventHubsBindingsAADAuthentication(t *testing.T) {
 metadata := createEventHubsBindingsAADMetadata()
 eventHubsBindings := NewAzureEventHubs(log)
 
-err := eventHubsBindings.Init(metadata)
+err := eventHubsBindings.Init(context.Background(), metadata)
 require.NoError(t, err)
 
 req := &bindings.InvokeRequest{

@@ -146,7 +146,7 @@ func testReadIotHubEvents(t *testing.T) {
 
 logger := logger.NewLogger("bindings.azure.eventhubs.integration.test")
 eh := NewAzureEventHubs(logger)
-err := eh.Init(createIotHubBindingsMetadata())
+err := eh.Init(context.Background(), createIotHubBindingsMetadata())
 require.NoError(t, err)
 
 // Invoke az CLI via bash script to send test IoT device events
@@ -17,6 +17,8 @@ import (
 "context"
 "errors"
 "fmt"
+"sync"
+"sync/atomic"
 "time"
 
 servicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"

@@ -39,17 +41,21 @@ type AzureServiceBusQueues struct {
 client *impl.Client
 timeout time.Duration
 logger logger.Logger
+closed atomic.Bool
+wg sync.WaitGroup
+closeCh chan struct{}
 }
 
 // NewAzureServiceBusQueues returns a new AzureServiceBusQueues instance.
 func NewAzureServiceBusQueues(logger logger.Logger) bindings.InputOutputBinding {
 return &AzureServiceBusQueues{
 logger: logger,
+closeCh: make(chan struct{}),
 }
 }
 
 // Init parses connection properties and creates a new Service Bus Queue client.
-func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) (err error) {
+func (a *AzureServiceBusQueues) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
 a.metadata, err = impl.ParseMetadata(metadata.Properties, a.logger, (impl.MetadataModeBinding | impl.MetadataModeQueues))
 if err != nil {
 return err

@@ -62,7 +68,7 @@ func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) (err error) {
 }
 
 // Will do nothing if DisableEntityManagement is false
-err = a.client.EnsureQueue(context.Background(), a.metadata.QueueName)
+err = a.client.EnsureQueue(ctx, a.metadata.QueueName)
 if err != nil {
 return err
 }

@@ -100,14 +106,33 @@ func (a *AzureServiceBusQueues) Invoke(invokeCtx context.Context, req *bindings.
 return nil, nil
 }
 
-func (a *AzureServiceBusQueues) Read(subscribeCtx context.Context, handler bindings.Handler) error {
+func (a *AzureServiceBusQueues) Read(parentCtx context.Context, handler bindings.Handler) error {
+if a.closed.Load() {
+return errors.New("binding is closed")
+}
+
 // Reconnection backoff policy
 bo := backoff.NewExponentialBackOff()
 bo.MaxElapsedTime = 0
 bo.InitialInterval = time.Duration(a.metadata.MinConnectionRecoveryInSec) * time.Second
 bo.MaxInterval = time.Duration(a.metadata.MaxConnectionRecoveryInSec) * time.Second
 
+subscribeCtx, subscribeCancel := context.WithCancel(parentCtx)
+
+// Close the subscription context when the binding is closed.
+a.wg.Add(2)
 go func() {
+defer a.wg.Done()
+select {
+case <-a.closeCh:
+subscribeCancel()
+case <-parentCtx.Done():
+// nop
+}
+}()
+
+go func() {
+defer a.wg.Done()
 // Reconnect loop.
 for {
 sub := impl.NewSubscription(subscribeCtx, impl.SubsriptionOptions{

@@ -165,7 +190,12 @@ func (a *AzureServiceBusQueues) Read(subscribeCtx context.Context, handler bindi
 
 wait := bo.NextBackOff()
 a.logger.Warnf("Subscription to queue %s lost connection, attempting to reconnect in %s...", a.metadata.QueueName, wait)
-time.Sleep(wait)
+select {
+case <-time.After(wait):
+// nop
+case <-a.closeCh:
+return
+}
 }
 }()
 

@@ -204,7 +234,11 @@ func (a *AzureServiceBusQueues) getHandlerFn(handler bindings.Handler) impl.Hand
 }
 
 func (a *AzureServiceBusQueues) Close() (err error) {
+if a.closed.CompareAndSwap(false, true) {
+close(a.closeCh)
+}
 a.logger.Debug("Closing component")
 a.client.CloseSender(a.metadata.QueueName)
+a.wg.Wait()
 return nil
 }
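The Service Bus Queues Read above derives a subscription context from the caller's context and cancels it from a watcher goroutine when closeCh fires; that is how Close() can stop a subscription helper that only understands context cancellation. A sketch of that idea, using a hypothetical subscriptionContext helper (the real code also registers the watcher with the binding's WaitGroup):

package example

import "context"

// subscriptionContext returns a context that is canceled when the parent
// context ends, when closeCh is closed, or when the returned cancel func
// is called. The watcher goroutine always exits once ctx is done.
func subscriptionContext(parent context.Context, closeCh <-chan struct{}) (context.Context, context.CancelFunc) {
    ctx, cancel := context.WithCancel(parent)
    go func() {
        defer cancel()
        select {
        case <-ctx.Done(): // parent canceled or cancel() called
        case <-closeCh: // binding closed
        }
    }()
    return ctx, cancel
}

A Read implementation would then call subscriptionContext(parentCtx, closeCh), defer the cancel, and pass the derived context to the subscription loop.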
@@ -78,7 +78,7 @@ type SignalR struct {
 }
 
 // Init is responsible for initializing the SignalR output based on the metadata.
-func (s *SignalR) Init(metadata bindings.Metadata) (err error) {
+func (s *SignalR) Init(_ context.Context, metadata bindings.Metadata) (err error) {
 s.userAgent = "dapr-" + logger.DaprVersion
 
 err = s.parseMetadata(metadata.Properties)
@@ -16,9 +16,12 @@ package storagequeues
 import (
 "context"
 "encoding/base64"
+"errors"
 "fmt"
 "net/url"
 "strconv"
+"sync"
+"sync/atomic"
 "time"
 
 "github.com/Azure/azure-storage-queue-go/azqueue"

@@ -40,9 +43,10 @@ type consumer struct {
 
 // QueueHelper enables injection for testnig.
 type QueueHelper interface {
-Init(metadata bindings.Metadata) (*storageQueuesMetadata, error)
+Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error)
 Write(ctx context.Context, data []byte, ttl *time.Duration) error
 Read(ctx context.Context, consumer *consumer) error
+Close() error
 }
 
 // AzureQueueHelper concrete impl of queue helper.

@@ -55,7 +59,7 @@ type AzureQueueHelper struct {
 }
 
 // Init sets up this helper.
-func (d *AzureQueueHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
+func (d *AzureQueueHelper) Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error) {
 m, err := parseMetadata(metadata)
 if err != nil {
 return nil, err

@@ -89,9 +93,9 @@ func (d *AzureQueueHelper) Init(metadata bindings.Metadata) (*storageQueuesMetad
 d.queueURL = azqueue.NewQueueURL(*URL, p)
 }
 
-ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
-_, err = d.queueURL.Create(ctx, azqueue.Metadata{})
-cancel()
+createCtx, createCancel := context.WithTimeout(ctx, 2*time.Minute)
+_, err = d.queueURL.Create(createCtx, azqueue.Metadata{})
+createCancel()
 if err != nil {
 return nil, err
 }

@@ -128,7 +132,10 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error {
 }
 if res.NumMessages() == 0 {
 // Queue was empty so back off by 10 seconds before trying again
-time.Sleep(10 * time.Second)
+select {
+case <-time.After(10 * time.Second):
+case <-ctx.Done():
+}
 return nil
 }
 mt := res.Message(0).Text

@@ -162,6 +169,10 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error {
 return nil
 }
 
+func (d *AzureQueueHelper) Close() error {
+return nil
+}
+
 // NewAzureQueueHelper creates new helper.
 func NewAzureQueueHelper(logger logger.Logger) QueueHelper {
 return &AzureQueueHelper{

@@ -175,6 +186,10 @@ type AzureStorageQueues struct {
 helper QueueHelper
 
 logger logger.Logger
+
+wg sync.WaitGroup
+closeCh chan struct{}
+closed atomic.Bool
 }
 
 type storageQueuesMetadata struct {

@@ -189,12 +204,16 @@ type storageQueuesMetadata struct {
 
 // NewAzureStorageQueues returns a new AzureStorageQueues instance.
 func NewAzureStorageQueues(logger logger.Logger) bindings.InputOutputBinding {
-return &AzureStorageQueues{helper: NewAzureQueueHelper(logger), logger: logger}
+return &AzureStorageQueues{
+helper: NewAzureQueueHelper(logger),
+logger: logger,
+closeCh: make(chan struct{}),
+}
 }
 
 // Init parses connection properties and creates a new Storage Queue client.
-func (a *AzureStorageQueues) Init(metadata bindings.Metadata) (err error) {
-a.metadata, err = a.helper.Init(metadata)
+func (a *AzureStorageQueues) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
+a.metadata, err = a.helper.Init(ctx, metadata)
 if err != nil {
 return err
 }

@@ -261,14 +280,32 @@ func (a *AzureStorageQueues) Invoke(ctx context.Context, req *bindings.InvokeReq
 }
 
 func (a *AzureStorageQueues) Read(ctx context.Context, handler bindings.Handler) error {
+if a.closed.Load() {
+return errors.New("input binding is closed")
+}
+
 c := consumer{
 callback: handler,
 }
+
+// Close read context when binding is closed.
+readCtx, cancel := context.WithCancel(ctx)
+a.wg.Add(2)
 go func() {
+defer a.wg.Done()
+defer cancel()
+
+select {
+case <-a.closeCh:
+case <-ctx.Done():
+}
+}()
+go func() {
+defer a.wg.Done()
 // Read until context is canceled
 var err error
-for ctx.Err() == nil {
-err = a.helper.Read(ctx, &c)
+for readCtx.Err() == nil {
+err = a.helper.Read(readCtx, &c)
 if err != nil {
 a.logger.Errorf("error from c: %s", err)
 }

@@ -277,3 +314,11 @@ func (a *AzureStorageQueues) Read(ctx context.Context, handler bindings.Handler)
 
 return nil
 }
+
+func (a *AzureStorageQueues) Close() error {
+if a.closed.CompareAndSwap(false, true) {
+close(a.closeCh)
+}
+a.wg.Wait()
+return nil
+}
@@ -16,6 +16,7 @@ package storagequeues
 import (
 "context"
 "encoding/base64"
+"sync"
 "testing"
 "time"
 

@@ -32,9 +33,11 @@ type MockHelper struct {
 mock.Mock
 messages chan []byte
 metadata *storageQueuesMetadata
+closeCh chan struct{}
+wg sync.WaitGroup
 }
 
-func (m *MockHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
+func (m *MockHelper) Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error) {
 m.messages = make(chan []byte, 10)
 var err error
 m.metadata, err = parseMetadata(metadata)

@@ -50,12 +53,23 @@ func (m *MockHelper) Write(ctx context.Context, data []byte, ttl *time.Duration)
 func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {
 retvals := m.Called(ctx, consumer)
 
+readCtx, cancel := context.WithCancel(ctx)
+m.wg.Add(2)
 go func() {
+defer m.wg.Done()
+defer cancel()
+select {
+case <-readCtx.Done():
+case <-m.closeCh:
+}
+}()
+go func() {
+defer m.wg.Done()
 for msg := range m.messages {
 if m.metadata.DecodeBase64 {
 msg, _ = base64.StdEncoding.DecodeString(string(msg))
 }
-go consumer.callback(ctx, &bindings.ReadResponse{
+go consumer.callback(readCtx, &bindings.ReadResponse{
 Data: msg,
 })
 }

@@ -64,18 +78,24 @@ func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {
 return retvals.Error(0)
 }
 
+func (m *MockHelper) Close() error {
+defer m.wg.Wait()
+close(m.closeCh)
+return nil
+}
+
 func TestWriteQueue(t *testing.T) {
 mm := new(MockHelper)
 mm.On("Write", mock.AnythingOfType("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
 return in == nil
 })).Return(nil)
 
-a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
+a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
 
 m := bindings.Metadata{}
 m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
 
-err := a.Init(m)
+err := a.Init(context.Background(), m)
 assert.Nil(t, err)
 
 r := bindings.InvokeRequest{Data: []byte("This is my message")}

@@ -83,6 +103,7 @@ func TestWriteQueue(t *testing.T) {
 _, err = a.Invoke(context.Background(), &r)
 
 assert.Nil(t, err)
+assert.NoError(t, a.Close())
 }
 
 func TestWriteWithTTLInQueue(t *testing.T) {

@@ -91,12 +112,12 @@ func TestWriteWithTTLInQueue(t *testing.T) {
 return in != nil && *in == time.Second
 })).Return(nil)
 
-a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
+a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
 
 m := bindings.Metadata{}
 m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", metadata.TTLMetadataKey: "1"}
 
-err := a.Init(m)
+err := a.Init(context.Background(), m)
 assert.Nil(t, err)
 
 r := bindings.InvokeRequest{Data: []byte("This is my message")}

@@ -104,6 +125,7 @@ func TestWriteWithTTLInQueue(t *testing.T) {
 _, err = a.Invoke(context.Background(), &r)
 
 assert.Nil(t, err)
+assert.NoError(t, a.Close())
 }
 
 func TestWriteWithTTLInWrite(t *testing.T) {

@@ -112,12 +134,12 @@ func TestWriteWithTTLInWrite(t *testing.T) {
 return in != nil && *in == time.Second
 })).Return(nil)
 
-a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
+a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
 
 m := bindings.Metadata{}
 m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", metadata.TTLMetadataKey: "1"}
 
-err := a.Init(m)
+err := a.Init(context.Background(), m)
 assert.Nil(t, err)
 
 r := bindings.InvokeRequest{

@@ -128,6 +150,7 @@ func TestWriteWithTTLInWrite(t *testing.T) {
 _, err = a.Invoke(context.Background(), &r)
 
 assert.Nil(t, err)
+assert.NoError(t, a.Close())
 }
 
 // Uncomment this function to write a message to local storage queue

@@ -138,7 +161,7 @@ func TestWriteWithTTLInWrite(t *testing.T) {
 m := bindings.Metadata{}
 m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
 
-err := a.Init(m)
+err := a.Init(context.Background(), m)
 assert.Nil(t, err)
 
 r := bindings.InvokeRequest{Data: []byte("This is my message")}
|
r := bindings.InvokeRequest{Data: []byte("This is my message")}
|
||||||
|
|
@ -152,12 +175,12 @@ func TestReadQueue(t *testing.T) {
|
||||||
mm := new(MockHelper)
|
mm := new(MockHelper)
|
||||||
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
|
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
|
||||||
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
|
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
|
||||||
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
|
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
|
||||||
|
|
||||||
m := bindings.Metadata{}
|
m := bindings.Metadata{}
|
||||||
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
|
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
|
||||||
|
|
||||||
err := a.Init(m)
|
err := a.Init(context.Background(), m)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
r := bindings.InvokeRequest{Data: []byte("This is my message")}
|
r := bindings.InvokeRequest{Data: []byte("This is my message")}
|
||||||
|
|
@ -186,6 +209,7 @@ func TestReadQueue(t *testing.T) {
|
||||||
t.Fatal("Timeout waiting for messages")
|
t.Fatal("Timeout waiting for messages")
|
||||||
}
|
}
|
||||||
assert.Equal(t, 1, received)
|
assert.Equal(t, 1, received)
|
||||||
|
assert.NoError(t, a.Close())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadQueueDecode(t *testing.T) {
|
func TestReadQueueDecode(t *testing.T) {
|
||||||
|
|
@ -193,12 +217,12 @@ func TestReadQueueDecode(t *testing.T) {
|
||||||
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
|
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
|
||||||
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
|
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
|
||||||
|
|
||||||
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
|
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
|
||||||
|
|
||||||
m := bindings.Metadata{}
|
m := bindings.Metadata{}
|
||||||
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", "decodeBase64": "true"}
|
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", "decodeBase64": "true"}
|
||||||
|
|
||||||
err := a.Init(m)
|
err := a.Init(context.Background(), m)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
r := bindings.InvokeRequest{Data: []byte("VGhpcyBpcyBteSBtZXNzYWdl")}
|
r := bindings.InvokeRequest{Data: []byte("VGhpcyBpcyBteSBtZXNzYWdl")}
|
||||||
|
|
@ -227,6 +251,7 @@ func TestReadQueueDecode(t *testing.T) {
|
||||||
t.Fatal("Timeout waiting for messages")
|
t.Fatal("Timeout waiting for messages")
|
||||||
}
|
}
|
||||||
assert.Equal(t, 1, received)
|
assert.Equal(t, 1, received)
|
||||||
|
assert.NoError(t, a.Close())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Uncomment this function to test reding from local queue
|
// Uncomment this function to test reding from local queue
|
||||||
|
|
@ -237,7 +262,7 @@ func TestReadQueueDecode(t *testing.T) {
|
||||||
m := bindings.Metadata{}
|
m := bindings.Metadata{}
|
||||||
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
|
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
|
||||||
|
|
||||||
err := a.Init(m)
|
err := a.Init(context.Background(), m)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
r := bindings.InvokeRequest{Data: []byte("This is my message")}
|
r := bindings.InvokeRequest{Data: []byte("This is my message")}
|
||||||
|
|
@ -263,12 +288,12 @@ func TestReadQueueNoMessage(t *testing.T) {
|
||||||
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
|
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
|
||||||
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
|
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
|
||||||
|
|
||||||
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
|
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
|
||||||
|
|
||||||
m := bindings.Metadata{}
|
m := bindings.Metadata{}
|
||||||
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
|
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
|
||||||
|
|
||||||
err := a.Init(m)
|
err := a.Init(context.Background(), m)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
@ -285,6 +310,7 @@ func TestReadQueueNoMessage(t *testing.T) {
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
cancel()
|
cancel()
|
||||||
assert.Equal(t, 0, received)
|
assert.Equal(t, 0, received)
|
||||||
|
assert.NoError(t, a.Close())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseMetadata(t *testing.T) {
|
func TestParseMetadata(t *testing.T) {
|
||||||
|
|
|
||||||
|
|
@@ -48,7 +48,7 @@ func NewCFQueues(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init the component.
-func (q *CFQueues) Init(metadata bindings.Metadata) error {
+func (q *CFQueues) Init(_ context.Context, metadata bindings.Metadata) error {
 	// Decode the metadata
 	err := mapstructure.Decode(metadata.Properties, &q.metadata)
 	if err != nil {

@@ -51,7 +51,7 @@ func NewCommercetools(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init does metadata parsing and connection establishment.
-func (ct *Binding) Init(metadata bindings.Metadata) error {
+func (ct *Binding) Init(_ context.Context, metadata bindings.Metadata) error {
 	commercetoolsM, err := ct.getCommercetoolsMetadata(metadata)
 	if err != nil {
 		return err

@@ -16,6 +16,8 @@ package cron
 import (
 	"context"
 	"fmt"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/benbjohnson/clock"
@@ -34,6 +36,9 @@ type Binding struct {
 	schedule string
 	parser   cron.Parser
 	clk      clock.Clock
+	closed   atomic.Bool
+	closeCh  chan struct{}
+	wg       sync.WaitGroup
 }
 
 // NewCron returns a new Cron event input binding.
@@ -48,6 +53,7 @@ func NewCronWithClock(logger logger.Logger, clk clock.Clock) bindings.InputBindi
 		parser: cron.NewParser(
 			cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor,
 		),
+		closeCh: make(chan struct{}),
 	}
 }
 
@@ -56,7 +62,7 @@ func NewCronWithClock(logger logger.Logger, clk clock.Clock) bindings.InputBindi
 //
 // "15 * * * * *" - Every 15 sec
 // "0 30 * * * *" - Every 30 min
-func (b *Binding) Init(metadata bindings.Metadata) error {
+func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
 	b.name = metadata.Name
 	s, f := metadata.Properties["schedule"]
 	if !f || s == "" {
@@ -73,6 +79,10 @@ func (b *Binding) Init(metadata bindings.Metadata) error {
 
 // Read triggers the Cron scheduler.
 func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
+	if b.closed.Load() {
+		return errors.New("binding is closed")
+	}
+
 	c := cron.New(cron.WithParser(b.parser), cron.WithClock(b.clk))
 	id, err := c.AddFunc(b.schedule, func() {
 		b.logger.Debugf("name: %s, schedule fired: %v", b.name, time.Now())
@@ -89,12 +99,25 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
 	c.Start()
 	b.logger.Debugf("name: %s, next run: %v", b.name, time.Until(c.Entry(id).Next))
 
+	b.wg.Add(1)
 	go func() {
-		// Wait for context to be canceled
-		<-ctx.Done()
+		defer b.wg.Done()
+		// Wait for context to be canceled or component to be closed.
+		select {
+		case <-ctx.Done():
+		case <-b.closeCh:
+		}
 		b.logger.Debugf("name: %s, stopping schedule: %s", b.name, b.schedule)
 		c.Stop()
 	}()
 
 	return nil
 }
+
+func (b *Binding) Close() error {
+	if b.closed.CompareAndSwap(false, true) {
+		close(b.closeCh)
+	}
+	b.wg.Wait()
+	return nil
+}

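The cron change above is one instance of a shutdown pattern this PR applies across many bindings: an atomic.Bool makes the close channel close-once, and a WaitGroup lets Close block until background goroutines have drained. The following is a minimal, self-contained sketch of that pattern only; the package and type names are illustrative, not code from the repository.

package lifecycle

import (
	"context"
	"errors"
	"sync"
	"sync/atomic"
)

// worker demonstrates the close-once + wait pattern used by the bindings in this PR.
type worker struct {
	closed  atomic.Bool
	closeCh chan struct{}
	wg      sync.WaitGroup
}

func newWorker() *worker {
	return &worker{closeCh: make(chan struct{})}
}

// Start refuses to run after Close and stops on caller cancellation or Close().
func (w *worker) Start(ctx context.Context, stop func()) error {
	if w.closed.Load() {
		return errors.New("worker is closed")
	}
	w.wg.Add(1)
	go func() {
		defer w.wg.Done()
		// Wake up on either the caller's context or the component being closed.
		select {
		case <-ctx.Done():
		case <-w.closeCh:
		}
		stop()
	}()
	return nil
}

// Close signals shutdown exactly once and blocks until goroutines return.
func (w *worker) Close() error {
	if w.closed.CompareAndSwap(false, true) {
		close(w.closeCh)
	}
	w.wg.Wait()
	return nil
}
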
@@ -16,6 +16,7 @@ package cron
 import (
 	"context"
 	"os"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -84,7 +85,7 @@ func TestCronInitSuccess(t *testing.T) {
 
 	for _, test := range initTests {
 		c := getNewCron()
-		err := c.Init(getTestMetadata(test.schedule))
+		err := c.Init(context.Background(), getTestMetadata(test.schedule))
 		if test.errorExpected {
 			assert.Errorf(t, err, "Got no error while initializing an invalid schedule: %s", test.schedule)
 		} else {
@@ -99,38 +100,41 @@ func TestCronRead(t *testing.T) {
 	clk := clock.NewMock()
 	c := getNewCronWithClock(clk)
 	schedule := "@every 1s"
-	assert.NoErrorf(t, c.Init(getTestMetadata(schedule)), "error initializing valid schedule")
-	expectedCount := 5
-	observedCount := 0
+	assert.NoErrorf(t, c.Init(context.Background(), getTestMetadata(schedule)), "error initializing valid schedule")
+	expectedCount := int32(5)
+	var observedCount atomic.Int32
 	err := c.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
 		assert.NotNil(t, res)
-		observedCount++
+		observedCount.Add(1)
 		return nil, nil
 	})
 	// Check if cron triggers 5 times in 5 seconds
-	for i := 0; i < expectedCount; i++ {
+	for i := int32(0); i < expectedCount; i++ {
 		// Add time to mock clock in 1 second intervals using loop to allow cron go routine to run
 		clk.Add(time.Second)
 	}
 	// Wait for 1 second after adding the last second to mock clock to allow cron to finish triggering
-	time.Sleep(1 * time.Second)
-	assert.Equal(t, expectedCount, observedCount, "Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount)
+	assert.Eventually(t, func() bool {
+		return observedCount.Load() == expectedCount
+	}, time.Second, time.Millisecond*10,
+		"Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount.Load())
 	assert.NoErrorf(t, err, "error on read")
+	assert.NoError(t, c.Close())
 }
 
 func TestCronReadWithContextCancellation(t *testing.T) {
 	clk := clock.NewMock()
 	c := getNewCronWithClock(clk)
 	schedule := "@every 1s"
-	assert.NoErrorf(t, c.Init(getTestMetadata(schedule)), "error initializing valid schedule")
-	expectedCount := 5
-	observedCount := 0
+	assert.NoErrorf(t, c.Init(context.Background(), getTestMetadata(schedule)), "error initializing valid schedule")
+	expectedCount := int32(5)
+	var observedCount atomic.Int32
 	ctx, cancel := context.WithCancel(context.Background())
 	err := c.Read(ctx, func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
 		assert.NotNil(t, res)
-		assert.LessOrEqualf(t, observedCount, expectedCount, "Invoke didn't stop the schedule")
-		observedCount++
-		if observedCount == expectedCount {
+		assert.LessOrEqualf(t, observedCount.Load(), expectedCount, "Invoke didn't stop the schedule")
+		observedCount.Add(1)
+		if observedCount.Load() == expectedCount {
 			// Cancel context after 5 triggers
 			cancel()
 		}
@@ -141,7 +145,10 @@ func TestCronReadWithContextCancellation(t *testing.T) {
 		// Add time to mock clock in 1 second intervals using loop to allow cron go routine to run
 		clk.Add(time.Second)
 	}
-	time.Sleep(1 * time.Second)
-	assert.Equal(t, expectedCount, observedCount, "Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount)
+	assert.Eventually(t, func() bool {
+		return observedCount.Load() == expectedCount
+	}, time.Second, time.Millisecond*10,
+		"Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount.Load())
 	assert.NoErrorf(t, err, "error on read")
+	assert.NoError(t, c.Close())
 }

@@ -47,7 +47,7 @@ func NewDubboOutput(logger logger.Logger) bindings.OutputBinding {
 	return dubboBinding
 }
 
-func (out *DubboOutputBinding) Init(_ bindings.Metadata) error {
+func (out *DubboOutputBinding) Init(_ context.Context, _ bindings.Metadata) error {
 	dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
 	return nil
 }

@@ -54,12 +54,13 @@ func TestInvoke(t *testing.T) {
 	// 0. init dapr provided and dubbo server
 	stopCh := make(chan struct{})
 	defer close(stopCh)
+	// Create output and set serializer before go routine to prevent data race.
+	output := NewDubboOutput(logger.NewLogger("test"))
+	dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
 	go func() {
 		assert.Nil(t, runDubboServer(stopCh))
 	}()
 	time.Sleep(time.Second * 3)
-	dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
-	output := NewDubboOutput(logger.NewLogger("test"))
 
 	// 1. create req/rsp value
 	reqUser := &User{Name: testName}

@@ -83,14 +83,13 @@ func NewGCPStorage(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init performs connection parsing.
-func (g *GCPStorage) Init(metadata bindings.Metadata) error {
+func (g *GCPStorage) Init(ctx context.Context, metadata bindings.Metadata) error {
 	m, b, err := g.parseMetadata(metadata)
 	if err != nil {
 		return err
 	}
 
 	clientOptions := option.WithCredentialsJSON(b)
-	ctx := context.Background()
 	client, err := storage.NewClient(ctx, clientOptions)
 	if err != nil {
 		return err

@@ -16,7 +16,10 @@ package pubsub
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
+	"sync"
+	"sync/atomic"
 
 	"cloud.google.com/go/pubsub"
 	"google.golang.org/api/option"
@@ -36,6 +39,9 @@ type GCPPubSub struct {
 	client   *pubsub.Client
 	metadata *pubSubMetadata
 	logger   logger.Logger
+	closed   atomic.Bool
+	closeCh  chan struct{}
+	wg       sync.WaitGroup
 }
 
 type pubSubMetadata struct {
@@ -55,11 +61,14 @@ type pubSubMetadata struct {
 
 // NewGCPPubSub returns a new GCPPubSub instance.
 func NewGCPPubSub(logger logger.Logger) bindings.InputOutputBinding {
-	return &GCPPubSub{logger: logger}
+	return &GCPPubSub{
+		logger:  logger,
+		closeCh: make(chan struct{}),
+	}
 }
 
 // Init parses metadata and creates a new Pub Sub client.
-func (g *GCPPubSub) Init(metadata bindings.Metadata) error {
+func (g *GCPPubSub) Init(ctx context.Context, metadata bindings.Metadata) error {
 	b, err := g.parseMetadata(metadata)
 	if err != nil {
 		return err
@@ -71,7 +80,6 @@ func (g *GCPPubSub) Init(metadata bindings.Metadata) error {
 		return err
 	}
 	clientOptions := option.WithCredentialsJSON(b)
-	ctx := context.Background()
 	pubsubClient, err := pubsub.NewClient(ctx, pubsubMeta.ProjectID, clientOptions)
 	if err != nil {
 		return fmt.Errorf("error creating pubsub client: %s", err)
@@ -88,7 +96,12 @@ func (g *GCPPubSub) parseMetadata(metadata bindings.Metadata) ([]byte, error) {
 }
 
 func (g *GCPPubSub) Read(ctx context.Context, handler bindings.Handler) error {
+	if g.closed.Load() {
+		return errors.New("binding is closed")
+	}
+	g.wg.Add(1)
 	go func() {
+		defer g.wg.Done()
 		sub := g.client.Subscription(g.metadata.Subscription)
 		err := sub.Receive(ctx, func(c context.Context, m *pubsub.Message) {
 			_, err := handler(c, &bindings.ReadResponse{
@@ -128,5 +141,9 @@ func (g *GCPPubSub) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*b
 }
 
 func (g *GCPPubSub) Close() error {
+	if g.closed.CompareAndSwap(false, true) {
+		close(g.closeCh)
+	}
+	defer g.wg.Wait()
 	return g.client.Close()
 }

@@ -58,7 +58,7 @@ func NewGraphQL(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init initializes the GraphQL binding.
-func (gql *GraphQL) Init(metadata bindings.Metadata) error {
+func (gql *GraphQL) Init(_ context.Context, metadata bindings.Metadata) error {
 	gql.logger.Debug("GraphQL Error: Initializing GraphQL binding")
 
 	p := metadata.Properties

@@ -74,7 +74,7 @@ func NewHTTP(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init performs metadata parsing.
-func (h *HTTPSource) Init(metadata bindings.Metadata) error {
+func (h *HTTPSource) Init(_ context.Context, metadata bindings.Metadata) error {
 	var err error
 	if err = mapstructure.Decode(metadata.Properties, &h.metadata); err != nil {
 		return err
@@ -104,7 +104,7 @@ func (h *HTTPSource) Init(metadata bindings.Metadata) error {
 		Transport: netTransport,
 	}
 
-	if val, ok := metadata.Properties["errorIfNot2XX"]; ok {
+	if val := metadata.Properties["errorIfNot2XX"]; val != "" {
 		h.errorIfNot2XX = utils.IsTruthy(val)
 	} else {
 		// Default behavior

@@ -132,7 +132,7 @@ func InitBinding(s *httptest.Server, extraProps map[string]string) (bindings.Out
 	}
 
 	hs := NewHTTP(logger.NewLogger("test"))
-	err := hs.Init(m)
+	err := hs.Init(context.Background(), m)
 	return hs, err
 }
 
@@ -269,7 +269,7 @@ func InitBindingForHTTPS(s *httptest.Server, extraProps map[string]string) (bind
 		m.Properties[k] = v
 	}
 	hs := NewHTTP(logger.NewLogger("test"))
-	err := hs.Init(m)
+	err := hs.Init(context.Background(), m)
 	return hs, err
 }

@@ -75,7 +75,7 @@ func NewHuaweiOBS(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init does metadata parsing and connection creation.
-func (o *HuaweiOBS) Init(metadata bindings.Metadata) error {
+func (o *HuaweiOBS) Init(_ context.Context, metadata bindings.Metadata) error {
 	o.logger.Debugf("initializing Huawei OBS binding and parsing metadata")
 
 	m, err := o.parseMetadata(metadata)

@@ -92,7 +92,7 @@ func TestInit(t *testing.T) {
 			"accessKey": "dummy-ak",
 			"secretKey": "dummy-sk",
 		}
-		err := obs.Init(m)
+		err := obs.Init(context.Background(), m)
 		assert.Nil(t, err)
 	})
 	t.Run("Init with missing bucket name", func(t *testing.T) {
@@ -102,7 +102,7 @@ func TestInit(t *testing.T) {
 			"accessKey": "dummy-ak",
 			"secretKey": "dummy-sk",
 		}
-		err := obs.Init(m)
+		err := obs.Init(context.Background(), m)
 		assert.NotNil(t, err)
 		assert.Equal(t, err, fmt.Errorf("missing obs bucket name"))
 	})
@@ -113,7 +113,7 @@ func TestInit(t *testing.T) {
 			"endpoint": "dummy-endpoint",
 			"secretKey": "dummy-sk",
 		}
-		err := obs.Init(m)
+		err := obs.Init(context.Background(), m)
 		assert.NotNil(t, err)
 		assert.Equal(t, err, fmt.Errorf("missing the huawei access key"))
 	})
@@ -124,7 +124,7 @@ func TestInit(t *testing.T) {
 			"endpoint": "dummy-endpoint",
 			"accessKey": "dummy-ak",
 		}
-		err := obs.Init(m)
+		err := obs.Init(context.Background(), m)
 		assert.NotNil(t, err)
 		assert.Equal(t, err, fmt.Errorf("missing the huawei secret key"))
 	})
@@ -135,7 +135,7 @@ func TestInit(t *testing.T) {
 			"accessKey": "dummy-ak",
 			"secretKey": "dummy-sk",
 		}
-		err := obs.Init(m)
+		err := obs.Init(context.Background(), m)
 		assert.NotNil(t, err)
 		assert.Equal(t, err, fmt.Errorf("missing obs endpoint"))
 	})

@@ -64,7 +64,7 @@ func NewInflux(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init does metadata parsing and connection establishment.
-func (i *Influx) Init(metadata bindings.Metadata) error {
+func (i *Influx) Init(_ context.Context, metadata bindings.Metadata) error {
 	influxMeta, err := i.getInfluxMetadata(metadata)
 	if err != nil {
 		return err

@@ -54,7 +54,7 @@ func TestInflux_Init(t *testing.T) {
 	assert.Nil(t, influx.client)
 
 	m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{"Url": "a", "Token": "a", "Org": "a", "Bucket": "a"}}}
-	err := influx.Init(m)
+	err := influx.Init(context.Background(), m)
 	assert.Nil(t, err)
 
 	assert.NotNil(t, influx.queryAPI)

@@ -16,6 +16,7 @@ package bindings
 import (
 	"context"
 	"fmt"
+	"io"
 
 	"github.com/dapr/components-contrib/health"
 )
@@ -23,18 +24,21 @@ import (
 // InputBinding is the interface to define a binding that triggers on incoming events.
 type InputBinding interface {
 	// Init passes connection and properties metadata to the binding implementation.
-	Init(metadata Metadata) error
+	Init(ctx context.Context, metadata Metadata) error
 	// Read is a method that runs in background and triggers the callback function whenever an event arrives.
 	Read(ctx context.Context, handler Handler) error
+	// Close is a method that closes the connection to the binding. Must be
+	// called when the binding is no longer needed to free up resources.
+	io.Closer
 }
 
 // Handler is the handler used to invoke the app handler.
 type Handler func(context.Context, *ReadResponse) ([]byte, error)
 
-func PingInpBinding(inputBinding InputBinding) error {
+func PingInpBinding(ctx context.Context, inputBinding InputBinding) error {
 	// checks if this input binding has the ping option then executes
 	if inputBindingWithPing, ok := inputBinding.(health.Pinger); ok {
-		return inputBindingWithPing.Ping()
+		return inputBindingWithPing.Ping(ctx)
 	} else {
 		return fmt.Errorf("ping is not implemented by this input binding")
 	}

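For orientation, this is roughly how a caller drives the updated InputBinding interface: Init and Read take the caller's context, and Close (via io.Closer) is expected to block until background work has drained. This is a hedged sketch only; the helper name is illustrative and the import path is assumed to be the components-contrib bindings package.

package example

import (
	"context"

	"github.com/dapr/components-contrib/bindings"
)

// runInputBinding is an illustrative caller, not code from dapr/dapr.
func runInputBinding(ctx context.Context, b bindings.InputBinding, md bindings.Metadata) error {
	if err := b.Init(ctx, md); err != nil {
		return err
	}
	// Close blocks until the binding has released resources and its goroutines have returned.
	defer b.Close()

	return b.Read(ctx, func(ctx context.Context, resp *bindings.ReadResponse) ([]byte, error) {
		// Handle the delivered event here.
		return nil, nil
	})
}
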
@@ -15,7 +15,10 @@ package kafka
 
 import (
 	"context"
+	"errors"
 	"strings"
+	"sync"
+	"sync/atomic"
 
 	"github.com/dapr/kit/logger"
@@ -29,12 +32,13 @@ const (
 )
 
 type Binding struct {
 	kafka        *kafka.Kafka
 	publishTopic string
 	topics       []string
 	logger       logger.Logger
-	subscribeCtx    context.Context
-	subscribeCancel context.CancelFunc
+	closeCh chan struct{}
+	closed  atomic.Bool
+	wg      sync.WaitGroup
 }
 
 // NewKafka returns a new kafka binding instance.
@@ -43,15 +47,14 @@ func NewKafka(logger logger.Logger) bindings.InputOutputBinding {
 	// in kafka binding component, disable consumer retry by default
 	k.DefaultConsumeRetryEnabled = false
 	return &Binding{
 		kafka:  k,
 		logger: logger,
+		closeCh: make(chan struct{}),
 	}
 }
 
-func (b *Binding) Init(metadata bindings.Metadata) error {
-	b.subscribeCtx, b.subscribeCancel = context.WithCancel(context.Background())
-
-	err := b.kafka.Init(metadata.Properties)
+func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
+	err := b.kafka.Init(ctx, metadata.Properties)
 	if err != nil {
 		return err
 	}
@@ -74,7 +77,10 @@ func (b *Binding) Operations() []bindings.OperationKind {
 }
 
 func (b *Binding) Close() (err error) {
-	b.subscribeCancel()
+	if b.closed.CompareAndSwap(false, true) {
+		close(b.closeCh)
+	}
+	defer b.wg.Wait()
 	return b.kafka.Close()
 }
 
@@ -84,6 +90,10 @@ func (b *Binding) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin
 }
 
 func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
+	if b.closed.Load() {
+		return errors.New("error: binding is closed")
+	}
+
 	if len(b.topics) == 0 {
 		b.logger.Warnf("kafka binding: no topic defined, input bindings will not be started")
 		return nil
@@ -96,31 +106,22 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
 	for _, t := range b.topics {
 		b.kafka.AddTopicHandler(t, handlerConfig)
 	}
-
+	b.wg.Add(1)
 	go func() {
-		// Wait for context cancelation
+		defer b.wg.Done()
+		// Wait for context cancelation or closure.
 		select {
 		case <-ctx.Done():
-		case <-b.subscribeCtx.Done():
+		case <-b.closeCh:
 		}
 
-		// Remove the topic handler before restarting the subscriber
+		// Remove the topic handlers.
 		for _, t := range b.topics {
 			b.kafka.RemoveTopicHandler(t)
 		}
-
-		// If the component's context has been canceled, do not re-subscribe
-		if b.subscribeCtx.Err() != nil {
-			return
-		}
-
-		err := b.kafka.Subscribe(b.subscribeCtx)
-		if err != nil {
-			b.logger.Errorf("kafka binding: error re-subscribing: %v", err)
-		}
 	}()
 
-	return b.kafka.Subscribe(b.subscribeCtx)
+	return b.kafka.Subscribe(ctx)
 }
 
 func adaptHandler(handler bindings.Handler) kafka.EventHandler {

@@ -2,8 +2,11 @@ package kubemq
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"strings"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	qs "github.com/kubemq-io/kubemq-go/queues_stream"
@@ -19,31 +22,30 @@ type Kubemq interface {
 }
 
 type kubeMQ struct {
 	client *qs.QueuesStreamClient
 	opts   *options
 	logger logger.Logger
-	ctx       context.Context
-	ctxCancel context.CancelFunc
+	closed  atomic.Bool
+	closeCh chan struct{}
+	wg      sync.WaitGroup
}
 
 func NewKubeMQ(logger logger.Logger) Kubemq {
 	return &kubeMQ{
 		client: nil,
 		opts:   nil,
 		logger: logger,
-		ctx:       nil,
-		ctxCancel: nil,
+		closeCh: make(chan struct{}),
 	}
 }
 
-func (k *kubeMQ) Init(metadata bindings.Metadata) error {
+func (k *kubeMQ) Init(ctx context.Context, metadata bindings.Metadata) error {
 	opts, err := createOptions(metadata)
 	if err != nil {
 		return err
 	}
 	k.opts = opts
-	k.ctx, k.ctxCancel = context.WithCancel(context.Background())
-	client, err := qs.NewQueuesStreamClient(k.ctx,
+	client, err := qs.NewQueuesStreamClient(ctx,
 		qs.WithAddress(opts.host, opts.port),
 		qs.WithCheckConnection(true),
 		qs.WithAuthToken(opts.authToken),
@@ -53,22 +55,39 @@ func (k *kubeMQ) Init(metadata bindings.Metadata) error {
 		k.logger.Errorf("error init kubemq client error: %s", err.Error())
 		return err
 	}
-	k.ctx, k.ctxCancel = context.WithCancel(context.Background())
 	k.client = client
 	return nil
 }
 
 func (k *kubeMQ) Read(ctx context.Context, handler bindings.Handler) error {
+	if k.closed.Load() {
+		return errors.New("binding is closed")
+	}
+	k.wg.Add(2)
+	processCtx, cancel := context.WithCancel(ctx)
 	go func() {
+		defer k.wg.Done()
+		defer cancel()
+		select {
+		case <-k.closeCh:
+		case <-processCtx.Done():
+		}
+	}()
+	go func() {
+		defer k.wg.Done()
 		for {
-			err := k.processQueueMessage(k.ctx, handler)
+			err := k.processQueueMessage(processCtx, handler)
 			if err != nil {
 				k.logger.Error(err.Error())
-				time.Sleep(time.Second)
 			}
-			if k.ctx.Err() != nil {
-				return
+			// If context cancelled or kubeMQ closed, exit. Otherwise, continue
+			// after a second.
+			select {
+			case <-time.After(time.Second):
+				continue
+			case <-processCtx.Done():
 			}
+			return
 		}
 	}()
 	return nil
@@ -82,7 +101,7 @@ func (k *kubeMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind
 		SetPolicyExpirationSeconds(parsePolicyExpirationSeconds(req.Metadata)).
 		SetPolicyMaxReceiveCount(parseSetPolicyMaxReceiveCount(req.Metadata)).
 		SetPolicyMaxReceiveQueue(parsePolicyMaxReceiveQueue(req.Metadata))
-	result, err := k.client.Send(k.ctx, queueMessage)
+	result, err := k.client.Send(ctx, queueMessage)
 	if err != nil {
 		return nil, err
 	}
@@ -101,6 +120,14 @@ func (k *kubeMQ) Operations() []bindings.OperationKind {
 	return []bindings.OperationKind{bindings.CreateOperation}
 }
 
+func (k *kubeMQ) Close() error {
+	if k.closed.CompareAndSwap(false, true) {
+		close(k.closeCh)
+	}
+	defer k.wg.Wait()
+	return k.client.Close()
+}
+
 func (k *kubeMQ) processQueueMessage(ctx context.Context, handler bindings.Handler) error {
 	pr := qs.NewPollRequest().
 		SetChannel(k.opts.channel).

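The kubemq Read loop above also shows the second recurring change in this PR: the unconditional time.Sleep between poll attempts is replaced by a select on a timer and the cancellation signal, so Close is never left waiting on a sleeping goroutine. A generic, hedged sketch of that idea follows; the package and function names are illustrative only.

package poller

import (
	"context"
	"time"
)

// pollLoop retries process once per second until the context is cancelled.
func pollLoop(ctx context.Context, process func(context.Context) error, logErr func(error)) {
	for {
		if err := process(ctx); err != nil {
			logErr(err)
		}
		// Cancellation-aware delay: never blocks shutdown on a bare sleep.
		select {
		case <-time.After(time.Second):
		case <-ctx.Done():
			return
		}
	}
}
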
@@ -106,7 +106,7 @@ func Test_kubeMQ_Init(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			kubemq := NewKubeMQ(logger.NewLogger("test"))
-			err := kubemq.Init(tt.meta)
+			err := kubemq.Init(context.Background(), tt.meta)
 			if tt.wantErr {
 				require.Error(t, err)
 			} else {
@@ -120,7 +120,7 @@ func Test_kubeMQ_Invoke_Read_Single_Message(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
 	defer cancel()
 	kubemq := NewKubeMQ(logger.NewLogger("test"))
-	err := kubemq.Init(getDefaultMetadata("test.read.single"))
+	err := kubemq.Init(context.Background(), getDefaultMetadata("test.read.single"))
 	require.NoError(t, err)
 	dataReadCh := make(chan []byte)
 	invokeRequest := &bindings.InvokeRequest{
@@ -147,7 +147,7 @@ func Test_kubeMQ_Invoke_Read_Single_MessageWithHandlerError(t *testing.T) {
 	kubemq := NewKubeMQ(logger.NewLogger("test"))
 	md := getDefaultMetadata("test.read.single.error")
 	md.Properties["autoAcknowledged"] = "false"
-	err := kubemq.Init(md)
+	err := kubemq.Init(context.Background(), md)
 	require.NoError(t, err)
 	invokeRequest := &bindings.InvokeRequest{
 		Data: []byte("test"),
@@ -182,7 +182,7 @@ func Test_kubeMQ_Invoke_Error(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
 	defer cancel()
 	kubemq := NewKubeMQ(logger.NewLogger("test"))
-	err := kubemq.Init(getDefaultMetadata("***test***"))
+	err := kubemq.Init(context.Background(), getDefaultMetadata("***test***"))
 	require.NoError(t, err)
 
 	invokeRequest := &bindings.InvokeRequest{

@@ -18,6 +18,8 @@ import (
 	"encoding/json"
 	"errors"
 	"strconv"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	v1 "k8s.io/api/core/v1"
@@ -35,6 +37,9 @@ type kubernetesInput struct {
 	namespace    string
 	resyncPeriod time.Duration
 	logger       logger.Logger
+	closed       atomic.Bool
+	closeCh      chan struct{}
+	wg           sync.WaitGroup
 }
 
 type EventResponse struct {
@@ -45,10 +50,13 @@ type EventResponse struct {
 
 // NewKubernetes returns a new Kubernetes event input binding.
 func NewKubernetes(logger logger.Logger) bindings.InputBinding {
-	return &kubernetesInput{logger: logger}
+	return &kubernetesInput{
+		logger:  logger,
+		closeCh: make(chan struct{}),
+	}
 }
 
-func (k *kubernetesInput) Init(metadata bindings.Metadata) error {
+func (k *kubernetesInput) Init(ctx context.Context, metadata bindings.Metadata) error {
 	client, err := kubeclient.GetKubeClient()
 	if err != nil {
 		return err
@@ -78,6 +86,9 @@ func (k *kubernetesInput) parseMetadata(metadata bindings.Metadata) error {
 }
 
 func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) error {
+	if k.closed.Load() {
+		return errors.New("binding is closed")
+	}
 	watchlist := cache.NewListWatchFromClient(
 		k.kubeClient.CoreV1().RESTClient(),
 		"events",
@@ -126,12 +137,28 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er
 		},
 	)
 
+	k.wg.Add(3)
+	readCtx, cancel := context.WithCancel(ctx)
+
+	// catch when binding is closed.
+	go func() {
+		defer k.wg.Done()
+		defer cancel()
+		select {
+		case <-readCtx.Done():
+		case <-k.closeCh:
+		}
+	}()
+
 	// Start the controller in backgound
-	stopCh := make(chan struct{})
-	go controller.Run(stopCh)
+	go func() {
+		defer k.wg.Done()
+		controller.Run(readCtx.Done())
+	}()
 
 	// Watch for new messages and for context cancellation
 	go func() {
+		defer k.wg.Done()
 		var (
 			obj  EventResponse
 			data []byte
@@ -148,8 +175,7 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er
 					Data: data,
 				})
 			}
-		case <-ctx.Done():
-			close(stopCh)
+		case <-readCtx.Done():
 			return
 		}
 	}
@@ -157,3 +183,11 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er
 
 	return nil
 }
+
+func (k *kubernetesInput) Close() error {
+	if k.closed.CompareAndSwap(false, true) {
+		close(k.closeCh)
+	}
+	k.wg.Wait()
+	return nil
+}

@@ -64,7 +64,7 @@ func NewLocalStorage(logger logger.Logger) bindings.OutputBinding {
 }
 
 // Init performs metadata parsing.
-func (ls *LocalStorage) Init(metadata bindings.Metadata) error {
+func (ls *LocalStorage) Init(_ context.Context, metadata bindings.Metadata) error {
 	m, err := ls.parseMetadata(metadata)
 	if err != nil {
 		return fmt.Errorf("failed to parse metadata: %w", err)

@@ -40,30 +40,29 @@ type MQTT struct {
 	logger       logger.Logger
 	isSubscribed atomic.Bool
 	readHandler  bindings.Handler
-	ctx     context.Context
-	cancel  context.CancelFunc
 	backOff backoff.BackOff
+	closeCh chan struct{}
+	closed  atomic.Bool
+	wg      sync.WaitGroup
 }
 
 // NewMQTT returns a new MQTT instance.
 func NewMQTT(logger logger.Logger) bindings.InputOutputBinding {
 	return &MQTT{
 		logger:  logger,
+		closeCh: make(chan struct{}),
 	}
 }
 
 // Init does MQTT connection parsing.
-func (m *MQTT) Init(metadata bindings.Metadata) (err error) {
+func (m *MQTT) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
 	m.metadata, err = parseMQTTMetaData(metadata, m.logger)
 	if err != nil {
 		return err
 	}
 
-	m.ctx, m.cancel = context.WithCancel(context.Background())
-
 	// TODO: Make the backoff configurable for constant or exponential
-	b := backoff.NewConstantBackOff(5 * time.Second)
-	m.backOff = backoff.WithContext(b, m.ctx)
+	m.backOff = backoff.NewConstantBackOff(5 * time.Second)
 
 	return nil
 }
@@ -104,7 +103,7 @@ func (m *MQTT) getProducer() (mqtt.Client, error) {
 	return p, nil
 }
 
-func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
+func (m *MQTT) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
 	producer, err := m.getProducer()
 	if err != nil {
 		return nil, fmt.Errorf("failed to create producer connection: %w", err)
@@ -118,7 +117,7 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
 	bo := backoff.WithMaxRetries(
 		backoff.NewConstantBackOff(200*time.Millisecond), 3,
 	)
-	bo = backoff.WithContext(bo, parentCtx)
+	bo = backoff.WithContext(bo, ctx)
 
 	topic, ok := req.Metadata[mqttTopic]
 	if !ok || topic == "" {
@@ -127,14 +126,13 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
 	}
 	return nil, retry.NotifyRecover(func() (err error) {
 		token := producer.Publish(topic, m.metadata.qos, m.metadata.retain, req.Data)
-		ctx, cancel := context.WithTimeout(parentCtx, defaultWait)
-		defer cancel()
 		select {
 		case <-token.Done():
 			err = token.Error()
-		case <-m.ctx.Done():
-			// Context canceled
-			err = m.ctx.Err()
+		case <-m.closeCh:
+			err = errors.New("mqtt client closed")
+		case <-time.After(defaultWait):
+			err = errors.New("mqtt client timeout")
 		case <-ctx.Done():
 			// Context canceled
 			err = ctx.Err()
@@ -151,6 +149,10 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
 }
 
 func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error {
+	if m.closed.Load() {
+		return errors.New("error: binding is closed")
+	}
+
 	// If the subscription is already active, wait 2s before retrying (in case we're still disconnecting), otherwise return an error
 	if !m.isSubscribed.CompareAndSwap(false, true) {
 		m.logger.Debug("Subscription is already active; waiting 2s before retrying…")
@@ -177,11 +179,14 @@ func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error {
 
 	// In background, watch for contexts cancelation and stop the connection
 	// However, do not call "unsubscribe" which would cause the broker to stop tracking the last message received by this consumer group
+	m.wg.Add(1)
 	go func() {
+		defer m.wg.Done()
+
 		select {
 		case <-ctx.Done():
 			// nop
-		case <-m.ctx.Done():
+		case <-m.closeCh:
 			// nop
 		}
 
@@ -208,14 +213,12 @@ func (m *MQTT) connect(clientID string, isSubscriber bool) (mqtt.Client, error)
 	}
 	client := mqtt.NewClient(opts)
 
-	ctx, cancel := context.WithTimeout(m.ctx, defaultWait)
-	defer cancel()
 	token := client.Connect()
 	select {
 	case <-token.Done():
 		err = token.Error()
-	case <-ctx.Done():
-		err = ctx.Err()
+	case <-time.After(defaultWait):
+		err = errors.New("mqtt client timed out connecting")
 	}
 	if err != nil {
 		return nil, fmt.Errorf("failed to connect: %w", err)
@@ -290,43 +293,46 @@ func (m *MQTT) createClientOptions(uri *url.URL, clientID string) *mqtt.ClientOp
 	return opts
 }
 
-func (m *MQTT) handleMessage(client mqtt.Client, mqttMsg mqtt.Message) {
-	// We're using m.ctx as context in this method because we don't have access to the Read context
-	// Canceling the Read context makes Read invoke "Disconnect" anyways
-	ctx := m.ctx
+func (m *MQTT) handleMessage() func(client mqtt.Client, mqttMsg mqtt.Message) {
+	return func(client mqtt.Client, mqttMsg mqtt.Message) {
+		bo := m.backOff
+		if m.metadata.backOffMaxRetries >= 0 {
+			bo = backoff.WithMaxRetries(bo, uint64(m.metadata.backOffMaxRetries))
|
||||||
|
}
|
||||||
|
|
||||||
var bo backoff.BackOff = backoff.WithContext(m.backOff, ctx)
|
err := retry.NotifyRecover(
|
||||||
if m.metadata.backOffMaxRetries >= 0 {
|
func() error {
|
||||||
bo = backoff.WithMaxRetries(bo, uint64(m.metadata.backOffMaxRetries))
|
m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
|
||||||
}
|
// Use a background context here so that the context is not tied to the
|
||||||
|
// first Invoke first created the producer.
|
||||||
|
// TODO: add context to mqtt library, and add a OnConnectWithContext option
|
||||||
|
// to change this func signature to
|
||||||
|
// func(c mqtt.Client, ctx context.Context)
|
||||||
|
_, err := m.readHandler(context.Background(), &bindings.ReadResponse{
|
||||||
|
Data: mqttMsg.Payload(),
|
||||||
|
Metadata: map[string]string{
|
||||||
|
mqttTopic: mqttMsg.Topic(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
err := retry.NotifyRecover(
|
// Ack the message on success
|
||||||
func() error {
|
mqttMsg.Ack()
|
||||||
m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
|
return nil
|
||||||
_, err := m.readHandler(ctx, &bindings.ReadResponse{
|
},
|
||||||
Data: mqttMsg.Payload(),
|
bo,
|
||||||
Metadata: map[string]string{
|
func(err error, d time.Duration) {
|
||||||
mqttTopic: mqttMsg.Topic(),
|
m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying…", mqttMsg.Topic(), mqttMsg.MessageID())
|
||||||
},
|
},
|
||||||
})
|
func() {
|
||||||
if err != nil {
|
m.logger.Infof("Successfully processed MQTT message after it previously failed: %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
|
||||||
return err
|
},
|
||||||
}
|
)
|
||||||
|
if err != nil {
|
||||||
// Ack the message on success
|
m.logger.Errorf("Failed processing MQTT message: %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err)
|
||||||
mqttMsg.Ack()
|
}
|
||||||
return nil
|
|
||||||
},
|
|
||||||
bo,
|
|
||||||
func(err error, d time.Duration) {
|
|
||||||
m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying…", mqttMsg.Topic(), mqttMsg.MessageID())
|
|
||||||
},
|
|
||||||
func() {
|
|
||||||
m.logger.Infof("Successfully processed MQTT message after it previously failed: %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
m.logger.Errorf("Failed processing MQTT message: %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -336,17 +342,15 @@ func (m *MQTT) createSubscriberClientOptions(uri *url.URL, clientID string) *mqt
|
||||||
|
|
||||||
// On (re-)connection, add the topic subscription
|
// On (re-)connection, add the topic subscription
|
||||||
opts.OnConnect = func(c mqtt.Client) {
|
opts.OnConnect = func(c mqtt.Client) {
|
||||||
token := c.Subscribe(m.metadata.topic, m.metadata.qos, m.handleMessage)
|
token := c.Subscribe(m.metadata.topic, m.metadata.qos, m.handleMessage())
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
subscribeCtx, subscribeCancel := context.WithTimeout(m.ctx, defaultWait)
|
|
||||||
defer subscribeCancel()
|
|
||||||
select {
|
select {
|
||||||
case <-token.Done():
|
case <-token.Done():
|
||||||
// Subscription went through (sucecessfully or not)
|
// Subscription went through (sucecessfully or not)
|
||||||
err = token.Error()
|
err = token.Error()
|
||||||
case <-subscribeCtx.Done():
|
case <-time.After(defaultWait):
|
||||||
err = fmt.Errorf("error while waiting for subscription token: %w", subscribeCtx.Err())
|
err = errors.New("timed out waiting for subscription to complete")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Nothing we can do in case of errors besides logging them
|
// Nothing we can do in case of errors besides logging them
|
||||||
|
|
@ -363,13 +367,16 @@ func (m *MQTT) Close() error {
|
||||||
m.producerLock.Lock()
|
m.producerLock.Lock()
|
||||||
defer m.producerLock.Unlock()
|
defer m.producerLock.Unlock()
|
||||||
|
|
||||||
// Canceling the context also causes Read to stop receiving messages
|
if m.closed.CompareAndSwap(false, true) {
|
||||||
m.cancel()
|
close(m.closeCh)
|
||||||
|
}
|
||||||
|
|
||||||
if m.producer != nil {
|
if m.producer != nil {
|
||||||
m.producer.Disconnect(200)
|
m.producer.Disconnect(200)
|
||||||
m.producer = nil
|
m.producer = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.wg.Wait()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
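The mqtt3 hunks above replace the stored ctx/cancel pair with a closed flag, a closeCh channel, and a wait group. A minimal sketch of that close-once pattern, outside the diff and with illustrative type and function names (not code from this PR):

    package sketch

    import (
        "errors"
        "sync"
        "sync/atomic"
    )

    // component mirrors the fields this PR adds to the bindings.
    type component struct {
        closeCh chan struct{}
        closed  atomic.Bool
        wg      sync.WaitGroup
    }

    func newComponent() *component {
        return &component{closeCh: make(chan struct{})}
    }

    func (c *component) Read() error {
        if c.closed.Load() {
            return errors.New("binding is closed")
        }
        c.wg.Add(1)
        go func() {
            defer c.wg.Done()
            <-c.closeCh // background work stops when Close is called
        }()
        return nil
    }

    // Close may be called more than once: CompareAndSwap ensures closeCh is
    // closed exactly once, and Wait blocks until the goroutines have returned.
    func (c *component) Close() error {
        if c.closed.CompareAndSwap(false, true) {
            close(c.closeCh)
        }
        c.wg.Wait()
        return nil
    }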
@@ -49,6 +49,7 @@ func getConnectionString() string {

 func TestInvokeWithTopic(t *testing.T) {
 t.Parallel()
+ctx := context.Background()

 url := getConnectionString()
 if url == "" {

@@ -79,7 +80,7 @@ func TestInvokeWithTopic(t *testing.T) {
 logger := logger.NewLogger("test")

 r := NewMQTT(logger).(*MQTT)
-err := r.Init(metadata)
+err := r.Init(ctx, metadata)
 assert.Nil(t, err)

 conn, err := r.connect(uuid.NewString(), false)

@@ -127,4 +128,5 @@ func TestInvokeWithTopic(t *testing.T) {
 assert.True(t, ok)
 assert.Equal(t, dataCustomized, mqttMessage.Payload())
 assert.Equal(t, topicCustomized, mqttMessage.Topic())
+assert.NoError(t, r.Close())
 }
@@ -205,7 +205,6 @@ func TestParseMetadata(t *testing.T) {
 logger := logger.NewLogger("test")
 m := NewMQTT(logger).(*MQTT)
 m.backOff = backoff.NewConstantBackOff(5 * time.Second)
-m.ctx, m.cancel = context.WithCancel(context.Background())
 m.readHandler = func(ctx context.Context, r *bindings.ReadResponse) ([]byte, error) {
 assert.Equal(t, payload, r.Data)
 metadata := r.Metadata

@@ -215,7 +214,7 @@ func TestParseMetadata(t *testing.T) {
 return r.Data, nil
 }

-m.handleMessage(nil, &mqttMockMessage{
+m.handleMessage()(nil, &mqttMockMessage{
 topic: topic,
 payload: payload,
 })
@@ -81,7 +81,7 @@ func NewMysql(logger logger.Logger) bindings.OutputBinding {
 }

 // Init initializes the MySQL binding.
-func (m *Mysql) Init(metadata bindings.Metadata) error {
+func (m *Mysql) Init(ctx context.Context, metadata bindings.Metadata) error {
 m.logger.Debug("Initializing MySql binding")

 p := metadata.Properties

@@ -115,7 +115,7 @@ func (m *Mysql) Init(metadata bindings.Metadata) error {
 return err
 }

-err = db.Ping()
+err = db.PingContext(ctx)
 if err != nil {
 return fmt.Errorf("unable to ping the DB: %w", err)
 }
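The MySQL hunk swaps db.Ping() for db.PingContext(ctx) so initialization honors the caller's deadline. A small illustrative sketch with database/sql; the five-second timeout is an assumption, not a value from this PR:

    package sketch

    import (
        "context"
        "database/sql"
        "time"
    )

    // pingWithDeadline mirrors the db.Ping() -> db.PingContext(ctx) change:
    // the ping now gives up when the caller's context does.
    func pingWithDeadline(ctx context.Context, db *sql.DB) error {
        ctx, cancel := context.WithTimeout(ctx, 5*time.Second) // example bound
        defer cancel()
        return db.PingContext(ctx)
    }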
@@ -75,7 +75,7 @@ func TestMysqlIntegration(t *testing.T) {

 b := NewMysql(logger.NewLogger("test")).(*Mysql)
 m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}}
-if err := b.Init(m); err != nil {
+if err := b.Init(context.Background(), m); err != nil {
 t.Fatal(err)
 }
@@ -21,6 +21,7 @@ import (
 "strconv"
 "strings"
 "sync"
+"sync/atomic"
 "time"

 "github.com/nacos-group/nacos-sdk-go/v2/clients"

@@ -56,6 +57,9 @@ type Nacos struct {
 logger logger.Logger
 configClient config_client.IConfigClient //nolint:nosnakecase
 readHandler func(ctx context.Context, response *bindings.ReadResponse) ([]byte, error)
+wg sync.WaitGroup
+closed atomic.Bool
+closeCh chan struct{}
 }

 // NewNacos returns a new Nacos instance.

@@ -63,11 +67,12 @@ func NewNacos(logger logger.Logger) bindings.OutputBinding {
 return &Nacos{
 logger: logger,
 watchesLock: sync.Mutex{},
+closeCh: make(chan struct{}),
 }
 }

 // Init implements InputBinding/OutputBinding's Init method.
-func (n *Nacos) Init(metadata bindings.Metadata) error {
+func (n *Nacos) Init(_ context.Context, metadata bindings.Metadata) error {
 n.settings = Settings{
 Timeout: defaultTimeout,
 }

@@ -146,6 +151,10 @@ func (n *Nacos) createConfigClient() error {

 // Read implements InputBinding's Read method.
 func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {
+if n.closed.Load() {
+return errors.New("binding is closed")
+}
+
 n.readHandler = handler

 n.watchesLock.Lock()

@@ -154,9 +163,14 @@ func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {
 }
 n.watchesLock.Unlock()

+n.wg.Add(1)
 go func() {
+defer n.wg.Done()
 // Cancel all listeners when the context is done
-<-ctx.Done()
+select {
+case <-ctx.Done():
+case <-n.closeCh:
+}
 n.cancelAllListeners()
 }()

@@ -165,8 +179,14 @@ func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {

 // Close implements cancel all listeners, see https://github.com/dapr/components-contrib/issues/779
 func (n *Nacos) Close() error {
+if n.closed.CompareAndSwap(false, true) {
+close(n.closeCh)
+}
+
 n.cancelAllListeners()

+n.wg.Wait()
+
 return nil
 }

@@ -223,7 +243,11 @@ func (n *Nacos) addListener(ctx context.Context, config configParam) {

 func (n *Nacos) addListenerFoInputBinding(ctx context.Context, config configParam) {
 if n.addToWatches(config) {
-go n.addListener(ctx, config)
+n.wg.Add(1)
+go func() {
+defer n.wg.Done()
+n.addListener(ctx, config)
+}()
 }
 }
@@ -35,7 +35,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
 m.Properties, err = getNacosLocalCacheMetadata()
 require.NoError(t, err)
 n := NewNacos(logger.NewLogger("test")).(*Nacos)
-err = n.Init(m)
+err = n.Init(context.Background(), m)
 require.NoError(t, err)
 var count int32
 ch := make(chan bool, 1)
@@ -22,15 +22,15 @@ import (

 // OutputBinding is the interface for an output binding, allowing users to invoke remote systems with optional payloads.
 type OutputBinding interface {
-Init(metadata Metadata) error
+Init(ctx context.Context, metadata Metadata) error
 Invoke(ctx context.Context, req *InvokeRequest) (*InvokeResponse, error)
 Operations() []OperationKind
 }

-func PingOutBinding(outputBinding OutputBinding) error {
+func PingOutBinding(ctx context.Context, outputBinding OutputBinding) error {
 // checks if this output binding has the ping option then executes
 if outputBindingWithPing, ok := outputBinding.(health.Pinger); ok {
-return outputBindingWithPing.Ping()
+return outputBindingWithPing.Ping(ctx)
 } else {
 return fmt.Errorf("ping is not implemented by this output binding")
 }
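PingOutBinding now threads the caller's context through to bindings that implement health.Pinger. A hedged caller-side sketch; the function name and the two-second timeout are illustrative, not part of this PR:

    package sketch

    import (
        "context"
        "log"
        "time"

        "github.com/dapr/components-contrib/bindings"
    )

    // checkBinding pings an output binding with a bounded context.
    func checkBinding(b bindings.OutputBinding) {
        ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
        defer cancel()
        if err := bindings.PingOutBinding(ctx, b); err != nil {
            log.Printf("ping failed or not implemented: %v", err)
        }
    }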
@@ -48,7 +48,7 @@ func NewPostgres(logger logger.Logger) bindings.OutputBinding {
 }

 // Init initializes the PostgreSql binding.
-func (p *Postgres) Init(metadata bindings.Metadata) error {
+func (p *Postgres) Init(ctx context.Context, metadata bindings.Metadata) error {
 url, ok := metadata.Properties[connectionURLKey]
 if !ok || url == "" {
 return errors.Errorf("required metadata not set: %s", connectionURLKey)

@@ -59,7 +59,9 @@ func (p *Postgres) Init(metadata bindings.Metadata) error {
 return errors.Wrap(err, "error opening DB connection")
 }

-p.db, err = pgxpool.NewWithConfig(context.Background(), poolConfig)
+// This context doesn't control the lifetime of the connection pool, and is
+// only scoped to postgres creating resources at init.
+p.db, err = pgxpool.NewWithConfig(ctx, poolConfig)
 if err != nil {
 return errors.Wrap(err, "unable to ping the DB")
 }
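As the new comment in the Postgres hunk notes, the Init context is scoped to resource creation only. A rough sketch of that distinction with pgxpool; the helper and its arguments are illustrative assumptions:

    package sketch

    import (
        "context"

        "github.com/jackc/pgx/v5/pgxpool"
    )

    // newPool passes the Init context to pool creation only; the pool's
    // lifetime is ended later by pool.Close(), not by cancelling ctx.
    func newPool(ctx context.Context, connString string) (*pgxpool.Pool, error) {
        cfg, err := pgxpool.ParseConfig(connString)
        if err != nil {
            return nil, err
        }
        return pgxpool.NewWithConfig(ctx, cfg)
    }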
@@ -64,7 +64,7 @@ func TestPostgresIntegration(t *testing.T) {
 // live DB test
 b := NewPostgres(logger.NewLogger("test")).(*Postgres)
 m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}}
-if err := b.Init(m); err != nil {
+if err := b.Init(context.Background(), m); err != nil {
 t.Fatal(err)
 }
@@ -74,7 +74,7 @@ func (p *Postmark) parseMetadata(meta bindings.Metadata) (postmarkMetadata, erro
 }

 // Init does metadata parsing and not much else :).
-func (p *Postmark) Init(metadata bindings.Metadata) error {
+func (p *Postmark) Init(_ context.Context, metadata bindings.Metadata) error {
 // Parse input metadata
 meta, err := p.parseMetadata(metadata)
 if err != nil {
@@ -19,6 +19,8 @@ import (
 "fmt"
 "math"
 "strconv"
+"sync"
+"sync/atomic"
 "time"

 amqp "github.com/rabbitmq/amqp091-go"

@@ -50,6 +52,9 @@ type RabbitMQ struct {
 metadata rabbitMQMetadata
 logger logger.Logger
 queue amqp.Queue
+closed atomic.Bool
+closeCh chan struct{}
+wg sync.WaitGroup
 }

 // Metadata is the rabbitmq config.

@@ -66,11 +71,14 @@ type rabbitMQMetadata struct {

 // NewRabbitMQ returns a new rabbitmq instance.
 func NewRabbitMQ(logger logger.Logger) bindings.InputOutputBinding {
-return &RabbitMQ{logger: logger}
+return &RabbitMQ{
+logger: logger,
+closeCh: make(chan struct{}),
+}
 }

 // Init does metadata parsing and connection creation.
-func (r *RabbitMQ) Init(metadata bindings.Metadata) error {
+func (r *RabbitMQ) Init(_ context.Context, metadata bindings.Metadata) error {
 err := r.parseMetadata(metadata)
 if err != nil {
 return err

@@ -226,6 +234,10 @@ func (r *RabbitMQ) declareQueue() (amqp.Queue, error) {
 }

 func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
+if r.closed.Load() {
+return errors.New("binding already closed")
+}
+
 msgs, err := r.channel.Consume(
 r.queue.Name,
 "",

@@ -239,14 +251,27 @@ func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
 return err
 }

+readCtx, cancel := context.WithCancel(ctx)
+
+r.wg.Add(2)
 go func() {
+defer r.wg.Done()
+defer cancel()
+select {
+case <-r.closeCh:
+case <-readCtx.Done():
+}
+}()
+
+go func() {
+defer r.wg.Done()
 var err error
 for {
 select {
-case <-ctx.Done():
+case <-readCtx.Done():
 return
 case d := <-msgs:
-_, err = handler(ctx, &bindings.ReadResponse{
+_, err = handler(readCtx, &bindings.ReadResponse{
 Data: d.Body,
 })
 if err != nil {

@@ -260,3 +285,11 @@ func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {

 return nil
 }

+func (r *RabbitMQ) Close() error {
+if r.closed.CompareAndSwap(false, true) {
+close(r.closeCh)
+}
+defer r.wg.Wait()
+return r.channel.Close()
+}
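With Close() blocking on the wait group, a caller of an input binding can shut down deterministically: cancel the Read context, then Close and wait. A hedged usage sketch; the handler body and helper name are placeholders:

    package sketch

    import (
        "context"

        "github.com/dapr/components-contrib/bindings"
    )

    // consume runs an input binding until the caller cancels ctx, then closes it.
    func consume(ctx context.Context, in bindings.InputBinding) error {
        err := in.Read(ctx, func(_ context.Context, resp *bindings.ReadResponse) ([]byte, error) {
            _ = resp.Data // process the message here
            return nil, nil
        })
        if err != nil {
            return err
        }
        <-ctx.Done()      // delivery happens in background goroutines
        return in.Close() // blocks until those goroutines have returned
    }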
@@ -85,7 +85,7 @@ func TestQueuesWithTTL(t *testing.T) {
 logger := logger.NewLogger("test")

 r := NewRabbitMQ(logger).(*RabbitMQ)
-err := r.Init(metadata)
+err := r.Init(context.Background(), metadata)
 assert.Nil(t, err)

 // Assert that if waited too long, we won't see any message

@@ -117,6 +117,7 @@ func TestQueuesWithTTL(t *testing.T) {
 assert.True(t, ok)
 msgBody := string(msg.Body)
 assert.Equal(t, testMsgContent, msgBody)
+assert.NoError(t, r.Close())
 }

 func TestPublishingWithTTL(t *testing.T) {

@@ -144,7 +145,7 @@ func TestPublishingWithTTL(t *testing.T) {
 logger := logger.NewLogger("test")

 rabbitMQBinding1 := NewRabbitMQ(logger).(*RabbitMQ)
-err := rabbitMQBinding1.Init(metadata)
+err := rabbitMQBinding1.Init(context.Background(), metadata)
 assert.Nil(t, err)

 // Assert that if waited too long, we won't see any message

@@ -175,7 +176,7 @@ func TestPublishingWithTTL(t *testing.T) {

 // Getting before it is expired, should return it
 rabbitMQBinding2 := NewRabbitMQ(logger).(*RabbitMQ)
-err = rabbitMQBinding2.Init(metadata)
+err = rabbitMQBinding2.Init(context.Background(), metadata)
 assert.Nil(t, err)

 const testMsgContent = "test_msg"

@@ -193,6 +194,9 @@ func TestPublishingWithTTL(t *testing.T) {
 assert.True(t, ok)
 msgBody := string(msg.Body)
 assert.Equal(t, testMsgContent, msgBody)

+assert.NoError(t, rabbitMQBinding1.Close())
+assert.NoError(t, rabbitMQBinding1.Close())
 }

 func TestExclusiveQueue(t *testing.T) {

@@ -222,7 +226,7 @@ func TestExclusiveQueue(t *testing.T) {
 logger := logger.NewLogger("test")

 r := NewRabbitMQ(logger).(*RabbitMQ)
-err := r.Init(metadata)
+err := r.Init(context.Background(), metadata)
 assert.Nil(t, err)

 // Assert that if waited too long, we won't see any message

@@ -276,7 +280,7 @@ func TestPublishWithPriority(t *testing.T) {
 logger := logger.NewLogger("test")

 r := NewRabbitMQ(logger).(*RabbitMQ)
-err := r.Init(metadata)
+err := r.Init(context.Background(), metadata)
 assert.Nil(t, err)

 // Assert that if waited too long, we won't see any message
@@ -28,9 +28,6 @@ type Redis struct {
 client rediscomponent.RedisClient
 clientSettings *rediscomponent.Settings
 logger logger.Logger

-ctx context.Context
-cancel context.CancelFunc
 }

 // NewRedis returns a new redis bindings instance.

@@ -39,15 +36,13 @@ func NewRedis(logger logger.Logger) bindings.OutputBinding {
 }

 // Init performs metadata parsing and connection creation.
-func (r *Redis) Init(meta bindings.Metadata) (err error) {
+func (r *Redis) Init(ctx context.Context, meta bindings.Metadata) (err error) {
 r.client, r.clientSettings, err = rediscomponent.ParseClientFromProperties(meta.Properties, nil)
 if err != nil {
 return err
 }

-r.ctx, r.cancel = context.WithCancel(context.Background())
+_, err = r.client.PingResult(ctx)

-_, err = r.client.PingResult(r.ctx)
 if err != nil {
 return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err)
 }

@@ -55,8 +50,8 @@ func (r *Redis) Init(meta bindings.Metadata) (err error) {
 return err
 }

-func (r *Redis) Ping() error {
+func (r *Redis) Ping(ctx context.Context) error {
-if _, err := r.client.PingResult(r.ctx); err != nil {
+if _, err := r.client.PingResult(ctx); err != nil {
 return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err)
 }

@@ -101,7 +96,5 @@ func (r *Redis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi
 }

 func (r *Redis) Close() error {
-r.cancel()

 return r.client.Close()
 }
@@ -40,7 +40,6 @@ func TestInvokeCreate(t *testing.T) {
 client: c,
 logger: logger.NewLogger("test"),
 }
-bind.ctx, bind.cancel = context.WithCancel(context.Background())

 _, err := c.DoRead(context.Background(), "GET", testKey)
 assert.Equal(t, redis.Nil, err)

@@ -66,7 +65,6 @@ func TestInvokeGet(t *testing.T) {
 client: c,
 logger: logger.NewLogger("test"),
 }
-bind.ctx, bind.cancel = context.WithCancel(context.Background())

 err := c.DoWrite(context.Background(), "SET", testKey, testData)
 assert.Equal(t, nil, err)

@@ -87,7 +85,6 @@ func TestInvokeDelete(t *testing.T) {
 client: c,
 logger: logger.NewLogger("test"),
 }
-bind.ctx, bind.cancel = context.WithCancel(context.Background())

 err := c.DoWrite(context.Background(), "SET", testKey, testData)
 assert.Equal(t, nil, err)
@@ -18,6 +18,8 @@ import (
 "encoding/json"
 "strconv"
 "strings"
+"sync"
+"sync/atomic"
 "time"

 r "github.com/dancannon/gorethink"

@@ -34,6 +36,9 @@ type Binding struct {
 logger logger.Logger
 session *r.Session
 config StateConfig
+closed atomic.Bool
+closeCh chan struct{}
+wg sync.WaitGroup
 }

 // StateConfig is the binding config.

@@ -45,12 +50,13 @@ type StateConfig struct {
 // NewRethinkDBStateChangeBinding returns a new RethinkDB actor event input binding.
 func NewRethinkDBStateChangeBinding(logger logger.Logger) bindings.InputBinding {
 return &Binding{
 logger: logger,
+closeCh: make(chan struct{}),
 }
 }

 // Init initializes the RethinkDB binding.
-func (b *Binding) Init(metadata bindings.Metadata) error {
+func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
 cfg, err := metadataToConfig(metadata.Properties, b.logger)
 if err != nil {
 return errors.Wrap(err, "unable to parse metadata properties")

@@ -68,6 +74,10 @@ func (b *Binding) Init(metadata bindings.Metadata) error {

 // Read triggers the RethinkDB scheduler.
 func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
+if b.closed.Load() {
+return errors.New("binding is closed")
+}
+
 b.logger.Infof("subscribing to state changes in %s.%s...", b.config.Database, b.config.Table)
 cursor, err := r.DB(b.config.Database).
 Table(b.config.Table).

@@ -81,8 +91,21 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
 errors.Wrapf(err, "error connecting to table %s", b.config.Table)
 }

+readCtx, cancel := context.WithCancel(ctx)
+b.wg.Add(2)
+
 go func() {
-for ctx.Err() == nil {
+defer b.wg.Done()
+defer cancel()
+select {
+case <-b.closeCh:
+case <-readCtx.Done():
+}
+}()
+
+go func() {
+defer b.wg.Done()
+for readCtx.Err() == nil {
 var change interface{}
 ok := cursor.Next(&change)
 if !ok {

@@ -105,7 +128,7 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
 },
 }

-if _, err := handler(ctx, resp); err != nil {
+if _, err := handler(readCtx, resp); err != nil {
 b.logger.Errorf("error invoking change handler: %v", err)
 continue
 }

@@ -117,6 +140,14 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
 return nil
 }

+func (b *Binding) Close() error {
+if b.closed.CompareAndSwap(false, true) {
+close(b.closeCh)
+}
+defer b.wg.Wait()
+return b.session.Close()
+}

 func metadataToConfig(cfg map[string]string, logger logger.Logger) (StateConfig, error) {
 c := StateConfig{}
 for k, v := range cfg {
@@ -71,7 +71,7 @@ func TestBinding(t *testing.T) {
 assert.NotNil(t, m.Properties)

 b := getNewRethinkActorBinding()
-err := b.Init(m)
+err := b.Init(context.Background(), m)
 assert.NoErrorf(t, err, "error initializing")

 ctx, cancel := context.WithCancel(context.Background())
@@ -18,6 +18,8 @@ import (
 "errors"
 "fmt"
 "strings"
+"sync"
+"sync/atomic"
 "time"

 mqc "github.com/apache/rocketmq-client-go/v2/consumer"

@@ -35,27 +37,27 @@ type RocketMQ struct {
 settings Settings
 producer mqw.Producer

-ctx context.Context
-cancel context.CancelFunc
 backOffConfig retry.Config
+closeCh chan struct{}
+closed atomic.Bool
+wg sync.WaitGroup
 }

 func NewRocketMQ(l logger.Logger) *RocketMQ {
 return &RocketMQ{ //nolint:exhaustivestruct
 logger: l,
 producer: nil,
+closeCh: make(chan struct{}),
 }
 }

 // Init performs metadata parsing.
-func (a *RocketMQ) Init(metadata bindings.Metadata) error {
+func (a *RocketMQ) Init(ctx context.Context, metadata bindings.Metadata) error {
 var err error
 if err = a.settings.Decode(metadata.Properties); err != nil {
 return err
 }

-a.ctx, a.cancel = context.WithCancel(context.Background())

 // Default retry configuration is used if no
 // backOff properties are set.
 if err = retry.DecodeConfigWithPrefix(

@@ -75,6 +77,10 @@ func (a *RocketMQ) Init(metadata bindings.Metadata) error {

 // Read triggers the rocketmq subscription.
 func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error {
+if a.closed.Load() {
+return errors.New("error: binding is closed")
+}
+
 a.logger.Debugf("binding rocketmq: start read input binding")

 consumer, err := a.setupConsumer()

@@ -114,10 +120,12 @@ func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error {
 a.logger.Debugf("binding-rocketmq: consumer started")

 // Listen for context cancelation to stop the subscription
+a.wg.Add(1)
 go func() {
+defer a.wg.Done()
 select {
 case <-ctx.Done():
-case <-a.ctx.Done():
+case <-a.closeCh:
 }

 innerErr := consumer.Shutdown()

@@ -131,8 +139,10 @@ func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error {

 // Close implements cancel all listeners, see https://github.com/dapr/components-contrib/issues/779
 func (a *RocketMQ) Close() error {
-a.cancel()
+defer a.wg.Wait()
+if a.closed.CompareAndSwap(false, true) {
+close(a.closeCh)
+}
 return nil
 }

@@ -199,21 +209,21 @@ func (a *RocketMQ) Operations() []bindings.OperationKind {
 return []bindings.OperationKind{bindings.CreateOperation}
 }

-func (a *RocketMQ) Invoke(req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
+func (a *RocketMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
 rst := &bindings.InvokeResponse{Data: nil, Metadata: nil}

 if req.Operation != bindings.CreateOperation {
 return rst, fmt.Errorf("binding-rocketmq error: unsupported operation %s", req.Operation)
 }

-return rst, a.sendMessage(req)
+return rst, a.sendMessage(ctx, req)
 }

-func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error {
+func (a *RocketMQ) sendMessage(ctx context.Context, req *bindings.InvokeRequest) error {
 topic := req.Metadata[metadataRocketmqTopic]

 if topic != "" {
-_, err := a.send(topic, req.Metadata[metadataRocketmqTag], req.Metadata[metadataRocketmqKey], req.Data)
+_, err := a.send(ctx, topic, req.Metadata[metadataRocketmqTag], req.Metadata[metadataRocketmqKey], req.Data)
 if err != nil {
 return err
 }

@@ -229,7 +239,7 @@ func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error {
 if err != nil {
 return err
 }
-_, err = a.send(topic, mqExpression, req.Metadata[metadataRocketmqKey], req.Data)
+_, err = a.send(ctx, topic, mqExpression, req.Metadata[metadataRocketmqKey], req.Data)
 if err != nil {
 return err
 }

@@ -239,9 +249,9 @@ func (a *RocketMQ) sendMessage(req *bindings.InvokeRequest) error {
 return nil
 }

-func (a *RocketMQ) send(topic, mqExpr, key string, data []byte) (bool, error) {
+func (a *RocketMQ) send(ctx context.Context, topic, mqExpr, key string, data []byte) (bool, error) {
 msg := primitive.NewMessage(topic, data).WithTag(mqExpr).WithKeys([]string{key})
-ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
 defer cancel()
 rst, err := a.producer.SendSync(ctx, msg)
 if err != nil {
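In the rocketmq send path the 30-second publish timeout is now derived from the caller's context rather than context.Background(), so cancelling Invoke also cancels an in-flight SendSync. A tiny illustrative sketch of the difference (helper name is an assumption):

    package sketch

    import (
        "context"
        "time"
    )

    // sendCtx bounds a publish by both the 30-second limit and the caller's
    // context; with context.Background() the caller's cancellation was ignored.
    func sendCtx(parent context.Context) (context.Context, context.CancelFunc) {
        return context.WithTimeout(parent, 30*time.Second)
    }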
@@ -35,7 +35,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
 m := bindings.Metadata{} //nolint:exhaustivestruct
 m.Properties = getTestMetadata()
 r := NewRocketMQ(logger.NewLogger("test"))
-err := r.Init(m)
+err := r.Init(context.Background(), m)
 require.NoError(t, err)

 var count int32

@@ -51,7 +51,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
 time.Sleep(5 * time.Second)
 atomic.StoreInt32(&count, 0)
 req := &bindings.InvokeRequest{Data: []byte("hello"), Operation: bindings.CreateOperation, Metadata: map[string]string{}}
-_, err = r.Invoke(req)
+_, err = r.Invoke(context.Background(), req)
 require.NoError(t, err)

 time.Sleep(10 * time.Second)
@@ -61,7 +61,7 @@ func NewSMTP(logger logger.Logger) bindings.OutputBinding {
 }

 // Init smtp component (parse metadata).
-func (s *Mailer) Init(metadata bindings.Metadata) error {
+func (s *Mailer) Init(_ context.Context, metadata bindings.Metadata) error {
 // parse metadata
 meta, err := s.parseMetadata(metadata)
 if err != nil {
@@ -84,7 +84,7 @@ func (sg *SendGrid) parseMetadata(meta bindings.Metadata) (sendGridMetadata, err
 }

 // Init does metadata parsing and not much else :).
-func (sg *SendGrid) Init(metadata bindings.Metadata) error {
+func (sg *SendGrid) Init(_ context.Context, metadata bindings.Metadata) error {
 // Parse input metadata
 meta, err := sg.parseMetadata(metadata)
 if err != nil {
@@ -60,19 +60,19 @@ func NewSMS(logger logger.Logger) bindings.OutputBinding {
 }
 }

-func (t *SMS) Init(metadata bindings.Metadata) error {
+func (t *SMS) Init(_ context.Context, metadata bindings.Metadata) error {
 twilioM := twilioMetadata{
 timeout: time.Minute * 5,
 }

 if metadata.Properties[fromNumber] == "" {
-return errors.New("\"fromNumber\" is a required field")
+return errors.New(`"fromNumber" is a required field`)
 }
 if metadata.Properties[accountSid] == "" {
-return errors.New("\"accountSid\" is a required field")
+return errors.New(`"accountSid" is a required field`)
 }
 if metadata.Properties[authToken] == "" {
-return errors.New("\"authToken\" is a required field")
+return errors.New(`"authToken" is a required field`)
 }

 twilioM.toNumber = metadata.Properties[toNumber]
@@ -53,7 +53,7 @@ func TestInit(t *testing.T) {
 m := bindings.Metadata{}
 m.Properties = map[string]string{"toNumber": "toNumber", "fromNumber": "fromNumber"}
 tw := NewSMS(logger.NewLogger("test"))
-err := tw.Init(m)
+err := tw.Init(context.Background(), m)
 assert.NotNil(t, err)
 }

@@ -66,7 +66,7 @@ func TestParseDuration(t *testing.T) {
 "authToken": "authToken", "timeout": "badtimeout",
 }
 tw := NewSMS(logger.NewLogger("test"))
-err := tw.Init(m)
+err := tw.Init(context.Background(), m)
 assert.NotNil(t, err)
 }

@@ -85,7 +85,7 @@ func TestWriteShouldSucceed(t *testing.T) {
 tw.httpClient = &http.Client{
 Transport: httpTransport,
 }
-err := tw.Init(m)
+err := tw.Init(context.Background(), m)
 assert.NoError(t, err)

 t.Run("Should succeed with expected url and headers", func(t *testing.T) {

@@ -123,7 +123,7 @@ func TestWriteShouldFail(t *testing.T) {
 tw.httpClient = &http.Client{
 Transport: httpTransport,
 }
-err := tw.Init(m)
+err := tw.Init(context.Background(), m)
 assert.NoError(t, err)

 t.Run("Missing 'to' should fail", func(t *testing.T) {

@@ -180,7 +180,7 @@ func TestMessageBody(t *testing.T) {
 tw.httpClient = &http.Client{
 Transport: httpTransport,
 }
-err := tw.Init(m)
+err := tw.Init(context.Background(), m)
 require.NoError(t, err)

 tester := func(reqData []byte, expectBody string) func(t *testing.T) {
@@ -19,6 +19,8 @@ import (
 "encoding/json"
 "fmt"
 "strconv"
+"sync"
+"sync/atomic"
 "time"

 "github.com/dghubble/go-twitter/twitter"

@@ -31,18 +33,21 @@ import (

 // Binding represents Twitter input/output binding.
 type Binding struct {
 client *twitter.Client
 query string
 logger logger.Logger
+closed atomic.Bool
+closeCh chan struct{}
+wg sync.WaitGroup
 }

 // NewTwitter returns a new Twitter event input binding.
 func NewTwitter(logger logger.Logger) bindings.InputOutputBinding {
-return &Binding{logger: logger}
+return &Binding{logger: logger, closeCh: make(chan struct{})}
 }

 // Init initializes the Twitter binding.
-func (t *Binding) Init(metadata bindings.Metadata) error {
+func (t *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
 ck, f := metadata.Properties["consumerKey"]
 if !f || ck == "" {
 return fmt.Errorf("consumerKey not set")

@@ -124,10 +129,17 @@ func (t *Binding) Read(ctx context.Context, handler bindings.Handler) error {
 }

 t.logger.Debug("starting handler...")
-go demux.HandleChan(stream.Messages)
+t.wg.Add(2)

 go func() {
-<-ctx.Done()
+defer t.wg.Done()
+demux.HandleChan(stream.Messages)
+}()
+go func() {
+defer t.wg.Done()
+select {
+case <-t.closeCh:
+case <-ctx.Done():
+}
 t.logger.Debug("stopping handler...")
 stream.Stop()
 }()

@@ -135,6 +147,14 @@ func (t *Binding) Read(ctx context.Context, handler bindings.Handler) error {
 return nil
 }

+func (t *Binding) Close() error {
+if t.closed.CompareAndSwap(false, true) {
+close(t.closeCh)
+}
+t.wg.Wait()
+return nil
+}

 // Invoke handles all operations.
 func (t *Binding) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
 t.logger.Debugf("operation: %v", req.Operation)
@ -60,7 +60,7 @@ func getRuntimeMetadata() map[string]string {
|
||||||
func TestInit(t *testing.T) {
|
func TestInit(t *testing.T) {
|
||||||
m := getTestMetadata()
|
m := getTestMetadata()
|
||||||
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
|
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
|
||||||
err := tw.Init(m)
|
err := tw.Init(context.Background(), m)
|
||||||
assert.Nilf(t, err, "error initializing valid metadata properties")
|
assert.Nilf(t, err, "error initializing valid metadata properties")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -69,7 +69,7 @@ func TestInit(t *testing.T) {
|
||||||
func TestReadError(t *testing.T) {
|
func TestReadError(t *testing.T) {
|
||||||
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
|
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
|
||||||
m := getTestMetadata()
|
m := getTestMetadata()
|
||||||
err := tw.Init(m)
|
err := tw.Init(context.Background(), m)
|
||||||
assert.Nilf(t, err, "error initializing valid metadata properties")
|
assert.Nilf(t, err, "error initializing valid metadata properties")
|
||||||
|
|
||||||
err = tw.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
|
err = tw.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
|
||||||
|
|
@ -79,6 +79,8 @@ func TestReadError(t *testing.T) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
assert.NoError(t, tw.Close())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRead executes the Read method which calls Twiter API
|
// TestRead executes the Read method which calls Twiter API
|
||||||
|
|
@ -93,7 +95,7 @@ func TestRead(t *testing.T) {
|
||||||
m.Properties["query"] = "microsoft"
|
m.Properties["query"] = "microsoft"
|
||||||
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
|
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
|
||||||
tw.logger.SetOutputLevel(logger.DebugLevel)
|
tw.logger.SetOutputLevel(logger.DebugLevel)
|
||||||
err := tw.Init(m)
|
err := tw.Init(context.Background(), m)
|
||||||
assert.Nilf(t, err, "error initializing read")
|
assert.Nilf(t, err, "error initializing read")
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
@ -116,6 +118,8 @@ func TestRead(t *testing.T) {
|
||||||
cancel()
|
cancel()
|
||||||
t.Fatal("Timeout waiting for messages")
|
t.Fatal("Timeout waiting for messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, tw.Close())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestInvoke executes the Invoke method which calls Twiter API
|
// TestInvoke executes the Invoke method which calls Twiter API
|
||||||
|
|
@ -129,7 +133,7 @@ func TestInvoke(t *testing.T) {
|
||||||
m.Properties = getRuntimeMetadata()
|
m.Properties = getRuntimeMetadata()
|
||||||
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
|
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
|
||||||
tw.logger.SetOutputLevel(logger.DebugLevel)
|
tw.logger.SetOutputLevel(logger.DebugLevel)
|
||||||
err := tw.Init(m)
|
err := tw.Init(context.Background(), m)
|
||||||
assert.Nilf(t, err, "error initializing Invoke")
|
assert.Nilf(t, err, "error initializing Invoke")
|
||||||
|
|
||||||
req := &bindings.InvokeRequest{
|
req := &bindings.InvokeRequest{
|
||||||
|
|
@ -141,4 +145,5 @@ func TestInvoke(t *testing.T) {
|
||||||
resp, err := tw.Invoke(context.Background(), req)
|
resp, err := tw.Invoke(context.Background(), req)
|
||||||
assert.Nilf(t, err, "error on invoke")
|
assert.Nilf(t, err, "error on invoke")
|
||||||
assert.NotNil(t, resp)
|
assert.NotNil(t, resp)
|
||||||
|
assert.NoError(t, tw.Close())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
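The test updates above all follow the same lifecycle pattern introduced by this PR: Init takes a caller-supplied context, and every test finishes by asserting that Close succeeds so background resources are released. A minimal sketch of that pattern, using a hypothetical fakeBinding instead of any real component:

package example

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"
)

// fakeBinding is a stand-in for a real input binding; only the lifecycle matters here.
type fakeBinding struct {
	closeCh chan struct{}
}

// Init receives the caller's context; a real binding would use it while
// parsing metadata and dialing connections.
func (f *fakeBinding) Init(ctx context.Context, _ map[string]string) error {
	f.closeCh = make(chan struct{})
	return ctx.Err()
}

// Close releases whatever Init created.
func (f *fakeBinding) Close() error {
	close(f.closeCh)
	return nil
}

func TestLifecycle(t *testing.T) {
	b := &fakeBinding{}
	assert.NoError(t, b.Init(context.Background(), map[string]string{}))
	// Asserting Close keeps tests from leaking goroutines between cases.
	assert.NoError(t, b.Close())
}
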
@@ -61,7 +61,7 @@ func NewZeebeCommand(logger logger.Logger) bindings.OutputBinding {
 }

 // Init does metadata parsing and connection creation.
-func (z *ZeebeCommand) Init(metadata bindings.Metadata) error {
+func (z *ZeebeCommand) Init(ctx context.Context, metadata bindings.Metadata) error {
 client, err := z.clientFactory.Get(metadata)
 if err != nil {
 return err
@@ -114,7 +114,7 @@ func (z *ZeebeCommand) Invoke(ctx context.Context, req *bindings.InvokeRequest)
 case UpdateJobRetriesOperation:
 return z.updateJobRetries(ctx, req)
 case ThrowErrorOperation:
-return z.throwError(req)
+return z.throwError(ctx, req)
 case bindings.GetOperation:
 fallthrough
 case bindings.CreateOperation:

@@ -58,7 +58,7 @@ func TestInit(t *testing.T) {
 }

 cmd := ZeebeCommand{clientFactory: mcf, logger: testLogger}
-err := cmd.Init(metadata)
+err := cmd.Init(context.Background(), metadata)
 assert.Error(t, err, errParsing)
 })

@@ -67,7 +67,7 @@ func TestInit(t *testing.T) {
 mcf := mockClientFactory{}

 cmd := ZeebeCommand{clientFactory: mcf, logger: testLogger}
-err := cmd.Init(metadata)
+err := cmd.Init(context.Background(), metadata)

 assert.NoError(t, err)

@@ -30,7 +30,7 @@ type throwErrorPayload struct {
 ErrorMessage string `json:"errorMessage"`
 }

-func (z *ZeebeCommand) throwError(req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
+func (z *ZeebeCommand) throwError(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
 var payload throwErrorPayload
 err := json.Unmarshal(req.Data, &payload)
 if err != nil {
@@ -53,7 +53,7 @@ func (z *ZeebeCommand) throwError(req *bindings.InvokeRequest) (*bindings.Invoke
 cmd = cmd.ErrorMessage(payload.ErrorMessage)
 }

-_, err = cmd.Send(context.Background())
+_, err = cmd.Send(ctx)
 if err != nil {
 return nil, fmt.Errorf("cannot throw error for job key %d: %w", payload.JobKey, err)
 }

@@ -19,6 +19,8 @@ import (
 "errors"
 "fmt"
 "strconv"
+"sync"
+"sync/atomic"
 "time"

 "github.com/camunda/zeebe/clients/go/v8/pkg/entities"
@@ -39,6 +41,9 @@ type ZeebeJobWorker struct {
 client zbc.Client
 metadata *jobWorkerMetadata
 logger logger.Logger
+closed atomic.Bool
+closeCh chan struct{}
+wg sync.WaitGroup
 }

 // https://docs.zeebe.io/basics/job-workers.html
@@ -64,11 +69,15 @@ type jobHandler struct {

 // NewZeebeJobWorker returns a new ZeebeJobWorker instance.
 func NewZeebeJobWorker(logger logger.Logger) bindings.InputBinding {
-return &ZeebeJobWorker{clientFactory: zeebe.NewClientFactoryImpl(logger), logger: logger}
+return &ZeebeJobWorker{
+clientFactory: zeebe.NewClientFactoryImpl(logger),
+logger: logger,
+closeCh: make(chan struct{}),
+}
 }

 // Init does metadata parsing and connection creation.
-func (z *ZeebeJobWorker) Init(metadata bindings.Metadata) error {
+func (z *ZeebeJobWorker) Init(ctx context.Context, metadata bindings.Metadata) error {
 meta, err := z.parseMetadata(metadata)
 if err != nil {
 return err
@@ -90,6 +99,10 @@ func (z *ZeebeJobWorker) Init(metadata bindings.Metadata) error {
 }

 func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) error {
+if z.closed.Load() {
+return fmt.Errorf("binding is closed")
+}
+
 h := jobHandler{
 callback: handler,
 logger: z.logger,
@@ -99,8 +112,14 @@ func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) err

 jobWorker := z.getJobWorker(h)

+z.wg.Add(1)
 go func() {
-<-ctx.Done()
+defer z.wg.Done()
+
+select {
+case <-z.closeCh:
+case <-ctx.Done():
+}

 jobWorker.Close()
 jobWorker.AwaitClose()
@@ -110,6 +129,14 @@ func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) err
 return nil
 }

+func (z *ZeebeJobWorker) Close() error {
+if z.closed.CompareAndSwap(false, true) {
+close(z.closeCh)
+}
+z.wg.Wait()
+return nil
+}
+
 func (z *ZeebeJobWorker) parseMetadata(meta bindings.Metadata) (*jobWorkerMetadata, error) {
 var m jobWorkerMetadata
 err := metadata.DecodeMetadata(meta.Properties, &m)

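The job worker above combines three pieces to make Close both idempotent and blocking: an atomic.Bool so the close channel is closed at most once, a channel that signals background goroutines, and a WaitGroup so Close waits for them to return. A standalone sketch of the same pattern (the worker type and its methods are illustrative, not the binding's actual API):

package example

import (
	"errors"
	"sync"
	"sync/atomic"
)

type worker struct {
	closed  atomic.Bool
	closeCh chan struct{}
	wg      sync.WaitGroup
}

func newWorker() *worker {
	return &worker{closeCh: make(chan struct{})}
}

// Start launches a background loop that stops when Close is called.
func (w *worker) Start(work func()) error {
	if w.closed.Load() {
		return errors.New("worker is closed")
	}
	w.wg.Add(1)
	go func() {
		defer w.wg.Done()
		for {
			select {
			case <-w.closeCh:
				return
			default:
				work()
			}
		}
	}()
	return nil
}

// Close signals all goroutines exactly once and blocks until they return.
func (w *worker) Close() error {
	if w.closed.CompareAndSwap(false, true) {
		close(w.closeCh)
	}
	w.wg.Wait()
	return nil
}

Close can safely be called multiple times and from multiple goroutines; only the first call closes the channel, and every call blocks until the workers have drained.
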
@@ -14,6 +14,7 @@ limitations under the License.
 package jobworker

 import (
+"context"
 "errors"
 "testing"

@@ -53,10 +54,11 @@ func TestInit(t *testing.T) {
 metadata := bindings.Metadata{}
 var mcf mockClientFactory

-jobWorker := ZeebeJobWorker{clientFactory: &mcf, logger: testLogger}
-err := jobWorker.Init(metadata)
+jobWorker := ZeebeJobWorker{clientFactory: &mcf, logger: testLogger, closeCh: make(chan struct{})}
+err := jobWorker.Init(context.Background(), metadata)

 assert.Error(t, err, ErrMissingJobType)
+assert.NoError(t, jobWorker.Close())
 })

 t.Run("sets client from client factory", func(t *testing.T) {
@@ -66,8 +68,8 @@ func TestInit(t *testing.T) {
 mcf := mockClientFactory{
 metadata: metadata,
 }
-jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger}
-err := jobWorker.Init(metadata)
+jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
+err := jobWorker.Init(context.Background(), metadata)

 assert.NoError(t, err)

@@ -76,6 +78,7 @@ func TestInit(t *testing.T) {
 assert.NoError(t, err)
 assert.Equal(t, mc, jobWorker.client)
 assert.Equal(t, metadata, mcf.metadata)
+assert.NoError(t, jobWorker.Close())
 })

 t.Run("returns error if client could not be instantiated properly", func(t *testing.T) {
@@ -85,9 +88,10 @@ func TestInit(t *testing.T) {
 error: errParsing,
 }

-jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger}
-err := jobWorker.Init(metadata)
+jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
+err := jobWorker.Init(context.Background(), metadata)
 assert.Error(t, err, errParsing)
+assert.NoError(t, jobWorker.Close())
 })

 t.Run("sets client from client factory", func(t *testing.T) {
@@ -98,8 +102,8 @@ func TestInit(t *testing.T) {
 metadata: metadata,
 }

-jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger}
-err := jobWorker.Init(metadata)
+jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
+err := jobWorker.Init(context.Background(), metadata)

 assert.NoError(t, err)

@@ -108,5 +112,6 @@ func TestInit(t *testing.T) {
 assert.NoError(t, err)
 assert.Equal(t, mc, jobWorker.client)
 assert.Equal(t, metadata, mcf.metadata)
+assert.NoError(t, jobWorker.Close())
 })
 }

@@ -73,7 +73,7 @@ func NewAzureAppConfigurationStore(logger logger.Logger) configuration.Store {
 }

 // Init does metadata and connection parsing.
-func (r *ConfigurationStore) Init(metadata configuration.Metadata) error {
+func (r *ConfigurationStore) Init(_ context.Context, metadata configuration.Metadata) error {
 m, err := parseMetadata(metadata)
 if err != nil {
 return err

@@ -204,7 +204,7 @@ func TestInit(t *testing.T) {
 Properties: testProperties,
 }}

-err := s.Init(m)
+err := s.Init(context.Background(), m)
 assert.Nil(t, err)
 cs, ok := s.(*ConfigurationStore)
 assert.True(t, ok)
@@ -229,7 +229,7 @@ func TestInit(t *testing.T) {
 Properties: testProperties,
 }}

-err := s.Init(m)
+err := s.Init(context.Background(), m)
 assert.Nil(t, err)
 cs, ok := s.(*ConfigurationStore)
 assert.True(t, ok)

@@ -86,7 +86,7 @@ func NewPostgresConfigurationStore(logger logger.Logger) configuration.Store {
 }
 }

-func (p *ConfigurationStore) Init(metadata configuration.Metadata) error {
+func (p *ConfigurationStore) Init(parentCtx context.Context, metadata configuration.Metadata) error {
 p.logger.Debug(InfoStartInit)
 if p.client != nil {
 return fmt.Errorf(ErrorAlreadyInitialized)
@@ -98,7 +98,7 @@ func (p *ConfigurationStore) Init(metadata configuration.Metadata) error {
 p.metadata = m
 }
 p.ActiveSubscriptions = make(map[string]*subscription)
-ctx, cancel := context.WithTimeout(context.Background(), p.metadata.maxIdleTimeout)
+ctx, cancel := context.WithTimeout(parentCtx, p.metadata.maxIdleTimeout)
 defer cancel()
 client, err := Connect(ctx, p.metadata.connectionString, p.metadata.maxIdleTimeout)
 if err != nil {

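Deriving the connection timeout from the caller's context rather than from context.Background() means that cancelling Init also aborts the in-flight dial. A minimal sketch of the idea using database/sql instead of the component's own Connect helper (driver name and DSN are placeholders):

package example

import (
	"context"
	"database/sql"
	"time"
)

// connect derives a bounded context from the caller's context so that
// cancelling the caller (for example, a cancelled Init) also aborts the dial.
func connect(parentCtx context.Context, driver, dsn string, timeout time.Duration) (*sql.DB, error) {
	ctx, cancel := context.WithTimeout(parentCtx, timeout)
	defer cancel()

	db, err := sql.Open(driver, dsn) // sql.Open validates arguments but does not dial yet
	if err != nil {
		return nil, err
	}
	// PingContext performs the actual connection attempt and honours ctx.
	if err := db.PingContext(ctx); err != nil {
		db.Close()
		return nil, err
	}
	return db, nil
}
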
@@ -143,7 +143,7 @@ func parseRedisMetadata(meta configuration.Metadata) (metadata, error) {
 }

 // Init does metadata and connection parsing.
-func (r *ConfigurationStore) Init(metadata configuration.Metadata) error {
+func (r *ConfigurationStore) Init(ctx context.Context, metadata configuration.Metadata) error {
 m, err := parseRedisMetadata(metadata)
 if err != nil {
 return err
@@ -156,11 +156,11 @@ func (r *ConfigurationStore) Init(metadata configuration.Metadata) error {
 r.client = r.newClient(m)
 }

-if _, err = r.client.Ping(context.TODO()).Result(); err != nil {
+if _, err = r.client.Ping(ctx).Result(); err != nil {
 return fmt.Errorf("redis store: error connecting to redis at %s: %s", m.Host, err)
 }

-r.replicas, err = r.getConnectedSlaves()
+r.replicas, err = r.getConnectedSlaves(ctx)

 return err
 }
@@ -204,8 +204,8 @@ func (r *ConfigurationStore) newFailoverClient(m metadata) *redis.Client {
 return redis.NewFailoverClient(opts)
 }

-func (r *ConfigurationStore) getConnectedSlaves() (int, error) {
-res, err := r.client.Do(context.Background(), "INFO", "replication").Result()
+func (r *ConfigurationStore) getConnectedSlaves(ctx context.Context) (int, error) {
+res, err := r.client.Do(ctx, "INFO", "replication").Result()
 if err != nil {
 return 0, err
 }

@@ -18,7 +18,7 @@ import "context"
 // Store is an interface to perform operations on store.
 type Store interface {
 // Init configuration store.
-Init(metadata Metadata) error
+Init(ctx context.Context, metadata Metadata) error

 // Get configuration.
 Get(ctx context.Context, req *GetRequest) (*GetResponse, error)

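With this interface change every configuration store receives the runtime's context at Init. A toy in-memory store sketch showing the new signature; the Metadata type here is a simplified stand-in for the real configuration.Metadata, and Get is reduced to a plain key lookup:

package example

import (
	"context"
	"sync"
)

// Metadata is a simplified stand-in for configuration.Metadata.
type Metadata struct {
	Properties map[string]string
}

type memoryStore struct {
	mu    sync.RWMutex
	items map[string]string
}

// Init now accepts the caller's context; a store that dials an external
// service would pass ctx down to its client here.
func (m *memoryStore) Init(ctx context.Context, md Metadata) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.items = make(map[string]string, len(md.Properties))
	for k, v := range md.Properties {
		m.items[k] = v
	}
	return ctx.Err() // respect an already-cancelled caller
}

func (m *memoryStore) Get(ctx context.Context, key string) (string, bool) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	v, ok := m.items[key]
	return v, ok
}
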
go.mod
@@ -400,3 +400,5 @@ replace github.com/toolkits/concurrent => github.com/niean/gotools v0.0.0-201512

 // this is a fork which addresses a performance issues due to go routines
 replace dubbo.apache.org/dubbo-go/v3 => dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3
+
+replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181

@@ -1,5 +1,22 @@
+/*
+Copyright 2023 The Dapr Authors
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
 package health

+import (
+"context"
+)
+
 type Pinger interface {
-Ping() error
+Ping(ctx context.Context) error
 }

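Giving Pinger a context parameter lets SQL-backed components satisfy it directly with database/sql's PingContext and lets callers put a deadline on health checks. A hedged sketch of one way the interface might be satisfied and consumed (the component and helper names are illustrative, not part of the health package):

package example

import (
	"context"
	"database/sql"
	"time"
)

// Pinger mirrors the health.Pinger interface shown above.
type Pinger interface {
	Ping(ctx context.Context) error
}

// sqlComponent satisfies Pinger by delegating to the driver's PingContext.
type sqlComponent struct {
	db *sql.DB
}

func (c *sqlComponent) Ping(ctx context.Context) error {
	return c.db.PingContext(ctx)
}

// checkHealth bounds a ping so a hung backend cannot block the caller forever.
func checkHealth(ctx context.Context, p Pinger) error {
	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	return p.Ping(ctx)
}
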
@@ -19,6 +19,7 @@ import (
 "fmt"
 "strconv"
 "sync"
+"sync/atomic"
 "time"

 "github.com/Shopify/sarama"
@@ -31,6 +32,7 @@ type consumer struct {
 k *Kafka
 ready chan bool
 running chan struct{}
+stopped atomic.Bool
 once sync.Once
 mutex sync.Mutex
 }
@@ -275,9 +277,6 @@ func (k *Kafka) Subscribe(ctx context.Context) error {

 k.cg = cg

-ctx, cancel := context.WithCancel(ctx)
-k.cancel = cancel
-
 ready := make(chan bool)
 k.consumer = consumer{
 k: k,
@@ -320,7 +319,10 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
 k.logger.Errorf("Error closing consumer group: %v", err)
 }

-close(k.consumer.running)
+// Ensure running channel is only closed once.
+if k.consumer.stopped.CompareAndSwap(false, true) {
+close(k.consumer.running)
+}
 }()

 <-ready

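Closing a Go channel twice panics, so the consumer now guards the close with an atomic.Bool: only the goroutine that wins the CompareAndSwap performs the close. The same guard in isolation, with illustrative names rather than the Kafka component's actual types:

package example

import "sync/atomic"

type runState struct {
	running chan struct{}
	stopped atomic.Bool
}

func newRunState() *runState {
	return &runState{running: make(chan struct{})}
}

// stop is safe to call from multiple goroutines: CompareAndSwap flips the
// flag exactly once, so close executes exactly once and never panics.
func (r *runState) stop() {
	if r.stopped.CompareAndSwap(false, true) {
		close(r.running)
	}
}

// done lets callers block until stop has been called.
func (r *runState) done() <-chan struct{} {
	return r.running
}

Any number of goroutines can call stop concurrently; exactly one of them closes the channel.
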
@@ -331,7 +333,6 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
 // Close down consumer group resources, refresh once.
 func (k *Kafka) closeSubscriptionResources() {
 if k.cg != nil {
-k.cancel()
 err := k.cg.Close()
 if err != nil {
 k.logger.Errorf("Error closing consumer group: %v", err)

@@ -36,7 +36,6 @@ type Kafka struct {
 saslPassword string
 initialOffset int64
 cg sarama.ConsumerGroup
-cancel context.CancelFunc
 consumer consumer
 config *sarama.Config
 subscribeTopics TopicHandlerConfig
@@ -60,7 +59,7 @@ func NewKafka(logger logger.Logger) *Kafka {
 }

 // Init does metadata parsing and connection establishment.
-func (k *Kafka) Init(metadata map[string]string) error {
+func (k *Kafka) Init(_ context.Context, metadata map[string]string) error {
 upgradedMetadata, err := k.upgradeMetadata(metadata)
 if err != nil {
 return err

@@ -107,7 +107,8 @@ func (ts *OAuthTokenSource) Token() (*sarama.AccessToken, error) {

 oidcCfg := ccred.Config{ClientID: ts.ClientID, ClientSecret: ts.ClientSecret, Scopes: ts.Scopes, TokenURL: ts.TokenEndpoint.TokenURL, AuthStyle: ts.TokenEndpoint.AuthStyle}

-timeoutCtx, _ := ctx.WithTimeout(ctx.TODO(), tokenRequestTimeout) //nolint:govet
+timeoutCtx, cancel := ctx.WithTimeout(ctx.TODO(), tokenRequestTimeout)
+defer cancel()

 ts.configureClient()

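Keeping the CancelFunc and deferring it, instead of discarding it with _ and silencing go vet, releases the timer as soon as the token call returns. The general shape of that fix, sketched with the standard context and net/http packages (the URL and timeout value are placeholders):

package example

import (
	"context"
	"net/http"
	"time"
)

const tokenRequestTimeout = 30 * time.Second // illustrative value

// fetch bounds an outbound request and always releases the timer.
func fetch(parent context.Context, url string) (*http.Response, error) {
	ctx, cancel := context.WithTimeout(parent, tokenRequestTimeout)
	defer cancel() // releases resources as soon as fetch returns

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, err
	}
	return http.DefaultClient.Do(req)
}
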
@@ -38,9 +38,6 @@ type StandaloneRedisLock struct {
 metadata rediscomponent.Metadata

 logger logger.Logger
-
-ctx context.Context
-cancel context.CancelFunc
 }

 // NewStandaloneRedisLock returns a new standalone redis lock.
@@ -54,7 +51,7 @@ func NewStandaloneRedisLock(logger logger.Logger) lock.Store {
 }

 // Init StandaloneRedisLock.
-func (r *StandaloneRedisLock) InitLockStore(metadata lock.Metadata) error {
+func (r *StandaloneRedisLock) InitLockStore(ctx context.Context, metadata lock.Metadata) error {
 // 1. parse config
 m, err := rediscomponent.ParseRedisMetadata(metadata.Properties)
 if err != nil {
@@ -75,13 +72,12 @@ func (r *StandaloneRedisLock) InitLockStore(metadata lock.Metadata) error {
 if err != nil {
 return err
 }
-r.ctx, r.cancel = context.WithCancel(context.Background())
 // 3. connect to redis
-if _, err = r.client.PingResult(r.ctx); err != nil {
+if _, err = r.client.PingResult(ctx); err != nil {
 return fmt.Errorf("[standaloneRedisLock]: error connecting to redis at %s: %s", r.clientSettings.Host, err)
 }
 // no replica
-replicas, err := r.getConnectedSlaves()
+replicas, err := r.getConnectedSlaves(ctx)
 // pass the validation if error occurs,
 // since some redis versions such as miniredis do not recognize the `INFO` command.
 if err == nil && replicas > 0 {
@@ -101,8 +97,8 @@ func needFailover(properties map[string]string) bool {
 return false
 }

-func (r *StandaloneRedisLock) getConnectedSlaves() (int, error) {
-res, err := r.client.DoRead(r.ctx, "INFO", "replication")
+func (r *StandaloneRedisLock) getConnectedSlaves(ctx context.Context) (int, error) {
+res, err := r.client.DoRead(ctx, "INFO", "replication")
 if err != nil {
 return 0, err
 }
@@ -183,9 +179,6 @@ func newInternalErrorUnlockResponse() *lock.UnlockResponse {

 // Close shuts down the client's redis connections.
 func (r *StandaloneRedisLock) Close() error {
-if r.cancel != nil {
-r.cancel()
-}
 if r.client != nil {
 closeErr := r.client.Close()
 r.client = nil

@@ -42,7 +42,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) {
 cfg.Properties["redisPassword"] = ""

 // init
-err := comp.InitLockStore(cfg)
+err := comp.InitLockStore(context.Background(), cfg)
 assert.Error(t, err)
 })

@@ -58,7 +58,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) {
 cfg.Properties["redisPassword"] = ""

 // init
-err := comp.InitLockStore(cfg)
+err := comp.InitLockStore(context.Background(), cfg)
 assert.Error(t, err)
 })

@@ -75,7 +75,7 @@ func TestStandaloneRedisLock_InitError(t *testing.T) {
 cfg.Properties["maxRetries"] = "1 "

 // init
-err := comp.InitLockStore(cfg)
+err := comp.InitLockStore(context.Background(), cfg)
 assert.Error(t, err)
 })
 }
@@ -96,7 +96,7 @@ func TestStandaloneRedisLock_TryLock(t *testing.T) {
 cfg.Properties["redisHost"] = s.Addr()
 cfg.Properties["redisPassword"] = ""
 // init
-err = comp.InitLockStore(cfg)
+err = comp.InitLockStore(context.Background(), cfg)
 assert.NoError(t, err)
 // 1. client1 trylock
 ownerID1 := uuid.New().String()

@@ -17,7 +17,7 @@ import "context"

 type Store interface {
 // Init this component.
-InitLockStore(metadata Metadata) error
+InitLockStore(ctx context.Context, metadata Metadata) error

 // TryLock tries to acquire a lock.
 TryLock(ctx context.Context, req *TryLockRequest) (*TryLockResponse, error)

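On the caller side, the new InitLockStore signature lets the runtime bound initialization so a lock store that cannot reach its backend fails fast instead of hanging. A sketch of such a caller; the Store and Metadata types below are simplified stand-ins for the real lock package types, and the 10-second budget is illustrative:

package example

import (
	"context"
	"fmt"
	"time"
)

// Metadata and Store are simplified stand-ins for the lock package types.
type Metadata struct {
	Properties map[string]string
}

type Store interface {
	InitLockStore(ctx context.Context, metadata Metadata) error
}

// initWithTimeout gives a store a bounded window to connect to its backend.
func initWithTimeout(parent context.Context, s Store, md Metadata) error {
	ctx, cancel := context.WithTimeout(parent, 10*time.Second)
	defer cancel()

	if err := s.InitLockStore(ctx, md); err != nil {
		return fmt.Errorf("init lock store: %w", err)
	}
	return nil
}
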
@@ -45,13 +45,13 @@ const (
 )

 // GetHandler retruns the HTTP handler provided by the middleware.
-func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
+func (m *Middleware) GetHandler(ctx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
 meta, err := m.getNativeMetadata(metadata)
 if err != nil {
 return nil, err
 }

-provider, err := oidc.NewProvider(context.Background(), meta.IssuerURL)
+provider, err := oidc.NewProvider(ctx, meta.IssuerURL)
 if err != nil {
 return nil, err
 }

@@ -14,6 +14,7 @@ limitations under the License.
 package oauth2

 import (
+"context"
 "net/http"
 "net/url"
 "strings"
@@ -59,7 +60,7 @@ const (
 )

 // GetHandler retruns the HTTP handler provided by the middleware.
-func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
+func (m *Middleware) GetHandler(ctx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
 meta, err := m.getNativeMetadata(metadata)
 if err != nil {
 return nil, err

@@ -5,6 +5,7 @@
 package mock_oauth2clientcredentials

 import (
+"context"
 reflect "reflect"

 gomock "github.com/golang/mock/gomock"
@@ -36,7 +37,7 @@ func (m *MockTokenProviderInterface) EXPECT() *MockTokenProviderInterfaceMockRec
 }

 // GetToken mocks base method
-func (m *MockTokenProviderInterface) GetToken(arg0 *clientcredentials.Config) (*oauth2.Token, error) {
+func (m *MockTokenProviderInterface) GetToken(ctx context.Context, arg0 *clientcredentials.Config) (*oauth2.Token, error) {
 m.ctrl.T.Helper()
 ret := m.ctrl.Call(m, "GetToken", arg0)
 ret0, _ := ret[0].(*oauth2.Token)

@@ -45,7 +45,7 @@ type oAuth2ClientCredentialsMiddlewareMetadata struct {

 // TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests.
 type TokenProviderInterface interface {
-GetToken(conf *clientcredentials.Config) (*oauth2.Token, error)
+GetToken(ctx context.Context, conf *clientcredentials.Config) (*oauth2.Token, error)
 }

 // NewOAuth2ClientCredentialsMiddleware returns a new oAuth2 middleware.
@@ -68,7 +68,7 @@ type Middleware struct {
 }

 // GetHandler retruns the HTTP handler provided by the middleware.
-func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
+func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
 meta, err := m.getNativeMetadata(metadata)
 if err != nil {
 m.log.Errorf("getNativeMetadata error: %s", err)
@@ -101,7 +101,7 @@ func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Ha
 if !found {
 m.log.Debugf("Cached token not found, try get one")

-token, err := m.tokenProvider.GetToken(conf)
+token, err := m.tokenProvider.GetToken(r.Context(), conf)
 if err != nil {
 m.log.Errorf("Error acquiring token: %s", err)
 return
@@ -171,8 +171,8 @@ func (m *Middleware) SetTokenProvider(tokenProvider TokenProviderInterface) {
 }

 // GetToken returns a token from the current OAuth2 ClientCredentials Configuration.
-func (m *Middleware) GetToken(conf *clientcredentials.Config) (*oauth2.Token, error) {
-tokenSource := conf.TokenSource(context.Background())
+func (m *Middleware) GetToken(ctx context.Context, conf *clientcredentials.Config) (*oauth2.Token, error) {
+tokenSource := conf.TokenSource(ctx)

 return tokenSource.Token()
 }

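Passing r.Context() into the token provider ties the outbound token request to the lifetime of the incoming HTTP request: if the client goes away, the token fetch is cancelled with it. A simplified sketch of that flow with golang.org/x/oauth2/clientcredentials; unlike the real middleware it does no caching, and headerName is just an illustrative parameter:

package example

import (
	"net/http"

	"golang.org/x/oauth2/clientcredentials"
)

// tokenMiddleware fetches a client-credentials token per request, using the
// request's context so a disconnected client cancels the outbound token call.
func tokenMiddleware(conf *clientcredentials.Config, headerName string, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tok, err := conf.TokenSource(r.Context()).Token()
		if err != nil {
			http.Error(w, "error acquiring token", http.StatusInternalServerError)
			return
		}
		// Attach the token to the request before handing off to the next handler.
		r.Header.Set(headerName, "Bearer "+tok.AccessToken)
		next.ServeHTTP(w, r)
	})
}
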
@@ -14,6 +14,7 @@ limitations under the License.
 package oauth2clientcredentials

 import (
+"context"
 "net/http"
 "net/http/httptest"
 "testing"
@@ -45,7 +46,7 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
 metadata.Properties = map[string]string{}

 log := logger.NewLogger("oauth2clientcredentials.test")
-_, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
+_, err := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
 assert.EqualError(t, err, "metadata errors: Parameter 'headerName' needs to be set. Parameter 'clientID' needs to be set. Parameter 'clientSecret' needs to be set. Parameter 'scopes' needs to be set. Parameter 'tokenURL' needs to be set. ")

 // Invalid authStyle (non int)
@@ -57,17 +58,17 @@ func TestOAuth2ClientCredentialsMetadata(t *testing.T) {
 "headerName": "someHeader",
 "authStyle": "asdf", // This is the value to test
 }
-_, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
+_, err2 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
 assert.EqualError(t, err2, "metadata errors: 1 error(s) decoding:\n\n* cannot parse 'AuthStyle' as int: strconv.ParseInt: parsing \"asdf\": invalid syntax")

 // Invalid authStyle (int > 2)
 metadata.Properties["authStyle"] = "3"
-_, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
+_, err3 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
 assert.EqualError(t, err3, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '3'. ")

 // Invalid authStyle (int < 0)
 metadata.Properties["authStyle"] = "-1"
-_, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(metadata)
+_, err4 := NewOAuth2ClientCredentialsMiddleware(log).GetHandler(context.Background(), metadata)
 assert.EqualError(t, err4, "metadata errors: Parameter 'authStyle' can only have the values 0,1,2. Received: '-1'. ")
 }

@@ -109,7 +110,7 @@ func TestOAuth2ClientCredentialsToken(t *testing.T) {
 log := logger.NewLogger("oauth2clientcredentials.test")
 oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
 oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
-handler, err := oauth2clientcredentialsMiddleware.GetHandler(metadata)
+handler, err := oauth2clientcredentialsMiddleware.GetHandler(context.Background(), metadata)
 require.NoError(t, err)

 // First handler call should return abc Token
@@ -169,7 +170,7 @@ func TestOAuth2ClientCredentialsCache(t *testing.T) {
 log := logger.NewLogger("oauth2clientcredentials.test")
 oauth2clientcredentialsMiddleware, _ := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
 oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
-handler, err := oauth2clientcredentialsMiddleware.GetHandler(metadata)
+handler, err := oauth2clientcredentialsMiddleware.GetHandler(context.Background(), metadata)
 require.NoError(t, err)

 // First handler call should return abc Token

@@ -106,13 +106,13 @@ func (s *Status) Valid() bool {
 }

 // GetHandler returns the HTTP handler provided by the middleware.
-func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
+func (m *Middleware) GetHandler(parentCtx context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
 meta, err := m.getNativeMetadata(metadata)
 if err != nil {
 return nil, err
 }

-ctx, cancel := context.WithTimeout(context.TODO(), time.Minute)
+ctx, cancel := context.WithTimeout(parentCtx, time.Minute)
 query, err := rego.New(
 rego.Query("result = data.http.allow"),
 rego.Module("inline.rego", meta.Rego),

@@ -14,6 +14,7 @@ limitations under the License.
 package opa

 import (
+"context"
 "encoding/json"
 "net/http"
 "net/http/httptest"
@@ -331,7 +332,7 @@ func TestOpaPolicy(t *testing.T) {
 t.Run(name, func(t *testing.T) {
 opaMiddleware := NewMiddleware(log)

-handler, err := opaMiddleware.GetHandler(test.meta)
+handler, err := opaMiddleware.GetHandler(context.Background(), test.meta)
 if test.shouldHandlerError {
 require.Error(t, err)
 return

@@ -14,6 +14,7 @@ limitations under the License.
 package ratelimit

 import (
+"context"
 "fmt"
 "net/http"
 "strconv"
@@ -46,7 +47,7 @@ func NewRateLimitMiddleware(_ logger.Logger) middleware.Middleware {
 type Middleware struct{}

 // GetHandler returns the HTTP handler provided by the middleware.
-func (m *Middleware) GetHandler(metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
+func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
 meta, err := m.getNativeMetadata(metadata)
 if err != nil {
 return nil, err

@@ -40,7 +40,7 @@ func NewMiddleware(logger logger.Logger) middleware.Middleware {
 }

 // GetHandler retruns the HTTP handler provided by the middleware.
-func (m *Middleware) GetHandler(metadata middleware.Metadata) (
+func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (
 func(next http.Handler) http.Handler, error,
 ) {
 if err := m.getNativeMetadata(metadata); err != nil {

@@ -14,6 +14,7 @@ limitations under the License.
 package routeralias

 import (
+"context"
 "io"
 "net/http"
 "net/http/httptest"
@@ -46,7 +47,7 @@ func TestRequestHandlerWithIllegalRouterRule(t *testing.T) {
 }
 log := logger.NewLogger("routeralias.test")
 ralias := NewMiddleware(log)
-handler, err := ralias.GetHandler(meta)
+handler, err := ralias.GetHandler(context.Background(), meta)
 assert.Nil(t, err)

 t.Run("hit: change router with common request", func(t *testing.T) {