Merge branch 'master' of https://github.com/dapr/components-contrib into release-1.10

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2023-02-16 23:23:23 +00:00
commit f9cf54ca6f
334 changed files with 4869 additions and 1580 deletions

View File

@ -18,3 +18,5 @@ require (
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)
replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181

View File

@ -1,6 +1,6 @@
{
"name": "Dapr Components Contributor Environment",
"image": "ghcr.io/dapr/dapr-dev:0.1.9",
"image": "ghcr.io/dapr/dapr-dev:latest",
"containerEnv": {
// Uncomment to overwrite devcontainer .kube/config and .minikube certs with the localhost versions
// each time the devcontainer starts, if the respective .kube-localhost/config and .minikube-localhost
@ -11,14 +11,21 @@
// the localhost bind-mount /var/run/docker-host.sock.
// "BIND_LOCALHOST_DOCKER": "true",
// Necessary for components-contrib's certification tests
"GOLANG_PROTOBUF_REGISTRATION_CONFLICT": "true"
},
"extensions": [
"davidanson.vscode-markdownlint",
"golang.go",
"ms-azuretools.vscode-dapr",
"ms-azuretools.vscode-docker",
"ms-kubernetes-tools.vscode-kubernetes-tools"
],
"features": {
"ghcr.io/devcontainers/features/sshd:1": {},
"ghcr.io/devcontainers/features/github-cli:1": {},
"ghcr.io/devcontainers/features/azure-cli:1": {}
},
"mounts": [
// Mount docker-in-docker library volume
"type=volume,source=dind-var-lib-docker,target=/var/lib/docker",
@ -57,7 +64,12 @@
"settings": {
"go.toolsManagement.checkForUpdates": "local",
"go.useLanguageServer": true,
"go.gopath": "/go"
"go.gopath": "/go",
"go.buildTags": "e2e,perf,conftests,unit,integration_test,certtests",
"git.alwaysSignOff": true,
"terminal.integrated.env.linux": {
"GOLANG_PROTOBUF_REGISTRATION_CONFLICT": "ignore"
}
},
"workspaceFolder": "/workspaces/components-contrib",
"workspaceMount": "type=bind,source=${localWorkspaceFolder},target=/workspaces/components-contrib",

View File

@ -33,3 +33,5 @@ require (
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181

View File

@ -195,6 +195,14 @@ async function cmdOkToTest(github, issue, isFromPulls) {
client_payload: testPayload,
});
// Fire repository_dispatch event to trigger unit tests for other architectures and OS
await github.rest.repos.createDispatchEvent({
owner: issue.owner,
repo: issue.repo,
event_type: "build-all",
client_payload: testPayload,
});
console.log(`[cmdOkToTest] triggered certification and conformance tests for ${JSON.stringify(testPayload)}`);
}
}

View File

@ -18,10 +18,12 @@ on:
types: [certification-test]
workflow_dispatch:
schedule:
- cron: '5 */12 * * *'
- cron: '25 */8 * * *'
push:
branches:
- release-*
pull_request:
branches:
- master
- release-*
jobs:
@ -59,7 +61,8 @@ jobs:
- state.cassandra
- state.memcached
- state.mysql
- bindings.alicloud.dubbo
- state.sqlite
- bindings.dubbo
- bindings.kafka
- bindings.redis
- bindings.cron

View File

@ -0,0 +1,166 @@
# ------------------------------------------------------------
# Copyright 2021 The Dapr Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------
name: "Build, Lint, Unit Test - complete matrix"
on:
repository_dispatch:
types: [build-all]
workflow_dispatch:
schedule:
- cron: '0 */6 * * *'
push:
branches:
- release-*
tags:
- v*
pull_request:
branches:
- release-*
jobs:
post-comment:
name: Post comment on Repository Dispatch
runs-on: ubuntu-latest
steps:
- name: Parse repository_dispatch payload
if: github.event_name == 'repository_dispatch'
working-directory: ${{ github.workspace }}
shell: bash
run: |
if [ ${{ github.event.client_payload.command }} = "ok-to-test" ]; then
echo "CHECKOUT_REF=${{ github.event.client_payload.pull_head_ref }}" >> $GITHUB_ENV
echo "PR_NUMBER=${{ github.event.client_payload.issue.number }}" >> $GITHUB_ENV
fi
- name: Create PR comment
if: env.PR_NUMBER != ''
uses: artursouza/sticky-pull-request-comment@da9e86aa2a80e4ae3b854d251add33bd6baabcba
with:
header: ${{ github.run_id }}
number: ${{ env.PR_NUMBER }}
GITHUB_TOKEN: ${{ secrets.DAPR_BOT_TOKEN }}
message: |
# Complete Build Matrix
The build status is currently not updated here. Please visit the action run below directly.
🔗 **[Link to Action run](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }})**
Commit ref: ${{ env.CHECKOUT_REF }}
build:
name: Build ${{ matrix.target_os }}_${{ matrix.target_arch }} binaries
runs-on: ${{ matrix.os }}
needs: post-comment
env:
GOVER: "1.19"
GOOS: ${{ matrix.target_os }}
GOARCH: ${{ matrix.target_arch }}
GOPROXY: https://proxy.golang.org
GOLANGCI_LINT_VER: "v1.50.1"
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macOS-latest]
target_arch: [arm, amd64]
include:
- os: ubuntu-latest
target_os: linux
- os: windows-latest
target_os: windows
- os: macOS-latest
target_os: darwin
exclude:
- os: windows-latest
target_arch: arm
- os: macOS-latest
target_arch: arm
steps:
- name: Set default payload repo and ref
run: |
echo "CHECKOUT_REPO=${{ github.repository }}" >> $GITHUB_ENV
echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
- name: Parse repository_dispatch payload
if: github.event_name == 'repository_dispatch'
working-directory: ${{ github.workspace }}
shell: bash
run: |
if [ ${{ github.event.client_payload.command }} = "ok-to-test" ]; then
echo "CHECKOUT_REF=${{ github.event.client_payload.pull_head_ref }}" >> $GITHUB_ENV
fi
- name: Set up Go ${{ env.GOVER }}
if: ${{ steps.skip_check.outputs.should_skip != 'true' }}
uses: actions/setup-go@v3
with:
go-version: ${{ env.GOVER }}
- name: Check out code into the Go module directory
if: ${{ steps.skip_check.outputs.should_skip != 'true' }}
uses: actions/checkout@v3
with:
repository: ${{ env.CHECKOUT_REPO }}
ref: ${{ env.CHECKOUT_REF }}
- name: Cache Go modules (Linux)
if: matrix.target_os == 'linux'
uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ matrix.target_os }}-${{ matrix.target_arch }}-go-${{ env.GOVER }}-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ matrix.target_os }}-${{ matrix.target_arch }}-go-${{ env.GOVER }}-
- name: Cache Go modules (Windows)
if: matrix.target_os == 'windows'
uses: actions/cache@v3
with:
path: |
~\AppData\Local\go-build
~\go\pkg\mod
key: ${{ matrix.target_os }}-${{ matrix.target_arch }}-go-${{ env.GOVER }}-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ matrix.target_os }}-${{ matrix.target_arch }}-go-${{ env.GOVER }}-
- name: Cache Go modules (macOS)
if: matrix.target_os == 'darwin'
uses: actions/cache@v3
with:
path: |
~/Library/Caches/go-build
~/go/pkg/mod
key: ${{ matrix.target_os }}-${{ matrix.target_arch }}-go-${{ env.GOVER }}-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ matrix.target_os }}-${{ matrix.target_arch }}-go-${{ env.GOVER }}-
- name: Check components-schema
if: matrix.target_arch == 'amd64' && matrix.target_os == 'linux' && steps.skip_check.outputs.should_skip != 'true'
run: make check-component-metadata-schema-diff
- name: Run golangci-lint
if: matrix.target_arch == 'amd64' && matrix.target_os == 'linux' && steps.skip_check.outputs.should_skip != 'true'
uses: golangci/golangci-lint-action@v3.2.0
with:
version: ${{ env.GOLANGCI_LINT_VER }}
skip-cache: true
args: --timeout 15m
- name: Run go mod tidy check diff
if: matrix.target_arch == 'amd64' && matrix.target_os == 'linux' && steps.skip_check.outputs.should_skip != 'true'
run: make modtidy-all check-mod-diff
# - name: Run Go Vulnerability Check
# if: matrix.target_arch == 'amd64' && matrix.target_os == 'linux' && steps.skip_check.outputs.should_skip != 'true'
# run: |
# go install golang.org/x/vuln/cmd/govulncheck@latest
# govulncheck ./...
- name: Run make test
env:
COVERAGE_OPTS: "-coverprofile=coverage.txt -covermode=atomic"
IPFS_TEST: "1"
if: matrix.target_arch != 'arm' && steps.skip_check.outputs.should_skip != 'true'
run: make test
- name: Codecov
if: matrix.target_arch == 'amd64' && matrix.target_os == 'linux'
uses: codecov/codecov-action@v3

View File

@ -11,47 +11,28 @@
# limitations under the License.
# ------------------------------------------------------------
name: components-contrib
name: "Build, Lint, Unit Test - Linux AMD64 Only"
on:
merge_group:
push:
branches:
- master
- release-*
- feature/*
tags:
- v*
pull_request:
branches:
- master
- release-*
- feature/*
jobs:
build:
name: Build ${{ matrix.target_os }}_${{ matrix.target_arch }} binaries
runs-on: ${{ matrix.os }}
name: Build linux_amd64 binaries
runs-on: ubuntu-latest
env:
GOVER: "1.19"
GOOS: ${{ matrix.target_os }}
GOARCH: ${{ matrix.target_arch }}
GOOS: linux
GOARCH: amd64
GOPROXY: https://proxy.golang.org
GOLANGCI_LINT_VER: "v1.50.1"
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macOS-latest]
target_arch: [arm, amd64]
include:
- os: ubuntu-latest
target_os: linux
- os: windows-latest
target_os: windows
- os: macOS-latest
target_os: darwin
exclude:
- os: windows-latest
target_arch: arm
- os: macOS-latest
target_arch: arm
steps:
- name: Set up Go ${{ env.GOVER }}
if: ${{ steps.skip_check.outputs.should_skip != 'true' }}
@ -62,50 +43,29 @@ jobs:
if: ${{ steps.skip_check.outputs.should_skip != 'true' }}
uses: actions/checkout@v3
- name: Cache Go modules (Linux)
if: matrix.target_os == 'linux'
uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ matrix.target_os }}-${{ matrix.target_arch }}-go-${{ env.GOVER }}-${{ hashFiles('**/go.sum') }}
key: linux-amd64-go-${{ env.GOVER }}-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ matrix.target_os }}-${{ matrix.target_arch }}-go-${{ env.GOVER }}-
- name: Cache Go modules (Windows)
if: matrix.target_os == 'windows'
uses: actions/cache@v3
with:
path: |
~\AppData\Local\go-build
~\go\pkg\mod
key: ${{ matrix.target_os }}-${{ matrix.target_arch }}-go-${{ env.GOVER }}-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ matrix.target_os }}-${{ matrix.target_arch }}-go-${{ env.GOVER }}-
- name: Cache Go modules (macOS)
if: matrix.target_os == 'darwin'
uses: actions/cache@v3
with:
path: |
~/Library/Caches/go-build
~/go/pkg/mod
key: ${{ matrix.target_os }}-${{ matrix.target_arch }}-go-${{ env.GOVER }}-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ matrix.target_os }}-${{ matrix.target_arch }}-go-${{ env.GOVER }}-
linux-amd64-go-${{ env.GOVER }}-
- name: Check components-schema
if: matrix.target_arch == 'amd64' && matrix.target_os == 'linux' && steps.skip_check.outputs.should_skip != 'true'
if: steps.skip_check.outputs.should_skip != 'true'
run: make check-component-metadata-schema-diff
- name: Run golangci-lint
if: matrix.target_arch == 'amd64' && matrix.target_os == 'linux' && steps.skip_check.outputs.should_skip != 'true'
if: steps.skip_check.outputs.should_skip != 'true'
uses: golangci/golangci-lint-action@v3.2.0
with:
version: ${{ env.GOLANGCI_LINT_VER }}
skip-cache: true
args: --timeout 15m
- name: Run go mod tidy check diff
if: matrix.target_arch == 'amd64' && matrix.target_os == 'linux' && steps.skip_check.outputs.should_skip != 'true'
if: steps.skip_check.outputs.should_skip != 'true'
run: make modtidy-all check-mod-diff
# - name: Run Go Vulnerability Check
# if: matrix.target_arch == 'amd64' && matrix.target_os == 'linux' && steps.skip_check.outputs.should_skip != 'true'
# if: steps.skip_check.outputs.should_skip != 'true'
# run: |
# go install golang.org/x/vuln/cmd/govulncheck@latest
# govulncheck ./...
@ -113,7 +73,7 @@ jobs:
env:
COVERAGE_OPTS: "-coverprofile=coverage.txt -covermode=atomic"
IPFS_TEST: "1"
if: matrix.target_arch != 'arm' && steps.skip_check.outputs.should_skip != 'true'
if: steps.skip_check.outputs.should_skip != 'true'
run: make test
- name: Codecov
if: matrix.target_arch == 'amd64' && matrix.target_os == 'linux'

View File

@ -17,8 +17,12 @@ on:
repository_dispatch:
types: [conformance-test]
workflow_dispatch:
merge_group:
schedule:
- cron: '0 */8 * * *'
push:
branches:
- 'release-*'
pull_request:
branches:
- master

View File

@ -80,7 +80,7 @@ func NewDingTalkWebhook(l logger.Logger) bindings.InputOutputBinding {
}
// Init performs metadata parsing.
func (t *DingTalkWebhook) Init(metadata bindings.Metadata) error {
func (t *DingTalkWebhook) Init(_ context.Context, metadata bindings.Metadata) error {
var err error
if err = t.settings.Decode(metadata.Properties); err != nil {
return fmt.Errorf("dingtalk configuration error: %w", err)
@ -107,6 +107,13 @@ func (t *DingTalkWebhook) Read(ctx context.Context, handler bindings.Handler) er
return nil
}
func (t *DingTalkWebhook) Close() error {
webhooks.Lock()
defer webhooks.Unlock()
delete(webhooks.m, t.settings.ID)
return nil
}
// Operations returns list of operations supported by dingtalk webhook binding.
func (t *DingTalkWebhook) Operations() []bindings.OperationKind {
return []bindings.OperationKind{bindings.CreateOperation, bindings.GetOperation}

View File

@ -57,7 +57,7 @@ func TestPublishMsg(t *testing.T) { //nolint:paralleltest
}}}
d := NewDingTalkWebhook(logger.NewLogger("test"))
err := d.Init(m)
err := d.Init(context.Background(), m)
require.NoError(t, err)
req := &bindings.InvokeRequest{Data: []byte(msg), Operation: bindings.CreateOperation, Metadata: map[string]string{}}
@ -78,7 +78,7 @@ func TestBindingReadAndInvoke(t *testing.T) { //nolint:paralleltest
}}
d := NewDingTalkWebhook(logger.NewLogger("test"))
err := d.Init(m)
err := d.Init(context.Background(), m)
assert.NoError(t, err)
var count int32
@ -106,3 +106,18 @@ func TestBindingReadAndInvoke(t *testing.T) { //nolint:paralleltest
require.FailNow(t, "read timeout")
}
}
func TestBindingClose(t *testing.T) {
d := NewDingTalkWebhook(logger.NewLogger("test"))
m := bindings.Metadata{Base: metadata.Base{
Name: "test",
Properties: map[string]string{
"url": "/test",
"secret": "",
"id": "x",
},
}}
assert.NoError(t, d.Init(context.Background(), m))
assert.NoError(t, d.Close())
assert.NoError(t, d.Close(), "second close should not error")
}

