Merge branch 'master' into blob-storage-v2

This commit is contained in:
Alessandro (Ale) Segala 2023-11-01 16:07:23 -07:00 committed by GitHub
commit c8e60e920f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 1365 additions and 741 deletions

1
go.mod
View File

@ -90,6 +90,7 @@ require (
github.com/oracle/oci-go-sdk/v54 v54.0.0
github.com/pashagolub/pgxmock/v2 v2.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/puzpuzpuz/xsync/v3 v3.0.0
github.com/rabbitmq/amqp091-go v1.8.1
github.com/redis/go-redis/v9 v9.2.1
github.com/sendgrid/sendgrid-go v3.13.0+incompatible

2
go.sum
View File

@ -1732,6 +1732,8 @@ github.com/prometheus/statsd_exporter v0.21.0/go.mod h1:rbT83sZq2V+p73lHhPZfMc3M
github.com/prometheus/statsd_exporter v0.22.7 h1:7Pji/i2GuhK6Lu7DHrtTkFmNBCudCPT1pX2CziuyQR0=
github.com/prometheus/statsd_exporter v0.22.7/go.mod h1:N/TevpjkIh9ccs6nuzY3jQn9dFqnUakOjnEuMPJJJnI=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/puzpuzpuz/xsync/v3 v3.0.0 h1:QwUcmah+dZZxy6va/QSU26M6O6Q422afP9jO8JlnRSA=
github.com/puzpuzpuz/xsync/v3 v3.0.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/rabbitmq/amqp091-go v1.8.1 h1:RejT1SBUim5doqcL6s7iN6SBmsQqyTgXb1xMlH0h1hA=
github.com/rabbitmq/amqp091-go v1.8.1/go.mod h1:+jPrT9iY2eLjRaMSRHUhc3z14E/l85kv/f+6luSD3pc=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=

View File

@ -20,6 +20,9 @@ import (
"strings"
"time"
// Blank import for the sqlite driver
_ "modernc.org/sqlite"
"github.com/dapr/kit/logger"
)
@ -44,6 +47,7 @@ func (m *SqliteAuthMetadata) Reset() {
m.DisableWAL = false
}
// Validate the auth metadata and returns an error if it's not valid.
func (m *SqliteAuthMetadata) Validate() error {
// Validate and sanitize input
if m.ConnectionString == "" {
@ -60,10 +64,16 @@ func (m *SqliteAuthMetadata) Validate() error {
return nil
}
// IsInMemoryDB returns true if the connection string is for an in-memory database.
func (m SqliteAuthMetadata) IsInMemoryDB() bool {
lc := strings.ToLower(m.ConnectionString)
return strings.HasPrefix(lc, ":memory:") || strings.HasPrefix(lc, "file::memory:")
}
// GetConnectionString returns the parsed connection string.
func (m *SqliteAuthMetadata) GetConnectionString(log logger.Logger) (string, error) {
// Check if we're using the in-memory database
lc := strings.ToLower(m.ConnectionString)
isMemoryDB := strings.HasPrefix(lc, ":memory:") || strings.HasPrefix(lc, "file::memory:")
isMemoryDB := m.IsInMemoryDB()
// Get the "query string" from the connection string if present
idx := strings.IndexRune(m.ConnectionString, '?')
@ -151,7 +161,7 @@ func (m *SqliteAuthMetadata) GetConnectionString(log logger.Logger) (string, err
connString += "?" + qs.Encode()
// If the connection string doesn't begin with "file:", add the prefix
if !strings.HasPrefix(lc, "file:") {
if !strings.HasPrefix(strings.ToLower(m.ConnectionString), "file:") {
log.Debug("prefix 'file:' added to the connection string")
connString = "file:" + connString
}

View File

@ -56,9 +56,6 @@ type AzureEventHubs struct {
checkpointStoreLock *sync.RWMutex
managementCreds azcore.TokenCredential
// TODO(@ItalyPaleAle): Remove in Dapr 1.13
isFailed atomic.Bool
}
// HandlerResponseItem represents a response from the handler for each message.
@ -251,37 +248,6 @@ func (aeh *AzureEventHubs) Subscribe(subscribeCtx context.Context, config Subscr
return fmt.Errorf("error trying to establish a connection: %w", err)
}
// Ensure that no subscriber using the old "track 1" SDK is active
// TODO(@ItalyPaleAle): Remove this for Dapr 1.13
{
// If a previous topic already failed, no need to try with other topics, as we're about to panic anyways
if aeh.isFailed.Load() {
return errors.New("subscribing to another topic on this component failed and Dapr is scheduled to crash; will not try subscribing to a new topic")
}
ctx, cancel := context.WithTimeout(subscribeCtx, 2*time.Minute)
err = aeh.ensureNoTrack1Subscribers(ctx, topic)
cancel()
if err != nil {
// If there's a timeout, it means that the other client was still active after the timeout
// In this case, we return an error here so Dapr can continue the initialization and report a "healthy" status (but this subscription won't be active)
// After 2 minutes, then, we panic, which ensures that during a rollout Kubernetes will see that this pod is unhealthy and re-creates that. Hopefully, by then other instances of the app will have been updated and no more locks will be present
if errors.Is(err, context.DeadlineExceeded) {
aeh.isFailed.Store(true)
errMsg := fmt.Sprintf("Another instance is currently subscribed to the topic %s in this Event Hub using an old version of Dapr, and this is not supported. Please ensure that all applications subscribed to the same topic, with this consumer group, are using Dapr 1.10 or newer.", topic)
aeh.logger.Error(errMsg + " ⚠️⚠️⚠️ Dapr will crash in 2 minutes to force the orchestrator to restart the process after the rollout of other instances is complete.")
go func() {
time.Sleep(2 * time.Minute)
aeh.logger.Fatalf("Another instance is currently subscribed to the topic %s in this Event Hub using an old version of Dapr, and this is not supported. Please ensure that all applications subscribed to the same topic, with this consumer group, are using Dapr 1.10 or newer.", topic)
}()
return fmt.Errorf("another instance is currently subscribed to the topic %s in this Event Hub using an old version of Dapr", topic)
}
// In case of other errors, just return the error
return fmt.Errorf("failed to check for subscribers using an old version of Dapr: %w", err)
}
}
// This component has built-in retries because Event Hubs doesn't support N/ACK for messages
retryHandler := func(ctx context.Context, events []*azeventhubs.ReceivedEventData) ([]HandlerResponseItem, error) {
b := aeh.backOffConfig.NewBackOffWithContext(ctx)
@ -621,7 +587,7 @@ func (aeh *AzureEventHubs) createCheckpointStore(ctx context.Context) (checkpoin
}
// Get the Azure Blob Storage client and ensure the container exists
client, err := aeh.createStorageClient(ctx, true)
client, err := aeh.createStorageClient(ctx)
if err != nil {
return nil, err
}
@ -641,8 +607,7 @@ func (aeh *AzureEventHubs) createCheckpointStore(ctx context.Context) (checkpoin
}
// Creates a client to access Azure Blob Storage.
// TODO(@ItalyPaleAle): Remove ensureContainer option (and default to true) for Dapr 1.13
func (aeh *AzureEventHubs) createStorageClient(ctx context.Context, ensureContainer bool) (*container.Client, error) {
func (aeh *AzureEventHubs) createStorageClient(ctx context.Context) (*container.Client, error) {
m := blobstorage.ContainerClientOpts{
ConnectionString: aeh.metadata.StorageConnectionString,
ContainerName: aeh.metadata.StorageContainerName,
@ -655,13 +620,11 @@ func (aeh *AzureEventHubs) createStorageClient(ctx context.Context, ensureContai
return nil, err
}
if ensureContainer {
// Ensure the container exists
// We're setting "accessLevel" to nil to make sure it's private
err = m.EnsureContainer(ctx, client, nil)
if err != nil {
return nil, err
}
// Ensure the container exists
// We're setting "accessLevel" to nil to make sure it's private
err = m.EnsureContainer(ctx, client, nil)
if err != nil {
return nil, err
}
return client, nil

View File

@ -1,93 +0,0 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package eventhubs
import (
"context"
"errors"
"fmt"
"net/http"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
"github.com/cenkalti/backoff/v4"
"github.com/dapr/kit/retry"
)
// This method ensures that there are currently no active subscribers to the same Event Hub topic that are using the old ("track 1") SDK of Azure Event Hubs. This is the SDK that was in use until Dapr 1.9.
// Because the new SDK stores checkpoints in a different way, clients using the new ("track 2") and the old SDK cannot coexist.
// To ensure this doesn't happen, when we create a new subscription to the same topic and with the same consumer group, we check if there's a file in Azure Storage with the checkpoint created by the old SDK and with a still-active lease. If that's true, we wait (until the context expires) before we crash Dapr with a log message describing what's happening.
// These conflicts should be transient anyways, as mixed versions of Dapr should only happen during a rollout of a new version of Dapr.
// TODO(@ItalyPaleAle): Remove this (entire file) for Dapr 1.13
func (aeh *AzureEventHubs) ensureNoTrack1Subscribers(parentCtx context.Context, topic string) error {
// Get a client to Azure Blob Storage
// Because we are not using "ensureContainer=true", we can pass a nil context
client, err := aeh.createStorageClient(nil, false) //nolint:staticcheck
if err != nil {
return err
}
// In the old version of the SDK, checkpoints were stored in the root of the storage account and were named like:
// `dapr-(topic)-(consumer-group)-(partition-key)`
// We need to list those up and check if they have an active lease
prefix := fmt.Sprintf("dapr-%s-%s-", topic, aeh.metadata.ConsumerGroup)
// Retry until we find no active leases - or the context expires
backOffConfig := retry.DefaultConfig()
backOffConfig.Policy = retry.PolicyExponential
backOffConfig.MaxInterval = time.Minute
backOffConfig.MaxElapsedTime = 0
backOffConfig.MaxRetries = -1
b := backOffConfig.NewBackOffWithContext(parentCtx)
err = backoff.Retry(func() error {
pager := client.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{
Prefix: &prefix,
})
for pager.More() {
ctx, cancel := context.WithTimeout(parentCtx, resourceGetTimeout)
resp, innerErr := pager.NextPage(ctx)
cancel()
if innerErr != nil {
// Treat these errors as permanent
resErr := &azcore.ResponseError{}
if !errors.As(err, &resErr) || resErr.StatusCode != http.StatusNotFound {
// A "not-found" error means that the storage container doesn't exist, so let's not handle it here
// Just return no error
return nil
}
return backoff.Permanent(fmt.Errorf("failed to list blobs: %w", innerErr))
}
for _, blob := range resp.Segment.BlobItems {
if blob == nil || blob.Name == nil || blob.Properties == nil || blob.Properties.LeaseState == nil {
continue
}
aeh.logger.Debugf("Found checkpoint from an older Dapr version %s", *blob.Name)
// If the blob is locked, it means that there's another Dapr process with an old version of the SDK running, so we need to wait
if *blob.Properties.LeaseStatus == "locked" {
aeh.logger.Warnf("Found active lease on checkpoint %s from an older Dapr version - waiting for lease to expire", *blob.Name)
return fmt.Errorf("found active lease on checkpoint %s from an old Dapr version", *blob.Name)
}
}
}
return nil
}, b)
// If the error is a DeadlineExceeded on the operation and not on parentCtx, handle that separately to avoid crashing Dapr needlessly
if err != nil && errors.Is(err, context.DeadlineExceeded) && parentCtx.Err() != context.DeadlineExceeded {
err = errors.New("failed to list blobs: request timed out")
}
return err
}

View File

@ -77,7 +77,7 @@ func Migrate(ctx context.Context, db DatabaseConn, opts MigrationOptions) error
return fmt.Errorf("invalid migration level found in metadata table: %s", migrationLevelStr)
}
}
opts.Logger.Debug("Migrate: current migration level: %d", migrationLevel)
opts.Logger.Debugf("Migrate: current migration level: %d", migrationLevel)
// Perform the migrations
for i := migrationLevel; i < len(opts.Migrations); i++ {

View File

@ -56,11 +56,11 @@ func (m Migrations) Perform(ctx context.Context, migrationFns []sqlinternal.Migr
defer func() {
m.Logger.Debug("Releasing advisory lock")
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
_, err = m.DB.Exec(queryCtx, "SELECT pg_advisory_unlock($1)", lockID)
_, rollbackErr := m.DB.Exec(queryCtx, "SELECT pg_advisory_unlock($1)", lockID)
cancel()
if err != nil {
if rollbackErr != nil {
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
m.Logger.Fatalf("Failed to release advisory lock: %v", err)
m.Logger.Fatalf("Failed to release advisory lock: %v", rollbackErr)
}
}()

View File

@ -58,11 +58,11 @@ func (m *Migrations) Perform(ctx context.Context, migrationFns []sqlinternal.Mig
return
}
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
_, err = m.conn.ExecContext(queryCtx, "ROLLBACK TRANSACTION")
_, rollbackErr := m.conn.ExecContext(queryCtx, "ROLLBACK TRANSACTION")
cancel()
if err != nil {
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
m.Logger.Fatalf("Failed to rollback transaction: %v", err)
if rollbackErr != nil {
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving transactions open
m.Logger.Fatalf("Failed to rollback transaction: %v", rollbackErr)
}
}()

View File

@ -14,6 +14,7 @@ limitations under the License.
package consul
import (
"context"
"fmt"
"math/rand"
"net"
@ -238,7 +239,7 @@ func newResolver(logger logger.Logger, resolverConfig resolverConfig, client cli
}
// Init will configure component. It will also register service or validate client connection based on config.
func (r *resolver) Init(metadata nr.Metadata) (err error) {
func (r *resolver) Init(ctx context.Context, metadata nr.Metadata) (err error) {
r.config, err = getConfig(metadata)
if err != nil {
return err
@ -274,7 +275,7 @@ func (r *resolver) Init(metadata nr.Metadata) (err error) {
}
// ResolveID resolves name to address via consul.
func (r *resolver) ResolveID(req nr.ResolveRequest) (addr string, err error) {
func (r *resolver) ResolveID(ctx context.Context, req nr.ResolveRequest) (addr string, err error) {
cfg := r.config
svc, err := r.getService(req.ID)
if err != nil {
@ -327,7 +328,8 @@ func formatAddress(address string, port string) (addr string, err error) {
// getConfig configuration from metadata, defaults are best suited for self-hosted mode.
func getConfig(metadata nr.Metadata) (resolverCfg resolverConfig, err error) {
if metadata.Properties[nr.DaprPort] == "" {
props := metadata.GetPropertiesMap()
if props[nr.DaprPort] == "" {
return resolverCfg, fmt.Errorf("metadata property missing: %s", nr.DaprPort)
}
@ -341,7 +343,7 @@ func getConfig(metadata nr.Metadata) (resolverCfg resolverConfig, err error) {
resolverCfg.UseCache = cfg.UseCache
resolverCfg.Client = getClientConfig(cfg)
resolverCfg.Registration, err = getRegistrationConfig(cfg, metadata.Properties)
resolverCfg.Registration, err = getRegistrationConfig(cfg, props)
if err != nil {
return resolverCfg, err
}
@ -353,7 +355,7 @@ func getConfig(metadata nr.Metadata) (resolverCfg resolverConfig, err error) {
resolverCfg.Registration.Meta = map[string]string{}
}
resolverCfg.Registration.Meta[resolverCfg.DaprPortMetaKey] = metadata.Properties[nr.DaprPort]
resolverCfg.Registration.Meta[resolverCfg.DaprPortMetaKey] = props[nr.DaprPort]
}
return resolverCfg, nil
@ -382,22 +384,25 @@ func getRegistrationConfig(cfg configSpec, props map[string]string) (*consul.Age
appPort string
host string
httpPort string
ok bool
)
if appID, ok = props[nr.AppID]; !ok {
appID = props[nr.AppID]
if appID == "" {
return nil, fmt.Errorf("metadata property missing: %s", nr.AppID)
}
if appPort, ok = props[nr.AppPort]; !ok {
appPort = props[nr.AppPort]
if appPort == "" {
return nil, fmt.Errorf("metadata property missing: %s", nr.AppPort)
}
if host, ok = props[nr.HostAddress]; !ok {
host = props[nr.HostAddress]
if host == "" {
return nil, fmt.Errorf("metadata property missing: %s", nr.HostAddress)
}
if httpPort, ok = props[nr.DaprHTTPPort]; !ok {
httpPort = props[nr.DaprHTTPPort]
if httpPort == "" {
return nil, fmt.Errorf("metadata property missing: %s", nr.DaprHTTPPort)
} else if _, err := strconv.ParseUint(httpPort, 10, 0); err != nil {
return nil, fmt.Errorf("error parsing %s: %w", nr.DaprHTTPPort, err)

View File

@ -14,6 +14,7 @@ limitations under the License.
package consul
import (
"context"
"fmt"
"net"
"strconv"
@ -24,7 +25,6 @@ import (
consul "github.com/hashicorp/consul/api"
"github.com/stretchr/testify/assert"
"github.com/dapr/components-contrib/metadata"
nr "github.com/dapr/components-contrib/nameresolution"
"github.com/dapr/kit/logger"
)
@ -184,8 +184,6 @@ func (m *mockRegistry) get(service string) *registryEntry {
}
func TestInit(t *testing.T) {
t.Parallel()
tests := []struct {
testName string
metadata nr.Metadata
@ -193,16 +191,12 @@ func TestInit(t *testing.T) {
}{
{
"given no configuration don't register service just check agent",
nr.Metadata{Base: metadata.Base{
Properties: getTestPropsWithoutKey(""),
}, Configuration: nil},
nr.Metadata{Instance: getInstanceInfoWithoutKey(""), Configuration: nil},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
var mock mockClient
resolver := newResolver(logger.NewLogger("test"), resolverConfig{}, &mock, &registry{}, make(chan struct{}))
_ = resolver.Init(metadata)
_ = resolver.Init(context.Background(), metadata)
assert.Equal(t, 1, mock.initClientCalled)
assert.Equal(t, 0, mock.mockAgent.serviceRegisterCalled)
@ -212,20 +206,16 @@ func TestInit(t *testing.T) {
{
"given SelfRegister true then register service",
nr.Metadata{
Base: metadata.Base{
Properties: getTestPropsWithoutKey(""),
},
Instance: getInstanceInfoWithoutKey(""),
Configuration: configSpec{
SelfRegister: true,
},
},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
var mock mockClient
resolver := newResolver(logger.NewLogger("test"), resolverConfig{}, &mock, &registry{}, make(chan struct{}))
_ = resolver.Init(metadata)
_ = resolver.Init(context.Background(), metadata)
assert.Equal(t, 1, mock.initClientCalled)
assert.Equal(t, 1, mock.mockAgent.serviceRegisterCalled)
@ -235,19 +225,17 @@ func TestInit(t *testing.T) {
{
"given AdvancedRegistraion then register service",
nr.Metadata{
Base: metadata.Base{Properties: getTestPropsWithoutKey("")},
Instance: getInstanceInfoWithoutKey(""),
Configuration: configSpec{
AdvancedRegistration: &consul.AgentServiceRegistration{},
QueryOptions: &consul.QueryOptions{},
},
},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
var mock mockClient
resolver := newResolver(logger.NewLogger("test"), resolverConfig{}, &mock, &registry{}, make(chan struct{}))
_ = resolver.Init(metadata)
_ = resolver.Init(context.Background(), metadata)
assert.Equal(t, 1, mock.initClientCalled)
assert.Equal(t, 1, mock.mockAgent.serviceRegisterCalled)
@ -259,7 +247,6 @@ func TestInit(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.testName, func(t *testing.T) {
t.Parallel()
tt.test(t, tt.metadata)
})
}
@ -283,8 +270,6 @@ func TestResolveID(t *testing.T) {
ID: "test-app",
},
func(t *testing.T, req nr.ResolveRequest) {
t.Helper()
blockingCall := make(chan uint64)
meta := &consul.QueryMeta{
LastIndex: 0,
@ -360,7 +345,7 @@ func TestResolveID(t *testing.T) {
},
}
resolver := newResolver(logger.NewLogger("test"), cfg, mock, mockReg, make(chan struct{}))
addr, _ := resolver.ResolveID(req)
addr, _ := resolver.ResolveID(context.Background(), req)
// no apps in registry - cache miss, call agent directly
assert.Equal(t, 1, mockReg.getCalled)
@ -387,11 +372,11 @@ func TestResolveID(t *testing.T) {
assert.Equal(t, int32(2), mockReg.addOrUpdateCalled.Load())
// resolve id should only hit cache now
addr, _ = resolver.ResolveID(req)
addr, _ = resolver.ResolveID(context.Background(), req)
assert.Equal(t, "10.3.245.137:70007", addr)
addr, _ = resolver.ResolveID(req)
addr, _ = resolver.ResolveID(context.Background(), req)
assert.Equal(t, "10.3.245.137:70007", addr)
addr, _ = resolver.ResolveID(req)
addr, _ = resolver.ResolveID(context.Background(), req)
assert.Equal(t, "10.3.245.137:70007", addr)
assert.Equal(t, 2, mock.mockHealth.serviceCalled)
@ -426,8 +411,6 @@ func TestResolveID(t *testing.T) {
ID: "test-app",
},
func(t *testing.T, req nr.ResolveRequest) {
t.Helper()
blockingCall := make(chan uint64)
meta := &consul.QueryMeta{}
@ -520,7 +503,7 @@ func TestResolveID(t *testing.T) {
},
}
resolver := newResolver(logger.NewLogger("test"), cfg, &mock, mockReg, make(chan struct{}))
addr, _ := resolver.ResolveID(req)
addr, _ := resolver.ResolveID(context.Background(), req)
// no apps in registry - cache miss, call agent directly
assert.Equal(t, 1, mockReg.getCalled)
@ -548,9 +531,9 @@ func TestResolveID(t *testing.T) {
assert.Equal(t, int32(2), mockReg.addOrUpdateCalled.Load())
// resolve id should only hit cache now
_, _ = resolver.ResolveID(req)
_, _ = resolver.ResolveID(req)
_, _ = resolver.ResolveID(req)
_, _ = resolver.ResolveID(context.Background(), req)
_, _ = resolver.ResolveID(context.Background(), req)
_, _ = resolver.ResolveID(context.Background(), req)
assert.Equal(t, 2, mock.mockHealth.serviceCalled)
// change one check for node1 app to critical
@ -600,8 +583,6 @@ func TestResolveID(t *testing.T) {
ID: "test-app",
},
func(t *testing.T, req nr.ResolveRequest) {
t.Helper()
blockingCall := make(chan uint64)
meta := &consul.QueryMeta{
LastIndex: 0,
@ -667,7 +648,7 @@ func TestResolveID(t *testing.T) {
},
}
resolver := newResolver(logger.NewLogger("test"), cfg, mock, mockReg, make(chan struct{}))
addr, _ := resolver.ResolveID(req)
addr, _ := resolver.ResolveID(context.Background(), req)
// Cache miss pass through
assert.Equal(t, 1, mockReg.getCalled)
@ -702,8 +683,6 @@ func TestResolveID(t *testing.T) {
ID: "test-app",
},
func(t *testing.T, req nr.ResolveRequest) {
t.Helper()
blockingCall := make(chan uint64)
meta := &consul.QueryMeta{
LastIndex: 0,
@ -770,7 +749,7 @@ func TestResolveID(t *testing.T) {
},
}
resolver := newResolver(logger.NewLogger("test"), cfg, mock, mockReg, make(chan struct{})).(*resolver)
addr, _ := resolver.ResolveID(req)
addr, _ := resolver.ResolveID(context.Background(), req)
// Cache miss pass through
assert.Equal(t, 1, mockReg.getCalled)
@ -798,7 +777,6 @@ func TestResolveID(t *testing.T) {
ID: "test-app",
},
func(t *testing.T, req nr.ResolveRequest) {
t.Helper()
mock := mockClient{
mockHealth: mockHealth{
serviceResult: []*consul.ServiceEntry{},
@ -806,7 +784,7 @@ func TestResolveID(t *testing.T) {
}
resolver := newResolver(logger.NewLogger("test"), testConfig, &mock, &registry{}, make(chan struct{}))
_, err := resolver.ResolveID(req)
_, err := resolver.ResolveID(context.Background(), req)
assert.Equal(t, 1, mock.mockHealth.serviceCalled)
assert.Error(t, err)
},
@ -817,7 +795,6 @@ func TestResolveID(t *testing.T) {
ID: "test-app",
},
func(t *testing.T, req nr.ResolveRequest) {
t.Helper()
mock := mockClient{
mockHealth: mockHealth{
serviceResult: []*consul.ServiceEntry{
@ -835,7 +812,7 @@ func TestResolveID(t *testing.T) {
}
resolver := newResolver(logger.NewLogger("test"), testConfig, &mock, &registry{}, make(chan struct{}))
addr, _ := resolver.ResolveID(req)
addr, _ := resolver.ResolveID(context.Background(), req)
assert.Equal(t, "10.3.245.137:50005", addr)
},
@ -846,7 +823,6 @@ func TestResolveID(t *testing.T) {
ID: "test-app",
},
func(t *testing.T, req nr.ResolveRequest) {
t.Helper()
mock := mockClient{
mockHealth: mockHealth{
serviceResult: []*consul.ServiceEntry{
@ -864,7 +840,7 @@ func TestResolveID(t *testing.T) {
}
resolver := newResolver(logger.NewLogger("test"), testConfig, &mock, &registry{}, make(chan struct{}))
addr, _ := resolver.ResolveID(req)
addr, _ := resolver.ResolveID(context.Background(), req)
assert.Equal(t, "[2001:db8:3333:4444:5555:6666:7777:8888]:50005", addr)
},
@ -875,7 +851,6 @@ func TestResolveID(t *testing.T) {
ID: "test-app",
},
func(t *testing.T, req nr.ResolveRequest) {
t.Helper()
mock := mockClient{
mockHealth: mockHealth{
serviceResult: []*consul.ServiceEntry{
@ -905,7 +880,7 @@ func TestResolveID(t *testing.T) {
total1 := 0
total2 := 0
for i := 0; i < 100; i++ {
addr, _ := resolver.ResolveID(req)
addr, _ := resolver.ResolveID(context.Background(), req)
if addr == "10.3.245.137:50005" {
total1++
@ -928,7 +903,6 @@ func TestResolveID(t *testing.T) {
ID: "test-app",
},
func(t *testing.T, req nr.ResolveRequest) {
t.Helper()
mock := mockClient{
mockHealth: mockHealth{
serviceResult: []*consul.ServiceEntry{
@ -961,7 +935,7 @@ func TestResolveID(t *testing.T) {
}
resolver := newResolver(logger.NewLogger("test"), testConfig, &mock, &registry{}, make(chan struct{}))
addr, _ := resolver.ResolveID(req)
addr, _ := resolver.ResolveID(context.Background(), req)
assert.Equal(t, "10.3.245.137:50005", addr)
},
@ -972,7 +946,6 @@ func TestResolveID(t *testing.T) {
ID: "test-app",
},
func(t *testing.T, req nr.ResolveRequest) {
t.Helper()
mock := mockClient{
mockHealth: mockHealth{
serviceResult: []*consul.ServiceEntry{
@ -990,7 +963,7 @@ func TestResolveID(t *testing.T) {
}
resolver := newResolver(logger.NewLogger("test"), testConfig, &mock, &registry{}, make(chan struct{}))
_, err := resolver.ResolveID(req)
_, err := resolver.ResolveID(context.Background(), req)
assert.Error(t, err)
},
@ -1001,7 +974,6 @@ func TestResolveID(t *testing.T) {
ID: "test-app",
},
func(t *testing.T, req nr.ResolveRequest) {
t.Helper()
mock := mockClient{
mockHealth: mockHealth{
serviceResult: []*consul.ServiceEntry{
@ -1016,7 +988,7 @@ func TestResolveID(t *testing.T) {
}
resolver := newResolver(logger.NewLogger("test"), testConfig, &mock, &registry{}, make(chan struct{}))
_, err := resolver.ResolveID(req)
_, err := resolver.ResolveID(context.Background(), req)
assert.Error(t, err)
},
@ -1032,8 +1004,6 @@ func TestResolveID(t *testing.T) {
}
func TestClose(t *testing.T) {
t.Parallel()
tests := []struct {
testName string
metadata nr.Metadata
@ -1041,12 +1011,8 @@ func TestClose(t *testing.T) {
}{
{
"should deregister",
nr.Metadata{Base: metadata.Base{
Properties: getTestPropsWithoutKey(""),
}, Configuration: nil},
nr.Metadata{Instance: getInstanceInfoWithoutKey(""), Configuration: nil},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
var mock mockClient
cfg := resolverConfig{
Registration: &consul.AgentServiceRegistration{},
@ -1061,12 +1027,8 @@ func TestClose(t *testing.T) {
},
{
"should not deregister",
nr.Metadata{Base: metadata.Base{
Properties: getTestPropsWithoutKey(""),
}, Configuration: nil},
nr.Metadata{Instance: getInstanceInfoWithoutKey(""), Configuration: nil},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
var mock mockClient
cfg := resolverConfig{
Registration: &consul.AgentServiceRegistration{},
@ -1081,12 +1043,8 @@ func TestClose(t *testing.T) {
},
{
"should not deregister when no registration",
nr.Metadata{Base: metadata.Base{
Properties: getTestPropsWithoutKey(""),
}, Configuration: nil},
nr.Metadata{Instance: getInstanceInfoWithoutKey(""), Configuration: nil},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
var mock mockClient
cfg := resolverConfig{
Registration: nil,
@ -1101,12 +1059,8 @@ func TestClose(t *testing.T) {
},
{
"should stop watcher if started",
nr.Metadata{Base: metadata.Base{
Properties: getTestPropsWithoutKey(""),
}, Configuration: nil},
nr.Metadata{Instance: getInstanceInfoWithoutKey(""), Configuration: nil},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
var mock mockClient
resolver := newResolver(logger.NewLogger("test"), resolverConfig{}, &mock, &registry{}, make(chan struct{})).(*resolver)
resolver.watcherStarted.Store(true)
@ -1136,8 +1090,6 @@ func TestClose(t *testing.T) {
}
func TestRegistry(t *testing.T) {
t.Parallel()
appID := "myService"
tests := []struct {
testName string
@ -1146,8 +1098,6 @@ func TestRegistry(t *testing.T) {
{
"should add and update entry",
func(t *testing.T) {
t.Helper()
registry := &registry{}
result := []*consul.ServiceEntry{
@ -1181,8 +1131,6 @@ func TestRegistry(t *testing.T) {
{
"should expire entries",
func(t *testing.T) {
t.Helper()
registry := &registry{}
registry.entries.Store(
"A",
@ -1249,8 +1197,6 @@ func TestRegistry(t *testing.T) {
{
"should remove entry",
func(t *testing.T) {
t.Helper()
registry := &registry{}
entry := &registryEntry{
services: []*consul.ServiceEntry{
@ -1298,8 +1244,6 @@ func TestRegistry(t *testing.T) {
}
func TestParseConfig(t *testing.T) {
t.Parallel()
tests := []struct {
testName string
shouldParse bool
@ -1309,9 +1253,9 @@ func TestParseConfig(t *testing.T) {
{
"valid configuration in metadata",
true,
map[interface{}]interface{}{
map[any]any{
"Checks": []interface{}{
map[interface{}]interface{}{
map[any]any{
"Name": "test-app health check name",
"CheckID": "test-app health check id",
"Interval": "15s",
@ -1322,12 +1266,12 @@ func TestParseConfig(t *testing.T) {
"dapr",
"test",
},
"Meta": map[interface{}]interface{}{
"Meta": map[any]any{
"APP_PORT": "123",
"DAPR_HTTP_PORT": "3500",
"DAPR_GRPC_PORT": "50005",
},
"QueryOptions": map[interface{}]interface{}{
"QueryOptions": map[any]any{
"UseCache": true,
"Filter": "Checks.ServiceTags contains dapr",
},
@ -1371,8 +1315,8 @@ func TestParseConfig(t *testing.T) {
{
"fail on unsupported map key",
false,
map[interface{}]interface{}{
1000: map[interface{}]interface{}{
map[any]any{
1000: map[any]any{
"DAPR_HTTP_PORT": "3500",
"DAPR_GRPC_PORT": "50005",
},
@ -1384,7 +1328,6 @@ func TestParseConfig(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.testName, func(t *testing.T) {
t.Parallel()
actual, err := parseConfig(tt.input)
if tt.shouldParse {
@ -1398,8 +1341,6 @@ func TestParseConfig(t *testing.T) {
}
func TestGetConfig(t *testing.T) {
t.Parallel()
tests := []struct {
testName string
metadata nr.Metadata
@ -1408,11 +1349,10 @@ func TestGetConfig(t *testing.T) {
{
"empty configuration should only return Client, QueryOptions and DaprPortMetaKey",
nr.Metadata{
Base: metadata.Base{Properties: getTestPropsWithoutKey("")},
Instance: getInstanceInfoWithoutKey(""),
Configuration: nil,
},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
actual, _ := getConfig(metadata)
// Client
@ -1435,13 +1375,12 @@ func TestGetConfig(t *testing.T) {
{
"empty configuration with SelfRegister should default correctly",
nr.Metadata{
Base: metadata.Base{Properties: getTestPropsWithoutKey("")},
Configuration: map[interface{}]interface{}{
Instance: getInstanceInfoWithoutKey(""),
Configuration: map[any]any{
"SelfRegister": true,
},
},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
actual, _ := getConfig(metadata)
// Client
assert.Equal(t, consul.DefaultConfig().Address, actual.Client.Address)
@ -1450,9 +1389,9 @@ func TestGetConfig(t *testing.T) {
assert.Equal(t, 1, len(actual.Registration.Checks))
check := actual.Registration.Checks[0]
assert.Equal(t, "Dapr Health Status", check.Name)
assert.Equal(t, "daprHealth:test-app-"+metadata.Properties[nr.HostAddress]+"-"+metadata.Properties[nr.DaprHTTPPort], check.CheckID)
assert.Equal(t, "daprHealth:test-app-"+metadata.Instance.Address+"-"+strconv.Itoa(metadata.Instance.DaprHTTPPort), check.CheckID)
assert.Equal(t, "15s", check.Interval)
assert.Equal(t, fmt.Sprintf("http://%s/v1.0/healthz?appid=%s", net.JoinHostPort(metadata.Properties[nr.HostAddress], metadata.Properties[nr.DaprHTTPPort]), metadata.Properties[nr.AppID]), check.HTTP)
assert.Equal(t, fmt.Sprintf("http://%s/v1.0/healthz?appid=%s", net.JoinHostPort(metadata.Instance.Address, strconv.Itoa(metadata.Instance.DaprHTTPPort)), metadata.Instance.AppID), check.HTTP)
// Metadata
assert.Equal(t, 1, len(actual.Registration.Meta))
@ -1471,17 +1410,16 @@ func TestGetConfig(t *testing.T) {
{
"DaprPortMetaKey should set registration meta and config used for resolve",
nr.Metadata{
Base: metadata.Base{Properties: getTestPropsWithoutKey("")},
Configuration: map[interface{}]interface{}{
Instance: getInstanceInfoWithoutKey(""),
Configuration: map[any]any{
"SelfRegister": true,
"DaprPortMetaKey": "random_key",
},
},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
actual, _ := getConfig(metadata)
daprPort := metadata.Properties[nr.DaprPort]
daprPort := strconv.Itoa(metadata.Instance.DaprInternalPort)
assert.Equal(t, "random_key", actual.DaprPortMetaKey)
assert.Equal(t, daprPort, actual.Registration.Meta["random_key"])
@ -1490,14 +1428,13 @@ func TestGetConfig(t *testing.T) {
{
"SelfDeregister should set DeregisterOnClose",
nr.Metadata{
Base: metadata.Base{Properties: getTestPropsWithoutKey("")},
Configuration: map[interface{}]interface{}{
Instance: getInstanceInfoWithoutKey(""),
Configuration: map[any]any{
"SelfRegister": true,
"SelfDeregister": true,
},
},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
actual, _ := getConfig(metadata)
assert.Equal(t, true, actual.DeregisterOnClose)
@ -1506,13 +1443,12 @@ func TestGetConfig(t *testing.T) {
{
"missing AppID property should error when SelfRegister true",
nr.Metadata{
Base: metadata.Base{Properties: getTestPropsWithoutKey(nr.AppID)},
Configuration: map[interface{}]interface{}{
Instance: getInstanceInfoWithoutKey("AppID"),
Configuration: map[any]any{
"SelfRegister": true,
},
},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
_, err := getConfig(metadata)
assert.Error(t, err)
assert.Contains(t, err.Error(), nr.AppID)
@ -1536,13 +1472,12 @@ func TestGetConfig(t *testing.T) {
{
"missing AppPort property should error when SelfRegister true",
nr.Metadata{
Base: metadata.Base{Properties: getTestPropsWithoutKey(nr.AppPort)},
Configuration: map[interface{}]interface{}{
Instance: getInstanceInfoWithoutKey("AppPort"),
Configuration: map[any]any{
"SelfRegister": true,
},
},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
_, err := getConfig(metadata)
assert.Error(t, err)
assert.Contains(t, err.Error(), nr.AppPort)
@ -1564,18 +1499,17 @@ func TestGetConfig(t *testing.T) {
},
},
{
"missing HostAddress property should error when SelfRegister true",
"missing Address property should error when SelfRegister true",
nr.Metadata{
Base: metadata.Base{Properties: getTestPropsWithoutKey(nr.HostAddress)},
Configuration: map[interface{}]interface{}{
Instance: getInstanceInfoWithoutKey("Address"),
Configuration: map[any]any{
"SelfRegister": true,
},
},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
_, err := getConfig(metadata)
assert.Error(t, err)
assert.Contains(t, err.Error(), nr.HostAddress)
assert.Contains(t, err.Error(), "HOST_ADDRESS")
metadata.Configuration = configSpec{
SelfRegister: false,
@ -1596,16 +1530,15 @@ func TestGetConfig(t *testing.T) {
{
"missing DaprHTTPPort property should error only when SelfRegister true",
nr.Metadata{
Base: metadata.Base{Properties: getTestPropsWithoutKey(nr.DaprHTTPPort)},
Configuration: map[interface{}]interface{}{
Instance: getInstanceInfoWithoutKey("DaprHTTPPort"),
Configuration: map[any]any{
"SelfRegister": true,
},
},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
_, err := getConfig(metadata)
assert.Error(t, err)
assert.Contains(t, err.Error(), nr.DaprHTTPPort)
assert.Contains(t, err.Error(), "DAPR_HTTP_PORT")
metadata.Configuration = configSpec{
SelfRegister: false,
@ -1624,19 +1557,18 @@ func TestGetConfig(t *testing.T) {
},
},
{
"missing DaprPort property should always error",
"missing DaprInternalPort property should always error",
nr.Metadata{
Base: metadata.Base{Properties: getTestPropsWithoutKey(nr.DaprPort)},
Instance: getInstanceInfoWithoutKey("DaprInternalPort"),
},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
metadata.Configuration = configSpec{
SelfRegister: false,
}
_, err := getConfig(metadata)
assert.Error(t, err)
assert.Contains(t, err.Error(), nr.DaprPort)
assert.Contains(t, err.Error(), "DAPR_PORT")
metadata.Configuration = configSpec{
SelfRegister: true,
@ -1644,7 +1576,7 @@ func TestGetConfig(t *testing.T) {
_, err = getConfig(metadata)
assert.Error(t, err)
assert.Contains(t, err.Error(), nr.DaprPort)
assert.Contains(t, err.Error(), "DAPR_PORT")
metadata.Configuration = configSpec{
AdvancedRegistration: &consul.AgentServiceRegistration{},
@ -1653,16 +1585,16 @@ func TestGetConfig(t *testing.T) {
_, err = getConfig(metadata)
assert.Error(t, err)
assert.Contains(t, err.Error(), nr.DaprPort)
assert.Contains(t, err.Error(), "DAPR_PORT")
},
},
{
"registration should configure correctly",
nr.Metadata{
Base: metadata.Base{Properties: getTestPropsWithoutKey("")},
Configuration: map[interface{}]interface{}{
Instance: getInstanceInfoWithoutKey(""),
Configuration: map[any]any{
"Checks": []interface{}{
map[interface{}]interface{}{
map[any]any{
"Name": "test-app health check name",
"CheckID": "test-app health check id",
"Interval": "15s",
@ -1672,11 +1604,11 @@ func TestGetConfig(t *testing.T) {
"Tags": []interface{}{
"test",
},
"Meta": map[interface{}]interface{}{
"Meta": map[any]any{
"APP_PORT": "8650",
"DAPR_GRPC_PORT": "50005",
},
"QueryOptions": map[interface{}]interface{}{
"QueryOptions": map[any]any{
"UseCache": false,
"Filter": "Checks.ServiceTags contains something",
},
@ -1686,16 +1618,13 @@ func TestGetConfig(t *testing.T) {
},
},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
actual, _ := getConfig(metadata)
appPort, _ := strconv.Atoi(metadata.Properties[nr.AppPort])
// Enabled Registration
assert.NotNil(t, actual.Registration)
assert.Equal(t, metadata.Properties[nr.AppID], actual.Registration.Name)
assert.Equal(t, metadata.Properties[nr.HostAddress], actual.Registration.Address)
assert.Equal(t, appPort, actual.Registration.Port)
assert.Equal(t, metadata.Instance.AppID, actual.Registration.Name)
assert.Equal(t, metadata.Instance.Address, actual.Registration.Address)
assert.Equal(t, metadata.Instance.AppPort, actual.Registration.Port)
assert.Equal(t, "test-app health check name", actual.Registration.Checks[0].Name)
assert.Equal(t, "test-app health check id", actual.Registration.Checks[0].CheckID)
assert.Equal(t, "15s", actual.Registration.Checks[0].Interval)
@ -1703,7 +1632,7 @@ func TestGetConfig(t *testing.T) {
assert.Equal(t, "test", actual.Registration.Tags[0])
assert.Equal(t, "8650", actual.Registration.Meta["APP_PORT"])
assert.Equal(t, "50005", actual.Registration.Meta["DAPR_GRPC_PORT"])
assert.Equal(t, metadata.Properties[nr.DaprPort], actual.Registration.Meta["PORT"])
assert.Equal(t, strconv.Itoa(metadata.Instance.DaprInternalPort), actual.Registration.Meta["PORT"])
assert.Equal(t, false, actual.QueryOptions.UseCache)
assert.Equal(t, "Checks.ServiceTags contains something", actual.QueryOptions.Filter)
assert.Equal(t, "PORT", actual.DaprPortMetaKey)
@ -1713,9 +1642,9 @@ func TestGetConfig(t *testing.T) {
{
"advanced registration should override/ignore other configs",
nr.Metadata{
Base: metadata.Base{Properties: getTestPropsWithoutKey("")},
Configuration: map[interface{}]interface{}{
"AdvancedRegistration": map[interface{}]interface{}{
Instance: getInstanceInfoWithoutKey(""),
Configuration: map[any]any{
"AdvancedRegistration": map[any]any{
"Name": "random-app-id",
"Port": 0o00,
"Address": "123.345.678",
@ -1724,7 +1653,7 @@ func TestGetConfig(t *testing.T) {
"APP_PORT": "000",
},
"Checks": []interface{}{
map[interface{}]interface{}{
map[any]any{
"Name": "random health check name",
"CheckID": "random health check id",
"Interval": "15s",
@ -1733,7 +1662,7 @@ func TestGetConfig(t *testing.T) {
},
},
"Checks": []interface{}{
map[interface{}]interface{}{
map[any]any{
"Name": "test-app health check name",
"CheckID": "test-app health check id",
"Interval": "15s",
@ -1753,7 +1682,6 @@ func TestGetConfig(t *testing.T) {
},
},
func(t *testing.T, metadata nr.Metadata) {
t.Helper()
actual, _ := getConfig(metadata)
// Enabled Registration
@ -1771,18 +1699,13 @@ func TestGetConfig(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.testName, func(t *testing.T) {
t.Parallel()
tt.test(t, tt.metadata)
})
}
}
func TestMapConfig(t *testing.T) {
t.Parallel()
t.Run("should map full configuration", func(t *testing.T) {
t.Helper()
expected := intermediateConfig{
Client: &Config{
Address: "Address",
@ -2074,8 +1997,6 @@ func TestMapConfig(t *testing.T) {
})
t.Run("should map empty configuration", func(t *testing.T) {
t.Helper()
expected := intermediateConfig{}
actual := mapConfig(expected)
@ -2215,17 +2136,29 @@ func compareCheck(t *testing.T, expected *AgentServiceCheck, actual *consul.Agen
assert.Equal(t, expected.GRPCUseTLS, actual.GRPCUseTLS)
}
func getTestPropsWithoutKey(removeKey string) map[string]string {
metadata := map[string]string{
nr.AppID: "test-app",
nr.AppPort: "8650",
nr.DaprPort: "50001",
nr.DaprHTTPPort: "3500",
nr.HostAddress: "127.0.0.1",
func getInstanceInfoWithoutKey(removeKey string) nr.Instance {
res := nr.Instance{
AppID: "test-app",
AppPort: 8650,
DaprInternalPort: 50001,
DaprHTTPPort: 3500,
Address: "127.0.0.1",
}
delete(metadata, removeKey)
return metadata
switch removeKey {
case "AppID":
res.AppID = ""
case "AppPort":
res.AppPort = 0
case "DaprInternalPort":
res.DaprInternalPort = 0
case "DaprHTTPPort":
res.DaprHTTPPort = 0
case "Address":
res.Address = ""
}
return res
}
func waitTillTrueOrTimeout(d time.Duration, condition func() bool) {

View File

@ -15,6 +15,7 @@ package kubernetes
import (
"bytes"
"context"
"strconv"
"text/template"
@ -53,7 +54,7 @@ func NewResolver(logger logger.Logger) nameresolution.Resolver {
}
// Init initializes Kubernetes name resolver.
func (k *resolver) Init(metadata nameresolution.Metadata) error {
func (k *resolver) Init(ctx context.Context, metadata nameresolution.Metadata) error {
configInterface, err := config.Normalize(metadata.Configuration)
if err != nil {
return err
@ -83,7 +84,7 @@ func (k *resolver) Init(metadata nameresolution.Metadata) error {
}
// ResolveID resolves name to address in Kubernetes.
func (k *resolver) ResolveID(req nameresolution.ResolveRequest) (string, error) {
func (k *resolver) ResolveID(ctx context.Context, req nameresolution.ResolveRequest) (string, error) {
if k.tmpl != nil {
return executeTemplateWithResolveRequest(k.tmpl, req)
}

View File

@ -14,6 +14,7 @@ limitations under the License.
package kubernetes
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
@ -27,7 +28,7 @@ func TestResolve(t *testing.T) {
request := nameresolution.ResolveRequest{ID: "myid", Namespace: "abc", Port: 1234}
const expect = "myid-dapr.abc.svc.cluster.local:1234"
target, err := resolver.ResolveID(request)
target, err := resolver.ResolveID(context.Background(), request)
assert.NoError(t, err)
assert.Equal(t, expect, target)
@ -35,7 +36,7 @@ func TestResolve(t *testing.T) {
func TestResolveWithCustomClusterDomain(t *testing.T) {
resolver := NewResolver(logger.NewLogger("test"))
_ = resolver.Init(nameresolution.Metadata{
_ = resolver.Init(context.Background(), nameresolution.Metadata{
Configuration: map[string]interface{}{
"clusterDomain": "mydomain.com",
},
@ -43,7 +44,7 @@ func TestResolveWithCustomClusterDomain(t *testing.T) {
request := nameresolution.ResolveRequest{ID: "myid", Namespace: "abc", Port: 1234}
const expect = "myid-dapr.abc.svc.mydomain.com:1234"
target, err := resolver.ResolveID(request)
target, err := resolver.ResolveID(context.Background(), request)
assert.NoError(t, err)
assert.Equal(t, expect, target)
@ -51,7 +52,7 @@ func TestResolveWithCustomClusterDomain(t *testing.T) {
func TestResolveWithTemplate(t *testing.T) {
resolver := NewResolver(logger.NewLogger("test"))
_ = resolver.Init(nameresolution.Metadata{
_ = resolver.Init(context.Background(), nameresolution.Metadata{
Configuration: map[string]interface{}{
"template": "{{.ID}}-{{.Namespace}}.internal:{{.Port}}",
},
@ -59,7 +60,7 @@ func TestResolveWithTemplate(t *testing.T) {
request := nameresolution.ResolveRequest{ID: "myid", Namespace: "abc", Port: 1234}
const expected = "myid-abc.internal:1234"
target, err := resolver.ResolveID(request)
target, err := resolver.ResolveID(context.Background(), request)
assert.NoError(t, err)
assert.Equal(t, target, expected)
@ -67,7 +68,7 @@ func TestResolveWithTemplate(t *testing.T) {
func TestResolveWithTemplateAndData(t *testing.T) {
resolver := NewResolver(logger.NewLogger("test"))
_ = resolver.Init(nameresolution.Metadata{
_ = resolver.Init(context.Background(), nameresolution.Metadata{
Configuration: map[string]interface{}{
"template": "{{.ID}}-{{.Data.region}}.internal:{{.Port}}",
},
@ -82,7 +83,7 @@ func TestResolveWithTemplateAndData(t *testing.T) {
},
}
const expected = "myid-myland.internal:1234"
target, err := resolver.ResolveID(request)
target, err := resolver.ResolveID(context.Background(), request)
assert.NoError(t, err)
assert.Equal(t, target, expected)

View File

@ -269,34 +269,23 @@ func (m *Resolver) startRefreshers() {
}
// Init registers service for mDNS.
func (m *Resolver) Init(metadata nameresolution.Metadata) error {
props := metadata.Properties
appID := props[nameresolution.AppID]
if appID == "" {
func (m *Resolver) Init(ctx context.Context, metadata nameresolution.Metadata) error {
if metadata.Instance.AppID == "" {
return errors.New("name is missing")
}
hostAddress := props[nameresolution.HostAddress]
if hostAddress == "" {
if metadata.Instance.Address == "" {
return errors.New("address is missing")
}
if props[nameresolution.DaprPort] == "" {
return errors.New("port is missing")
if metadata.Instance.DaprInternalPort <= 0 {
return errors.New("port is missing or invalid")
}
port, err := strconv.Atoi(props[nameresolution.DaprPort])
if err != nil {
return errors.New("port is invalid")
}
err = m.registerMDNS("", appID, []string{hostAddress}, port)
err := m.registerMDNS("", metadata.Instance.AppID, []string{metadata.Instance.Address}, metadata.Instance.DaprInternalPort)
if err != nil {
return err
}
m.logger.Infof("local service entry announced: %s -> %s:%d", appID, hostAddress, port)
m.logger.Infof("local service entry announced: %s -> %s:%d", metadata.Instance.AppID, metadata.Instance.Address, metadata.Instance.DaprInternalPort)
go m.startRefreshers()
@ -412,7 +401,7 @@ func (m *Resolver) registerMDNS(instanceID string, appID string, ips []string, p
}
// ResolveID resolves name to address via mDNS.
func (m *Resolver) ResolveID(req nameresolution.ResolveRequest) (string, error) {
func (m *Resolver) ResolveID(parentCtx context.Context, req nameresolution.ResolveRequest) (string, error) {
// check for cached IPv4 addresses for this app id first.
if addr := m.nextIPv4Address(req.ID); addr != nil {
return *addr, nil
@ -445,7 +434,7 @@ func (m *Resolver) ResolveID(req nameresolution.ResolveRequest) (string, error)
// requested app id. The rest will subscribe for an address or error.
var once *sync.Once
var published chan struct{}
ctx, cancel := context.WithTimeout(context.Background(), browseOneTimeout)
ctx, cancel := context.WithTimeout(parentCtx, browseOneTimeout)
defer cancel()
appIDSubs.Once.Do(func() {
published = make(chan struct{})

View File

@ -14,6 +14,7 @@ limitations under the License.
package mdns
import (
"context"
"fmt"
"math"
"sync"
@ -24,7 +25,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/metadata"
nr "github.com/dapr/components-contrib/nameresolution"
"github.com/dapr/kit/logger"
)
@ -37,35 +37,35 @@ const (
func TestInitMetadata(t *testing.T) {
tests := []struct {
missingProp string
props map[string]string
instance nr.Instance
}{
{
"name",
map[string]string{
nr.HostAddress: localhost,
nr.DaprPort: "30003",
nr.Instance{
Address: localhost,
DaprInternalPort: 30003,
},
},
{
"address",
map[string]string{
nr.AppID: "testAppID",
nr.DaprPort: "30003",
nr.Instance{
AppID: "testAppID",
DaprInternalPort: 30003,
},
},
{
"port",
map[string]string{
nr.AppID: "testAppID",
nr.HostAddress: localhost,
nr.Instance{
AppID: "testAppID",
Address: localhost,
},
},
{
"port",
map[string]string{
nr.AppID: "testAppID",
nr.HostAddress: localhost,
nr.DaprPort: "abcd",
nr.Instance{
AppID: "testAppID",
Address: localhost,
DaprInternalPort: 0,
},
},
}
@ -77,7 +77,7 @@ func TestInitMetadata(t *testing.T) {
for _, tt := range tests {
t.Run(tt.missingProp+" is missing", func(t *testing.T) {
// act
err := resolver.Init(nr.Metadata{Base: metadata.Base{Properties: tt.props}})
err := resolver.Init(context.Background(), nr.Metadata{Instance: tt.instance})
// assert
assert.Error(t, err)
@ -89,14 +89,14 @@ func TestInitRegister(t *testing.T) {
// arrange
resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
defer resolver.Close()
md := nr.Metadata{Base: metadata.Base{Properties: map[string]string{
nr.AppID: "testAppID",
nr.HostAddress: localhost,
nr.DaprPort: "1234",
}}}
md := nr.Metadata{Instance: nr.Instance{
AppID: "testAppID",
Address: localhost,
DaprInternalPort: 1234,
}}
// act
err := resolver.Init(md)
err := resolver.Init(context.Background(), md)
require.NoError(t, err)
}
@ -104,21 +104,21 @@ func TestInitRegisterDuplicate(t *testing.T) {
// arrange
resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
defer resolver.Close()
md := nr.Metadata{Base: metadata.Base{Properties: map[string]string{
nr.AppID: "testAppID",
nr.HostAddress: localhost,
nr.DaprPort: "1234",
}}}
md2 := nr.Metadata{Base: metadata.Base{Properties: map[string]string{
nr.AppID: "testAppID",
nr.HostAddress: localhost,
nr.DaprPort: "1234",
}}}
md := nr.Metadata{Instance: nr.Instance{
AppID: "testAppID",
Address: localhost,
DaprInternalPort: 1234,
}}
md2 := nr.Metadata{Instance: nr.Instance{
AppID: "testAppID",
Address: localhost,
DaprInternalPort: 1234,
}}
// act
err := resolver.Init(md)
err := resolver.Init(context.Background(), md)
require.NoError(t, err)
err = resolver.Init(md2)
err = resolver.Init(context.Background(), md2)
expectedError := "app id testAppID already registered for port 1234"
require.EqualErrorf(t, err, expectedError, "Error should be: %v, got %v", expectedError, err)
}
@ -127,18 +127,18 @@ func TestResolver(t *testing.T) {
// arrange
resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
defer resolver.Close()
md := nr.Metadata{Base: metadata.Base{Properties: map[string]string{
nr.AppID: "testAppID",
nr.HostAddress: localhost,
nr.DaprPort: "1234",
}}}
md := nr.Metadata{Instance: nr.Instance{
AppID: "testAppID",
Address: localhost,
DaprInternalPort: 1234,
}}
// act
err := resolver.Init(md)
err := resolver.Init(context.Background(), md)
require.NoError(t, err)
request := nr.ResolveRequest{ID: "testAppID"}
pt, err := resolver.ResolveID(request)
pt, err := resolver.ResolveID(context.Background(), request)
// assert
require.NoError(t, err)
@ -148,18 +148,18 @@ func TestResolver(t *testing.T) {
func TestResolverClose(t *testing.T) {
// arrange
resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
md := nr.Metadata{Base: metadata.Base{Properties: map[string]string{
nr.AppID: "testAppID",
nr.HostAddress: localhost,
nr.DaprPort: "1234",
}}}
md := nr.Metadata{Instance: nr.Instance{
AppID: "testAppID",
Address: localhost,
DaprInternalPort: 1234,
}}
// act
err := resolver.Init(md)
err := resolver.Init(context.Background(), md)
require.NoError(t, err)
request := nr.ResolveRequest{ID: "testAppID"}
pt, err := resolver.ResolveID(request)
pt, err := resolver.ResolveID(context.Background(), request)
// assert
require.NoError(t, err)
@ -203,7 +203,7 @@ func TestResolverMultipleInstances(t *testing.T) {
request := nr.ResolveRequest{ID: "testAppID"}
// first resolution will return the first responder's address and trigger a cache refresh.
addr1, err := resolver.ResolveID(request)
addr1, err := resolver.ResolveID(context.Background(), request)
require.NoError(t, err)
require.Contains(t, []string{instanceAPQDN, instanceBPQDN}, addr1)
@ -218,7 +218,7 @@ func TestResolverMultipleInstances(t *testing.T) {
instanceACount := atomic.Uint32{}
instanceBCount := atomic.Uint32{}
for i := 0; i < 100; i++ {
addr, err := resolver.ResolveID(request)
addr, err := resolver.ResolveID(context.Background(), request)
require.NoError(t, err)
require.Contains(t, []string{instanceAPQDN, instanceBPQDN}, addr)
if addr == instanceAPQDN {
@ -239,7 +239,7 @@ func TestResolverNotFound(t *testing.T) {
// act
request := nr.ResolveRequest{ID: "testAppIDNotFound"}
pt, err := resolver.ResolveID(request)
pt, err := resolver.ResolveID(context.Background(), request)
// assert
expectedError := "couldn't find service: testAppIDNotFound"
@ -281,14 +281,14 @@ func ResolverConcurrencySubsriberClear(t *testing.T) {
// arrange
resolver := NewResolver(logger.NewLogger("test")).(*Resolver)
defer resolver.Close()
md := nr.Metadata{Base: metadata.Base{Properties: map[string]string{
nr.AppID: "testAppID",
nr.HostAddress: localhost,
nr.DaprPort: "1234",
}}}
md := nr.Metadata{Instance: nr.Instance{
AppID: "testAppID",
Address: localhost,
DaprInternalPort: 1234,
}}
// act
err := resolver.Init(md)
err := resolver.Init(context.Background(), md)
require.NoError(t, err)
request := nr.ResolveRequest{ID: "testAppID"}
@ -299,7 +299,7 @@ func ResolverConcurrencySubsriberClear(t *testing.T) {
go func() {
defer wg.Done()
pt, err := resolver.ResolveID(request)
pt, err := resolver.ResolveID(context.Background(), request)
require.NoError(t, err)
require.Equal(t, fmt.Sprintf("%s:1234", localhost), pt)
}()
@ -371,7 +371,7 @@ func ResolverConcurrencyFound(t *testing.T) {
request := nr.ResolveRequest{ID: appID}
start := time.Now()
pt, err := resolver.ResolveID(request)
pt, err := resolver.ResolveID(context.Background(), request)
elapsed := time.Since(start)
// assert
require.NoError(t, err)
@ -420,7 +420,7 @@ func ResolverConcurrencyNotFound(t *testing.T) {
// act
start := time.Now()
pt, err := resolver.ResolveID(request)
pt, err := resolver.ResolveID(context.Background(), request)
elapsed := time.Since(start)
// assert

View File

@ -13,8 +13,13 @@ limitations under the License.
package nameresolution
import "github.com/dapr/components-contrib/metadata"
import (
"strconv"
"github.com/dapr/components-contrib/metadata"
)
// These constants are used for the "legacy" way to pass instance information using a map.
const (
// HostAddress is the address of the instance.
HostAddress string = "HOST_ADDRESS"
@ -26,10 +31,53 @@ const (
AppPort string = "APP_PORT"
// AppID is the ID of the application.
AppID string = "APP_ID"
// Namespace is the namespace of the application.
Namespace string = "NAMESPACE"
)
// Metadata contains a name resolution specific set of metadata properties.
type Metadata struct {
metadata.Base `json:",inline"`
Configuration interface{}
Instance Instance
Configuration any
}
// Instance contains information about the instance.
type Instance struct {
// App ID.
AppID string
// Namespace of the app.
Namespace string
// Address of the instance.
Address string
// Dapr HTTP API port.
DaprHTTPPort int
// Dapr internal gRPC port (for sidecar-to-sidecar communication).
DaprInternalPort int
// Port the application is listening on (either HTTP or gRPC).
AppPort int
}
// GetPropertiesMap returns a map with the instance properties.
// This is used by components that haven't adopted the new Instance struct to receive instance information.
func (m Metadata) GetPropertiesMap() map[string]string {
var daprHTTPPort, daprPort, appPort string
if m.Instance.DaprHTTPPort > 0 {
daprHTTPPort = strconv.Itoa(m.Instance.DaprHTTPPort)
}
if m.Instance.DaprInternalPort > 0 {
daprPort = strconv.Itoa(m.Instance.DaprInternalPort)
}
if m.Instance.AppPort > 0 {
appPort = strconv.Itoa(m.Instance.AppPort)
}
return map[string]string{
HostAddress: m.Instance.Address,
DaprHTTPPort: daprHTTPPort,
DaprPort: daprPort,
AppPort: appPort,
AppID: m.Instance.AppID,
Namespace: m.Instance.Namespace,
}
}

View File

@ -13,10 +13,14 @@ limitations under the License.
package nameresolution
import (
"context"
)
// Resolver is the interface of name resolver.
type Resolver interface {
// Init initializes name resolver.
Init(metadata Metadata) error
Init(ctx context.Context, metadata Metadata) error
// ResolveID resolves name to address.
ResolveID(req ResolveRequest) (string, error)
ResolveID(ctx context.Context, req ResolveRequest) (string, error)
}

View File

@ -0,0 +1,328 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlite
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
internalsql "github.com/dapr/components-contrib/internal/component/sql"
"github.com/dapr/components-contrib/nameresolution"
"github.com/dapr/kit/logger"
)
// ErrNoHost is returned by ResolveID when no host can be found.
var ErrNoHost = errors.New("no host found with the given ID")
// Internally-used error to indicate the registration was lost
var errRegistrationLost = errors.New("host registration lost")
type resolver struct {
logger logger.Logger
metadata sqliteMetadata
db *sql.DB
gc internalsql.GarbageCollector
registrationID string
closed atomic.Bool
closeCh chan struct{}
wg sync.WaitGroup
}
// NewResolver creates a name resolver that is based on a SQLite DB.
func NewResolver(logger logger.Logger) nameresolution.Resolver {
return &resolver{
logger: logger,
closeCh: make(chan struct{}),
}
}
// Init initializes the name resolver.
func (s *resolver) Init(ctx context.Context, md nameresolution.Metadata) error {
if s.closed.Load() {
return errors.New("component is closed")
}
err := s.metadata.InitWithMetadata(md)
if err != nil {
return err
}
connString, err := s.metadata.GetConnectionString(s.logger)
if err != nil {
// Already logged
return err
}
// Show a warning if SQLite is configured with an in-memory DB
if s.metadata.SqliteAuthMetadata.IsInMemoryDB() {
s.logger.Warn("Configuring name resolution with an in-memory SQLite database. Service invocation across different apps will not work.")
} else {
s.logger.Infof("Configuring SQLite name resolution with path %s", connString[len("file:"):strings.Index(connString, "?")])
}
s.db, err = sql.Open("sqlite", connString)
if err != nil {
return fmt.Errorf("failed to create connection: %w", err)
}
// Performs migrations
err = performMigrations(ctx, s.db, s.logger, migrationOptions{
HostsTableName: s.metadata.TableName,
MetadataTableName: s.metadata.MetadataTableName,
})
if err != nil {
return fmt.Errorf("failed to perform migrations: %w", err)
}
// Init the background GC
err = s.initGC()
if err != nil {
return err
}
// Register the host and update in background
err = s.registerHost(ctx)
if err != nil {
return err
}
s.wg.Add(1)
go s.renewRegistration()
return nil
}
func (s *resolver) initGC() (err error) {
s.gc, err = internalsql.ScheduleGarbageCollector(internalsql.GCOptions{
Logger: s.logger,
UpdateLastCleanupQuery: func(arg any) (string, any) {
return fmt.Sprintf(`INSERT INTO %s (key, value)
VALUES ('nr-last-cleanup', CURRENT_TIMESTAMP)
ON CONFLICT (key)
DO UPDATE SET value = CURRENT_TIMESTAMP
WHERE unixepoch(CURRENT_TIMESTAMP) - unixepoch(value) > ?;`,
s.metadata.MetadataTableName,
), arg
},
DeleteExpiredValuesQuery: fmt.Sprintf(
`DELETE FROM %s WHERE unixepoch(CURRENT_TIMESTAMP) - last_update < %d`,
s.metadata.TableName,
int(s.metadata.UpdateInterval.Seconds()),
),
CleanupInterval: s.metadata.CleanupInterval,
DB: internalsql.AdaptDatabaseSQLConn(s.db),
})
return err
}
// Registers the host
func (s *resolver) registerHost(ctx context.Context) error {
// Get the registration ID
u, err := uuid.NewRandom()
if err != nil {
return fmt.Errorf("failed to generate registration ID: %w", err)
}
s.registrationID = u.String()
queryCtx, queryCancel := context.WithTimeout(ctx, s.metadata.Timeout)
defer queryCancel()
// There's a unique index on address
// We use REPLACE to take over any previous registration for that address
// TODO: Add support for namespacing. See https://github.com/dapr/components-contrib/issues/3179
_, err = s.db.ExecContext(queryCtx,
fmt.Sprintf("REPLACE INTO %s (registration_id, address, app_id, namespace, last_update) VALUES (?, ?, ?, ?, unixepoch(CURRENT_TIMESTAMP))", s.metadata.TableName),
s.registrationID, s.metadata.GetAddress(), s.metadata.appID, "",
)
if err != nil {
return fmt.Errorf("failed to register host: %w", err)
}
return nil
}
// In backgrounds, periodically renews the host's registration
// Should be invoked in a background goroutine
func (s *resolver) renewRegistration() {
defer s.wg.Done()
addr := s.metadata.GetAddress()
// Update every UpdateInterval - Timeout (+ 1 second buffer)
// This is because the record has to be updated every UpdateInterval, but we allow up to "timeout" for it to be performed
d := s.metadata.UpdateInterval - s.metadata.Timeout - 1
s.logger.Debugf("Started renewing host registration in background with interval %v", s.metadata.UpdateInterval)
t := time.NewTicker(d)
defer t.Stop()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for {
select {
case <-s.closeCh:
// Component is closing
s.logger.Debug("Stopped renewing host registration: component is closing")
return
case <-t.C:
// Renew on the ticker
s.wg.Add(1)
go func() {
defer s.wg.Done()
err := s.doRenewRegistration(ctx, addr)
if err != nil {
// Log errors
s.logger.Errorf("Failed to update host registration: %v", err)
if errors.Is(err, errRegistrationLost) {
// This means that our registration has been taken over by another host
// It should never happen unless there's something really bad going on
// Panicking here to force a restart of Dapr
s.logger.Fatalf("Host registration lost")
}
}
}()
}
}
}
func (s *resolver) doRenewRegistration(ctx context.Context, addr string) error {
// We retry this query in case of database error, up to the timeout
queryCtx, queryCancel := context.WithTimeout(ctx, s.metadata.Timeout)
defer queryCancel()
// We use string formatting here for the table name only
//nolint:gosec
query := fmt.Sprintf("UPDATE %s SET last_update = unixepoch(CURRENT_TIMESTAMP) WHERE registration_id = ? AND address = ?", s.metadata.TableName)
b := backoff.WithContext(backoff.NewConstantBackOff(50*time.Millisecond), queryCtx)
return backoff.Retry(func() error {
res, err := s.db.ExecContext(queryCtx, query, s.registrationID, addr)
if err != nil {
return fmt.Errorf("database error: %w", err)
}
n, _ := res.RowsAffected()
if n == 0 {
// This is a permanent error
return backoff.Permanent(errRegistrationLost)
}
return nil
}, b)
}
// ResolveID resolves name to address.
func (s *resolver) ResolveID(ctx context.Context, req nameresolution.ResolveRequest) (addr string, err error) {
queryCtx, queryCancel := context.WithTimeout(ctx, s.metadata.Timeout)
defer queryCancel()
//nolint:gosec
q := fmt.Sprintf(
// See: https://stackoverflow.com/a/24591696
`SELECT address
FROM %[1]s
WHERE
ROWID = (
SELECT ROWID
FROM %[1]s
WHERE
app_id = ?
AND unixepoch(CURRENT_TIMESTAMP) - last_update < %[2]d
ORDER BY RANDOM()
LIMIT 1
)`,
s.metadata.TableName,
int(s.metadata.UpdateInterval.Seconds()),
)
err = s.db.QueryRowContext(queryCtx, q, req.ID).Scan(&addr)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", ErrNoHost
}
return "", fmt.Errorf("failed to look up address: %w", err)
}
return addr, nil
}
// Removes the registration for the host
func (s *resolver) deregisterHost(ctx context.Context) error {
if s.registrationID == "" {
// We never registered
return nil
}
queryCtx, queryCancel := context.WithTimeout(ctx, s.metadata.Timeout)
defer queryCancel()
res, err := s.db.ExecContext(queryCtx,
fmt.Sprintf("DELETE FROM %s WHERE registration_id = ? AND address = ?", s.metadata.TableName),
s.registrationID, s.metadata.GetAddress(),
)
if err != nil {
return fmt.Errorf("failed to unregister host: %w", err)
}
n, _ := res.RowsAffected()
if n == 0 {
return errors.New("failed to unregister host: no row deleted")
}
return nil
}
// Close implements io.Closer.
func (s *resolver) Close() (err error) {
if !s.closed.CompareAndSwap(false, true) {
s.wg.Wait()
return nil
}
close(s.closeCh)
s.wg.Wait()
errs := make([]error, 0)
if s.gc != nil {
err = s.gc.Close()
if err != nil {
errs = append(errs, err)
}
}
if s.db != nil {
err := s.deregisterHost(context.Background())
if err != nil {
errs = append(errs, err)
}
err = s.db.Close()
if err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}

View File

@ -0,0 +1,125 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlite
import (
"errors"
"fmt"
"net"
"strconv"
"time"
authSqlite "github.com/dapr/components-contrib/internal/authentication/sqlite"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/nameresolution"
)
const (
defaultTableName = "hosts"
defaultMetadataTableName = "metadata"
defaultUpdateInterval = 5 * time.Second
defaultCleanupInternal = time.Hour
// For a nameresolver, we want a fairly low timeout
defaultTimeout = time.Second
defaultBusyTimeout = 800 * time.Millisecond
)
type sqliteMetadata struct {
// Config options - passed by the user via the Configuration resource
authSqlite.SqliteAuthMetadata `mapstructure:",squash"`
TableName string `mapstructure:"tableName"`
MetadataTableName string `mapstructure:"metadataTableName"`
UpdateInterval time.Duration `mapstructure:"updateInterval"` // Units smaller than seconds are not accepted
CleanupInterval time.Duration `mapstructure:"cleanupInterval" mapstructurealiases:"cleanupIntervalInSeconds"`
// Instance properties - these are passed by the runtime
appID string
namespace string
hostAddress string
port int
}
func (m *sqliteMetadata) InitWithMetadata(meta nameresolution.Metadata) error {
// Reset the object
m.reset()
// Set and validate the instance properties
m.appID = meta.Instance.AppID
if m.appID == "" {
return errors.New("name is missing")
}
m.hostAddress = meta.Instance.Address
if m.hostAddress == "" {
return errors.New("address is missing")
}
m.port = meta.Instance.DaprInternalPort
if m.port == 0 {
return errors.New("port is missing or invalid")
}
m.namespace = meta.Instance.Namespace // Can be empty
// Decode the configuration using DecodeMetadata
err := metadata.DecodeMetadata(meta.Configuration, &m)
if err != nil {
return err
}
// Validate and sanitize configuration
err = m.SqliteAuthMetadata.Validate()
if err != nil {
return err
}
if !authSqlite.ValidIdentifier(m.TableName) {
return fmt.Errorf("invalid identifier for table name: %s", m.TableName)
}
if !authSqlite.ValidIdentifier(m.MetadataTableName) {
return fmt.Errorf("invalid identifier for metadata table name: %s", m.MetadataTableName)
}
// For updateInterval, we do not accept units smaller than seconds due to implementation limitations with SQLite
if m.UpdateInterval != m.UpdateInterval.Truncate(time.Second) {
return errors.New("update interval must not contain fractions of seconds")
}
// UpdateInterval must also be greater than Timeout
if (m.UpdateInterval - m.Timeout) < time.Second {
return errors.New("update interval must be at least 1s greater than timeout")
}
return nil
}
func (m sqliteMetadata) GetAddress() string {
return net.JoinHostPort(m.hostAddress, strconv.Itoa(m.port))
}
// Reset the object
func (m *sqliteMetadata) reset() {
m.SqliteAuthMetadata.Reset()
// We lower the default thresholds for the nameresolver
m.Timeout = defaultTimeout
m.BusyTimeout = defaultBusyTimeout
m.TableName = defaultTableName
m.MetadataTableName = defaultMetadataTableName
m.UpdateInterval = defaultUpdateInterval
m.CleanupInterval = defaultCleanupInternal
m.appID = ""
m.namespace = ""
m.hostAddress = ""
m.port = 0
}

View File

@ -0,0 +1,65 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlite
import (
"context"
"database/sql"
"fmt"
sqlinternal "github.com/dapr/components-contrib/internal/component/sql"
sqlitemigrations "github.com/dapr/components-contrib/internal/component/sql/migrations/sqlite"
"github.com/dapr/kit/logger"
)
type migrationOptions struct {
HostsTableName string
MetadataTableName string
}
// Perform the required migrations
func performMigrations(ctx context.Context, db *sql.DB, logger logger.Logger, opts migrationOptions) error {
m := sqlitemigrations.Migrations{
Pool: db,
Logger: logger,
MetadataTableName: opts.MetadataTableName,
MetadataKey: "nr-migrations",
}
return m.Perform(ctx, []sqlinternal.MigrationFn{
// Migration 0: create the hosts table
func(ctx context.Context) error {
logger.Infof("Creating hosts table '%s'", opts.HostsTableName)
_, err := m.GetConn().ExecContext(
ctx,
fmt.Sprintf(
`CREATE TABLE %[1]s (
registration_id TEXT NOT NULL PRIMARY KEY,
address TEXT NOT NULL,
app_id TEXT NOT NULL,
namespace TEXT NOT NULL,
last_update INTEGER NOT NULL
);
CREATE UNIQUE INDEX %[1]s_address_idx ON %[1]s (address);
CREATE INDEX %[1]s_last_update_idx ON %[1]s (last_update);`,
opts.HostsTableName,
),
)
if err != nil {
return fmt.Errorf("failed to create hosts table: %w", err)
}
return nil
},
})
}

View File

@ -0,0 +1,153 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlite
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/nameresolution"
"github.com/dapr/kit/logger"
)
func TestSqliteNameResolver(t *testing.T) {
nr := NewResolver(logger.NewLogger("test")).(*resolver)
t.Run("Init", func(t *testing.T) {
err := nr.Init(context.Background(), nameresolution.Metadata{
Instance: nameresolution.Instance{
Address: "127.0.0.1",
DaprInternalPort: 1234,
AppID: "myapp",
},
Configuration: map[string]string{
"connectionString": ":memory:",
"cleanupInterval": "0",
"updateInterval": "120s",
},
})
require.NoError(t, err)
})
require.False(t, t.Failed(), "Cannot continue if init step failed")
t.Run("Populate test data", func(t *testing.T) {
// Note updateInterval is 120s
now := time.Now().Unix()
rows := [][]any{
{"2cb5f837", "1.1.1.1:1", "app-1", "", now},
{"4d1e7b11", "1.1.1.1:2", "app-1", "", now},
{"05add1fa", "1.1.1.1:3", "app-1", "", now},
{"f1b24d4b", "2.2.2.2:1", "app-2", "", now},
{"23fb164f", "2.2.2.2:2", "app-2", "", now - 200},
{"db50a29e", "3.3.3.3:1", "app-3", "", now},
{"eef793d4", "4.4.4.4:1", "app-4", "", now - 200},
{"ef06eb49", "5.5.5.5:1", "app-5", "", now},
{"b0e6cd89", "6.6.6.6:1", "app-6", "", now},
{"36e99c68", "7.7.7.7:1", "app-7", "", now},
{"f77ed318", "8.8.8.8:1", "app-8", "", now - 100},
}
for i, r := range rows {
_, err := nr.db.Exec("INSERT INTO hosts VALUES (?, ?, ?, ?, ?)", r...)
require.NoErrorf(t, err, "Failed to insert row %d", i)
}
})
if t.Failed() {
nr.Close()
require.Fail(t, "Cannot continue if populate test data step failed")
}
t.Run("Resolve", func(t *testing.T) {
type testCase struct {
appID string
expectEmpty bool
expectOne string
expectAny []string
}
tt := map[string]testCase{
"single host resolved 1": {appID: "app-5", expectOne: "5.5.5.5:1"},
"single host resolved 2": {appID: "app-8", expectAny: []string{"8.8.8.8:1"}}, // Use expectAny to make the test run multiple times
"not found": {appID: "notfound", expectEmpty: true},
"host expired": {appID: "app-4", expectEmpty: true},
"multiple hosts found": {appID: "app-1", expectAny: []string{"1.1.1.1:1", "1.1.1.1:2", "1.1.1.1:3"}},
"one host expired": {appID: "app-2", expectAny: []string{"2.2.2.2:1"}}, // Use expectAny to make the test run multiple times
}
for name, tc := range tt {
t.Run(name, func(t *testing.T) {
if len(tc.expectAny) == 0 {
res, err := nr.ResolveID(context.Background(), nameresolution.ResolveRequest{ID: tc.appID})
if tc.expectEmpty {
require.Error(t, err)
require.ErrorIs(t, err, ErrNoHost)
require.Empty(t, res)
} else {
require.NoError(t, err)
require.Equal(t, tc.expectOne, res)
}
} else {
for i := 0; i < 20; i++ {
res, err := nr.ResolveID(context.Background(), nameresolution.ResolveRequest{ID: tc.appID})
require.NoErrorf(t, err, "Error on iteration %d", i)
require.Contains(t, tc.expectAny, res)
}
}
})
}
})
// Simulate the ticker
t.Run("Renew registration", func(t *testing.T) {
t.Run("Succeess", func(t *testing.T) {
const addr = "127.0.0.1:1234"
// Get current last_update value
var lastUpdate int
err := nr.db.QueryRow("SELECT last_update FROM hosts WHERE address = ?", addr).Scan(&lastUpdate)
require.NoError(t, err)
// Must sleep for 1s
time.Sleep(time.Second)
// Renew
err = nr.doRenewRegistration(context.Background(), addr)
require.NoError(t, err)
// Get updated last_update
var newLastUpdate int
err = nr.db.QueryRow("SELECT last_update FROM hosts WHERE address = ?", addr).Scan(&newLastUpdate)
require.NoError(t, err)
// Should have increased
require.Greater(t, newLastUpdate, lastUpdate)
})
t.Run("Lost registration", func(t *testing.T) {
// Renew
err := nr.doRenewRegistration(context.Background(), "fail")
require.Error(t, err)
require.ErrorIs(t, err, errRegistrationLost)
})
})
t.Run("Close", func(t *testing.T) {
err := nr.Close()
require.NoError(t, err)
})
}

View File

@ -26,7 +26,7 @@ type snsSqsMetadata struct {
// aws partition in which SNS/SQS should create resources.
internalPartition string `mapstructure:"-"`
// name of the queue for this application. The is provided by the runtime as "consumerID".
SqsQueueName string `mapstructure:"consumerID" mdignore:"true"`
SqsQueueName string `mapstructure:"consumerID" mdignore:"true"`
// name of the dead letter queue for this application.
SqsDeadLettersQueueName string `mapstructure:"sqsDeadLettersQueueName"`
// flag to SNS and SQS FIFO.

View File

@ -41,36 +41,24 @@ import (
"github.com/dapr/kit/logger"
)
type topicHandler struct {
topicName string
handler pubsub.Handler
ctx context.Context
}
type snsSqs struct {
topicsLocker TopicsLocker
// key is the sanitized topic name
topicArns map[string]string
// key is the sanitized topic name
topicHandlers map[string]topicHandler
topicsLock sync.RWMutex
// key is the topic name, value holds the ARN of the queue and its url.
queues sync.Map
queues map[string]*sqsQueueInfo
// key is a composite key of queue ARN and topic ARN mapping to subscription ARN.
subscriptions sync.Map
snsClient *sns.SNS
sqsClient *sqs.SQS
stsClient *sts.STS
metadata *snsSqsMetadata
logger logger.Logger
id string
opsTimeout time.Duration
backOffConfig retry.Config
pollerRunning chan struct{}
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
subscriptions map[string]string
snsClient *sns.SNS
sqsClient *sqs.SQS
stsClient *sts.STS
metadata *snsSqsMetadata
logger logger.Logger
id string
opsTimeout time.Duration
backOffConfig retry.Config
subscriptionManager SubscriptionManagement
closed atomic.Bool
}
type sqsQueueInfo struct {
@ -105,11 +93,8 @@ func NewSnsSqs(l logger.Logger) pubsub.PubSub {
}
return &snsSqs{
logger: l,
id: id,
topicsLock: sync.RWMutex{},
pollerRunning: make(chan struct{}, 1),
closeCh: make(chan struct{}),
logger: l,
id: id,
}
}
@ -160,13 +145,6 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error {
s.metadata = md
// both Publish and Subscribe need reference the topic ARN, queue ARN and subscription ARN between topic and queue
// track these ARNs in these maps.
s.topicArns = make(map[string]string)
s.topicHandlers = make(map[string]topicHandler)
s.queues = sync.Map{}
s.subscriptions = sync.Map{}
sess, err := awsAuth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint)
if err != nil {
return fmt.Errorf("error creating an AWS client: %w", err)
@ -189,6 +167,13 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error {
if err != nil {
return fmt.Errorf("error decoding backOff config: %w", err)
}
// subscription manager responsible for managing the lifecycle of subscriptions.
s.subscriptionManager = NewSubscriptionMgmt(s.logger)
s.topicsLocker = NewLockManager()
s.topicArns = make(map[string]string)
s.queues = make(map[string]*sqsQueueInfo)
s.subscriptions = make(map[string]string)
return nil
}
@ -243,7 +228,7 @@ func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, e
})
cancelFn()
if err != nil {
return "", fmt.Errorf("error: %w while getting topic: %v with arn: %v", err, topic, arn)
return "", fmt.Errorf("error: %w, while getting (sanitized) topic: %v with arn: %v", err, topic, arn)
}
return *getTopicOutput.Attributes["TopicArn"], nil
@ -251,40 +236,45 @@ func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, e
// get the topic ARN from the topics map. If it doesn't exist in the map, try to fetch it from AWS, if it doesn't exist
// at all, issue a request to create the topic.
func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn string, sanitizedName string, err error) {
s.topicsLock.Lock()
defer s.topicsLock.Unlock()
func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn string, sanitizedTopic string, err error) {
sanitizedTopic = nameToAWSSanitizedName(topic, s.metadata.Fifo)
sanitizedName = nameToAWSSanitizedName(topic, s.metadata.Fifo)
var loadOK bool
if topicArn, loadOK = s.topicArns[sanitizedTopic]; loadOK {
if len(topicArn) > 0 {
s.logger.Debugf("Found existing topic ARN for topic %s: %s", topic, topicArn)
topicArnCached, ok := s.topicArns[sanitizedName]
if ok && topicArnCached != "" {
s.logger.Debugf("found existing topic ARN for topic %s: %s", topic, topicArnCached)
return topicArnCached, sanitizedName, nil
return topicArn, sanitizedTopic, err
} else {
err = fmt.Errorf("the ARN for (sanitized) topic: %s was empty", sanitizedTopic)
return topicArn, sanitizedTopic, err
}
}
// creating queues is idempotent, the names serve as unique keys among a given region.
s.logger.Debugf("No SNS topic arn found for %s\nCreating SNS topic", topic)
s.logger.Debugf("No SNS topic ARN found for topic: %s. creating SNS with (sanitized) topic: %s", topic, sanitizedTopic)
if !s.metadata.DisableEntityManagement {
topicArn, err = s.createTopic(ctx, sanitizedName)
topicArn, err = s.createTopic(ctx, sanitizedTopic)
if err != nil {
s.logger.Errorf("error creating new topic %s: %w", topic, err)
err = fmt.Errorf("error creating new (sanitized) topic '%s': %w", topic, err)
return "", "", err
return topicArn, sanitizedTopic, err
}
} else {
topicArn, err = s.getTopicArn(ctx, sanitizedName)
topicArn, err = s.getTopicArn(ctx, sanitizedTopic)
if err != nil {
s.logger.Errorf("error fetching info for topic %s: %w", topic, err)
err = fmt.Errorf("error fetching info for (sanitized) topic: %s. wrapped error is: %w", topic, err)
return "", "", err
return topicArn, sanitizedTopic, err
}
}
// record topic ARN.
s.topicArns[sanitizedName] = topicArn
s.topicArns[sanitizedTopic] = topicArn
return topicArn, sanitizedName, nil
return topicArn, sanitizedTopic, err
}
func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQueueInfo, error) {
@ -346,13 +336,13 @@ func (s *snsSqs) getOrCreateQueue(ctx context.Context, queueName string) (*sqsQu
queueInfo *sqsQueueInfo
)
if cachedQueueInfo, ok := s.queues.Load(queueName); ok {
s.logger.Debugf("Found queue arn for %s: %s", queueName, cachedQueueInfo.(*sqsQueueInfo).arn)
if cachedQueueInfo, ok := s.queues[queueName]; ok {
s.logger.Debugf("Found queue ARN for %s: %s", queueName, cachedQueueInfo.arn)
return cachedQueueInfo.(*sqsQueueInfo), nil
return cachedQueueInfo, nil
}
// creating queues is idempotent, the names serve as unique keys among a given region.
s.logger.Debugf("No SQS queue arn found for %s\nCreating SQS queue", queueName)
s.logger.Debugf("No SQS queue ARN found for %s\nCreating SQS queue", queueName)
sanitizedName := nameToAWSSanitizedName(queueName, s.metadata.Fifo)
@ -372,8 +362,8 @@ func (s *snsSqs) getOrCreateQueue(ctx context.Context, queueName string) (*sqsQu
}
}
s.queues.Store(queueName, queueInfo)
s.logger.Debugf("Created SQS queue: %s: with arn: %s", queueName, queueInfo.arn)
s.queues[queueName] = queueInfo
s.logger.Debugf("created SQS queue: %s: with arn: %s", queueName, queueInfo.arn)
return queueInfo, nil
}
@ -429,13 +419,13 @@ func (s *snsSqs) getSnsSqsSubscriptionArn(parentCtx context.Context, topicArn st
func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, topicArn string) (subscriptionArn string, err error) {
compositeKey := fmt.Sprintf("%s:%s", queueArn, topicArn)
if cachedSubscriptionArn, ok := s.subscriptions.Load(compositeKey); ok {
if cachedSubscriptionArn, ok := s.subscriptions[compositeKey]; ok {
s.logger.Debugf("Found subscription of queue arn: %s to topic arn: %s: %s", queueArn, topicArn, cachedSubscriptionArn)
return cachedSubscriptionArn.(string), nil
return cachedSubscriptionArn, nil
}
s.logger.Debugf("No subscription arn found of queue arn:%s to topic arn: %s\nCreating subscription", queueArn, topicArn)
s.logger.Debugf("No subscription ARN found of queue arn:%s to topic arn: %s\nCreating subscription", queueArn, topicArn)
if !s.metadata.DisableEntityManagement {
subscriptionArn, err = s.createSnsSqsSubscription(ctx, queueArn, topicArn)
@ -447,13 +437,13 @@ func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, to
} else {
subscriptionArn, err = s.getSnsSqsSubscriptionArn(ctx, topicArn)
if err != nil {
s.logger.Errorf("error fetching info for topic arn %s: %w", topicArn, err)
s.logger.Errorf("error fetching info for topic ARN %s: %w", topicArn, err)
return "", err
}
}
s.subscriptions.Store(compositeKey, subscriptionArn)
s.subscriptions[compositeKey] = subscriptionArn
s.logger.Debugf("Subscribed to topic %s: %s", topicArn, subscriptionArn)
return subscriptionArn, nil
@ -555,18 +545,25 @@ func (s *snsSqs) callHandler(ctx context.Context, message *sqs.Message, queueInf
// for the user to be able to understand the source of the coming message, we'd use the original,
// dirty name to be carried over in the pubsub.NewMessage Topic field.
sanitizedTopic := snsMessagePayload.parseTopicArn()
s.topicsLock.RLock()
handler, ok := s.topicHandlers[sanitizedTopic]
s.topicsLock.RUnlock()
if !ok || handler.topicName == "" {
return fmt.Errorf("handler for topic (sanitized): %s not found", sanitizedTopic)
// get a handler by sanitized topic name and perform validations
var (
handler *SubscriptionTopicHandler
loadOK bool
)
if handler, loadOK = s.subscriptionManager.GetSubscriptionTopicHandler(sanitizedTopic); loadOK {
if len(handler.requestTopic) == 0 {
return fmt.Errorf("handler topic name is missing")
}
} else {
return fmt.Errorf("handler for (sanitized) topic: %s was not found", sanitizedTopic)
}
s.logger.Debugf("Processing SNS message id: %s of topic: %s", *message.MessageId, sanitizedTopic)
s.logger.Debugf("Processing SNS message id: %s of (sanitized) topic: %s", *message.MessageId, sanitizedTopic)
// call the handler with its own subscription context
err = handler.handler(handler.ctx, &pubsub.NewMessage{
Data: []byte(snsMessagePayload.Message),
Topic: handler.topicName,
Topic: handler.requestTopic,
})
if err != nil {
return fmt.Errorf("error handling message: %w", err)
@ -575,6 +572,8 @@ func (s *snsSqs) callHandler(ctx context.Context, message *sqs.Message, queueInf
return s.acknowledgeMessage(ctx, queueInfo.url, message.ReceiptHandle)
}
// consumeSubscription is responsible for polling messages from the queue and calling the handler.
// it is being passed as a callback to the subscription manager that initializes the context of the handler.
func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLettersQueueInfo *sqsQueueInfo) {
sqsPullExponentialBackoff := s.backOffConfig.NewBackOffWithContext(ctx)
@ -601,12 +600,13 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
// iteration. Therefore, a global backoff (to the internal backoff) is used (sqsPullExponentialBackoff).
messageResponse, err := s.sqsClient.ReceiveMessageWithContext(ctx, receiveMessageInput)
if err != nil {
if err == context.Canceled || err == context.DeadlineExceeded {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || ctx.Err() != nil {
s.logger.Warn("context canceled; stopping consuming from queue arn: %v", queueInfo.arn)
continue
}
if awsErr, ok := err.(awserr.Error); ok {
var awsErr awserr.Error
if errors.As(err, &awsErr) {
s.logger.Errorf("AWS operation error while consuming from queue arn: %v with error: %w. retrying...", queueInfo.arn, awsErr.Error())
} else {
s.logger.Errorf("error consuming from queue arn: %v with error: %w. retrying...", queueInfo.arn, err)
@ -619,7 +619,6 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
sqsPullExponentialBackoff.Reset()
if len(messageResponse.Messages) < 1 {
// s.logger.Debug("No messages received, continuing")
continue
}
s.logger.Debugf("%v message(s) received on queue %s", len(messageResponse.Messages), queueInfo.arn)
@ -632,11 +631,10 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
}
f := func(message *sqs.Message) {
defer wg.Done()
if err := s.callHandler(ctx, message, queueInfo); err != nil {
s.logger.Errorf("error while handling received message. error is: %v", err)
}
wg.Done()
}
wg.Add(1)
@ -653,9 +651,6 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
}
wg.Wait()
}
// Signal that the poller stopped
<-s.pollerRunning
}
func (s *snsSqs) createDeadLettersQueueAttributes(queueInfo, deadLettersQueueInfo *sqsQueueInfo) (*sqs.SetQueueAttributesInput, error) {
@ -763,6 +758,9 @@ func (s *snsSqs) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han
return errors.New("component is closed")
}
s.topicsLocker.Lock(req.Topic)
defer s.topicsLocker.Unlock(req.Topic)
// subscribers declare a topic ARN and declare a SQS queue to use
// these should be idempotent - queues should not be created if they exist.
topicArn, sanitizedName, err := s.getOrCreateTopic(ctx, req.Topic)
@ -824,63 +822,15 @@ func (s *snsSqs) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han
return wrappedErr
}
// Store the handler for this topic
s.topicsLock.Lock()
defer s.topicsLock.Unlock()
s.topicHandlers[sanitizedName] = topicHandler{
topicName: req.Topic,
handler: handler,
ctx: ctx,
}
// start the subscription manager
s.subscriptionManager.Init(queueInfo, deadLettersQueueInfo, s.consumeSubscription)
// pollerCancel is used to cancel the polling goroutine. We use a noop cancel
// func in case the poller is already running and there is no cancel to use
// from the select below.
var pollerCancel context.CancelFunc = func() {}
// Start the poller for the queue if it's not running already
select {
case s.pollerRunning <- struct{}{}:
// If inserting in the channel succeeds, then it's not running already
// Use a context that is tied to the background context
var subctx context.Context
subctx, pollerCancel = context.WithCancel(context.Background())
s.wg.Add(2)
go func() {
defer s.wg.Done()
defer pollerCancel()
select {
case <-s.closeCh:
case <-subctx.Done():
}
}()
go func() {
defer s.wg.Done()
s.consumeSubscription(subctx, queueInfo, deadLettersQueueInfo)
}()
default:
// Do nothing, it means the poller is already running
}
// Watch for subscription context cancellation to remove this subscription
s.wg.Add(1)
go func() {
defer s.wg.Done()
select {
case <-ctx.Done():
case <-s.closeCh:
}
s.topicsLock.Lock()
defer s.topicsLock.Unlock()
// Remove the handler
delete(s.topicHandlers, sanitizedName)
// If we don't have any topic left, close the poller.
if len(s.topicHandlers) == 0 {
pollerCancel()
}
}()
s.subscriptionManager.Subscribe(&SubscriptionTopicHandler{
topic: sanitizedName,
requestTopic: req.Topic,
handler: handler,
ctx: ctx,
})
return nil
}
@ -920,9 +870,9 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error
// client. Blocks until all goroutines have returned.
func (s *snsSqs) Close() error {
if s.closed.CompareAndSwap(false, true) {
close(s.closeCh)
s.subscriptionManager.Close()
}
s.wg.Wait()
return nil
}

View File

@ -0,0 +1,179 @@
package snssqs
import (
"context"
"sync"
"github.com/puzpuzpuz/xsync/v3"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
)
type (
SubscriptionAction int
)
const (
Subscribe SubscriptionAction = iota
Unsubscribe
)
type SubscriptionTopicHandler struct {
topic string
requestTopic string
handler pubsub.Handler
ctx context.Context
}
type changeSubscriptionTopicHandler struct {
action SubscriptionAction
handler *SubscriptionTopicHandler
}
type SubscriptionManager struct {
logger logger.Logger
consumeCancelFunc context.CancelFunc
closeCh chan struct{}
topicsChangeCh chan changeSubscriptionTopicHandler
topicsHandlers *xsync.MapOf[string, *SubscriptionTopicHandler]
lock sync.Mutex
wg sync.WaitGroup
initOnce sync.Once
}
type SubscriptionManagement interface {
Init(queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo, cbk func(context.Context, *sqsQueueInfo, *sqsQueueInfo))
Subscribe(topicHandler *SubscriptionTopicHandler)
Close()
GetSubscriptionTopicHandler(topic string) (*SubscriptionTopicHandler, bool)
}
func NewSubscriptionMgmt(log logger.Logger) SubscriptionManagement {
return &SubscriptionManager{
logger: log,
consumeCancelFunc: func() {}, // noop until we (re)start sqs consumption
closeCh: make(chan struct{}),
topicsChangeCh: make(chan changeSubscriptionTopicHandler),
topicsHandlers: xsync.NewMapOf[string, *SubscriptionTopicHandler](),
}
}
func createQueueConsumerCbk(queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo, cbk func(ctx context.Context, queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo)) func(ctx context.Context) {
return func(ctx context.Context) {
cbk(ctx, queueInfo, dlqInfo)
}
}
func (sm *SubscriptionManager) Init(queueInfo *sqsQueueInfo, dlqInfo *sqsQueueInfo, cbk func(context.Context, *sqsQueueInfo, *sqsQueueInfo)) {
sm.initOnce.Do(func() {
queueConsumerCbk := createQueueConsumerCbk(queueInfo, dlqInfo, cbk)
go sm.queueConsumerController(queueConsumerCbk)
sm.logger.Debug("Subscription manager initialized")
})
}
// queueConsumerController is responsible for managing the subscription lifecycle
// and the only place where the topicsHandlers map is updated.
// it is running in a separate goroutine and is responsible for starting and stopping sqs consumption
// where its lifecycle is managed by the subscription manager,
// and it has its own context with its child contexts used for sqs consumption and aborting of the consumption.
// it is also responsible for managing the lifecycle of the subscription handlers.
func (sm *SubscriptionManager) queueConsumerController(queueConsumerCbk func(context.Context)) {
ctx := context.Background()
for {
select {
case changeEvent := <-sm.topicsChangeCh:
topic := changeEvent.handler.topic
sm.logger.Debugf("Subscription change event received with action: %v, on topic: %s", changeEvent.action, topic)
// topic change events are serialized so that no interleaving can occur
sm.lock.Lock()
// although we have a lock here, the topicsHandlers map is thread safe and can be accessed concurrently so other subscribers that are already consuming messages
// can get the handler for the topic while we're still updating the map without blocking them
current := sm.topicsHandlers.Size()
switch changeEvent.action {
case Subscribe:
sm.topicsHandlers.Store(topic, changeEvent.handler)
// if before we've added the subscription there were no subscriptions, this subscribe signals us to start consuming from sqs
if current == 0 {
var subCtx context.Context
// create a new context for sqs consumption with a cancel func to be used when we unsubscribe from all topics
subCtx, sm.consumeCancelFunc = context.WithCancel(ctx)
// start sqs consumption
sm.logger.Info("Starting SQS consumption")
go queueConsumerCbk(subCtx)
}
case Unsubscribe:
sm.topicsHandlers.Delete(topic)
// for idempotency, we check the size of the map after the delete operation, as we might have already deleted the subscription
afterDelete := sm.topicsHandlers.Size()
// if before we've removed this subscription we had one (last) subscription, this signals us to stop sqs consumption
if current == 1 && afterDelete == 0 {
sm.logger.Info("Last subscription removed. no more handlers are mapped to topics. stopping SQS consumption")
sm.consumeCancelFunc()
}
}
sm.lock.Unlock()
case <-sm.closeCh:
return
}
}
}
func (sm *SubscriptionManager) Subscribe(topicHandler *SubscriptionTopicHandler) {
sm.logger.Debug("Subscribing to topic: ", topicHandler.topic)
sm.wg.Add(1)
go func() {
defer sm.wg.Done()
sm.createSubscribeListener(topicHandler)
}()
}
func (sm *SubscriptionManager) createSubscribeListener(topicHandler *SubscriptionTopicHandler) {
sm.logger.Debug("Creating a subscribe listener for topic: ", topicHandler.topic)
sm.topicsChangeCh <- changeSubscriptionTopicHandler{Subscribe, topicHandler}
closeCh := make(chan struct{})
// the unsubscriber is expected to be terminated by the dapr runtime as it cancels the context upon unsubscribe
go sm.createUnsubscribeListener(topicHandler.ctx, topicHandler.topic, closeCh)
// if the SubscriptinoManager is being closed and somehow the dapr runtime did not call unsubscribe, we close the control
// channel here to terminate the unsubscriber and return
defer close(closeCh)
<-sm.closeCh
}
// ctx is a context provided by daprd per subscription. unrelated to the consuming sm.baseCtx
func (sm *SubscriptionManager) createUnsubscribeListener(ctx context.Context, topic string, closeCh <-chan struct{}) {
sm.logger.Debug("Creating an unsubscribe listener for topic: ", topic)
defer sm.unsubscribe(topic)
for {
select {
case <-ctx.Done():
return
case <-closeCh:
return
}
}
}
func (sm *SubscriptionManager) unsubscribe(topic string) {
sm.logger.Debug("Unsubscribing from topic: ", topic)
if value, ok := sm.GetSubscriptionTopicHandler(topic); ok {
sm.topicsChangeCh <- changeSubscriptionTopicHandler{Unsubscribe, value}
}
}
func (sm *SubscriptionManager) Close() {
close(sm.closeCh)
sm.wg.Wait()
}
func (sm *SubscriptionManager) GetSubscriptionTopicHandler(topic string) (*SubscriptionTopicHandler, bool) {
return sm.topicsHandlers.Load(topic)
}

View File

@ -0,0 +1,44 @@
package snssqs
import (
"sync"
"github.com/puzpuzpuz/xsync/v3"
)
// TopicsLockManager is a singleton for fine-grained locking, to prevent the component r/w operations
// from locking the entire component out when performing operations on different topics.
type TopicsLockManager struct {
xLockMap *xsync.MapOf[string, *sync.Mutex]
}
type TopicsLocker interface {
Lock(topic string) *sync.Mutex
Unlock(topic string)
}
func NewLockManager() *TopicsLockManager {
return &TopicsLockManager{xLockMap: xsync.NewMapOf[string, *sync.Mutex]()}
}
func (lm *TopicsLockManager) Lock(key string) *sync.Mutex {
lock, _ := lm.xLockMap.LoadOrCompute(key, func() *sync.Mutex {
l := &sync.Mutex{}
l.Lock()
return l
})
return lock
}
func (lm *TopicsLockManager) Unlock(key string) {
lm.xLockMap.Compute(key, func(oldValue *sync.Mutex, exists bool) (newValue *sync.Mutex, delete bool) {
// if exists then the mutex must be already locked, and we unlock it
if exists {
oldValue.Unlock()
}
// we return to comply with the Compute signature, but not using the returned values
return oldValue, false
})
}

View File

@ -389,35 +389,6 @@ func (m *MySQL) ensureStateTable(ctx context.Context, schemaName, stateTableName
}
}
// Create the DaprSaveFirstWriteV1 stored procedure
_, err = m.db.ExecContext(ctx, `CREATE PROCEDURE IF NOT EXISTS DaprSaveFirstWriteV1(tableName VARCHAR(255), id VARCHAR(255), value JSON, etag VARCHAR(36), isbinary BOOLEAN, expiredateToken TEXT)
LANGUAGE SQL
MODIFIES SQL DATA
BEGIN
SET @id = id;
SET @value = value;
SET @etag = etag;
SET @isbinary = isbinary;
SET @selectQuery = concat('SELECT COUNT(id) INTO @count FROM ', tableName ,' WHERE id = ? AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)');
PREPARE select_stmt FROM @selectQuery;
EXECUTE select_stmt USING @id;
DEALLOCATE PREPARE select_stmt;
IF @count < 1 THEN
SET @upsertQuery = concat('INSERT INTO ', tableName, ' SET id=?, value=?, eTag=?, isbinary=?, expiredate=', expiredateToken, ' ON DUPLICATE KEY UPDATE value=?, eTag=?, isbinary=?, expiredate=', expiredateToken);
PREPARE upsert_stmt FROM @upsertQuery;
EXECUTE upsert_stmt USING @id, @value, @etag, @isbinary, @value, @etag, @isbinary;
DEALLOCATE PREPARE upsert_stmt;
ELSE
SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'Row already exists';
END IF;
END`)
if err != nil {
return err
}
return nil
}
@ -596,7 +567,6 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
ttlQuery string
params []any
result sql.Result
maxRows int64 = 1
)
var v any
@ -624,10 +594,7 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
ttlQuery = "NULL"
}
mustCommit := false
hasEtag := req.ETag != nil && *req.ETag != ""
if hasEtag {
if req.HasETag() {
// When an eTag is provided do an update - not insert
query = `UPDATE ` + m.tableName + `
SET value = ?, eTag = ?, isbinary = ?, expiredate = ` + ttlQuery + `
@ -636,30 +603,32 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
params = []any{enc, eTag, isBinary, req.Key, *req.ETag}
} else if req.Options.Concurrency == state.FirstWrite {
// If we're not in a transaction already, start one as we need to ensure consistency
if querier == m.db {
querier, err = m.db.BeginTx(parentCtx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer querier.(*sql.Tx).Rollback()
mustCommit = true
}
// With first-write-wins and no etag, we can insert the row only if it doesn't exist
// Things get a bit tricky when the row exists but it is expired, so it just hasn't been garbage-collected yet
// What we can do in that case is to first check if the row doesn't exist or has expired, and then perform an upsert
// To do that, we use a stored procedure
query = "CALL DaprSaveFirstWriteV1(?, ?, ?, ?, ?, ?)"
params = []any{m.tableName, req.Key, enc, eTag, isBinary, ttlQuery}
// If the operation uses first-write concurrency, we need to handle the special case of a row that has expired but hasn't been garbage collected yet
// In this case, the row should be considered as if it were deleted
query = `REPLACE INTO ` + m.tableName + `
WITH a AS (
SELECT
? AS id,
? AS value,
? AS isbinary,
CURRENT_TIMESTAMP AS insertDate,
CURRENT_TIMESTAMP AS updateDate,
? AS eTag,
` + ttlQuery + ` AS expiredate
FROM ` + m.tableName + `
WHERE NOT EXISTS (
SELECT 1
FROM ` + m.tableName + `
WHERE id = ?
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)
)
)
SELECT * FROM a`
params = []any{req.Key, enc, isBinary, eTag, req.Key}
} else {
// If this is a duplicate MySQL returns that two rows affected
maxRows = 2
query = `INSERT INTO ` + m.tableName + ` (id, value, eTag, isbinary, expiredate)
VALUES (?, ?, ?, ?, ` + ttlQuery + `)
ON DUPLICATE KEY UPDATE
value=?, eTag=?, isbinary=?, expiredate=` + ttlQuery
params = []any{req.Key, enc, eTag, isBinary, enc, eTag, isBinary}
query = `REPLACE INTO ` + m.tableName + ` (id, value, eTag, isbinary, expiredate)
VALUES (?, ?, ?, ?, ` + ttlQuery + `)`
params = []any{req.Key, enc, eTag, isBinary}
}
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
@ -667,42 +636,19 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
result, err = querier.ExecContext(ctx, query, params...)
if err != nil {
if hasEtag {
return state.NewETagError(state.ETagMismatch, err)
}
return err
}
// Do not count affected rows when using first-write
// Conflicts are handled separately
if hasEtag || req.Options.Concurrency != state.FirstWrite {
var rows int64
rows, err = result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
err = errors.New("rows affected error: no rows match given key and eTag")
err = state.NewETagError(state.ETagMismatch, err)
m.logger.Error(err)
return err
}
if rows > maxRows {
err = fmt.Errorf("rows affected error: more than %d row affected; actual %d", maxRows, rows)
m.logger.Error(err)
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
// Commit the transaction if needed
if mustCommit {
err = querier.(*sql.Tx).Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
if rows == 0 && (req.HasETag() || req.Options.Concurrency == state.FirstWrite) {
err = errors.New("rows affected error: no rows match given key and eTag")
err = state.NewETagError(state.ETagMismatch, err)
m.logger.Error(err)
return err
}
return nil

View File

@ -180,7 +180,7 @@ func TestMultiCommitSetsAndDeletes(t *testing.T) {
defer m.mySQL.Close()
m.mock1.ExpectBegin()
m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("REPLACE INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectCommit()
@ -255,24 +255,8 @@ func TestSetHandlesErr(t *testing.T) {
m, _ := mockDatabase(t)
defer m.mySQL.Close()
t.Run("error occurs when update with tag", func(t *testing.T) {
m.mock1.ExpectExec("UPDATE state").WillReturnError(errors.New("error"))
eTag := "946af561"
request := createSetRequest()
request.ETag = &eTag
// Act
err := m.mySQL.Set(context.Background(), &request)
// Assert
assert.Error(t, err)
assert.IsType(t, &state.ETagError{}, err)
assert.Equal(t, err.(*state.ETagError).Kind(), state.ETagMismatch)
})
t.Run("error occurs when insert", func(t *testing.T) {
m.mock1.ExpectExec("INSERT INTO state").WillReturnError(errors.New("error"))
m.mock1.ExpectExec("REPLACE INTO state").WillReturnError(errors.New("error"))
request := createSetRequest()
// Act
@ -284,7 +268,7 @@ func TestSetHandlesErr(t *testing.T) {
})
t.Run("insert on conflict", func(t *testing.T) {
m.mock1.ExpectExec("INSERT INTO state").WillReturnResult(sqlmock.NewResult(1, 2))
m.mock1.ExpectExec("REPLACE INTO state").WillReturnResult(sqlmock.NewResult(1, 2))
request := createSetRequest()
// Act
@ -294,17 +278,6 @@ func TestSetHandlesErr(t *testing.T) {
assert.NoError(t, err)
})
t.Run("too many rows error", func(t *testing.T) {
m.mock1.ExpectExec("INSERT INTO state").WillReturnResult(sqlmock.NewResult(1, 3))
request := createSetRequest()
// Act
err := m.mySQL.Set(context.Background(), &request)
// Assert
assert.Error(t, err)
})
t.Run("no rows effected error", func(t *testing.T) {
m.mock1.ExpectExec("UPDATE state").WillReturnResult(sqlmock.NewResult(1, 0))
@ -716,7 +689,7 @@ func TestValidSetRequest(t *testing.T) {
}
m.mock1.ExpectBegin()
m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("REPLACE INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectCommit()
// Act
@ -805,9 +778,9 @@ func TestMultiOperationOrder(t *testing.T) {
// expected to run the operations in sequence
m.mock1.ExpectBegin()
m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("REPLACE INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("DELETE FROM").WithArgs("k1").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("REPLACE INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectCommit()
// Act

View File

@ -104,7 +104,17 @@ func (a *sqliteDBAccess) Init(ctx context.Context, md state.Metadata) error {
return fmt.Errorf("failed to perform migrations: %w", err)
}
gc, err := internalsql.ScheduleGarbageCollector(internalsql.GCOptions{
// Init the background GC
err = a.initGC()
if err != nil {
return err
}
return nil
}
func (a *sqliteDBAccess) initGC() (err error) {
a.gc, err = internalsql.ScheduleGarbageCollector(internalsql.GCOptions{
Logger: a.logger,
UpdateLastCleanupQuery: func(arg any) (string, any) {
return fmt.Sprintf(`INSERT INTO %s (key, value)
@ -124,12 +134,7 @@ func (a *sqliteDBAccess) Init(ctx context.Context, md state.Metadata) error {
CleanupInterval: a.metadata.CleanupInterval,
DB: internalsql.AdaptDatabaseSQLConn(a.db),
})
if err != nil {
return err
}
a.gc = gc
return nil
return err
}
func (a *sqliteDBAccess) CleanupExpired() error {
@ -333,52 +338,40 @@ func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state
// Only check for etag if FirstWrite specified (ref oracledatabaseaccess)
var (
res sql.Result
mustCommit bool
stmt string
res sql.Result
stmt string
)
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.Timeout)
defer cancel()
// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
// And the same is for DATETIME function's seconds parameter (which is from an integer anyways).
if !req.HasETag() {
switch {
case !req.HasETag() && req.Options.Concurrency == state.FirstWrite:
// If the operation uses first-write concurrency, we need to handle the special case of a row that has expired but hasn't been garbage collected yet
// In this case, the row should be considered as if it were deleted
// With SQLite, the only way we can handle that is by performing a SELECT query first
if req.Options.Concurrency == state.FirstWrite {
// If we're not in a transaction already, start one as we need to ensure consistency
if db == a.db {
db, err = a.db.BeginTx(parentCtx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer db.(*sql.Tx).Rollback()
mustCommit = true
}
// Check if there's already a row with the given key that has not expired yet
var count int
stmt = `SELECT COUNT(key)
stmt = `WITH a AS (
SELECT
?, ?, ?, ?, ` + expiration + `, CURRENT_TIMESTAMP
FROM ` + a.metadata.TableName + `
WHERE key = ?
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`
err = db.QueryRowContext(parentCtx, stmt, req.Key).Scan(&count)
if err != nil {
return fmt.Errorf("failed to check for existing row with first-write concurrency: %w", err)
}
// If the row exists, then we just return an etag error
// Otherwise, we can fall through and continue with an INSERT OR REPLACE statement
if count > 0 {
return state.NewETagError(state.ETagMismatch, nil)
}
}
WHERE NOT EXISTS (
SELECT 1
FROM ` + a.metadata.TableName + `
WHERE key = ?
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)
)
)
INSERT OR REPLACE INTO ` + a.metadata.TableName + `
SELECT * FROM a`
res, err = db.ExecContext(ctx, stmt, req.Key, requestValue, isBinary, newEtag, req.Key)
case !req.HasETag():
stmt = "INSERT OR REPLACE INTO " + a.metadata.TableName + `
(key, value, is_binary, etag, update_time, expiration_time)
VALUES(?, ?, ?, ?, CURRENT_TIMESTAMP, ` + expiration + `)`
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.Timeout)
defer cancel()
res, err = db.ExecContext(ctx, stmt, req.Key, requestValue, isBinary, newEtag, req.Key)
} else {
res, err = db.ExecContext(ctx, stmt, req.Key, requestValue, isBinary, newEtag)
default:
stmt = `UPDATE ` + a.metadata.TableName + ` SET
value = ?,
etag = ?,
@ -389,8 +382,6 @@ func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state
key = ?
AND etag = ?
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.Timeout)
defer cancel()
res, err = db.ExecContext(ctx, stmt, requestValue, newEtag, isBinary, req.Key, *req.ETag)
}
if err != nil {
@ -403,20 +394,12 @@ func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state
return err
}
if rows == 0 {
if req.HasETag() {
if req.HasETag() || req.Options.Concurrency == state.FirstWrite {
return state.NewETagError(state.ETagMismatch, nil)
}
return errors.New("no item was updated")
}
// Commit the transaction if needed
if mustCommit {
err = db.(*sql.Tx).Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
}
return nil
}
@ -450,17 +433,25 @@ func (a *sqliteDBAccess) ExecuteMulti(parentCtx context.Context, reqs []state.Tr
return tx.Commit()
}
// Close implements io.Close.
func (a *sqliteDBAccess) Close() error {
if a.db != nil {
_ = a.db.Close()
}
// Close implements io.Closer.
func (a *sqliteDBAccess) Close() (err error) {
errs := make([]error, 0)
if a.gc != nil {
return a.gc.Close()
err = a.gc.Close()
if err != nil {
errs = append(errs, err)
}
}
return nil
if a.db != nil {
err = a.db.Close()
if err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}
func (a *sqliteDBAccess) doDelete(parentCtx context.Context, db querier, req *state.DeleteRequest) error {

View File

@ -17,8 +17,8 @@ require (
github.com/cloudwego/kitex v0.5.0
github.com/cloudwego/kitex-examples v0.1.1
github.com/dapr/components-contrib v1.12.0-rc.4.0.20231009175401-9f2cc5c158bb
github.com/dapr/dapr v1.12.0
github.com/dapr/go-sdk v1.9.0
github.com/dapr/dapr v1.12.1-0.20231013174004-b6540a1c464d
github.com/dapr/go-sdk v1.6.1-0.20231014032604-69e788045df0
github.com/dapr/kit v0.12.1
github.com/eclipse/paho.mqtt.golang v1.4.3
github.com/go-chi/chi/v5 v5.0.10
@ -227,6 +227,7 @@ require (
github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.11.0 // indirect
github.com/prometheus/statsd_exporter v0.22.7 // indirect
github.com/puzpuzpuz/xsync/v3 v3.0.0 // indirect
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect
github.com/redis/go-redis/v9 v9.2.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
@ -314,3 +315,6 @@ require (
)
replace github.com/dapr/components-contrib => ../../
// TODO: REMOVE WHEN https://github.com/dapr/dapr/pull/7038 IS MERGED
replace github.com/dapr/dapr => github.com/italypaleale/dapr v1.6.1-0.20231015174742-7538aab2c0f2

View File

@ -332,10 +332,8 @@ github.com/cyphar/filepath-securejoin v0.2.4 h1:Ugdm7cg7i6ZK6x3xDF1oEu1nfkyfH53E
github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4=
github.com/danieljoos/wincred v1.1.2 h1:QLdCxFs1/Yl4zduvBdcHB8goaYk9RARS2SgLLRuAyr0=
github.com/danieljoos/wincred v1.1.2/go.mod h1:GijpziifJoIBfYh+S7BbkdUTU4LfM+QnGqR5Vl2tAx0=
github.com/dapr/dapr v1.12.0 h1:JrnzIYupSHLVK95YWlMZv+h/rfg5rGl53ml9EG4Vrqg=
github.com/dapr/dapr v1.12.0/go.mod h1:LjmQepPe5+A898VHMUsmX0r0D0OTj6ijtqCIZqMvo7o=
github.com/dapr/go-sdk v1.9.0 h1:36pUgSwgh/SIYniRT6t1DAu3tv4DcYUmdIvktI6QpoM=
github.com/dapr/go-sdk v1.9.0/go.mod h1:bK9bNEsC6hY3RMKh69r0nBjLqb6njeWTEGVMOgP9g20=
github.com/dapr/go-sdk v1.6.1-0.20231014032604-69e788045df0 h1:mojcJ67LMl6mzpZRnpCuuD9iwOnu6UUNnHxGQvuDH40=
github.com/dapr/go-sdk v1.6.1-0.20231014032604-69e788045df0/go.mod h1:hgDH/7xTmza3hO8eGCfeQvAW6XnPAZ76j4XsYTyvKSM=
github.com/dapr/kit v0.12.1 h1:XT0CJQQaKRYSzIzZo15O1PAHGUrMGoAavdFRcNVZ+UE=
github.com/dapr/kit v0.12.1/go.mod h1:eNYjsudq3Ij0x8CLWsPturHor56sZRNu5tk2hUiJT80=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -757,6 +755,8 @@ github.com/imdario/mergo v0.3.13 h1:lFzP57bqS/wsqKssCGmtLAb8A0wKjLGrve2q3PPVcBk=
github.com/imdario/mergo v0.3.13/go.mod h1:4lJ1jqUDcsbIECGy0RUJAXNIhg+6ocWgb1ALK2O4oXg=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo=
github.com/italypaleale/dapr v1.6.1-0.20231015174742-7538aab2c0f2 h1:mcWBaoo2bHzrvXBLH1/XhplzfAC5BAw2b3ws5ECaPMg=
github.com/italypaleale/dapr v1.6.1-0.20231015174742-7538aab2c0f2/go.mod h1:PHlURwPY4djIz7ZLdCLxc5hrkFwf2r+E5GjNWxvyP0A=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
@ -1108,6 +1108,8 @@ github.com/prometheus/statsd_exporter v0.21.0/go.mod h1:rbT83sZq2V+p73lHhPZfMc3M
github.com/prometheus/statsd_exporter v0.22.7 h1:7Pji/i2GuhK6Lu7DHrtTkFmNBCudCPT1pX2CziuyQR0=
github.com/prometheus/statsd_exporter v0.22.7/go.mod h1:N/TevpjkIh9ccs6nuzY3jQn9dFqnUakOjnEuMPJJJnI=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/puzpuzpuz/xsync/v3 v3.0.0 h1:QwUcmah+dZZxy6va/QSU26M6O6Q422afP9jO8JlnRSA=
github.com/puzpuzpuz/xsync/v3 v3.0.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/rabbitmq/amqp091-go v1.8.1 h1:RejT1SBUim5doqcL6s7iN6SBmsQqyTgXb1xMlH0h1hA=
github.com/rabbitmq/amqp091-go v1.8.1/go.mod h1:+jPrT9iY2eLjRaMSRHUhc3z14E/l85kv/f+6luSD3pc=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=