View File

@ -45,7 +45,7 @@ func NewAliCloudOSS(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing and connection creation.
func (s *AliCloudOSS) Init(metadata bindings.Metadata) error {
func (s *AliCloudOSS) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := s.parseMetadata(metadata)
if err != nil {
return err

View File

@ -31,7 +31,7 @@ type Callback struct {
}
// parse metadata field
func (s *AliCloudSlsLogstorage) Init(metadata bindings.Metadata) error {
func (s *AliCloudSlsLogstorage) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := s.parseMeta(metadata)
if err != nil {
return err

View File

@ -58,7 +58,7 @@ func NewAliCloudTableStore(log logger.Logger) bindings.OutputBinding {
}
}
func (s *AliCloudTableStore) Init(metadata bindings.Metadata) error {
func (s *AliCloudTableStore) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := s.parseMetadata(metadata)
if err != nil {
return err

View File

@ -51,7 +51,7 @@ func TestDataEncodeAndDecode(t *testing.T) {
metadata := bindings.Metadata{Base: metadata.Base{
Properties: getTestProperties(),
}}
aliCloudTableStore.Init(metadata)
aliCloudTableStore.Init(context.Background(), metadata)
// test create
putData := map[string]interface{}{

View File

@ -78,7 +78,7 @@ func NewAPNS(logger logger.Logger) bindings.OutputBinding {
// Init will configure the APNS output binding using the metadata specified
// in the binding's configuration.
func (a *APNS) Init(metadata bindings.Metadata) error {
func (a *APNS) Init(ctx context.Context, metadata bindings.Metadata) error {
if err := a.makeURLPrefix(metadata); err != nil {
return err
}

View File

@ -51,7 +51,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.Equal(t, developmentPrefix, binding.urlPrefix)
})
@ -66,7 +66,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.Equal(t, productionPrefix, binding.urlPrefix)
})
@ -80,7 +80,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.Equal(t, productionPrefix, binding.urlPrefix)
})
@ -95,7 +95,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Error(t, err, "invalid value for development parameter: True")
})
@ -107,7 +107,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Error(t, err, "the key-id parameter is required")
})
@ -120,7 +120,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.Equal(t, testKeyID, binding.authorizationBuilder.keyID)
})
@ -133,7 +133,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Error(t, err, "the team-id parameter is required")
})
@ -146,7 +146,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.Equal(t, testTeamID, binding.authorizationBuilder.teamID)
})
@ -159,7 +159,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Error(t, err, "the private-key parameter is required")
})
@ -172,7 +172,7 @@ func TestInit(t *testing.T) {
},
}}
binding := NewAPNS(testLogger).(*APNS)
err := binding.Init(metadata)
err := binding.Init(context.Background(), metadata)
assert.Nil(t, err)
assert.NotNil(t, binding.authorizationBuilder.privateKey)
})
@ -335,7 +335,7 @@ func makeTestBinding(t *testing.T, log logger.Logger) *APNS {
privateKeyKey: testPrivateKey,
},
}}
err := testBinding.Init(bindingMetadata)
err := testBinding.Init(context.Background(), bindingMetadata)
assert.Nil(t, err)
return testBinding

View File

@ -49,7 +49,7 @@ func NewDynamoDB(logger logger.Logger) bindings.OutputBinding {
}
// Init performs connection parsing for DynamoDB.
func (d *DynamoDB) Init(metadata bindings.Metadata) error {
func (d *DynamoDB) Init(_ context.Context, metadata bindings.Metadata) error {
meta, err := d.getDynamoDBMetadata(metadata)
if err != nil {
return err

View File

@ -15,7 +15,10 @@ package kinesis
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/aws/aws-sdk-go/aws"
@ -45,6 +48,10 @@ type AWSKinesis struct {
streamARN *string
consumerARN *string
logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
type kinesisMetadata struct {
@ -83,11 +90,14 @@ type recordProcessor struct {
// NewAWSKinesis returns a new AWS Kinesis instance.
func NewAWSKinesis(logger logger.Logger) bindings.InputOutputBinding {
return &AWSKinesis{logger: logger}
return &AWSKinesis{
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init does metadata parsing and connection creation.
func (a *AWSKinesis) Init(metadata bindings.Metadata) error {
func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error {
m, err := a.parseMetadata(metadata)
if err != nil {
return err
@ -107,7 +117,7 @@ func (a *AWSKinesis) Init(metadata bindings.Metadata) error {
}
streamName := aws.String(m.StreamName)
stream, err := client.DescribeStream(&kinesis.DescribeStreamInput{
stream, err := client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{
StreamName: streamName,
})
if err != nil {
@ -147,6 +157,10 @@ func (a *AWSKinesis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*
}
func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err error) {
if a.closed.Load() {
return errors.New("binding is closed")
}
if a.metadata.KinesisConsumerMode == SharedThroughput {
a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.workerConfig)
err = a.worker.Start()
@ -166,8 +180,13 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er
}
// Wait for context cancelation then stop
a.wg.Add(1)
go func() {
<-ctx.Done()
defer a.wg.Done()
select {
case <-ctx.Done():
case <-a.closeCh:
}
if a.metadata.KinesisConsumerMode == SharedThroughput {
a.worker.Shutdown()
} else if a.metadata.KinesisConsumerMode == ExtendedFanout {
@ -188,14 +207,25 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
a.consumerARN = consumerARN
a.wg.Add(len(streamDesc.Shards))
for i, shard := range streamDesc.Shards {
go func(idx int, s *kinesis.Shard) error {
go func(idx int, s *kinesis.Shard) {
defer a.wg.Done()
// Reconnection backoff
bo := backoff.NewExponentialBackOff()
bo.InitialInterval = 2 * time.Second
// Repeat until context is canceled
for ctx.Err() == nil {
// Repeat until context is canceled or binding closed.
for {
select {
case <-ctx.Done():
return
case <-a.closeCh:
return
default:
}
sub, err := a.client.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{
ConsumerARN: consumerARN,
ShardId: s.ShardId,
@ -204,8 +234,12 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
if err != nil {
wait := bo.NextBackOff()
a.logger.Errorf("Error while reading from shard %v: %v. Attempting to reconnect in %s...", s.ShardId, err, wait)
time.Sleep(wait)
continue
select {
case <-ctx.Done():
return
case <-time.After(wait):
continue
}
}
// Reset the backoff on connection success
@ -223,22 +257,30 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
}
}
}
return nil
}(i, shard)
}
return nil
}
func (a *AWSKinesis) ensureConsumer(parentCtx context.Context, streamARN *string) (*string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
consumer, err := a.client.DescribeStreamConsumerWithContext(ctx, &kinesis.DescribeStreamConsumerInput{
func (a *AWSKinesis) Close() error {
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
a.wg.Wait()
return nil
}
func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*string, error) {
// Only set timeout on consumer call.
conCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
consumer, err := a.client.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{
ConsumerName: &a.metadata.ConsumerName,
StreamARN: streamARN,
})
cancel()
if err != nil {
return a.registerConsumer(parentCtx, streamARN)
return a.registerConsumer(ctx, streamARN)
}
return consumer.ConsumerDescription.ConsumerARN, nil

View File

@ -99,7 +99,7 @@ func NewAWSS3(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing and connection creation.
func (s *AWSS3) Init(metadata bindings.Metadata) error {
func (s *AWSS3) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := s.parseMetadata(metadata)
if err != nil {
return err

View File

@ -61,7 +61,7 @@ func NewAWSSES(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing.
func (a *AWSSES) Init(metadata bindings.Metadata) error {
func (a *AWSSES) Init(_ context.Context, metadata bindings.Metadata) error {
// Parse input metadata
meta, err := a.parseMetadata(metadata)
if err != nil {

View File

@ -53,7 +53,7 @@ func NewAWSSNS(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing.
func (a *AWSSNS) Init(metadata bindings.Metadata) error {
func (a *AWSSNS) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := a.parseMetadata(metadata)
if err != nil {
return err

View File

@ -16,6 +16,9 @@ package sqs
import (
"context"
"encoding/json"
"errors"
"sync"
"sync/atomic"
"time"
"github.com/aws/aws-sdk-go/aws"
@ -31,7 +34,10 @@ type AWSSQS struct {
Client *sqs.SQS
QueueURL *string
logger logger.Logger
logger logger.Logger
wg sync.WaitGroup
closeCh chan struct{}
closed atomic.Bool
}
type sqsMetadata struct {
@ -45,11 +51,14 @@ type sqsMetadata struct {
// NewAWSSQS returns a new AWS SQS instance.
func NewAWSSQS(logger logger.Logger) bindings.InputOutputBinding {
return &AWSSQS{logger: logger}
return &AWSSQS{
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init does metadata parsing and connection creation.
func (a *AWSSQS) Init(metadata bindings.Metadata) error {
func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error {
m, err := a.parseSQSMetadata(metadata)
if err != nil {
return err
@ -61,7 +70,7 @@ func (a *AWSSQS) Init(metadata bindings.Metadata) error {
}
queueName := m.QueueName
resultURL, err := client.GetQueueUrl(&sqs.GetQueueUrlInput{
resultURL, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{
QueueName: aws.String(queueName),
})
if err != nil {
@ -89,9 +98,20 @@ func (a *AWSSQS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind
}
func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("binding is closed")
}
a.wg.Add(1)
go func() {
// Repeat until the context is canceled
for ctx.Err() == nil {
defer a.wg.Done()
// Repeat until the context is canceled or component is closed
for {
if ctx.Err() != nil || a.closed.Load() {
return
}
result, err := a.Client.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{
QueueUrl: a.QueueURL,
AttributeNames: aws.StringSlice([]string{
@ -126,13 +146,25 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error {
}
}
time.Sleep(time.Millisecond * 50)
select {
case <-ctx.Done():
case <-a.closeCh:
case <-time.After(time.Millisecond * 50):
}
}
}()
return nil
}
func (a *AWSSQS) Close() error {
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
a.wg.Wait()
return nil
}
func (a *AWSSQS) parseSQSMetadata(metadata bindings.Metadata) (*sqsMetadata, error) {
b, err := json.Marshal(metadata.Properties)
if err != nil {

View File

@ -92,7 +92,7 @@ func NewAzureBlobStorage(logger logger.Logger) bindings.OutputBinding {
}
// Init performs metadata parsing.
func (a *AzureBlobStorage) Init(metadata bindings.Metadata) error {
func (a *AzureBlobStorage) Init(_ context.Context, metadata bindings.Metadata) error {
var err error
a.containerClient, a.metadata, err = storageinternal.CreateContainerStorageClient(a.logger, metadata.Properties)
if err != nil {

View File

@ -53,7 +53,7 @@ func NewCosmosDB(logger logger.Logger) bindings.OutputBinding {
}
// Init performs CosmosDB connection parsing and connecting.
func (c *CosmosDB) Init(metadata bindings.Metadata) error {
func (c *CosmosDB) Init(ctx context.Context, metadata bindings.Metadata) error {
m, err := c.parseMetadata(metadata)
if err != nil {
return err
@ -103,9 +103,9 @@ func (c *CosmosDB) Init(metadata bindings.Metadata) error {
}
c.client = dbContainer
ctx, cancel := context.WithTimeout(context.Background(), timeoutValue*time.Second)
_, err = c.client.Read(ctx, nil)
cancel()
readCtx, readCancel := context.WithTimeout(ctx, timeoutValue*time.Second)
defer readCancel()
_, err = c.client.Read(readCtx, nil)
return err
}

View File

@ -59,7 +59,7 @@ func NewCosmosDBGremlinAPI(logger logger.Logger) bindings.OutputBinding {
}
// Init performs CosmosDBGremlinAPI connection parsing and connecting.
func (c *CosmosDBGremlinAPI) Init(metadata bindings.Metadata) error {
func (c *CosmosDBGremlinAPI) Init(_ context.Context, metadata bindings.Metadata) error {
c.logger.Debug("Initializing Cosmos Graph DB binding")
m, err := c.parseMetadata(metadata)

View File

@ -21,6 +21,8 @@ import (
"net/url"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
@ -34,6 +36,7 @@ import (
"github.com/valyala/fasthttp"
"github.com/dapr/components-contrib/bindings"
"github.com/dapr/components-contrib/contenttype"
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/kit/logger"
@ -56,6 +59,9 @@ type AzureEventGrid struct {
metadata *azureEventGridMetadata
logger logger.Logger
jwks jwk.Set
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
}
type azureEventGridMetadata struct {
@ -84,11 +90,14 @@ type azureEventGridMetadata struct {
// NewAzureEventGrid returns a new Azure Event Grid instance.
func NewAzureEventGrid(logger logger.Logger) bindings.InputOutputBinding {
return &AzureEventGrid{logger: logger}
return &AzureEventGrid{
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init performs metadata init.
func (a *AzureEventGrid) Init(metadata bindings.Metadata) error {
func (a *AzureEventGrid) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := a.parseMetadata(metadata)
if err != nil {
return err
@ -149,6 +158,10 @@ func (a *AzureEventGrid) initJWKSCache(ctx context.Context) error {
}
func (a *AzureEventGrid) Read(ctx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("binding is closed")
}
err := a.ensureInputBindingMetadata()
if err != nil {
return err
@ -164,17 +177,22 @@ func (a *AzureEventGrid) Read(ctx context.Context, handler bindings.Handler) err
}
// Run the server in background
a.wg.Add(2)
go func() {
defer a.wg.Done()
a.logger.Infof("Listening for Event Grid events at http://localhost:%s%s", a.metadata.HandshakePort, a.metadata.subscriberPath)
srvErr := srv.ListenAndServe(":" + a.metadata.HandshakePort)
if err != nil {
a.logger.Errorf("Error starting server: %v", srvErr)
}
}()
// Close the server when context is canceled
// Close the server when context is canceled or binding closed.
go func() {
<-ctx.Done()
defer a.wg.Done()
select {
case <-ctx.Done():
case <-a.closeCh:
}
srvErr := srv.Shutdown()
if err != nil {
a.logger.Errorf("Error shutting down server: %v", srvErr)
@ -193,6 +211,14 @@ func (a *AzureEventGrid) Operations() []bindings.OperationKind {
return []bindings.OperationKind{bindings.CreateOperation}
}
func (a *AzureEventGrid) Close() error {
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
a.wg.Wait()
return nil
}
func (a *AzureEventGrid) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
err := a.ensureOutputBindingMetadata()
if err != nil {
@ -203,7 +229,7 @@ func (a *AzureEventGrid) Invoke(ctx context.Context, req *bindings.InvokeRequest
request := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(request)
request.Header.SetMethod(fasthttp.MethodPost)
request.Header.Set("Content-Type", "application/cloudevents+json")
request.Header.Set("Content-Type", contenttype.CloudEventContentType)
request.Header.Set("aeg-sas-key", a.metadata.AccessKey)
request.Header.Set("User-Agent", "dapr/"+logger.DaprVersion)
request.SetRequestURI(a.metadata.TopicEndpoint)

View File

@ -37,7 +37,7 @@ func NewAzureEventHubs(logger logger.Logger) bindings.InputOutputBinding {
}
// Init performs metadata init.
func (a *AzureEventHubs) Init(metadata bindings.Metadata) error {
func (a *AzureEventHubs) Init(_ context.Context, metadata bindings.Metadata) error {
return a.AzureEventHubs.Init(metadata.Properties)
}

View File

@ -102,7 +102,7 @@ func testEventHubsBindingsAADAuthentication(t *testing.T) {
metadata := createEventHubsBindingsAADMetadata()
eventHubsBindings := NewAzureEventHubs(log)
err := eventHubsBindings.Init(metadata)
err := eventHubsBindings.Init(context.Background(), metadata)
require.NoError(t, err)
req := &bindings.InvokeRequest{
@ -146,7 +146,7 @@ func testReadIotHubEvents(t *testing.T) {
logger := logger.NewLogger("bindings.azure.eventhubs.integration.test")
eh := NewAzureEventHubs(logger)
err := eh.Init(createIotHubBindingsMetadata())
err := eh.Init(context.Background(), createIotHubBindingsMetadata())
require.NoError(t, err)
// Invoke az CLI via bash script to send test IoT device events

View File

@ -17,6 +17,8 @@ import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
servicebus "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
@ -37,17 +39,21 @@ type AzureServiceBusQueues struct {
metadata *impl.Metadata
client *impl.Client
logger logger.Logger
closed atomic.Bool
wg sync.WaitGroup
closeCh chan struct{}
}
// NewAzureServiceBusQueues returns a new AzureServiceBusQueues instance.
func NewAzureServiceBusQueues(logger logger.Logger) bindings.InputOutputBinding {
return &AzureServiceBusQueues{
logger: logger,
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init parses connection properties and creates a new Service Bus Queue client.
func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) (err error) {
func (a *AzureServiceBusQueues) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
a.metadata, err = impl.ParseMetadata(metadata.Properties, a.logger, (impl.MetadataModeBinding | impl.MetadataModeQueues))
if err != nil {
return err
@ -59,7 +65,7 @@ func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) (err error) {
}
// Will do nothing if DisableEntityManagement is false
err = a.client.EnsureQueue(context.Background(), a.metadata.QueueName)
err = a.client.EnsureQueue(ctx, a.metadata.QueueName)
if err != nil {
return err
}
@ -78,10 +84,16 @@ func (a *AzureServiceBusQueues) Invoke(ctx context.Context, req *bindings.Invoke
}
func (a *AzureServiceBusQueues) Read(ctx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("binding is closed")
}
// Reconnection backoff policy
bo := a.client.ReconnectionBackoff()
a.wg.Add(1)
go func() {
defer a.wg.Done()
logMsg := "queue " + a.metadata.QueueName
// Reconnect loop.
@ -127,20 +139,17 @@ func (a *AzureServiceBusQueues) Read(ctx context.Context, handler bindings.Handl
a.logger.Errorf("Error from receiver: %v", err)
}
// If context was canceled, do not attempt to reconnect
if ctx.Err() != nil {
a.logger.Debug("Context canceled; will not reconnect")
return
}
wait := bo.NextBackOff()
a.logger.Warnf("Subscription to queue %s lost connection, attempting to reconnect in %s...", a.metadata.QueueName, wait)
time.Sleep(wait)
// Check for context canceled again, after sleeping
if ctx.Err() != nil {
select {
case <-time.After(wait):
// nop
case <-ctx.Done():
a.logger.Debug("Context canceled; will not reconnect")
return
case <-a.closeCh:
a.logger.Debug("Component is closing; will not reconnect")
return
}
}
}()
@ -180,7 +189,11 @@ func (a *AzureServiceBusQueues) getHandlerFn(handler bindings.Handler) impl.Hand
}
func (a *AzureServiceBusQueues) Close() (err error) {
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
a.logger.Debug("Closing component")
a.client.Close(a.logger)
a.wg.Wait()
return nil
}

View File

@ -78,7 +78,7 @@ type SignalR struct {
}
// Init is responsible for initializing the SignalR output based on the metadata.
func (s *SignalR) Init(metadata bindings.Metadata) (err error) {
func (s *SignalR) Init(_ context.Context, metadata bindings.Metadata) (err error) {
s.userAgent = "dapr-" + logger.DaprVersion
err = s.parseMetadata(metadata.Properties)

View File

@ -16,9 +16,12 @@ package storagequeues
import (
"context"
"encoding/base64"
"errors"
"fmt"
"net/url"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/Azure/azure-storage-queue-go/azqueue"
@ -40,9 +43,10 @@ type consumer struct {
// QueueHelper enables injection for testnig.
type QueueHelper interface {
Init(metadata bindings.Metadata) (*storageQueuesMetadata, error)
Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error)
Write(ctx context.Context, data []byte, ttl *time.Duration) error
Read(ctx context.Context, consumer *consumer) error
Close() error
}
// AzureQueueHelper concrete impl of queue helper.
@ -55,7 +59,7 @@ type AzureQueueHelper struct {
}
// Init sets up this helper.
func (d *AzureQueueHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
func (d *AzureQueueHelper) Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error) {
m, err := parseMetadata(metadata)
if err != nil {
return nil, err
@ -89,9 +93,9 @@ func (d *AzureQueueHelper) Init(metadata bindings.Metadata) (*storageQueuesMetad
d.queueURL = azqueue.NewQueueURL(*URL, p)
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
_, err = d.queueURL.Create(ctx, azqueue.Metadata{})
cancel()
createCtx, createCancel := context.WithTimeout(ctx, 2*time.Minute)
_, err = d.queueURL.Create(createCtx, azqueue.Metadata{})
createCancel()
if err != nil {
return nil, err
}
@ -128,7 +132,10 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error {
}
if res.NumMessages() == 0 {
// Queue was empty so back off by 10 seconds before trying again
time.Sleep(10 * time.Second)
select {
case <-time.After(10 * time.Second):
case <-ctx.Done():
}
return nil
}
mt := res.Message(0).Text
@ -162,6 +169,10 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error {
return nil
}
func (d *AzureQueueHelper) Close() error {
return nil
}
// NewAzureQueueHelper creates new helper.
func NewAzureQueueHelper(logger logger.Logger) QueueHelper {
return &AzureQueueHelper{
@ -175,6 +186,10 @@ type AzureStorageQueues struct {
helper QueueHelper
logger logger.Logger
wg sync.WaitGroup
closeCh chan struct{}
closed atomic.Bool
}
type storageQueuesMetadata struct {
@ -189,12 +204,16 @@ type storageQueuesMetadata struct {
// NewAzureStorageQueues returns a new AzureStorageQueues instance.
func NewAzureStorageQueues(logger logger.Logger) bindings.InputOutputBinding {
return &AzureStorageQueues{helper: NewAzureQueueHelper(logger), logger: logger}
return &AzureStorageQueues{
helper: NewAzureQueueHelper(logger),
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init parses connection properties and creates a new Storage Queue client.
func (a *AzureStorageQueues) Init(metadata bindings.Metadata) (err error) {
a.metadata, err = a.helper.Init(metadata)
func (a *AzureStorageQueues) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
a.metadata, err = a.helper.Init(ctx, metadata)
if err != nil {
return err
}
@ -261,14 +280,32 @@ func (a *AzureStorageQueues) Invoke(ctx context.Context, req *bindings.InvokeReq
}
func (a *AzureStorageQueues) Read(ctx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("input binding is closed")
}
c := consumer{
callback: handler,
}
// Close read context when binding is closed.
readCtx, cancel := context.WithCancel(ctx)
a.wg.Add(2)
go func() {
defer a.wg.Done()
defer cancel()
select {
case <-a.closeCh:
case <-ctx.Done():
}
}()
go func() {
defer a.wg.Done()
// Read until context is canceled
var err error
for ctx.Err() == nil {
err = a.helper.Read(ctx, &c)
for readCtx.Err() == nil {
err = a.helper.Read(readCtx, &c)
if err != nil {
a.logger.Errorf("error from c: %s", err)
}
@ -277,3 +314,11 @@ func (a *AzureStorageQueues) Read(ctx context.Context, handler bindings.Handler)
return nil
}
func (a *AzureStorageQueues) Close() error {
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
a.wg.Wait()
return nil
}

View File

@ -16,6 +16,7 @@ package storagequeues
import (
"context"
"encoding/base64"
"sync"
"testing"
"time"
@ -32,9 +33,11 @@ type MockHelper struct {
mock.Mock
messages chan []byte
metadata *storageQueuesMetadata
closeCh chan struct{}
wg sync.WaitGroup
}
func (m *MockHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
func (m *MockHelper) Init(ctx context.Context, metadata bindings.Metadata) (*storageQueuesMetadata, error) {
m.messages = make(chan []byte, 10)
var err error
m.metadata, err = parseMetadata(metadata)
@ -50,12 +53,23 @@ func (m *MockHelper) Write(ctx context.Context, data []byte, ttl *time.Duration)
func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {
retvals := m.Called(ctx, consumer)
readCtx, cancel := context.WithCancel(ctx)
m.wg.Add(2)
go func() {
defer m.wg.Done()
defer cancel()
select {
case <-readCtx.Done():
case <-m.closeCh:
}
}()
go func() {
defer m.wg.Done()
for msg := range m.messages {
if m.metadata.DecodeBase64 {
msg, _ = base64.StdEncoding.DecodeString(string(msg))
}
go consumer.callback(ctx, &bindings.ReadResponse{
go consumer.callback(readCtx, &bindings.ReadResponse{
Data: msg,
})
}
@ -64,18 +78,24 @@ func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {
return retvals.Error(0)
}
func (m *MockHelper) Close() error {
defer m.wg.Wait()
close(m.closeCh)
return nil
}
func TestWriteQueue(t *testing.T) {
mm := new(MockHelper)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
return in == nil
})).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -83,6 +103,7 @@ func TestWriteQueue(t *testing.T) {
_, err = a.Invoke(context.Background(), &r)
assert.Nil(t, err)
assert.NoError(t, a.Close())
}
func TestWriteWithTTLInQueue(t *testing.T) {
@ -91,12 +112,12 @@ func TestWriteWithTTLInQueue(t *testing.T) {
return in != nil && *in == time.Second
})).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", metadata.TTLMetadataKey: "1"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -104,6 +125,7 @@ func TestWriteWithTTLInQueue(t *testing.T) {
_, err = a.Invoke(context.Background(), &r)
assert.Nil(t, err)
assert.NoError(t, a.Close())
}
func TestWriteWithTTLInWrite(t *testing.T) {
@ -112,12 +134,12 @@ func TestWriteWithTTLInWrite(t *testing.T) {
return in != nil && *in == time.Second
})).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", metadata.TTLMetadataKey: "1"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
r := bindings.InvokeRequest{
@ -128,6 +150,7 @@ func TestWriteWithTTLInWrite(t *testing.T) {
_, err = a.Invoke(context.Background(), &r)
assert.Nil(t, err)
assert.NoError(t, a.Close())
}
// Uncomment this function to write a message to local storage queue
@ -138,7 +161,7 @@ func TestWriteWithTTLInWrite(t *testing.T) {
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -152,12 +175,12 @@ func TestReadQueue(t *testing.T) {
mm := new(MockHelper)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -186,6 +209,7 @@ func TestReadQueue(t *testing.T) {
t.Fatal("Timeout waiting for messages")
}
assert.Equal(t, 1, received)
assert.NoError(t, a.Close())
}
func TestReadQueueDecode(t *testing.T) {
@ -193,12 +217,12 @@ func TestReadQueueDecode(t *testing.T) {
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1", "decodeBase64": "true"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("VGhpcyBpcyBteSBtZXNzYWdl")}
@ -227,6 +251,7 @@ func TestReadQueueDecode(t *testing.T) {
t.Fatal("Timeout waiting for messages")
}
assert.Equal(t, 1, received)
assert.NoError(t, a.Close())
}
// Uncomment this function to test reding from local queue
@ -237,7 +262,7 @@ func TestReadQueueDecode(t *testing.T) {
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
r := bindings.InvokeRequest{Data: []byte("This is my message")}
@ -263,12 +288,12 @@ func TestReadQueueNoMessage(t *testing.T) {
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
mm.On("Read", mock.AnythingOfType("*context.cancelCtx"), mock.AnythingOfType("*storagequeues.consumer")).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test"), closeCh: make(chan struct{})}
m := bindings.Metadata{}
m.Properties = map[string]string{"storageAccessKey": "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", "queue": "queue1", "storageAccount": "devstoreaccount1"}
err := a.Init(m)
err := a.Init(context.Background(), m)
assert.Nil(t, err)
ctx, cancel := context.WithCancel(context.Background())
@ -285,6 +310,7 @@ func TestReadQueueNoMessage(t *testing.T) {
time.Sleep(1 * time.Second)
cancel()
assert.Equal(t, 0, received)
assert.NoError(t, a.Close())
}
func TestParseMetadata(t *testing.T) {

View File

@ -48,7 +48,7 @@ func NewCFQueues(logger logger.Logger) bindings.OutputBinding {
}
// Init the component.
func (q *CFQueues) Init(metadata bindings.Metadata) error {
func (q *CFQueues) Init(_ context.Context, metadata bindings.Metadata) error {
// Decode the metadata
err := mapstructure.Decode(metadata.Properties, &q.metadata)
if err != nil {

View File

@ -51,7 +51,7 @@ func NewCommercetools(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing and connection establishment.
func (ct *Binding) Init(metadata bindings.Metadata) error {
func (ct *Binding) Init(_ context.Context, metadata bindings.Metadata) error {
commercetoolsM, err := ct.getCommercetoolsMetadata(metadata)
if err != nil {
return err

View File

@ -16,6 +16,8 @@ package cron
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/benbjohnson/clock"
@ -34,6 +36,9 @@ type Binding struct {
schedule string
parser cron.Parser
clk clock.Clock
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
// NewCron returns a new Cron event input binding.
@ -48,6 +53,7 @@ func NewCronWithClock(logger logger.Logger, clk clock.Clock) bindings.InputBindi
parser: cron.NewParser(
cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor,
),
closeCh: make(chan struct{}),
}
}
@ -56,7 +62,7 @@ func NewCronWithClock(logger logger.Logger, clk clock.Clock) bindings.InputBindi
//
// "15 * * * * *" - Every 15 sec
// "0 30 * * * *" - Every 30 min
func (b *Binding) Init(metadata bindings.Metadata) error {
func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
b.name = metadata.Name
s, f := metadata.Properties["schedule"]
if !f || s == "" {
@ -73,6 +79,10 @@ func (b *Binding) Init(metadata bindings.Metadata) error {
// Read triggers the Cron scheduler.
func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
if b.closed.Load() {
return errors.New("binding is closed")
}
c := cron.New(cron.WithParser(b.parser), cron.WithClock(b.clk))
id, err := c.AddFunc(b.schedule, func() {
b.logger.Debugf("name: %s, schedule fired: %v", b.name, time.Now())
@ -89,12 +99,25 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
c.Start()
b.logger.Debugf("name: %s, next run: %v", b.name, time.Until(c.Entry(id).Next))
b.wg.Add(1)
go func() {
// Wait for context to be canceled
<-ctx.Done()
defer b.wg.Done()
// Wait for context to be canceled or component to be closed.
select {
case <-ctx.Done():
case <-b.closeCh:
}
b.logger.Debugf("name: %s, stopping schedule: %s", b.name, b.schedule)
c.Stop()
}()
return nil
}
func (b *Binding) Close() error {
if b.closed.CompareAndSwap(false, true) {
close(b.closeCh)
}
b.wg.Wait()
return nil
}

View File

@ -16,6 +16,7 @@ package cron
import (
"context"
"os"
"sync/atomic"
"testing"
"time"
@ -84,7 +85,7 @@ func TestCronInitSuccess(t *testing.T) {
for _, test := range initTests {
c := getNewCron()
err := c.Init(getTestMetadata(test.schedule))
err := c.Init(context.Background(), getTestMetadata(test.schedule))
if test.errorExpected {
assert.Errorf(t, err, "Got no error while initializing an invalid schedule: %s", test.schedule)
} else {
@ -99,38 +100,41 @@ func TestCronRead(t *testing.T) {
clk := clock.NewMock()
c := getNewCronWithClock(clk)
schedule := "@every 1s"
assert.NoErrorf(t, c.Init(getTestMetadata(schedule)), "error initializing valid schedule")
expectedCount := 5
observedCount := 0
assert.NoErrorf(t, c.Init(context.Background(), getTestMetadata(schedule)), "error initializing valid schedule")
expectedCount := int32(5)
var observedCount atomic.Int32
err := c.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
assert.NotNil(t, res)
observedCount++
observedCount.Add(1)
return nil, nil
})
// Check if cron triggers 5 times in 5 seconds
for i := 0; i < expectedCount; i++ {
for i := int32(0); i < expectedCount; i++ {
// Add time to mock clock in 1 second intervals using loop to allow cron go routine to run
clk.Add(time.Second)
}
// Wait for 1 second after adding the last second to mock clock to allow cron to finish triggering
time.Sleep(1 * time.Second)
assert.Equal(t, expectedCount, observedCount, "Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount)
assert.Eventually(t, func() bool {
return observedCount.Load() == expectedCount
}, time.Second, time.Millisecond*10,
"Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount.Load())
assert.NoErrorf(t, err, "error on read")
assert.NoError(t, c.Close())
}
func TestCronReadWithContextCancellation(t *testing.T) {
clk := clock.NewMock()
c := getNewCronWithClock(clk)
schedule := "@every 1s"
assert.NoErrorf(t, c.Init(getTestMetadata(schedule)), "error initializing valid schedule")
expectedCount := 5
observedCount := 0
assert.NoErrorf(t, c.Init(context.Background(), getTestMetadata(schedule)), "error initializing valid schedule")
expectedCount := int32(5)
var observedCount atomic.Int32
ctx, cancel := context.WithCancel(context.Background())
err := c.Read(ctx, func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
assert.NotNil(t, res)
assert.LessOrEqualf(t, observedCount, expectedCount, "Invoke didn't stop the schedule")
observedCount++
if observedCount == expectedCount {
assert.LessOrEqualf(t, observedCount.Load(), expectedCount, "Invoke didn't stop the schedule")
observedCount.Add(1)
if observedCount.Load() == expectedCount {
// Cancel context after 5 triggers
cancel()
}
@ -141,7 +145,10 @@ func TestCronReadWithContextCancellation(t *testing.T) {
// Add time to mock clock in 1 second intervals using loop to allow cron go routine to run
clk.Add(time.Second)
}
time.Sleep(1 * time.Second)
assert.Equal(t, expectedCount, observedCount, "Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount)
assert.Eventually(t, func() bool {
return observedCount.Load() == expectedCount
}, time.Second, time.Millisecond*10,
"Cron did not trigger expected number of times, expected %d, got %d", expectedCount, observedCount.Load())
assert.NoErrorf(t, err, "error on read")
assert.NoError(t, c.Close())
}

View File

@ -47,7 +47,7 @@ func NewDubboOutput(logger logger.Logger) bindings.OutputBinding {
return dubboBinding
}
func (out *DubboOutputBinding) Init(_ bindings.Metadata) error {
func (out *DubboOutputBinding) Init(_ context.Context, _ bindings.Metadata) error {
dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
return nil
}

View File

@ -54,12 +54,13 @@ func TestInvoke(t *testing.T) {
// 0. init dapr provided and dubbo server
stopCh := make(chan struct{})
defer close(stopCh)
// Create output and set serializer before go routine to prevent data race.
output := NewDubboOutput(logger.NewLogger("test"))
dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
go func() {
assert.Nil(t, runDubboServer(stopCh))
}()
time.Sleep(time.Second * 3)
dubboImpl.SetSerializer(constant.Hessian2Serialization, HessianSerializer{})
output := NewDubboOutput(logger.NewLogger("test"))
// 1. create req/rsp value
reqUser := &User{Name: testName}

View File

@ -83,14 +83,13 @@ func NewGCPStorage(logger logger.Logger) bindings.OutputBinding {
}
// Init performs connection parsing.
func (g *GCPStorage) Init(metadata bindings.Metadata) error {
func (g *GCPStorage) Init(ctx context.Context, metadata bindings.Metadata) error {
m, b, err := g.parseMetadata(metadata)
if err != nil {
return err
}
clientOptions := option.WithCredentialsJSON(b)
ctx := context.Background()
client, err := storage.NewClient(ctx, clientOptions)
if err != nil {
return err

View File

@ -16,7 +16,10 @@ package pubsub
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
"sync/atomic"
"cloud.google.com/go/pubsub"
"google.golang.org/api/option"
@ -36,6 +39,9 @@ type GCPPubSub struct {
client *pubsub.Client
metadata *pubSubMetadata
logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
type pubSubMetadata struct {
@ -55,11 +61,14 @@ type pubSubMetadata struct {
// NewGCPPubSub returns a new GCPPubSub instance.
func NewGCPPubSub(logger logger.Logger) bindings.InputOutputBinding {
return &GCPPubSub{logger: logger}
return &GCPPubSub{
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init parses metadata and creates a new Pub Sub client.
func (g *GCPPubSub) Init(metadata bindings.Metadata) error {
func (g *GCPPubSub) Init(ctx context.Context, metadata bindings.Metadata) error {
b, err := g.parseMetadata(metadata)
if err != nil {
return err
@ -71,7 +80,6 @@ func (g *GCPPubSub) Init(metadata bindings.Metadata) error {
return err
}
clientOptions := option.WithCredentialsJSON(b)
ctx := context.Background()
pubsubClient, err := pubsub.NewClient(ctx, pubsubMeta.ProjectID, clientOptions)
if err != nil {
return fmt.Errorf("error creating pubsub client: %s", err)
@ -88,7 +96,12 @@ func (g *GCPPubSub) parseMetadata(metadata bindings.Metadata) ([]byte, error) {
}
func (g *GCPPubSub) Read(ctx context.Context, handler bindings.Handler) error {
if g.closed.Load() {
return errors.New("binding is closed")
}
g.wg.Add(1)
go func() {
defer g.wg.Done()
sub := g.client.Subscription(g.metadata.Subscription)
err := sub.Receive(ctx, func(c context.Context, m *pubsub.Message) {
_, err := handler(c, &bindings.ReadResponse{
@ -128,5 +141,9 @@ func (g *GCPPubSub) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*b
}
func (g *GCPPubSub) Close() error {
if g.closed.CompareAndSwap(false, true) {
close(g.closeCh)
}
defer g.wg.Wait()
return g.client.Close()
}

View File

@ -58,7 +58,7 @@ func NewGraphQL(logger logger.Logger) bindings.OutputBinding {
}
// Init initializes the GraphQL binding.
func (gql *GraphQL) Init(metadata bindings.Metadata) error {
func (gql *GraphQL) Init(_ context.Context, metadata bindings.Metadata) error {
gql.logger.Debug("GraphQL Error: Initializing GraphQL binding")
p := metadata.Properties

View File

@ -74,7 +74,7 @@ func NewHTTP(logger logger.Logger) bindings.OutputBinding {
}
// Init performs metadata parsing.
func (h *HTTPSource) Init(metadata bindings.Metadata) error {
func (h *HTTPSource) Init(_ context.Context, metadata bindings.Metadata) error {
var err error
if err = mapstructure.Decode(metadata.Properties, &h.metadata); err != nil {
return err
@ -104,7 +104,7 @@ func (h *HTTPSource) Init(metadata bindings.Metadata) error {
Transport: netTransport,
}
if val, ok := metadata.Properties["errorIfNot2XX"]; ok {
if val := metadata.Properties["errorIfNot2XX"]; val != "" {
h.errorIfNot2XX = utils.IsTruthy(val)
} else {
// Default behavior

View File

@ -132,7 +132,7 @@ func InitBinding(s *httptest.Server, extraProps map[string]string) (bindings.Out
}
hs := NewHTTP(logger.NewLogger("test"))
err := hs.Init(m)
err := hs.Init(context.Background(), m)
return hs, err
}
@ -269,7 +269,7 @@ func InitBindingForHTTPS(s *httptest.Server, extraProps map[string]string) (bind
m.Properties[k] = v
}
hs := NewHTTP(logger.NewLogger("test"))
err := hs.Init(m)
err := hs.Init(context.Background(), m)
return hs, err
}

View File

@ -75,7 +75,7 @@ func NewHuaweiOBS(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing and connection creation.
func (o *HuaweiOBS) Init(metadata bindings.Metadata) error {
func (o *HuaweiOBS) Init(_ context.Context, metadata bindings.Metadata) error {
o.logger.Debugf("initializing Huawei OBS binding and parsing metadata")
m, err := o.parseMetadata(metadata)

View File

@ -92,7 +92,7 @@ func TestInit(t *testing.T) {
"accessKey": "dummy-ak",
"secretKey": "dummy-sk",
}
err := obs.Init(m)
err := obs.Init(context.Background(), m)
assert.Nil(t, err)
})
t.Run("Init with missing bucket name", func(t *testing.T) {
@ -102,7 +102,7 @@ func TestInit(t *testing.T) {
"accessKey": "dummy-ak",
"secretKey": "dummy-sk",
}
err := obs.Init(m)
err := obs.Init(context.Background(), m)
assert.NotNil(t, err)
assert.Equal(t, err, fmt.Errorf("missing obs bucket name"))
})
@ -113,7 +113,7 @@ func TestInit(t *testing.T) {
"endpoint": "dummy-endpoint",
"secretKey": "dummy-sk",
}
err := obs.Init(m)
err := obs.Init(context.Background(), m)
assert.NotNil(t, err)
assert.Equal(t, err, fmt.Errorf("missing the huawei access key"))
})
@ -124,7 +124,7 @@ func TestInit(t *testing.T) {
"endpoint": "dummy-endpoint",
"accessKey": "dummy-ak",
}
err := obs.Init(m)
err := obs.Init(context.Background(), m)
assert.NotNil(t, err)
assert.Equal(t, err, fmt.Errorf("missing the huawei secret key"))
})
@ -135,7 +135,7 @@ func TestInit(t *testing.T) {
"accessKey": "dummy-ak",
"secretKey": "dummy-sk",
}
err := obs.Init(m)
err := obs.Init(context.Background(), m)
assert.NotNil(t, err)
assert.Equal(t, err, fmt.Errorf("missing obs endpoint"))
})

View File

@ -64,7 +64,7 @@ func NewInflux(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing and connection establishment.
func (i *Influx) Init(metadata bindings.Metadata) error {
func (i *Influx) Init(_ context.Context, metadata bindings.Metadata) error {
influxMeta, err := i.getInfluxMetadata(metadata)
if err != nil {
return err

View File

@ -54,7 +54,7 @@ func TestInflux_Init(t *testing.T) {
assert.Nil(t, influx.client)
m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{"Url": "a", "Token": "a", "Org": "a", "Bucket": "a"}}}
err := influx.Init(m)
err := influx.Init(context.Background(), m)
assert.Nil(t, err)
assert.NotNil(t, influx.queryAPI)

View File

@ -16,6 +16,7 @@ package bindings
import (
"context"
"fmt"
"io"
"github.com/dapr/components-contrib/health"
)
@ -23,18 +24,21 @@ import (
// InputBinding is the interface to define a binding that triggers on incoming events.
type InputBinding interface {
// Init passes connection and properties metadata to the binding implementation.
Init(metadata Metadata) error
Init(ctx context.Context, metadata Metadata) error
// Read is a method that runs in background and triggers the callback function whenever an event arrives.
Read(ctx context.Context, handler Handler) error
// Close is a method that closes the connection to the binding. Must be
// called when the binding is no longer needed to free up resources.
io.Closer
}
// Handler is the handler used to invoke the app handler.
type Handler func(context.Context, *ReadResponse) ([]byte, error)
func PingInpBinding(inputBinding InputBinding) error {
func PingInpBinding(ctx context.Context, inputBinding InputBinding) error {
// checks if this input binding has the ping option then executes
if inputBindingWithPing, ok := inputBinding.(health.Pinger); ok {
return inputBindingWithPing.Ping()
return inputBindingWithPing.Ping(ctx)
} else {
return fmt.Errorf("ping is not implemented by this input binding")
}

View File

@ -15,7 +15,10 @@ package kafka
import (
"context"
"errors"
"strings"
"sync"
"sync/atomic"
"github.com/dapr/kit/logger"
@ -29,12 +32,13 @@ const (
)
type Binding struct {
kafka *kafka.Kafka
publishTopic string
topics []string
logger logger.Logger
subscribeCtx context.Context
subscribeCancel context.CancelFunc
kafka *kafka.Kafka
publishTopic string
topics []string
logger logger.Logger
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
}
// NewKafka returns a new kafka binding instance.
@ -43,15 +47,14 @@ func NewKafka(logger logger.Logger) bindings.InputOutputBinding {
// in kafka binding component, disable consumer retry by default
k.DefaultConsumeRetryEnabled = false
return &Binding{
kafka: k,
logger: logger,
kafka: k,
logger: logger,
closeCh: make(chan struct{}),
}
}
func (b *Binding) Init(metadata bindings.Metadata) error {
b.subscribeCtx, b.subscribeCancel = context.WithCancel(context.Background())
err := b.kafka.Init(metadata.Properties)
func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
err := b.kafka.Init(ctx, metadata.Properties)
if err != nil {
return err
}
@ -74,7 +77,10 @@ func (b *Binding) Operations() []bindings.OperationKind {
}
func (b *Binding) Close() (err error) {
b.subscribeCancel()
if b.closed.CompareAndSwap(false, true) {
close(b.closeCh)
}
defer b.wg.Wait()
return b.kafka.Close()
}
@ -84,6 +90,10 @@ func (b *Binding) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin
}
func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
if b.closed.Load() {
return errors.New("error: binding is closed")
}
if len(b.topics) == 0 {
b.logger.Warnf("kafka binding: no topic defined, input bindings will not be started")
return nil
@ -96,31 +106,22 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
for _, t := range b.topics {
b.kafka.AddTopicHandler(t, handlerConfig)
}
b.wg.Add(1)
go func() {
// Wait for context cancelation
defer b.wg.Done()
// Wait for context cancelation or closure.
select {
case <-ctx.Done():
case <-b.subscribeCtx.Done():
case <-b.closeCh:
}
// Remove the topic handler before restarting the subscriber
// Remove the topic handlers.
for _, t := range b.topics {
b.kafka.RemoveTopicHandler(t)
}
// If the component's context has been canceled, do not re-subscribe
if b.subscribeCtx.Err() != nil {
return
}
err := b.kafka.Subscribe(b.subscribeCtx)
if err != nil {
b.logger.Errorf("kafka binding: error re-subscribing: %v", err)
}
}()
return b.kafka.Subscribe(b.subscribeCtx)
return b.kafka.Subscribe(ctx)
}
func adaptHandler(handler bindings.Handler) kafka.EventHandler {

View File

@ -2,8 +2,11 @@ package kubemq
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
qs "github.com/kubemq-io/kubemq-go/queues_stream"
@ -19,31 +22,30 @@ type Kubemq interface {
}
type kubeMQ struct {
client *qs.QueuesStreamClient
opts *options
logger logger.Logger
ctx context.Context
ctxCancel context.CancelFunc
client *qs.QueuesStreamClient
opts *options
logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
func NewKubeMQ(logger logger.Logger) Kubemq {
return &kubeMQ{
client: nil,
opts: nil,
logger: logger,
ctx: nil,
ctxCancel: nil,
client: nil,
opts: nil,
logger: logger,
closeCh: make(chan struct{}),
}
}
func (k *kubeMQ) Init(metadata bindings.Metadata) error {
func (k *kubeMQ) Init(ctx context.Context, metadata bindings.Metadata) error {
opts, err := createOptions(metadata)
if err != nil {
return err
}
k.opts = opts
k.ctx, k.ctxCancel = context.WithCancel(context.Background())
client, err := qs.NewQueuesStreamClient(k.ctx,
client, err := qs.NewQueuesStreamClient(ctx,
qs.WithAddress(opts.host, opts.port),
qs.WithCheckConnection(true),
qs.WithAuthToken(opts.authToken),
@ -53,22 +55,39 @@ func (k *kubeMQ) Init(metadata bindings.Metadata) error {
k.logger.Errorf("error init kubemq client error: %s", err.Error())
return err
}
k.ctx, k.ctxCancel = context.WithCancel(context.Background())
k.client = client
return nil
}
func (k *kubeMQ) Read(ctx context.Context, handler bindings.Handler) error {
if k.closed.Load() {
return errors.New("binding is closed")
}
k.wg.Add(2)
processCtx, cancel := context.WithCancel(ctx)
go func() {
defer k.wg.Done()
defer cancel()
select {
case <-k.closeCh:
case <-processCtx.Done():
}
}()
go func() {
defer k.wg.Done()
for {
err := k.processQueueMessage(k.ctx, handler)
err := k.processQueueMessage(processCtx, handler)
if err != nil {
k.logger.Error(err.Error())
time.Sleep(time.Second)
}
if k.ctx.Err() != nil {
return
// If context cancelled or kubeMQ closed, exit. Otherwise, continue
// after a second.
select {
case <-time.After(time.Second):
continue
case <-processCtx.Done():
}
return
}
}()
return nil
@ -82,7 +101,7 @@ func (k *kubeMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind
SetPolicyExpirationSeconds(parsePolicyExpirationSeconds(req.Metadata)).
SetPolicyMaxReceiveCount(parseSetPolicyMaxReceiveCount(req.Metadata)).
SetPolicyMaxReceiveQueue(parsePolicyMaxReceiveQueue(req.Metadata))
result, err := k.client.Send(k.ctx, queueMessage)
result, err := k.client.Send(ctx, queueMessage)
if err != nil {
return nil, err
}
@ -101,6 +120,14 @@ func (k *kubeMQ) Operations() []bindings.OperationKind {
return []bindings.OperationKind{bindings.CreateOperation}
}
func (k *kubeMQ) Close() error {
if k.closed.CompareAndSwap(false, true) {
close(k.closeCh)
}
defer k.wg.Wait()
return k.client.Close()
}
func (k *kubeMQ) processQueueMessage(ctx context.Context, handler bindings.Handler) error {
pr := qs.NewPollRequest().
SetChannel(k.opts.channel).

View File

@ -106,7 +106,7 @@ func Test_kubeMQ_Init(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
kubemq := NewKubeMQ(logger.NewLogger("test"))
err := kubemq.Init(tt.meta)
err := kubemq.Init(context.Background(), tt.meta)
if tt.wantErr {
require.Error(t, err)
} else {
@ -120,7 +120,7 @@ func Test_kubeMQ_Invoke_Read_Single_Message(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
kubemq := NewKubeMQ(logger.NewLogger("test"))
err := kubemq.Init(getDefaultMetadata("test.read.single"))
err := kubemq.Init(context.Background(), getDefaultMetadata("test.read.single"))
require.NoError(t, err)
dataReadCh := make(chan []byte)
invokeRequest := &bindings.InvokeRequest{
@ -147,7 +147,7 @@ func Test_kubeMQ_Invoke_Read_Single_MessageWithHandlerError(t *testing.T) {
kubemq := NewKubeMQ(logger.NewLogger("test"))
md := getDefaultMetadata("test.read.single.error")
md.Properties["autoAcknowledged"] = "false"
err := kubemq.Init(md)
err := kubemq.Init(context.Background(), md)
require.NoError(t, err)
invokeRequest := &bindings.InvokeRequest{
Data: []byte("test"),
@ -182,7 +182,7 @@ func Test_kubeMQ_Invoke_Error(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
kubemq := NewKubeMQ(logger.NewLogger("test"))
err := kubemq.Init(getDefaultMetadata("***test***"))
err := kubemq.Init(context.Background(), getDefaultMetadata("***test***"))
require.NoError(t, err)
invokeRequest := &bindings.InvokeRequest{

View File

@ -18,6 +18,8 @@ import (
"encoding/json"
"errors"
"strconv"
"sync"
"sync/atomic"
"time"
v1 "k8s.io/api/core/v1"
@ -35,6 +37,9 @@ type kubernetesInput struct {
namespace string
resyncPeriod time.Duration
logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
type EventResponse struct {
@ -45,10 +50,13 @@ type EventResponse struct {
// NewKubernetes returns a new Kubernetes event input binding.
func NewKubernetes(logger logger.Logger) bindings.InputBinding {
return &kubernetesInput{logger: logger}
return &kubernetesInput{
logger: logger,
closeCh: make(chan struct{}),
}
}
func (k *kubernetesInput) Init(metadata bindings.Metadata) error {
func (k *kubernetesInput) Init(ctx context.Context, metadata bindings.Metadata) error {
client, err := kubeclient.GetKubeClient()
if err != nil {
return err
@ -78,6 +86,9 @@ func (k *kubernetesInput) parseMetadata(metadata bindings.Metadata) error {
}
func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) error {
if k.closed.Load() {
return errors.New("binding is closed")
}
watchlist := cache.NewListWatchFromClient(
k.kubeClient.CoreV1().RESTClient(),
"events",
@ -126,12 +137,28 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er
},
)
k.wg.Add(3)
readCtx, cancel := context.WithCancel(ctx)
// catch when binding is closed.
go func() {
defer k.wg.Done()
defer cancel()
select {
case <-readCtx.Done():
case <-k.closeCh:
}
}()
// Start the controller in backgound
stopCh := make(chan struct{})
go controller.Run(stopCh)
go func() {
defer k.wg.Done()
controller.Run(readCtx.Done())
}()
// Watch for new messages and for context cancellation
go func() {
defer k.wg.Done()
var (
obj EventResponse
data []byte
@ -148,8 +175,7 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er
Data: data,
})
}
case <-ctx.Done():
close(stopCh)
case <-readCtx.Done():
return
}
}
@ -157,3 +183,11 @@ func (k *kubernetesInput) Read(ctx context.Context, handler bindings.Handler) er
return nil
}
func (k *kubernetesInput) Close() error {
if k.closed.CompareAndSwap(false, true) {
close(k.closeCh)
}
k.wg.Wait()
return nil
}

View File

@ -64,7 +64,7 @@ func NewLocalStorage(logger logger.Logger) bindings.OutputBinding {
}
// Init performs metadata parsing.
func (ls *LocalStorage) Init(metadata bindings.Metadata) error {
func (ls *LocalStorage) Init(_ context.Context, metadata bindings.Metadata) error {
m, err := ls.parseMetadata(metadata)
if err != nil {
return fmt.Errorf("failed to parse metadata: %w", err)

View File

@ -40,30 +40,29 @@ type MQTT struct {
logger logger.Logger
isSubscribed atomic.Bool
readHandler bindings.Handler
ctx context.Context
cancel context.CancelFunc
backOff backoff.BackOff
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
}
// NewMQTT returns a new MQTT instance.
func NewMQTT(logger logger.Logger) bindings.InputOutputBinding {
return &MQTT{
logger: logger,
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init does MQTT connection parsing.
func (m *MQTT) Init(metadata bindings.Metadata) (err error) {
func (m *MQTT) Init(ctx context.Context, metadata bindings.Metadata) (err error) {
m.metadata, err = parseMQTTMetaData(metadata, m.logger)
if err != nil {
return err
}
m.ctx, m.cancel = context.WithCancel(context.Background())
// TODO: Make the backoff configurable for constant or exponential
b := backoff.NewConstantBackOff(5 * time.Second)
m.backOff = backoff.WithContext(b, m.ctx)
m.backOff = backoff.NewConstantBackOff(5 * time.Second)
return nil
}
@ -104,7 +103,7 @@ func (m *MQTT) getProducer() (mqtt.Client, error) {
return p, nil
}
func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
func (m *MQTT) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
producer, err := m.getProducer()
if err != nil {
return nil, fmt.Errorf("failed to create producer connection: %w", err)
@ -118,7 +117,7 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
bo := backoff.WithMaxRetries(
backoff.NewConstantBackOff(200*time.Millisecond), 3,
)
bo = backoff.WithContext(bo, parentCtx)
bo = backoff.WithContext(bo, ctx)
topic, ok := req.Metadata[mqttTopic]
if !ok || topic == "" {
@ -127,14 +126,13 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
}
return nil, retry.NotifyRecover(func() (err error) {
token := producer.Publish(topic, m.metadata.qos, m.metadata.retain, req.Data)
ctx, cancel := context.WithTimeout(parentCtx, defaultWait)
defer cancel()
select {
case <-token.Done():
err = token.Error()
case <-m.ctx.Done():
// Context canceled
err = m.ctx.Err()
case <-m.closeCh:
err = errors.New("mqtt client closed")
case <-time.After(defaultWait):
err = errors.New("mqtt client timeout")
case <-ctx.Done():
// Context canceled
err = ctx.Err()
@ -151,6 +149,10 @@ func (m *MQTT) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*
}
func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error {
if m.closed.Load() {
return errors.New("error: binding is closed")
}
// If the subscription is already active, wait 2s before retrying (in case we're still disconnecting), otherwise return an error
if !m.isSubscribed.CompareAndSwap(false, true) {
m.logger.Debug("Subscription is already active; waiting 2s before retrying…")
@ -177,11 +179,14 @@ func (m *MQTT) Read(ctx context.Context, handler bindings.Handler) error {
// In background, watch for contexts cancelation and stop the connection
// However, do not call "unsubscribe" which would cause the broker to stop tracking the last message received by this consumer group
m.wg.Add(1)
go func() {
defer m.wg.Done()
select {
case <-ctx.Done():
// nop
case <-m.ctx.Done():
case <-m.closeCh:
// nop
}
@ -208,14 +213,12 @@ func (m *MQTT) connect(clientID string, isSubscriber bool) (mqtt.Client, error)
}
client := mqtt.NewClient(opts)
ctx, cancel := context.WithTimeout(m.ctx, defaultWait)
defer cancel()
token := client.Connect()
select {
case <-token.Done():
err = token.Error()
case <-ctx.Done():
err = ctx.Err()
case <-time.After(defaultWait):
err = errors.New("mqtt client timed out connecting")
}
if err != nil {
return nil, fmt.Errorf("failed to connect: %w", err)
@ -290,43 +293,46 @@ func (m *MQTT) createClientOptions(uri *url.URL, clientID string) *mqtt.ClientOp
return opts
}
func (m *MQTT) handleMessage(client mqtt.Client, mqttMsg mqtt.Message) {
// We're using m.ctx as context in this method because we don't have access to the Read context
// Canceling the Read context makes Read invoke "Disconnect" anyways
ctx := m.ctx
func (m *MQTT) handleMessage() func(client mqtt.Client, mqttMsg mqtt.Message) {
return func(client mqtt.Client, mqttMsg mqtt.Message) {
bo := m.backOff
if m.metadata.backOffMaxRetries >= 0 {
bo = backoff.WithMaxRetries(bo, uint64(m.metadata.backOffMaxRetries))
}
var bo backoff.BackOff = backoff.WithContext(m.backOff, ctx)
if m.metadata.backOffMaxRetries >= 0 {
bo = backoff.WithMaxRetries(bo, uint64(m.metadata.backOffMaxRetries))
}
err := retry.NotifyRecover(
func() error {
m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
// Use a background context here so that the context is not tied to the
// first Invoke first created the producer.
// TODO: add context to mqtt library, and add a OnConnectWithContext option
// to change this func signature to
// func(c mqtt.Client, ctx context.Context)
_, err := m.readHandler(context.Background(), &bindings.ReadResponse{
Data: mqttMsg.Payload(),
Metadata: map[string]string{
mqttTopic: mqttMsg.Topic(),
},
})
if err != nil {
return err
}
err := retry.NotifyRecover(
func() error {
m.logger.Debugf("Processing MQTT message %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
_, err := m.readHandler(ctx, &bindings.ReadResponse{
Data: mqttMsg.Payload(),
Metadata: map[string]string{
mqttTopic: mqttMsg.Topic(),
},
})
if err != nil {
return err
}
// Ack the message on success
mqttMsg.Ack()
return nil
},
bo,
func(err error, d time.Duration) {
m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying…", mqttMsg.Topic(), mqttMsg.MessageID())
},
func() {
m.logger.Infof("Successfully processed MQTT message after it previously failed: %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
},
)
if err != nil {
m.logger.Errorf("Failed processing MQTT message: %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err)
// Ack the message on success
mqttMsg.Ack()
return nil
},
bo,
func(err error, d time.Duration) {
m.logger.Errorf("Error processing MQTT message: %s/%d. Retrying…", mqttMsg.Topic(), mqttMsg.MessageID())
},
func() {
m.logger.Infof("Successfully processed MQTT message after it previously failed: %s/%d", mqttMsg.Topic(), mqttMsg.MessageID())
},
)
if err != nil {
m.logger.Errorf("Failed processing MQTT message: %s/%d: %v", mqttMsg.Topic(), mqttMsg.MessageID(), err)
}
}
}
@ -336,17 +342,15 @@ func (m *MQTT) createSubscriberClientOptions(uri *url.URL, clientID string) *mqt
// On (re-)connection, add the topic subscription
opts.OnConnect = func(c mqtt.Client) {
token := c.Subscribe(m.metadata.topic, m.metadata.qos, m.handleMessage)
token := c.Subscribe(m.metadata.topic, m.metadata.qos, m.handleMessage())
var err error
subscribeCtx, subscribeCancel := context.WithTimeout(m.ctx, defaultWait)
defer subscribeCancel()
select {
case <-token.Done():
// Subscription went through (sucecessfully or not)
err = token.Error()
case <-subscribeCtx.Done():
err = fmt.Errorf("error while waiting for subscription token: %w", subscribeCtx.Err())
case <-time.After(defaultWait):
err = errors.New("timed out waiting for subscription to complete")
}
// Nothing we can do in case of errors besides logging them
@ -363,13 +367,16 @@ func (m *MQTT) Close() error {
m.producerLock.Lock()
defer m.producerLock.Unlock()
// Canceling the context also causes Read to stop receiving messages
m.cancel()
if m.closed.CompareAndSwap(false, true) {
close(m.closeCh)
}
if m.producer != nil {
m.producer.Disconnect(200)
m.producer = nil
}
m.wg.Wait()
return nil
}

View File

@ -49,6 +49,7 @@ func getConnectionString() string {
func TestInvokeWithTopic(t *testing.T) {
t.Parallel()
ctx := context.Background()
url := getConnectionString()
if url == "" {
@ -79,7 +80,7 @@ func TestInvokeWithTopic(t *testing.T) {
logger := logger.NewLogger("test")
r := NewMQTT(logger).(*MQTT)
err := r.Init(metadata)
err := r.Init(ctx, metadata)
assert.Nil(t, err)
conn, err := r.connect(uuid.NewString(), false)
@ -127,4 +128,5 @@ func TestInvokeWithTopic(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, dataCustomized, mqttMessage.Payload())
assert.Equal(t, topicCustomized, mqttMessage.Topic())
assert.NoError(t, r.Close())
}

View File

@ -205,7 +205,6 @@ func TestParseMetadata(t *testing.T) {
logger := logger.NewLogger("test")
m := NewMQTT(logger).(*MQTT)
m.backOff = backoff.NewConstantBackOff(5 * time.Second)
m.ctx, m.cancel = context.WithCancel(context.Background())
m.readHandler = func(ctx context.Context, r *bindings.ReadResponse) ([]byte, error) {
assert.Equal(t, payload, r.Data)
metadata := r.Metadata
@ -215,7 +214,7 @@ func TestParseMetadata(t *testing.T) {
return r.Data, nil
}
m.handleMessage(nil, &mqttMockMessage{
m.handleMessage()(nil, &mqttMockMessage{
topic: topic,
payload: payload,
})

View File

@ -81,7 +81,7 @@ func NewMysql(logger logger.Logger) bindings.OutputBinding {
}
// Init initializes the MySQL binding.
func (m *Mysql) Init(metadata bindings.Metadata) error {
func (m *Mysql) Init(ctx context.Context, metadata bindings.Metadata) error {
m.logger.Debug("Initializing MySql binding")
p := metadata.Properties
@ -115,7 +115,7 @@ func (m *Mysql) Init(metadata bindings.Metadata) error {
return err
}
err = db.Ping()
err = db.PingContext(ctx)
if err != nil {
return fmt.Errorf("unable to ping the DB: %w", err)
}

View File

@ -75,7 +75,7 @@ func TestMysqlIntegration(t *testing.T) {
b := NewMysql(logger.NewLogger("test")).(*Mysql)
m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}}
if err := b.Init(m); err != nil {
if err := b.Init(context.Background(), m); err != nil {
t.Fatal(err)
}

View File

@ -21,6 +21,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/nacos-group/nacos-sdk-go/v2/clients"
@ -56,6 +57,9 @@ type Nacos struct {
logger logger.Logger
configClient config_client.IConfigClient //nolint:nosnakecase
readHandler func(ctx context.Context, response *bindings.ReadResponse) ([]byte, error)
wg sync.WaitGroup
closed atomic.Bool
closeCh chan struct{}
}
// NewNacos returns a new Nacos instance.
@ -63,11 +67,12 @@ func NewNacos(logger logger.Logger) bindings.OutputBinding {
return &Nacos{
logger: logger,
watchesLock: sync.Mutex{},
closeCh: make(chan struct{}),
}
}
// Init implements InputBinding/OutputBinding's Init method.
func (n *Nacos) Init(metadata bindings.Metadata) error {
func (n *Nacos) Init(_ context.Context, metadata bindings.Metadata) error {
n.settings = Settings{
Timeout: defaultTimeout,
}
@ -146,6 +151,10 @@ func (n *Nacos) createConfigClient() error {
// Read implements InputBinding's Read method.
func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {
if n.closed.Load() {
return errors.New("binding is closed")
}
n.readHandler = handler
n.watchesLock.Lock()
@ -154,9 +163,14 @@ func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {
}
n.watchesLock.Unlock()
n.wg.Add(1)
go func() {
defer n.wg.Done()
// Cancel all listeners when the context is done
<-ctx.Done()
select {
case <-ctx.Done():
case <-n.closeCh:
}
n.cancelAllListeners()
}()
@ -165,8 +179,14 @@ func (n *Nacos) Read(ctx context.Context, handler bindings.Handler) error {
// Close implements cancel all listeners, see https://github.com/dapr/components-contrib/issues/779
func (n *Nacos) Close() error {
if n.closed.CompareAndSwap(false, true) {
close(n.closeCh)
}
n.cancelAllListeners()
n.wg.Wait()
return nil
}
@ -223,7 +243,11 @@ func (n *Nacos) addListener(ctx context.Context, config configParam) {
func (n *Nacos) addListenerFoInputBinding(ctx context.Context, config configParam) {
if n.addToWatches(config) {
go n.addListener(ctx, config)
n.wg.Add(1)
go func() {
defer n.wg.Done()
n.addListener(ctx, config)
}()
}
}

View File

@ -35,7 +35,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
m.Properties, err = getNacosLocalCacheMetadata()
require.NoError(t, err)
n := NewNacos(logger.NewLogger("test")).(*Nacos)
err = n.Init(m)
err = n.Init(context.Background(), m)
require.NoError(t, err)
var count int32
ch := make(chan bool, 1)

View File

@ -19,7 +19,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/bindings/alicloud/nacos"
"github.com/dapr/components-contrib/bindings/nacos"
)
func TestParseMetadata(t *testing.T) { //nolint:paralleltest

View File

@ -22,15 +22,15 @@ import (
// OutputBinding is the interface for an output binding, allowing users to invoke remote systems with optional payloads.
type OutputBinding interface {
Init(metadata Metadata) error
Init(ctx context.Context, metadata Metadata) error
Invoke(ctx context.Context, req *InvokeRequest) (*InvokeResponse, error)
Operations() []OperationKind
}
func PingOutBinding(outputBinding OutputBinding) error {
func PingOutBinding(ctx context.Context, outputBinding OutputBinding) error {
// checks if this output binding has the ping option then executes
if outputBindingWithPing, ok := outputBinding.(health.Pinger); ok {
return outputBindingWithPing.Ping()
return outputBindingWithPing.Ping(ctx)
} else {
return fmt.Errorf("ping is not implemented by this output binding")
}

View File

@ -48,7 +48,7 @@ func NewPostgres(logger logger.Logger) bindings.OutputBinding {
}
// Init initializes the PostgreSql binding.
func (p *Postgres) Init(metadata bindings.Metadata) error {
func (p *Postgres) Init(ctx context.Context, metadata bindings.Metadata) error {
url, ok := metadata.Properties[connectionURLKey]
if !ok || url == "" {
return errors.Errorf("required metadata not set: %s", connectionURLKey)
@ -59,7 +59,9 @@ func (p *Postgres) Init(metadata bindings.Metadata) error {
return errors.Wrap(err, "error opening DB connection")
}
p.db, err = pgxpool.NewWithConfig(context.Background(), poolConfig)
// This context doesn't control the lifetime of the connection pool, and is
// only scoped to postgres creating resources at init.
p.db, err = pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
return errors.Wrap(err, "unable to ping the DB")
}

View File

@ -64,7 +64,7 @@ func TestPostgresIntegration(t *testing.T) {
// live DB test
b := NewPostgres(logger.NewLogger("test")).(*Postgres)
m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{connectionURLKey: url}}}
if err := b.Init(m); err != nil {
if err := b.Init(context.Background(), m); err != nil {
t.Fatal(err)
}

View File

@ -74,7 +74,7 @@ func (p *Postmark) parseMetadata(meta bindings.Metadata) (postmarkMetadata, erro
}
// Init does metadata parsing and not much else :).
func (p *Postmark) Init(metadata bindings.Metadata) error {
func (p *Postmark) Init(_ context.Context, metadata bindings.Metadata) error {
// Parse input metadata
meta, err := p.parseMetadata(metadata)
if err != nil {

View File

@ -19,6 +19,8 @@ import (
"fmt"
"math"
"strconv"
"sync"
"sync/atomic"
"time"
amqp "github.com/rabbitmq/amqp091-go"
@ -50,6 +52,9 @@ type RabbitMQ struct {
metadata rabbitMQMetadata
logger logger.Logger
queue amqp.Queue
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
// Metadata is the rabbitmq config.
@ -66,11 +71,14 @@ type rabbitMQMetadata struct {
// NewRabbitMQ returns a new rabbitmq instance.
func NewRabbitMQ(logger logger.Logger) bindings.InputOutputBinding {
return &RabbitMQ{logger: logger}
return &RabbitMQ{
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init does metadata parsing and connection creation.
func (r *RabbitMQ) Init(metadata bindings.Metadata) error {
func (r *RabbitMQ) Init(_ context.Context, metadata bindings.Metadata) error {
err := r.parseMetadata(metadata)
if err != nil {
return err
@ -226,6 +234,10 @@ func (r *RabbitMQ) declareQueue() (amqp.Queue, error) {
}
func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
if r.closed.Load() {
return errors.New("binding already closed")
}
msgs, err := r.channel.Consume(
r.queue.Name,
"",
@ -239,14 +251,27 @@ func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
return err
}
readCtx, cancel := context.WithCancel(ctx)
r.wg.Add(2)
go func() {
defer r.wg.Done()
defer cancel()
select {
case <-r.closeCh:
case <-readCtx.Done():
}
}()
go func() {
defer r.wg.Done()
var err error
for {
select {
case <-ctx.Done():
case <-readCtx.Done():
return
case d := <-msgs:
_, err = handler(ctx, &bindings.ReadResponse{
_, err = handler(readCtx, &bindings.ReadResponse{
Data: d.Body,
})
if err != nil {
@ -260,3 +285,11 @@ func (r *RabbitMQ) Read(ctx context.Context, handler bindings.Handler) error {
return nil
}
func (r *RabbitMQ) Close() error {
if r.closed.CompareAndSwap(false, true) {
close(r.closeCh)
}
defer r.wg.Wait()
return r.channel.Close()
}

View File

@ -85,7 +85,7 @@ func TestQueuesWithTTL(t *testing.T) {
logger := logger.NewLogger("test")
r := NewRabbitMQ(logger).(*RabbitMQ)
err := r.Init(metadata)
err := r.Init(context.Background(), metadata)
assert.Nil(t, err)
// Assert that if waited too long, we won't see any message
@ -117,6 +117,7 @@ func TestQueuesWithTTL(t *testing.T) {
assert.True(t, ok)
msgBody := string(msg.Body)
assert.Equal(t, testMsgContent, msgBody)
assert.NoError(t, r.Close())
}
func TestPublishingWithTTL(t *testing.T) {
@ -144,7 +145,7 @@ func TestPublishingWithTTL(t *testing.T) {
logger := logger.NewLogger("test")
rabbitMQBinding1 := NewRabbitMQ(logger).(*RabbitMQ)
err := rabbitMQBinding1.Init(metadata)
err := rabbitMQBinding1.Init(context.Background(), metadata)
assert.Nil(t, err)
// Assert that if waited too long, we won't see any message
@ -175,7 +176,7 @@ func TestPublishingWithTTL(t *testing.T) {
// Getting before it is expired, should return it
rabbitMQBinding2 := NewRabbitMQ(logger).(*RabbitMQ)
err = rabbitMQBinding2.Init(metadata)
err = rabbitMQBinding2.Init(context.Background(), metadata)
assert.Nil(t, err)
const testMsgContent = "test_msg"
@ -193,6 +194,9 @@ func TestPublishingWithTTL(t *testing.T) {
assert.True(t, ok)
msgBody := string(msg.Body)
assert.Equal(t, testMsgContent, msgBody)
assert.NoError(t, rabbitMQBinding1.Close())
assert.NoError(t, rabbitMQBinding1.Close())
}
func TestExclusiveQueue(t *testing.T) {
@ -222,7 +226,7 @@ func TestExclusiveQueue(t *testing.T) {
logger := logger.NewLogger("test")
r := NewRabbitMQ(logger).(*RabbitMQ)
err := r.Init(metadata)
err := r.Init(context.Background(), metadata)
assert.Nil(t, err)
// Assert that if waited too long, we won't see any message
@ -276,7 +280,7 @@ func TestPublishWithPriority(t *testing.T) {
logger := logger.NewLogger("test")
r := NewRabbitMQ(logger).(*RabbitMQ)
err := r.Init(metadata)
err := r.Init(context.Background(), metadata)
assert.Nil(t, err)
// Assert that if waited too long, we won't see any message

View File

@ -28,9 +28,6 @@ type Redis struct {
client rediscomponent.RedisClient
clientSettings *rediscomponent.Settings
logger logger.Logger
ctx context.Context
cancel context.CancelFunc
}
// NewRedis returns a new redis bindings instance.
@ -39,15 +36,13 @@ func NewRedis(logger logger.Logger) bindings.OutputBinding {
}
// Init performs metadata parsing and connection creation.
func (r *Redis) Init(meta bindings.Metadata) (err error) {
func (r *Redis) Init(ctx context.Context, meta bindings.Metadata) (err error) {
r.client, r.clientSettings, err = rediscomponent.ParseClientFromProperties(meta.Properties, nil)
if err != nil {
return err
}
r.ctx, r.cancel = context.WithCancel(context.Background())
_, err = r.client.PingResult(r.ctx)
_, err = r.client.PingResult(ctx)
if err != nil {
return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err)
}
@ -55,8 +50,8 @@ func (r *Redis) Init(meta bindings.Metadata) (err error) {
return err
}
func (r *Redis) Ping() error {
if _, err := r.client.PingResult(r.ctx); err != nil {
func (r *Redis) Ping(ctx context.Context) error {
if _, err := r.client.PingResult(ctx); err != nil {
return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err)
}
@ -101,7 +96,5 @@ func (r *Redis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindi
}
func (r *Redis) Close() error {
r.cancel()
return r.client.Close()
}

View File

@ -40,7 +40,6 @@ func TestInvokeCreate(t *testing.T) {
client: c,
logger: logger.NewLogger("test"),
}
bind.ctx, bind.cancel = context.WithCancel(context.Background())
_, err := c.DoRead(context.Background(), "GET", testKey)
assert.Equal(t, redis.Nil, err)
@ -66,7 +65,6 @@ func TestInvokeGet(t *testing.T) {
client: c,
logger: logger.NewLogger("test"),
}
bind.ctx, bind.cancel = context.WithCancel(context.Background())
err := c.DoWrite(context.Background(), "SET", testKey, testData)
assert.Equal(t, nil, err)
@ -87,7 +85,6 @@ func TestInvokeDelete(t *testing.T) {
client: c,
logger: logger.NewLogger("test"),
}
bind.ctx, bind.cancel = context.WithCancel(context.Background())
err := c.DoWrite(context.Background(), "SET", testKey, testData)
assert.Equal(t, nil, err)

View File

@ -18,6 +18,8 @@ import (
"encoding/json"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
r "github.com/dancannon/gorethink"
@ -34,6 +36,9 @@ type Binding struct {
logger logger.Logger
session *r.Session
config StateConfig
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
// StateConfig is the binding config.
@ -45,12 +50,13 @@ type StateConfig struct {
// NewRethinkDBStateChangeBinding returns a new RethinkDB actor event input binding.
func NewRethinkDBStateChangeBinding(logger logger.Logger) bindings.InputBinding {
return &Binding{
logger: logger,
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init initializes the RethinkDB binding.
func (b *Binding) Init(metadata bindings.Metadata) error {
func (b *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
cfg, err := metadataToConfig(metadata.Properties, b.logger)
if err != nil {
return errors.Wrap(err, "unable to parse metadata properties")
@ -68,6 +74,10 @@ func (b *Binding) Init(metadata bindings.Metadata) error {
// Read triggers the RethinkDB scheduler.
func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
if b.closed.Load() {
return errors.New("binding is closed")
}
b.logger.Infof("subscribing to state changes in %s.%s...", b.config.Database, b.config.Table)
cursor, err := r.DB(b.config.Database).
Table(b.config.Table).
@ -81,8 +91,21 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
errors.Wrapf(err, "error connecting to table %s", b.config.Table)
}
readCtx, cancel := context.WithCancel(ctx)
b.wg.Add(2)
go func() {
for ctx.Err() == nil {
defer b.wg.Done()
defer cancel()
select {
case <-b.closeCh:
case <-readCtx.Done():
}
}()
go func() {
defer b.wg.Done()
for readCtx.Err() == nil {
var change interface{}
ok := cursor.Next(&change)
if !ok {
@ -105,7 +128,7 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
},
}
if _, err := handler(ctx, resp); err != nil {
if _, err := handler(readCtx, resp); err != nil {
b.logger.Errorf("error invoking change handler: %v", err)
continue
}
@ -117,6 +140,14 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error {
return nil
}
func (b *Binding) Close() error {
if b.closed.CompareAndSwap(false, true) {
close(b.closeCh)
}
defer b.wg.Wait()
return b.session.Close()
}
func metadataToConfig(cfg map[string]string, logger logger.Logger) (StateConfig, error) {
c := StateConfig{}
for k, v := range cfg {

View File

@ -71,7 +71,7 @@ func TestBinding(t *testing.T) {
assert.NotNil(t, m.Properties)
b := getNewRethinkActorBinding()
err := b.Init(m)
err := b.Init(context.Background(), m)
assert.NoErrorf(t, err, "error initializing")
ctx, cancel := context.WithCancel(context.Background())

View File

@ -18,6 +18,8 @@ import (
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
mqc "github.com/apache/rocketmq-client-go/v2/consumer"
@ -30,32 +32,32 @@ import (
"github.com/dapr/kit/retry"
)
type AliCloudRocketMQ struct {
type RocketMQ struct {
logger logger.Logger
settings Settings
producer mqw.Producer
ctx context.Context
cancel context.CancelFunc
backOffConfig retry.Config
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
}
func NewAliCloudRocketMQ(l logger.Logger) *AliCloudRocketMQ {
return &AliCloudRocketMQ{ //nolint:exhaustivestruct
func NewRocketMQ(l logger.Logger) *RocketMQ {
return &RocketMQ{ //nolint:exhaustivestruct
logger: l,
producer: nil,
closeCh: make(chan struct{}),
}
}
// Init performs metadata parsing.
func (a *AliCloudRocketMQ) Init(metadata bindings.Metadata) error {
func (a *RocketMQ) Init(ctx context.Context, metadata bindings.Metadata) error {
var err error
if err = a.settings.Decode(metadata.Properties); err != nil {
return err
}
a.ctx, a.cancel = context.WithCancel(context.Background())
// Default retry configuration is used if no
// backOff properties are set.
if err = retry.DecodeConfigWithPrefix(
@ -74,7 +76,11 @@ func (a *AliCloudRocketMQ) Init(metadata bindings.Metadata) error {
}
// Read triggers the rocketmq subscription.
func (a *AliCloudRocketMQ) Read(ctx context.Context, handler bindings.Handler) error {
func (a *RocketMQ) Read(ctx context.Context, handler bindings.Handler) error {
if a.closed.Load() {
return errors.New("error: binding is closed")
}
a.logger.Debugf("binding rocketmq: start read input binding")
consumer, err := a.setupConsumer()
@ -114,10 +120,12 @@ func (a *AliCloudRocketMQ) Read(ctx context.Context, handler bindings.Handler) e
a.logger.Debugf("binding-rocketmq: consumer started")
// Listen for context cancelation to stop the subscription
a.wg.Add(1)
go func() {
defer a.wg.Done()
select {
case <-ctx.Done():
case <-a.ctx.Done():
case <-a.closeCh:
}
innerErr := consumer.Shutdown()
@ -130,9 +138,11 @@ func (a *AliCloudRocketMQ) Read(ctx context.Context, handler bindings.Handler) e
}
// Close implements cancel all listeners, see https://github.com/dapr/components-contrib/issues/779
func (a *AliCloudRocketMQ) Close() error {
a.cancel()
func (a *RocketMQ) Close() error {
defer a.wg.Wait()
if a.closed.CompareAndSwap(false, true) {
close(a.closeCh)
}
return nil
}
@ -155,7 +165,7 @@ func parseTopic(key string) (mqType, mqExpression, topic string, err error) {
return
}
func (a *AliCloudRocketMQ) setupConsumer() (mqw.PushConsumer, error) {
func (a *RocketMQ) setupConsumer() (mqw.PushConsumer, error) {
if consumer, ok := mqw.Consumers[a.settings.AccessProto]; ok {
md := a.settings.ToRocketMQMetadata()
if err := consumer.Init(md); err != nil {
@ -172,7 +182,7 @@ func (a *AliCloudRocketMQ) setupConsumer() (mqw.PushConsumer, error) {
return nil, errors.New("binding-rocketmq error: cannot found rocketmq consumer")
}
func (a *AliCloudRocketMQ) setupPublisher() (mqw.Producer, error) {
func (a *RocketMQ) setupPublisher() (mqw.Producer, error) {
if producer, ok := mqw.Producers[a.settings.AccessProto]; ok {
md := a.settings.ToRocketMQMetadata()
if err := producer.Init(md); err != nil {
@ -195,25 +205,25 @@ func (a *AliCloudRocketMQ) setupPublisher() (mqw.Producer, error) {
}
// Operations returns list of operations supported by rocketmq binding.
func (a *AliCloudRocketMQ) Operations() []bindings.OperationKind {
func (a *RocketMQ) Operations() []bindings.OperationKind {
return []bindings.OperationKind{bindings.CreateOperation}
}
func (a *AliCloudRocketMQ) Invoke(req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
func (a *RocketMQ) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
rst := &bindings.InvokeResponse{Data: nil, Metadata: nil}
if req.Operation != bindings.CreateOperation {
return rst, fmt.Errorf("binding-rocketmq error: unsupported operation %s", req.Operation)
}
return rst, a.sendMessage(req)
return rst, a.sendMessage(ctx, req)
}
func (a *AliCloudRocketMQ) sendMessage(req *bindings.InvokeRequest) error {
func (a *RocketMQ) sendMessage(ctx context.Context, req *bindings.InvokeRequest) error {
topic := req.Metadata[metadataRocketmqTopic]
if topic != "" {
_, err := a.send(topic, req.Metadata[metadataRocketmqTag], req.Metadata[metadataRocketmqKey], req.Data)
_, err := a.send(ctx, topic, req.Metadata[metadataRocketmqTag], req.Metadata[metadataRocketmqKey], req.Data)
if err != nil {
return err
}
@ -229,7 +239,7 @@ func (a *AliCloudRocketMQ) sendMessage(req *bindings.InvokeRequest) error {
if err != nil {
return err
}
_, err = a.send(topic, mqExpression, req.Metadata[metadataRocketmqKey], req.Data)
_, err = a.send(ctx, topic, mqExpression, req.Metadata[metadataRocketmqKey], req.Data)
if err != nil {
return err
}
@ -239,9 +249,9 @@ func (a *AliCloudRocketMQ) sendMessage(req *bindings.InvokeRequest) error {
return nil
}
func (a *AliCloudRocketMQ) send(topic, mqExpr, key string, data []byte) (bool, error) {
func (a *RocketMQ) send(ctx context.Context, topic, mqExpr, key string, data []byte) (bool, error) {
msg := primitive.NewMessage(topic, data).WithTag(mqExpr).WithKeys([]string{key})
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
rst, err := a.producer.SendSync(ctx, msg)
if err != nil {
@ -256,7 +266,7 @@ func (a *AliCloudRocketMQ) send(topic, mqExpr, key string, data []byte) (bool, e
type mqCallback func(ctx context.Context, msgs ...*primitive.MessageExt) (mqc.ConsumeResult, error)
func (a *AliCloudRocketMQ) adaptCallback(_, consumerGroup, mqType, mqExpr string, handler bindings.Handler) mqCallback {
func (a *RocketMQ) adaptCallback(_, consumerGroup, mqType, mqExpr string, handler bindings.Handler) mqCallback {
return func(ctx context.Context, msgs ...*primitive.MessageExt) (mqc.ConsumeResult, error) {
success := true
for _, v := range msgs {

View File

@ -34,8 +34,8 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
}
m := bindings.Metadata{} //nolint:exhaustivestruct
m.Properties = getTestMetadata()
r := NewAliCloudRocketMQ(logger.NewLogger("test"))
err := r.Init(m)
r := NewRocketMQ(logger.NewLogger("test"))
err := r.Init(context.Background(), m)
require.NoError(t, err)
var count int32
@ -51,7 +51,7 @@ func TestInputBindingRead(t *testing.T) { //nolint:paralleltest
time.Sleep(5 * time.Second)
atomic.StoreInt32(&count, 0)
req := &bindings.InvokeRequest{Data: []byte("hello"), Operation: bindings.CreateOperation, Metadata: map[string]string{}}
_, err = r.Invoke(req)
_, err = r.Invoke(context.Background(), req)
require.NoError(t, err)
time.Sleep(10 * time.Second)

View File

@ -19,7 +19,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/bindings/alicloud/rocketmq"
"github.com/dapr/components-contrib/bindings/rocketmq"
)
func TestSettingsDecode(t *testing.T) { //nolint:paralleltest

View File

@ -61,7 +61,7 @@ func NewSMTP(logger logger.Logger) bindings.OutputBinding {
}
// Init smtp component (parse metadata).
func (s *Mailer) Init(metadata bindings.Metadata) error {
func (s *Mailer) Init(_ context.Context, metadata bindings.Metadata) error {
// parse metadata
meta, err := s.parseMetadata(metadata)
if err != nil {

View File

@ -84,7 +84,7 @@ func (sg *SendGrid) parseMetadata(meta bindings.Metadata) (sendGridMetadata, err
}
// Init does metadata parsing and not much else :).
func (sg *SendGrid) Init(metadata bindings.Metadata) error {
func (sg *SendGrid) Init(_ context.Context, metadata bindings.Metadata) error {
// Parse input metadata
meta, err := sg.parseMetadata(metadata)
if err != nil {

View File

@ -60,19 +60,19 @@ func NewSMS(logger logger.Logger) bindings.OutputBinding {
}
}
func (t *SMS) Init(metadata bindings.Metadata) error {
func (t *SMS) Init(_ context.Context, metadata bindings.Metadata) error {
twilioM := twilioMetadata{
timeout: time.Minute * 5,
}
if metadata.Properties[fromNumber] == "" {
return errors.New("\"fromNumber\" is a required field")
return errors.New(`"fromNumber" is a required field`)
}
if metadata.Properties[accountSid] == "" {
return errors.New("\"accountSid\" is a required field")
return errors.New(`"accountSid" is a required field`)
}
if metadata.Properties[authToken] == "" {
return errors.New("\"authToken\" is a required field")
return errors.New(`"authToken" is a required field`)
}
twilioM.toNumber = metadata.Properties[toNumber]

View File

@ -53,7 +53,7 @@ func TestInit(t *testing.T) {
m := bindings.Metadata{}
m.Properties = map[string]string{"toNumber": "toNumber", "fromNumber": "fromNumber"}
tw := NewSMS(logger.NewLogger("test"))
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.NotNil(t, err)
}
@ -66,7 +66,7 @@ func TestParseDuration(t *testing.T) {
"authToken": "authToken", "timeout": "badtimeout",
}
tw := NewSMS(logger.NewLogger("test"))
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.NotNil(t, err)
}
@ -85,7 +85,7 @@ func TestWriteShouldSucceed(t *testing.T) {
tw.httpClient = &http.Client{
Transport: httpTransport,
}
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.NoError(t, err)
t.Run("Should succeed with expected url and headers", func(t *testing.T) {
@ -123,7 +123,7 @@ func TestWriteShouldFail(t *testing.T) {
tw.httpClient = &http.Client{
Transport: httpTransport,
}
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.NoError(t, err)
t.Run("Missing 'to' should fail", func(t *testing.T) {
@ -180,7 +180,7 @@ func TestMessageBody(t *testing.T) {
tw.httpClient = &http.Client{
Transport: httpTransport,
}
err := tw.Init(m)
err := tw.Init(context.Background(), m)
require.NoError(t, err)
tester := func(reqData []byte, expectBody string) func(t *testing.T) {

View File

@ -19,6 +19,8 @@ import (
"encoding/json"
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/dghubble/go-twitter/twitter"
@ -31,18 +33,21 @@ import (
// Binding represents Twitter input/output binding.
type Binding struct {
client *twitter.Client
query string
logger logger.Logger
client *twitter.Client
query string
logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
// NewTwitter returns a new Twitter event input binding.
func NewTwitter(logger logger.Logger) bindings.InputOutputBinding {
return &Binding{logger: logger}
return &Binding{logger: logger, closeCh: make(chan struct{})}
}
// Init initializes the Twitter binding.
func (t *Binding) Init(metadata bindings.Metadata) error {
func (t *Binding) Init(ctx context.Context, metadata bindings.Metadata) error {
t.logger.Warnf("DEPRECATION NOTICE: Component bindings.twitter has been deprecated and will be removed in a future Dapr release.")
ck, f := metadata.Properties["consumerKey"]
if !f || ck == "" {
@ -125,10 +130,17 @@ func (t *Binding) Read(ctx context.Context, handler bindings.Handler) error {
}
t.logger.Debug("starting handler...")
go demux.HandleChan(stream.Messages)
t.wg.Add(2)
go func() {
<-ctx.Done()
defer t.wg.Done()
demux.HandleChan(stream.Messages)
}()
go func() {
defer t.wg.Done()
select {
case <-t.closeCh:
case <-ctx.Done():
}
t.logger.Debug("stopping handler...")
stream.Stop()
}()
@ -136,6 +148,14 @@ func (t *Binding) Read(ctx context.Context, handler bindings.Handler) error {
return nil
}
func (t *Binding) Close() error {
if t.closed.CompareAndSwap(false, true) {
close(t.closeCh)
}
t.wg.Wait()
return nil
}
// Invoke handles all operations.
func (t *Binding) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
t.logger.Debugf("operation: %v", req.Operation)

View File

@ -60,7 +60,7 @@ func getRuntimeMetadata() map[string]string {
func TestInit(t *testing.T) {
m := getTestMetadata()
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.Nilf(t, err, "error initializing valid metadata properties")
}
@ -69,7 +69,7 @@ func TestInit(t *testing.T) {
func TestReadError(t *testing.T) {
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
m := getTestMetadata()
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.Nilf(t, err, "error initializing valid metadata properties")
err = tw.Read(context.Background(), func(ctx context.Context, res *bindings.ReadResponse) ([]byte, error) {
@ -79,6 +79,8 @@ func TestReadError(t *testing.T) {
return nil, nil
})
assert.Error(t, err)
assert.NoError(t, tw.Close())
}
// TestRead executes the Read method which calls Twiter API
@ -93,7 +95,7 @@ func TestRead(t *testing.T) {
m.Properties["query"] = "microsoft"
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
tw.logger.SetOutputLevel(logger.DebugLevel)
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.Nilf(t, err, "error initializing read")
ctx, cancel := context.WithCancel(context.Background())
@ -116,6 +118,8 @@ func TestRead(t *testing.T) {
cancel()
t.Fatal("Timeout waiting for messages")
}
assert.NoError(t, tw.Close())
}
// TestInvoke executes the Invoke method which calls Twiter API
@ -129,7 +133,7 @@ func TestInvoke(t *testing.T) {
m.Properties = getRuntimeMetadata()
tw := NewTwitter(logger.NewLogger("test")).(*Binding)
tw.logger.SetOutputLevel(logger.DebugLevel)
err := tw.Init(m)
err := tw.Init(context.Background(), m)
assert.Nilf(t, err, "error initializing Invoke")
req := &bindings.InvokeRequest{
@ -141,4 +145,5 @@ func TestInvoke(t *testing.T) {
resp, err := tw.Invoke(context.Background(), req)
assert.Nilf(t, err, "error on invoke")
assert.NotNil(t, resp)
assert.NoError(t, tw.Close())
}

View File

@ -61,7 +61,7 @@ func NewZeebeCommand(logger logger.Logger) bindings.OutputBinding {
}
// Init does metadata parsing and connection creation.
func (z *ZeebeCommand) Init(metadata bindings.Metadata) error {
func (z *ZeebeCommand) Init(ctx context.Context, metadata bindings.Metadata) error {
client, err := z.clientFactory.Get(metadata)
if err != nil {
return err
@ -114,7 +114,7 @@ func (z *ZeebeCommand) Invoke(ctx context.Context, req *bindings.InvokeRequest)
case UpdateJobRetriesOperation:
return z.updateJobRetries(ctx, req)
case ThrowErrorOperation:
return z.throwError(req)
return z.throwError(ctx, req)
case bindings.GetOperation:
fallthrough
case bindings.CreateOperation:

View File

@ -58,7 +58,7 @@ func TestInit(t *testing.T) {
}
cmd := ZeebeCommand{clientFactory: mcf, logger: testLogger}
err := cmd.Init(metadata)
err := cmd.Init(context.Background(), metadata)
assert.Error(t, err, errParsing)
})
@ -67,7 +67,7 @@ func TestInit(t *testing.T) {
mcf := mockClientFactory{}
cmd := ZeebeCommand{clientFactory: mcf, logger: testLogger}
err := cmd.Init(metadata)
err := cmd.Init(context.Background(), metadata)
assert.NoError(t, err)

View File

@ -30,7 +30,7 @@ type throwErrorPayload struct {
ErrorMessage string `json:"errorMessage"`
}
func (z *ZeebeCommand) throwError(req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
func (z *ZeebeCommand) throwError(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
var payload throwErrorPayload
err := json.Unmarshal(req.Data, &payload)
if err != nil {
@ -53,7 +53,7 @@ func (z *ZeebeCommand) throwError(req *bindings.InvokeRequest) (*bindings.Invoke
cmd = cmd.ErrorMessage(payload.ErrorMessage)
}
_, err = cmd.Send(context.Background())
_, err = cmd.Send(ctx)
if err != nil {
return nil, fmt.Errorf("cannot throw error for job key %d: %w", payload.JobKey, err)
}

View File

@ -19,6 +19,8 @@ import (
"errors"
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/camunda/zeebe/clients/go/v8/pkg/entities"
@ -39,6 +41,9 @@ type ZeebeJobWorker struct {
client zbc.Client
metadata *jobWorkerMetadata
logger logger.Logger
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
// https://docs.zeebe.io/basics/job-workers.html
@ -64,11 +69,15 @@ type jobHandler struct {
// NewZeebeJobWorker returns a new ZeebeJobWorker instance.
func NewZeebeJobWorker(logger logger.Logger) bindings.InputBinding {
return &ZeebeJobWorker{clientFactory: zeebe.NewClientFactoryImpl(logger), logger: logger}
return &ZeebeJobWorker{
clientFactory: zeebe.NewClientFactoryImpl(logger),
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init does metadata parsing and connection creation.
func (z *ZeebeJobWorker) Init(metadata bindings.Metadata) error {
func (z *ZeebeJobWorker) Init(ctx context.Context, metadata bindings.Metadata) error {
meta, err := z.parseMetadata(metadata)
if err != nil {
return err
@ -90,6 +99,10 @@ func (z *ZeebeJobWorker) Init(metadata bindings.Metadata) error {
}
func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) error {
if z.closed.Load() {
return fmt.Errorf("binding is closed")
}
h := jobHandler{
callback: handler,
logger: z.logger,
@ -99,8 +112,14 @@ func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) err
jobWorker := z.getJobWorker(h)
z.wg.Add(1)
go func() {
<-ctx.Done()
defer z.wg.Done()
select {
case <-z.closeCh:
case <-ctx.Done():
}
jobWorker.Close()
jobWorker.AwaitClose()
@ -110,6 +129,14 @@ func (z *ZeebeJobWorker) Read(ctx context.Context, handler bindings.Handler) err
return nil
}
func (z *ZeebeJobWorker) Close() error {
if z.closed.CompareAndSwap(false, true) {
close(z.closeCh)
}
z.wg.Wait()
return nil
}
func (z *ZeebeJobWorker) parseMetadata(meta bindings.Metadata) (*jobWorkerMetadata, error) {
var m jobWorkerMetadata
err := metadata.DecodeMetadata(meta.Properties, &m)

View File

@ -14,6 +14,7 @@ limitations under the License.
package jobworker
import (
"context"
"errors"
"testing"
@ -53,10 +54,11 @@ func TestInit(t *testing.T) {
metadata := bindings.Metadata{}
var mcf mockClientFactory
jobWorker := ZeebeJobWorker{clientFactory: &mcf, logger: testLogger}
err := jobWorker.Init(metadata)
jobWorker := ZeebeJobWorker{clientFactory: &mcf, logger: testLogger, closeCh: make(chan struct{})}
err := jobWorker.Init(context.Background(), metadata)
assert.Error(t, err, ErrMissingJobType)
assert.NoError(t, jobWorker.Close())
})
t.Run("sets client from client factory", func(t *testing.T) {
@ -66,8 +68,8 @@ func TestInit(t *testing.T) {
mcf := mockClientFactory{
metadata: metadata,
}
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger}
err := jobWorker.Init(metadata)
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
err := jobWorker.Init(context.Background(), metadata)
assert.NoError(t, err)
@ -76,6 +78,7 @@ func TestInit(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, mc, jobWorker.client)
assert.Equal(t, metadata, mcf.metadata)
assert.NoError(t, jobWorker.Close())
})
t.Run("returns error if client could not be instantiated properly", func(t *testing.T) {
@ -85,9 +88,10 @@ func TestInit(t *testing.T) {
error: errParsing,
}
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger}
err := jobWorker.Init(metadata)
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
err := jobWorker.Init(context.Background(), metadata)
assert.Error(t, err, errParsing)
assert.NoError(t, jobWorker.Close())
})
t.Run("sets client from client factory", func(t *testing.T) {
@ -98,8 +102,8 @@ func TestInit(t *testing.T) {
metadata: metadata,
}
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger}
err := jobWorker.Init(metadata)
jobWorker := ZeebeJobWorker{clientFactory: mcf, logger: testLogger, closeCh: make(chan struct{})}
err := jobWorker.Init(context.Background(), metadata)
assert.NoError(t, err)
@ -108,5 +112,6 @@ func TestInit(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, mc, jobWorker.client)
assert.Equal(t, metadata, mcf.metadata)
assert.NoError(t, jobWorker.Close())
})
}

View File

@ -73,7 +73,7 @@ func NewAzureAppConfigurationStore(logger logger.Logger) configuration.Store {
}
// Init does metadata and connection parsing.
func (r *ConfigurationStore) Init(metadata configuration.Metadata) error {
func (r *ConfigurationStore) Init(_ context.Context, metadata configuration.Metadata) error {
m, err := parseMetadata(metadata)
if err != nil {
return err

View File

@ -204,7 +204,7 @@ func TestInit(t *testing.T) {
Properties: testProperties,
}}
err := s.Init(m)
err := s.Init(context.Background(), m)
assert.Nil(t, err)
cs, ok := s.(*ConfigurationStore)
assert.True(t, ok)
@ -229,7 +229,7 @@ func TestInit(t *testing.T) {
Properties: testProperties,
}}
err := s.Init(m)
err := s.Init(context.Background(), m)
assert.Nil(t, err)
cs, ok := s.(*ConfigurationStore)
assert.True(t, ok)

View File

@ -86,7 +86,7 @@ func NewPostgresConfigurationStore(logger logger.Logger) configuration.Store {
}
}
func (p *ConfigurationStore) Init(metadata configuration.Metadata) error {
func (p *ConfigurationStore) Init(parentCtx context.Context, metadata configuration.Metadata) error {
p.logger.Debug(InfoStartInit)
if p.client != nil {
return fmt.Errorf(ErrorAlreadyInitialized)
@ -98,7 +98,7 @@ func (p *ConfigurationStore) Init(metadata configuration.Metadata) error {
p.metadata = m
}
p.ActiveSubscriptions = make(map[string]*subscription)
ctx, cancel := context.WithTimeout(context.Background(), p.metadata.maxIdleTimeout)
ctx, cancel := context.WithTimeout(parentCtx, p.metadata.maxIdleTimeout)
defer cancel()
client, err := Connect(ctx, p.metadata.connectionString, p.metadata.maxIdleTimeout)
if err != nil {

View File

@ -143,7 +143,7 @@ func parseRedisMetadata(meta configuration.Metadata) (metadata, error) {
}
// Init does metadata and connection parsing.
func (r *ConfigurationStore) Init(metadata configuration.Metadata) error {
func (r *ConfigurationStore) Init(ctx context.Context, metadata configuration.Metadata) error {
m, err := parseRedisMetadata(metadata)
if err != nil {
return err
@ -156,11 +156,11 @@ func (r *ConfigurationStore) Init(metadata configuration.Metadata) error {
r.client = r.newClient(m)
}
if _, err = r.client.Ping(context.TODO()).Result(); err != nil {
if _, err = r.client.Ping(ctx).Result(); err != nil {
return fmt.Errorf("redis store: error connecting to redis at %s: %s", m.Host, err)
}
r.replicas, err = r.getConnectedSlaves()
r.replicas, err = r.getConnectedSlaves(ctx)
return err
}
@ -204,8 +204,8 @@ func (r *ConfigurationStore) newFailoverClient(m metadata) *redis.Client {
return redis.NewFailoverClient(opts)
}
func (r *ConfigurationStore) getConnectedSlaves() (int, error) {
res, err := r.client.Do(context.Background(), "INFO", "replication").Result()
func (r *ConfigurationStore) getConnectedSlaves(ctx context.Context) (int, error) {
res, err := r.client.Do(ctx, "INFO", "replication").Result()
if err != nil {
return 0, err
}

View File

@ -18,7 +18,7 @@ import "context"
// Store is an interface to perform operations on store.
type Store interface {
// Init configuration store.
Init(metadata Metadata) error
Init(ctx context.Context, metadata Metadata) error
// Get configuration.
Get(ctx context.Context, req *GetRequest) (*GetResponse, error)

2
go.mod
View File

@ -400,3 +400,5 @@ replace github.com/toolkits/concurrent => github.com/niean/gotools v0.0.0-201512
// this is a fork which addresses a performance issues due to go routines
replace dubbo.apache.org/dubbo-go/v3 => dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3
replace github.com/dapr/dapr => github.com/JoshVanL/dapr v0.0.0-20230206115221-708ff00f7181

2
go.sum
View File

@ -641,8 +641,6 @@ github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/bytecodealliance/wasmtime-go/v3 v3.0.2 h1:3uZCA/BLTIu+DqCfguByNMJa2HVHpXvjfy0Dy7g6fuA=
github.com/camunda/zeebe/clients/go/v8 v8.0.11 h1:rDqsr0h5H9wmPg1bk0srNRakhtvVO5KcMVgbftKTQqg=
github.com/camunda/zeebe/clients/go/v8 v8.0.11/go.mod h1:vqeNO1EphExqC15spP56PNXQ6SB8sMjhEfO16bfFRPo=
github.com/camunda/zeebe/clients/go/v8 v8.1.8 h1:/i3t1PaToPfED+609uNR9kdGo/LPFTE4jK5/SEbwY4Y=
github.com/camunda/zeebe/clients/go/v8 v8.1.8/go.mod h1:nQc5qX4lPSxWUW0VuJ+k3b+FdcVLNa29A/nAQG2q9u4=
github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ=

View File

@ -1,5 +1,22 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package health
import (
"context"
)
type Pinger interface {
Ping() error
Ping(ctx context.Context) error
}

View File

@ -19,6 +19,7 @@ import (
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/Shopify/sarama"
@ -31,6 +32,7 @@ type consumer struct {
k *Kafka
ready chan bool
running chan struct{}
stopped atomic.Bool
once sync.Once
mutex sync.Mutex
}
@ -275,9 +277,6 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
k.cg = cg
ctx, cancel := context.WithCancel(ctx)
k.cancel = cancel
ready := make(chan bool)
k.consumer = consumer{
k: k,
@ -320,7 +319,10 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
k.logger.Errorf("Error closing consumer group: %v", err)
}
close(k.consumer.running)
// Ensure running channel is only closed once.
if k.consumer.stopped.CompareAndSwap(false, true) {
close(k.consumer.running)
}
}()
<-ready
@ -331,7 +333,6 @@ func (k *Kafka) Subscribe(ctx context.Context) error {
// Close down consumer group resources, refresh once.
func (k *Kafka) closeSubscriptionResources() {
if k.cg != nil {
k.cancel()
err := k.cg.Close()
if err != nil {
k.logger.Errorf("Error closing consumer group: %v", err)

View File

@ -36,7 +36,6 @@ type Kafka struct {
saslPassword string
initialOffset int64
cg sarama.ConsumerGroup
cancel context.CancelFunc
consumer consumer
config *sarama.Config
subscribeTopics TopicHandlerConfig
@ -60,7 +59,7 @@ func NewKafka(logger logger.Logger) *Kafka {
}
// Init does metadata parsing and connection establishment.
func (k *Kafka) Init(metadata map[string]string) error {
func (k *Kafka) Init(_ context.Context, metadata map[string]string) error {
upgradedMetadata, err := k.upgradeMetadata(metadata)
if err != nil {
return err

View File

@ -107,7 +107,6 @@ func (k *Kafka) getKafkaMetadata(metadata map[string]string) (*kafkaMetadata, er
if val, ok := metadata["consumerID"]; ok && val != "" {
meta.ConsumerGroup = val
k.logger.Debugf("Using %s as ConsumerGroup", meta.ConsumerGroup)
k.logger.Warn("ConsumerID is deprecated, if ConsumerID and ConsumerGroup are both set, ConsumerGroup is used")
}
if val, ok := metadata["consumerGroup"]; ok && val != "" {

Some files were not shown because too many files have changed in this diff Show More