Merge pull request #2302 from ItalyPaleAle/postgres-ttl

Add TTL to postgres state store
This commit is contained in:
Bernd Verst 2022-12-16 11:55:19 -08:00 committed by GitHub
commit ea9b623ccb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 1345 additions and 397 deletions

View File

@ -21,7 +21,6 @@ import (
"fmt"
"net/http"
"reflect"
"strconv"
"strings"
"time"
@ -36,6 +35,7 @@ import (
contribmeta "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
"github.com/dapr/components-contrib/state/query"
stateutils "github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/ptr"
)
@ -77,7 +77,6 @@ type CosmosItem struct {
const (
metadataPartitionKey = "partitionKey"
metadataTTLKey = "ttlInSeconds"
defaultTimeout = 20 * time.Second
statusNotFound = "NotFound"
)
@ -481,7 +480,7 @@ func createUpsertItem(contentType string, req state.SetRequest, partitionKey str
isBinary = false
}
ttl, err := parseTTL(req.Metadata)
ttl, err := stateutils.ParseTTL(req.Metadata)
if err != nil {
return CosmosItem{}, fmt.Errorf("error parsing TTL from metadata: %s", err)
}
@ -534,20 +533,6 @@ func populatePartitionMetadata(key string, requestMetadata map[string]string) st
return key
}
func parseTTL(requestMetadata map[string]string) (*int, error) {
if val, found := requestMetadata[metadataTTLKey]; found && val != "" {
parsedVal, err := strconv.ParseInt(val, 10, 0)
if err != nil {
return nil, err
}
i := int(parsedVal)
return &i, nil
}
return nil, nil
}
func isNotFoundError(err error) bool {
if err == nil {
return false

View File

@ -21,6 +21,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/dapr/components-contrib/state"
stateutils "github.com/dapr/components-contrib/state/utils"
)
type widget struct {
@ -284,7 +285,7 @@ func TestCreateCosmosItemWithTTL(t *testing.T) {
Key: "testKey",
Value: value,
Metadata: map[string]string{
metadataTTLKey: strconv.Itoa(ttl),
stateutils.MetadataTTLKey: strconv.Itoa(ttl),
},
}
@ -316,7 +317,7 @@ func TestCreateCosmosItemWithTTL(t *testing.T) {
Key: "testKey",
Value: value,
Metadata: map[string]string{
metadataTTLKey: strconv.Itoa(ttl),
stateutils.MetadataTTLKey: strconv.Itoa(ttl),
},
}
@ -347,7 +348,7 @@ func TestCreateCosmosItemWithTTL(t *testing.T) {
Key: "testKey",
Value: value,
Metadata: map[string]string{
metadataTTLKey: "notattl",
stateutils.MetadataTTLKey: "notattl",
},
}

View File

@ -24,6 +24,7 @@ import (
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
stateutils "github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger"
)
@ -279,7 +280,7 @@ func (c *Cassandra) Set(ctx context.Context, req *state.SetRequest) error {
session = sess
}
ttl, err := parseTTL(req.Metadata)
ttl, err := stateutils.ParseTTL(req.Metadata)
if err != nil {
return fmt.Errorf("error parsing TTL from Metadata: %s", err)
}
@ -302,20 +303,6 @@ func (c *Cassandra) createSession(consistency gocql.Consistency) (*gocql.Session
return session, nil
}
func parseTTL(requestMetadata map[string]string) (*int, error) {
if val, found := requestMetadata[metadataTTLKey]; found && val != "" {
parsedVal, err := strconv.ParseInt(val, 10, 0)
if err != nil {
return nil, err
}
parsedInt := int(parsedVal)
return &parsedInt, nil
}
return nil, nil
}
func (c *Cassandra) GetComponentMetadata() map[string]string {
metadataStruct := cassandraMetadata{}
metadataInfo := map[string]string{}

View File

@ -14,7 +14,6 @@ limitations under the License.
package cassandra
import (
"strconv"
"strings"
"testing"
@ -111,35 +110,3 @@ func TestGetCassandraMetadata(t *testing.T) {
assert.NotNil(t, err)
})
}
func TestParseTTL(t *testing.T) {
t.Run("TTL Not an integer", func(t *testing.T) {
ttlInSeconds := "not an integer"
ttl, err := parseTTL(map[string]string{
"ttlInSeconds": ttlInSeconds,
})
assert.Error(t, err)
assert.Nil(t, ttl)
})
t.Run("TTL specified with wrong key", func(t *testing.T) {
ttlInSeconds := 12345
ttl, err := parseTTL(map[string]string{
"expirationTime": strconv.Itoa(ttlInSeconds),
})
assert.NoError(t, err)
assert.Nil(t, ttl)
})
t.Run("TTL is a number", func(t *testing.T) {
ttlInSeconds := 12345
ttl, err := parseTTL(map[string]string{
"ttlInSeconds": strconv.Itoa(ttlInSeconds),
})
assert.NoError(t, err)
assert.Equal(t, *ttl, ttlInSeconds)
})
t.Run("TTL not set", func(t *testing.T) {
ttl, err := parseTTL(map[string]string{})
assert.NoError(t, err)
assert.Nil(t, ttl)
})
}

View File

@ -93,7 +93,7 @@ func TestParseTTL(t *testing.T) {
},
})
assert.NotNil(t, err, "tll is not an integer")
assert.Error(t, err)
assert.Nil(t, ttl)
})
t.Run("TTL is a negative integer ends up translated to 0", func(t *testing.T) {

View File

@ -22,7 +22,6 @@ import (
"os"
"path"
"reflect"
"strconv"
"strings"
"time"
@ -34,6 +33,7 @@ import (
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
stateutils "github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger"
)
@ -278,17 +278,13 @@ func (r *StateStore) writeDocument(ctx context.Context, req *state.SetRequest) e
}
func (r *StateStore) convertTTLtoExpiryTime(req *state.SetRequest, metadata map[string]string) error {
ttl, ttlerr := parseTTL(req.Metadata)
ttl, ttlerr := stateutils.ParseTTL(req.Metadata)
if ttlerr != nil {
return fmt.Errorf("error in parsing TTL %w", ttlerr)
return fmt.Errorf("error parsing TTL: %w", ttlerr)
}
if ttl != nil {
if *ttl == -1 {
r.logger.Debugf("TTL is set to -1; this means: never expire. ")
} else {
metadata[expiryTimeMetaLabel] = time.Now().UTC().Add(time.Second * time.Duration(*ttl)).Format(isoDateTimeFormat)
r.logger.Debugf("Set %s in meta properties for object to ", expiryTimeMetaLabel, metadata[expiryTimeMetaLabel])
}
metadata[expiryTimeMetaLabel] = time.Now().UTC().Add(time.Second * time.Duration(*ttl)).Format(isoDateTimeFormat)
r.logger.Debugf("Set %s in meta properties for object to ", expiryTimeMetaLabel, metadata[expiryTimeMetaLabel])
}
return nil
}
@ -367,20 +363,6 @@ func getFileName(key string) string {
return path.Join(pr[0], pr[1])
}
func parseTTL(requestMetadata map[string]string) (*int, error) {
if val, found := requestMetadata[metadataTTLKey]; found && val != "" {
parsedVal, err := strconv.ParseInt(val, 10, 0)
if err != nil {
return nil, fmt.Errorf("error in parsing ttl metadata : %w", err)
}
parsedInt := int(parsedVal)
return &parsedInt, nil
}
return nil, nil
}
/**************** functions with OCI ObjectStorage Service interaction. */
func getNamespace(ctx context.Context, client objectstorage.ObjectStorageClient) (string, error) {

View File

@ -17,7 +17,6 @@ import (
"context"
"fmt"
"io"
"strconv"
"testing"
"time"
@ -394,36 +393,3 @@ func TestGetFilename(t *testing.T) {
assert.Equal(t, "app-id-key", filename)
})
}
func TestParseTTL(t *testing.T) {
t.Parallel()
t.Run("TTL Not an integer", func(t *testing.T) {
ttlInSeconds := "not an integer"
ttl, err := parseTTL(map[string]string{
"ttlInSeconds": ttlInSeconds,
})
assert.Error(t, err)
assert.Nil(t, ttl)
})
t.Run("TTL specified with wrong key", func(t *testing.T) {
ttlInSeconds := 12345
ttl, err := parseTTL(map[string]string{
"expirationTime": strconv.Itoa(ttlInSeconds),
})
assert.NoError(t, err)
assert.Nil(t, ttl)
})
t.Run("TTL is a number", func(t *testing.T) {
ttlInSeconds := 12345
ttl, err := parseTTL(map[string]string{
"ttlInSeconds": strconv.Itoa(ttlInSeconds),
})
assert.NoError(t, err)
assert.Equal(t, *ttl, ttlInSeconds)
})
t.Run("TTL not set", func(t *testing.T) {
ttl, err := parseTTL(map[string]string{})
assert.NoError(t, err)
assert.Nil(t, ttl)
})
}

View File

@ -21,7 +21,6 @@ import (
"fmt"
"net/url"
"os"
"strconv"
"testing"
"time"
@ -658,43 +657,6 @@ func setItemWithNoKey(t *testing.T, ods *OracleDatabase) {
assert.NotNil(t, err)
}
func TestParseTTL(t *testing.T) {
t.Parallel()
t.Run("TTL Not an integer", func(t *testing.T) {
t.Parallel()
ttlInSeconds := "not an integer"
ttl, err := parseTTL(map[string]string{
"ttlInSeconds": ttlInSeconds,
})
assert.Error(t, err)
assert.Nil(t, ttl)
})
t.Run("TTL specified with wrong key", func(t *testing.T) {
t.Parallel()
ttlInSeconds := 12345
ttl, err := parseTTL(map[string]string{
"expirationTime": strconv.Itoa(ttlInSeconds),
})
assert.NoError(t, err)
assert.Nil(t, ttl)
})
t.Run("TTL is a number", func(t *testing.T) {
t.Parallel()
ttlInSeconds := 12345
ttl, err := parseTTL(map[string]string{
"ttlInSeconds": strconv.Itoa(ttlInSeconds),
})
assert.NoError(t, err)
assert.Equal(t, *ttl, ttlInSeconds)
})
t.Run("TTL not set", func(t *testing.T) {
t.Parallel()
ttl, err := parseTTL(map[string]string{})
assert.NoError(t, err)
assert.Nil(t, ttl)
})
}
func testSetItemWithInvalidTTL(t *testing.T, ods *OracleDatabase) {
setReq := &state.SetRequest{
Key: randomKey(),

View File

@ -20,13 +20,12 @@ import (
"encoding/json"
"fmt"
"net/url"
"strconv"
"github.com/google/uuid"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
"github.com/dapr/components-contrib/state/utils"
stateutils "github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger"
// Blank import for the underlying Oracle Database driver.
@ -36,7 +35,6 @@ import (
const (
connectionStringKey = "connectionString"
oracleWalletLocationKey = "oracleWalletLocation"
metadataTTLKey = "ttlInSeconds"
errMissingConnectionString = "missing connection string"
tableName = "state"
)
@ -115,20 +113,6 @@ func (o *oracleDatabaseAccess) Init(metadata state.Metadata) error {
return nil
}
func parseTTL(requestMetadata map[string]string) (*int, error) {
if val, found := requestMetadata[metadataTTLKey]; found && val != "" {
parsedVal, err := strconv.ParseInt(val, 10, 0)
if err != nil {
return nil, fmt.Errorf("error in parsing ttl metadata : %w", err)
}
parsedInt := int(parsedVal)
return &parsedInt, nil
}
return nil, nil
}
// Set makes an insert or update to the database.
func (o *oracleDatabaseAccess) Set(ctx context.Context, req *state.SetRequest) error {
o.logger.Debug("Setting state value in OracleDatabase")
@ -149,19 +133,12 @@ func (o *oracleDatabaseAccess) Set(ctx context.Context, req *state.SetRequest) e
return fmt.Errorf("when FirstWrite is to be enforced, a value must be provided for the ETag")
}
var ttlSeconds int
ttl, ttlerr := parseTTL(req.Metadata)
ttl, ttlerr := stateutils.ParseTTL(req.Metadata)
if ttlerr != nil {
return fmt.Errorf("error in parsing TTL %w", ttlerr)
return fmt.Errorf("error parsing TTL: %w", ttlerr)
}
if ttl != nil {
if *ttl == -1 {
o.logger.Debugf("TTL is set to -1; this means: never expire. ")
} else {
if *ttl < -1 {
return fmt.Errorf("incorrect value for %s %d", metadataTTLKey, *ttl)
}
ttlSeconds = *ttl
}
ttlSeconds = *ttl
}
requestValue := req.Value
byteArray, isBinary := req.Value.([]uint8)
@ -172,7 +149,7 @@ func (o *oracleDatabaseAccess) Set(ctx context.Context, req *state.SetRequest) e
}
// Convert to json string.
bt, _ := utils.Marshal(requestValue, json.Marshal)
bt, _ := stateutils.Marshal(requestValue, json.Marshal)
value := string(bt)
var result sql.Result

View File

@ -15,6 +15,7 @@ package postgresql
import (
"context"
"database/sql"
"github.com/dapr/components-contrib/state"
)
@ -31,3 +32,12 @@ type dbAccess interface {
Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error)
Close() error // io.Closer
}
// Interface that contains methods for querying.
// Applies to both *sql.DB and *sql.Tx
type dbquerier interface {
Exec(query string, args ...any) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryRow(query string, args ...any) *sql.Row
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

View File

@ -0,0 +1,227 @@
/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package postgresql
import (
"context"
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/dapr/kit/logger"
)
// Performs migrations for the database schema
type migrations struct {
Logger logger.Logger
Conn *sql.DB
StateTableName string
MetadataTableName string
}
// Perform the required migrations
func (m *migrations) Perform(ctx context.Context) error {
// Use an advisory lock (with an arbitrary number) to ensure that no one else is performing migrations at the same time
// This is the only way to also ensure we are not running multiple "CREATE TABLE IF NOT EXISTS" at the exact same time
// See: https://www.postgresql.org/message-id/CA+TgmoZAdYVtwBfp1FL2sMZbiHCWT4UPrzRLNnX1Nb30Ku3-gg@mail.gmail.com
const lockID = 42
// Long timeout here as this query may block
queryCtx, cancel := context.WithTimeout(ctx, time.Minute)
_, err := m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_lock($1)", lockID)
cancel()
if err != nil {
return fmt.Errorf("faild to acquire advisory lock: %w", err)
}
// Release the lock
defer func() {
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
_, err = m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_unlock($1)", lockID)
cancel()
if err != nil {
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
m.Logger.Fatalf("Failed to release advisory lock: %v", err)
}
}()
// Check if the metadata table exists, which we also use to store the migration level
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
exists, _, _, err := m.tableExists(queryCtx, m.MetadataTableName)
cancel()
if err != nil {
return err
}
// If the table doesn't exist, create it
if !exists {
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
err = m.createMetadataTable(queryCtx)
cancel()
if err != nil {
return err
}
}
// Select the migration level
var (
migrationLevelStr string
migrationLevel int
)
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
err = m.Conn.
QueryRowContext(queryCtx,
fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName),
).Scan(&migrationLevelStr)
cancel()
if errors.Is(err, sql.ErrNoRows) {
// If there's no row...
migrationLevel = 0
} else if err != nil {
return fmt.Errorf("failed to read migration level: %w", err)
} else {
migrationLevel, err = strconv.Atoi(migrationLevelStr)
if err != nil || migrationLevel < 0 {
return fmt.Errorf("invalid migration level found in metadata table: %s", migrationLevelStr)
}
}
// Perform the migrations
for i := migrationLevel; i < len(allMigrations); i++ {
m.Logger.Infof("Performing migration %d", i)
err = allMigrations[i](ctx, m)
if err != nil {
return fmt.Errorf("failed to perform migration %d: %w", i, err)
}
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
_, err = m.Conn.ExecContext(queryCtx,
fmt.Sprintf(`INSERT INTO %s (key, value) VALUES ('migrations', $1) ON CONFLICT (key) DO UPDATE SET value = $1`, m.MetadataTableName),
strconv.Itoa(i+1),
)
cancel()
if err != nil {
return fmt.Errorf("failed to update migration level in metadata table: %w", err)
}
}
return nil
}
func (m migrations) createMetadataTable(ctx context.Context) error {
m.Logger.Infof("Creating metadata table '%s'", m.MetadataTableName)
// Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time
// In the next step we'll acquire a lock so there won't be issues with concurrency
_, err := m.Conn.Exec(fmt.Sprintf(
`CREATE TABLE IF NOT EXISTS %s (
key text NOT NULL PRIMARY KEY,
value text NOT NULL
)`,
m.MetadataTableName,
))
if err != nil {
return fmt.Errorf("failed to create metadata table: %w", err)
}
return nil
}
// If the table exists, returns true and the name of the table and schema
func (m migrations) tableExists(ctx context.Context, tableName string) (exists bool, schema string, table string, err error) {
table, schema, err = m.tableSchemaName(tableName)
if err != nil {
return false, "", "", err
}
if schema == "" {
err = m.Conn.
QueryRowContext(
ctx,
`SELECT table_name, table_schema
FROM information_schema.tables
WHERE table_name = $1`,
table,
).
Scan(&table, &schema)
} else {
err = m.Conn.
QueryRowContext(
ctx,
`SELECT table_name, table_schema
FROM information_schema.tables
WHERE table_schema = $1 AND table_name = $2`,
schema, table,
).
Scan(&table, &schema)
}
if err != nil && errors.Is(err, sql.ErrNoRows) {
return false, "", "", nil
} else if err != nil {
return false, "", "", fmt.Errorf("failed to check if table '%s' exists: %w", tableName, err)
}
return true, schema, table, nil
}
// If the table name includes a schema (e.g. `schema.table`, returns the two parts separately)
func (m migrations) tableSchemaName(tableName string) (table string, schema string, err error) {
parts := strings.Split(tableName, ".")
switch len(parts) {
case 1:
return parts[0], "", nil
case 2:
return parts[1], parts[0], nil
default:
return "", "", errors.New("invalid table name: must be in the format 'table' or 'schema.table'")
}
}
var allMigrations = [2]func(ctx context.Context, m *migrations) error{
// Migration 0: create the state table
func(ctx context.Context, m *migrations) error {
// We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table
m.Logger.Infof("Creating state table '%s'", m.StateTableName)
_, err := m.Conn.Exec(
fmt.Sprintf(
`CREATE TABLE IF NOT EXISTS %s (
key text NOT NULL PRIMARY KEY,
value jsonb NOT NULL,
isbinary boolean NOT NULL,
insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updatedate TIMESTAMP WITH TIME ZONE NULL
)`,
m.StateTableName,
),
)
if err != nil {
return fmt.Errorf("failed to create state table: %w", err)
}
return nil
},
// Migration 1: add the "expiredate" column
func(ctx context.Context, m *migrations) error {
m.Logger.Infof("Adding expiredate column to state table '%s'", m.StateTableName)
_, err := m.Conn.Exec(fmt.Sprintf(
`ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`,
m.StateTableName,
))
if err != nil {
return fmt.Errorf("failed to update state table: %w", err)
}
return nil
},
}

View File

@ -26,34 +26,38 @@ import (
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
"github.com/dapr/components-contrib/state/query"
"github.com/dapr/components-contrib/state/utils"
stateutils "github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/ptr"
// Blank import for the underlying PostgreSQL driver.
// Blank import for the underlying Postgres driver.
_ "github.com/jackc/pgx/v5/stdlib"
)
const (
connectionStringKey = "connectionString"
errMissingConnectionString = "missing connection string"
defaultTableName = "state"
defaultTableName = "state"
defaultMetadataTableName = "dapr_metadata"
cleanupIntervalKey = "cleanupIntervalInSeconds"
defaultCleanupInternal = 3600 // In seconds = 1 hour
)
// postgresDBAccess implements dbaccess.
type postgresDBAccess struct {
logger logger.Logger
metadata postgresMetadataStruct
db *sql.DB
connectionString string
tableName string
var errMissingConnectionString = errors.New("missing connection string")
// PostgresDBAccess implements dbaccess.
type PostgresDBAccess struct {
logger logger.Logger
metadata postgresMetadataStruct
cleanupInterval *time.Duration
db *sql.DB
ctx context.Context
cancel context.CancelFunc
}
// newPostgresDBAccess creates a new instance of postgresAccess.
func newPostgresDBAccess(logger logger.Logger) *postgresDBAccess {
logger.Debug("Instantiating new PostgreSQL state store")
func newPostgresDBAccess(logger logger.Logger) *PostgresDBAccess {
logger.Debug("Instantiating new Postgres state store")
return &postgresDBAccess{
return &PostgresDBAccess{
logger: logger,
}
}
@ -61,14 +65,66 @@ func newPostgresDBAccess(logger logger.Logger) *postgresDBAccess {
type postgresMetadataStruct struct {
ConnectionString string
ConnectionMaxIdleTime time.Duration
TableName string
TableName string // Could be in the format "schema.table" or just "table"
MetadataTableName string // Could be in the format "schema.table" or just "table"
}
// Init sets up PostgreSQL connection and ensures that the state table exists.
func (p *postgresDBAccess) Init(meta state.Metadata) error {
p.logger.Debug("Initializing PostgreSQL state store")
// Init sets up Postgres connection and ensures that the state table exists.
func (p *PostgresDBAccess) Init(meta state.Metadata) error {
p.logger.Debug("Initializing Postgres state store")
p.ctx, p.cancel = context.WithCancel(context.Background())
err := p.ParseMetadata(meta)
if err != nil {
p.logger.Errorf("Failed to parse metadata: %v", err)
return err
}
db, err := sql.Open("pgx", p.metadata.ConnectionString)
if err != nil {
p.logger.Error(err)
return err
}
p.db = db
pingCtx, pingCancel := context.WithTimeout(p.ctx, 30*time.Second)
pingErr := db.PingContext(pingCtx)
pingCancel()
if pingErr != nil {
return pingErr
}
p.db.SetConnMaxIdleTime(p.metadata.ConnectionMaxIdleTime)
if err != nil {
return err
}
migrate := &migrations{
Logger: p.logger,
Conn: p.db,
MetadataTableName: p.metadata.MetadataTableName,
StateTableName: p.metadata.TableName,
}
err = migrate.Perform(p.ctx)
if err != nil {
return err
}
p.ScheduleCleanupExpiredData(p.ctx)
return nil
}
func (p *PostgresDBAccess) GetDB() *sql.DB {
return p.db
}
func (p *PostgresDBAccess) ParseMetadata(meta state.Metadata) error {
m := postgresMetadataStruct{
TableName: defaultTableName,
TableName: defaultTableName,
MetadataTableName: defaultMetadataTableName,
}
err := metadata.DecodeMetadata(meta.Properties, &m)
if err != nil {
@ -77,44 +133,33 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error {
p.metadata = m
if m.ConnectionString == "" {
p.logger.Error("Missing postgreSQL connection string")
return errors.New(errMissingConnectionString)
}
p.connectionString = m.ConnectionString
db, err := sql.Open("pgx", p.connectionString)
if err != nil {
p.logger.Error(err)
return err
return errMissingConnectionString
}
p.db = db
s, ok := meta.Properties[cleanupIntervalKey]
if ok && s != "" {
cleanupIntervalInSec, err := strconv.ParseInt(s, 10, 0)
if err != nil {
return fmt.Errorf("invalid value for '%s': %s", cleanupIntervalKey, s)
}
pingErr := db.Ping()
if pingErr != nil {
return pingErr
// Non-positive value from meta means disable auto cleanup.
if cleanupIntervalInSec > 0 {
p.cleanupInterval = ptr.Of(time.Duration(cleanupIntervalInSec) * time.Second)
}
} else {
p.cleanupInterval = ptr.Of(defaultCleanupInternal * time.Second)
}
p.db.SetConnMaxIdleTime(m.ConnectionMaxIdleTime)
if err != nil {
return err
}
err = p.ensureStateTable(m.TableName)
if err != nil {
return err
}
p.tableName = m.TableName
return nil
}
// Set makes an insert or update to the database.
func (p *postgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error {
p.logger.Debug("Setting state value in PostgreSQL")
func (p *PostgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error {
return p.doSet(ctx, p.db, req)
}
func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
@ -135,22 +180,47 @@ func (p *postgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error
}
// Convert to json string
bt, _ := utils.Marshal(v, json.Marshal)
bt, _ := stateutils.Marshal(v, json.Marshal)
value := string(bt)
// TTL
var ttlSeconds int
ttl, ttlerr := stateutils.ParseTTL(req.Metadata)
if ttlerr != nil {
return fmt.Errorf("error parsing TTL: %w", ttlerr)
}
if ttl != nil {
ttlSeconds = *ttl
}
var result sql.Result
// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
// Other parameters use sql.DB parameter substitution.
if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || *req.ETag == "") {
result, err = p.db.ExecContext(ctx, fmt.Sprintf(
`INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3);`,
p.tableName), req.Key, value, isBinary)
} else if req.ETag == nil || *req.ETag == "" {
result, err = p.db.ExecContext(ctx, fmt.Sprintf(
`INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3)
ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW();`,
p.tableName), req.Key, value, isBinary)
// Sprintf is required for table name because query.DB does not substitute parameters for table names.
// Other parameters use query.DB parameter substitution.
var (
query string
queryExpiredate string
params []any
)
if req.ETag == nil || *req.ETag == "" {
if req.Options.Concurrency == state.FirstWrite {
query = `INSERT INTO %[1]s
(key, value, isbinary, expiredate)
VALUES
($1, $2, $3, %[2]s)`
} else {
query = `INSERT INTO %[1]s
(key, value, isbinary, expiredate)
VALUES
($1, $2, $3, %[2]s)
ON CONFLICT (key)
DO UPDATE SET
value = $2,
isbinary = $3,
updatedate = CURRENT_TIMESTAMP,
expiredate = %[2]s`
}
params = []any{req.Key, value, isBinary}
} else {
// Convert req.ETag to uint32 for postgres XID compatibility
var etag64 uint64
@ -158,20 +228,30 @@ func (p *postgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error
if err != nil {
return state.NewETagError(state.ETagInvalid, err)
}
etag := uint32(etag64)
// When an etag is provided do an update - no insert
result, err = p.db.ExecContext(ctx, fmt.Sprintf(
`UPDATE %s SET value = $1, isbinary = $2, updatedate = NOW()
WHERE key = $3 AND xmin = $4;`,
p.tableName), value, isBinary, req.Key, etag)
query = `UPDATE %[1]s
SET
value = $1,
isbinary = $2,
updatedate = CURRENT_TIMESTAMP,
expiredate = %[2]s
WHERE
key = $3
AND xmin = $4`
params = []any{value, isBinary, req.Key, uint32(etag64)}
}
if ttlSeconds > 0 {
queryExpiredate = "CURRENT_TIMESTAMP + interval '" + strconv.Itoa(ttlSeconds) + " seconds'"
} else {
queryExpiredate = "NULL"
}
result, err = db.ExecContext(parentCtx, fmt.Sprintf(query, p.metadata.TableName, queryExpiredate), params...)
if err != nil {
if req.ETag != nil && *req.ETag != "" {
return state.NewETagError(state.ETagMismatch, err)
}
return err
}
@ -179,7 +259,6 @@ func (p *postgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error
if err != nil {
return err
}
if rows != 1 {
return errors.New("no item was updated")
}
@ -187,33 +266,32 @@ func (p *postgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error
return nil
}
func (p *postgresDBAccess) BulkSet(ctx context.Context, req []state.SetRequest) error {
p.logger.Debug("Executing BulkSet request")
tx, err := p.db.Begin()
func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetRequest) error {
tx, err := p.db.BeginTx(parentCtx, nil)
if err != nil {
return err
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback()
if len(req) > 0 {
for _, s := range req {
sa := s // Fix for gosec G601: Implicit memory aliasing in for loop.
err = p.Set(ctx, &sa)
for i := range req {
err = p.doSet(parentCtx, tx, &req[i])
if err != nil {
tx.Rollback()
return err
}
}
}
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return err
return nil
}
// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned.
func (p *postgresDBAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
p.logger.Debug("Getting state value from PostgreSQL")
func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
if req.Key == "" {
return nil, errors.New("missing key in get operation")
}
@ -223,7 +301,15 @@ func (p *postgresDBAccess) Get(ctx context.Context, req *state.GetRequest) (*sta
isBinary bool
etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations
)
err := p.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, isbinary, xmin as etag FROM %s WHERE key = $1", p.tableName), req.Key).Scan(&value, &isBinary, &etag)
query := `SELECT
value, isbinary, xmin AS etag
FROM %s
WHERE
key = $1
AND (expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP)`
err := p.db.
QueryRowContext(parentCtx, fmt.Sprintf(query, p.metadata.TableName), req.Key).
Scan(&value, &isBinary, &etag)
if err != nil {
// If no rows exist, return an empty response, otherwise return the error.
if err == sql.ErrNoRows {
@ -261,8 +347,11 @@ func (p *postgresDBAccess) Get(ctx context.Context, req *state.GetRequest) (*sta
}
// Delete removes an item from the state store.
func (p *postgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) (err error) {
p.logger.Debug("Deleting state value from PostgreSQL")
func (p *PostgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) (err error) {
return p.doDelete(ctx, p.db, req)
}
func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req *state.DeleteRequest) (err error) {
if req.Key == "" {
return errors.New("missing key in delete operation")
}
@ -270,7 +359,7 @@ func (p *postgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest)
var result sql.Result
if req.ETag == nil || *req.ETag == "" {
result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1", req.Key)
result, err = db.ExecContext(parentCtx, "DELETE FROM state WHERE key = $1", req.Key)
} else {
// Convert req.ETag to uint32 for postgres XID compatibility
var etag64 uint64
@ -280,7 +369,7 @@ func (p *postgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest)
}
etag := uint32(etag64)
result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1 and xmin = $2", req.Key, etag)
result, err = db.ExecContext(parentCtx, "DELETE FROM state WHERE key = $1 AND xmin = $2", req.Key, etag)
}
if err != nil {
@ -299,92 +388,88 @@ func (p *postgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest)
return nil
}
func (p *postgresDBAccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
p.logger.Debug("Executing BulkDelete request")
tx, err := p.db.Begin()
func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.DeleteRequest) error {
tx, err := p.db.BeginTx(parentCtx, nil)
if err != nil {
return err
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback()
if len(req) > 0 {
for i := range req {
err = p.Delete(ctx, &req[i])
err = p.doDelete(parentCtx, tx, &req[i])
if err != nil {
tx.Rollback()
return err
}
}
}
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return err
return nil
}
func (p *postgresDBAccess) ExecuteMulti(ctx context.Context, request *state.TransactionalStateRequest) error {
p.logger.Debug("Executing PostgreSQL transaction")
tx, err := p.db.Begin()
func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *state.TransactionalStateRequest) error {
tx, err := p.db.BeginTx(parentCtx, nil)
if err != nil {
return err
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback()
for _, o := range request.Operations {
switch o.Operation {
case state.Upsert:
var setReq state.SetRequest
setReq, err = getSet(o)
if err != nil {
tx.Rollback()
return err
}
err = p.Set(ctx, &setReq)
err = p.doSet(parentCtx, tx, &setReq)
if err != nil {
tx.Rollback()
return err
}
case state.Delete:
var delReq state.DeleteRequest
delReq, err = getDelete(o)
if err != nil {
tx.Rollback()
return err
}
err = p.Delete(ctx, &delReq)
err = p.doDelete(parentCtx, tx, &delReq)
if err != nil {
tx.Rollback()
return err
}
default:
tx.Rollback()
return fmt.Errorf("unsupported operation: %s", o.Operation)
}
}
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return err
return nil
}
// Query executes a query against store.
func (p *postgresDBAccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
p.logger.Debug("Getting query value from PostgreSQL")
func (p *PostgresDBAccess) Query(parentCtx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
q := &Query{
query: "",
params: []interface{}{},
tableName: p.tableName,
params: []any{},
tableName: p.metadata.TableName,
}
qbuilder := query.NewQueryBuilder(q)
if err := qbuilder.BuildQuery(&req.Query); err != nil {
return &state.QueryResponse{}, err
}
data, token, err := q.execute(ctx, p.logger, p.db)
data, token, err := q.execute(parentCtx, p.logger, p.db)
if err != nil {
return &state.QueryResponse{}, err
}
@ -395,8 +480,94 @@ func (p *postgresDBAccess) Query(ctx context.Context, req *state.QueryRequest) (
}, nil
}
func (p *PostgresDBAccess) ScheduleCleanupExpiredData(ctx context.Context) {
if p.cleanupInterval == nil {
return
}
p.logger.Infof("Schedule expired data clean up every %d seconds", int(p.cleanupInterval.Seconds()))
go func() {
ticker := time.NewTicker(*p.cleanupInterval)
for {
select {
case <-ticker.C:
err := p.CleanupExpired(ctx)
if err != nil {
p.logger.Errorf("Error removing expired data: %v", err)
}
case <-ctx.Done():
p.logger.Debug("Stopped background cleanup of expired data")
return
}
}
}()
}
func (p *PostgresDBAccess) CleanupExpired(ctx context.Context) error {
// Check if the last iteration was too recent
// This performs an atomic operation, so allows coordination with other daprd processes too
canContinue, err := p.UpdateLastCleanup(ctx, p.db, *p.cleanupInterval)
if err != nil {
// Log errors only
p.logger.Warnf("Failed to read last cleanup time from database: %v", err)
}
if !canContinue {
p.logger.Debug("Last cleanup was performed too recently")
return nil
}
// Note we're not using the transaction here as we don't want this to be rolled back half-way or to lock the table unnecessarily
// Need to use fmt.Sprintf because we can't parametrize a table name
// Note we are not setting a timeout here as this query can take a "long" time, especially if there's no index on expiredate
//nolint:gosec
stmt := fmt.Sprintf(`DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate < CURRENT_TIMESTAMP`, p.metadata.TableName)
res, err := p.db.ExecContext(ctx, stmt)
if err != nil {
return fmt.Errorf("failed to execute query: %w", err)
}
cleaned, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("failed to count affected rows: %w", err)
}
p.logger.Infof("Removed %d expired rows", cleaned)
return nil
}
// UpdateLastCleanup sets the 'last-cleanup' value only if it's less than cleanupInterval.
// Returns true if the row was updated, which means that the cleanup can proceed.
func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier, cleanupInterval time.Duration) (bool, error) {
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
res, err := db.ExecContext(queryCtx,
fmt.Sprintf(`INSERT INTO %[1]s (key, value)
VALUES ('last-cleanup', CURRENT_TIMESTAMP)
ON CONFLICT (key)
DO UPDATE SET value = CURRENT_TIMESTAMP
WHERE (EXTRACT('epoch' FROM CURRENT_TIMESTAMP - %[1]s.value::timestamp with time zone) * 1000)::bigint > $1`,
p.metadata.MetadataTableName),
cleanupInterval.Milliseconds()-100, // Subtract 100ms for some buffer
)
cancel()
if err != nil {
return true, fmt.Errorf("failed to execute query: %w", err)
}
n, err := res.RowsAffected()
if err != nil {
return true, fmt.Errorf("failed to count affected rows: %w", err)
}
return n > 0, nil
}
// Close implements io.Close.
func (p *postgresDBAccess) Close() error {
func (p *PostgresDBAccess) Close() error {
if p.cancel != nil {
p.cancel()
p.cancel = nil
}
if p.db != nil {
return p.db.Close()
}
@ -404,34 +575,10 @@ func (p *postgresDBAccess) Close() error {
return nil
}
func (p *postgresDBAccess) ensureStateTable(stateTableName string) error {
exists, err := tableExists(p.db, stateTableName)
if err != nil {
return err
}
if !exists {
p.logger.Info("Creating PostgreSQL state table")
createTable := fmt.Sprintf(`CREATE TABLE %s (
key text NOT NULL PRIMARY KEY,
value jsonb NOT NULL,
isbinary boolean NOT NULL,
insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updatedate TIMESTAMP WITH TIME ZONE NULL);`, stateTableName)
_, err = p.db.Exec(createTable)
if err != nil {
return err
}
}
return nil
}
func tableExists(db *sql.DB, tableName string) (bool, error) {
exists := false
err := db.QueryRow("SELECT EXISTS (SELECT FROM pg_tables where tablename = $1)", tableName).Scan(&exists)
return exists, err
// GetCleanupInterval returns the cleanupInterval property.
// This is primarily used for tests.
func (p *PostgresDBAccess) GetCleanupInterval() *time.Duration {
return p.cleanupInterval
}
// Returns the set requests.

View File

@ -18,18 +18,23 @@ import (
"context"
"database/sql"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
"github.com/dapr/kit/logger"
// Blank import for pgx
_ "github.com/jackc/pgx/v5/stdlib"
)
type mocks struct {
db *sql.DB
mock sqlmock.Sqlmock
pgDba *postgresDBAccess
pgDba *PostgresDBAccess
}
func TestGetSetWithWrongType(t *testing.T) {
@ -451,7 +456,7 @@ func mockDatabase(t *testing.T) (*mocks, error) {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
dba := &postgresDBAccess{
dba := &PostgresDBAccess{
logger: logger,
db: db,
}
@ -462,3 +467,95 @@ func mockDatabase(t *testing.T) (*mocks, error) {
pgDba: dba,
}, err
}
func TestParseMetadata(t *testing.T) {
t.Run("missing connection string", func(t *testing.T) {
p := &PostgresDBAccess{}
props := map[string]string{}
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
assert.Error(t, err)
assert.ErrorIs(t, err, errMissingConnectionString)
})
t.Run("has connection string", func(t *testing.T) {
p := &PostgresDBAccess{}
props := map[string]string{
"connectionString": "foo",
}
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
assert.NoError(t, err)
})
t.Run("default table name", func(t *testing.T) {
p := &PostgresDBAccess{}
props := map[string]string{
"connectionString": "foo",
}
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
assert.NoError(t, err)
assert.Equal(t, p.metadata.TableName, defaultTableName)
})
t.Run("custom table name", func(t *testing.T) {
p := &PostgresDBAccess{}
props := map[string]string{
"connectionString": "foo",
"tableName": "mytable",
}
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
assert.NoError(t, err)
assert.Equal(t, p.metadata.TableName, "mytable")
})
t.Run("default cleanupIntervalInSeconds", func(t *testing.T) {
p := &PostgresDBAccess{}
props := map[string]string{
"connectionString": "foo",
}
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
assert.NoError(t, err)
_ = assert.NotNil(t, p.cleanupInterval) &&
assert.Equal(t, *p.cleanupInterval, defaultCleanupInternal*time.Second)
})
t.Run("invalid cleanupIntervalInSeconds", func(t *testing.T) {
p := &PostgresDBAccess{}
props := map[string]string{
"connectionString": "foo",
"cleanupIntervalInSeconds": "NaN",
}
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
assert.Error(t, err)
})
t.Run("positive cleanupIntervalInSeconds", func(t *testing.T) {
p := &PostgresDBAccess{}
props := map[string]string{
"connectionString": "foo",
"cleanupIntervalInSeconds": "42",
}
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
assert.NoError(t, err)
_ = assert.NotNil(t, p.cleanupInterval) &&
assert.Equal(t, *p.cleanupInterval, 42*time.Second)
})
t.Run("zero cleanupIntervalInSeconds", func(t *testing.T) {
p := &PostgresDBAccess{}
props := map[string]string{
"connectionString": "foo",
"cleanupIntervalInSeconds": "0",
}
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
assert.NoError(t, err)
assert.Nil(t, p.cleanupInterval)
})
}

View File

@ -24,7 +24,6 @@ import (
// PostgreSQL state store.
type PostgreSQL struct {
features []state.Feature
logger logger.Logger
dbaccess dbAccess
}
@ -40,7 +39,6 @@ func NewPostgreSQLStateStore(logger logger.Logger) state.Store {
// This unexported constructor allows injecting a dbAccess instance for unit testing.
func newPostgreSQLStateStore(logger logger.Logger, dba dbAccess) *PostgreSQL {
return &PostgreSQL{
features: []state.Feature{state.FeatureETag, state.FeatureTransactional, state.FeatureQueryAPI},
logger: logger,
dbaccess: dba,
}
@ -53,7 +51,7 @@ func (p *PostgreSQL) Init(metadata state.Metadata) error {
// Features returns the features available in this state store.
func (p *PostgreSQL) Features() []state.Feature {
return p.features
return []state.Feature{state.FeatureETag, state.FeatureTransactional, state.FeatureQueryAPI}
}
// Delete removes an entity from the store.
@ -102,10 +100,15 @@ func (p *PostgreSQL) Close() error {
if p.dbaccess != nil {
return p.dbaccess.Close()
}
return nil
}
// Returns the dbaccess property.
// This method is used in tests.
func (p *PostgreSQL) GetDBAccess() dbAccess {
return p.dbaccess
}
func (p *PostgreSQL) GetComponentMetadata() map[string]string {
metadataStruct := postgresMetadataStruct{}
metadataInfo := map[string]string{}

View File

@ -49,7 +49,7 @@ func TestPostgreSQLIntegration(t *testing.T) {
})
metadata := state.Metadata{
Base: metadata.Base{Properties: map[string]string{connectionStringKey: connectionString}},
Base: metadata.Base{Properties: map[string]string{"connectionString": connectionString}},
}
pgs := NewPostgreSQLStateStore(logger.NewLogger("test")).(*PostgreSQL)
@ -62,11 +62,6 @@ func TestPostgreSQLIntegration(t *testing.T) {
t.Fatal(error)
}
t.Run("Create table succeeds", func(t *testing.T) {
t.Parallel()
testCreateTable(t, pgs.dbaccess.(*postgresDBAccess))
})
t.Run("Get Set Delete one item", func(t *testing.T) {
t.Parallel()
setGetUpdateDeleteOneItem(t, pgs)
@ -161,33 +156,6 @@ func setGetUpdateDeleteOneItem(t *testing.T, pgs *PostgreSQL) {
deleteItem(t, pgs, key, getResponse.ETag)
}
// testCreateTable tests the ability to create the state table.
func testCreateTable(t *testing.T, dba *postgresDBAccess) {
tableName := "test_state"
// Drop the table if it already exists
exists, err := tableExists(dba.db, tableName)
assert.Nil(t, err)
if exists {
dropTable(t, dba.db, tableName)
}
// Create the state table and test for its existence
err = dba.ensureStateTable(tableName)
assert.Nil(t, err)
exists, err = tableExists(dba.db, tableName)
assert.Nil(t, err)
assert.True(t, exists)
// Drop the state table
dropTable(t, dba.db, tableName)
}
func dropTable(t *testing.T, db *sql.DB, tableName string) {
_, err := db.Exec(fmt.Sprintf("DROP TABLE %s", tableName))
assert.Nil(t, err)
}
func deleteItemThatDoesNotExist(t *testing.T, pgs *PostgreSQL) {
// Delete the item with a key not in the store
deleteReq := &state.DeleteRequest{
@ -477,7 +445,7 @@ func testInitConfiguration(t *testing.T) {
tests := []struct {
name string
props map[string]string
expectedErr string
expectedErr error
}{
{
name: "Empty",
@ -486,8 +454,8 @@ func testInitConfiguration(t *testing.T) {
},
{
name: "Valid connection string",
props: map[string]string{connectionStringKey: getConnectionString()},
expectedErr: "",
props: map[string]string{"connectionString": getConnectionString()},
expectedErr: nil,
},
}
@ -501,11 +469,11 @@ func testInitConfiguration(t *testing.T) {
}
err := p.Init(metadata)
if tt.expectedErr == "" {
assert.Nil(t, err)
if tt.expectedErr == nil {
assert.NoError(t, err)
} else {
assert.NotNil(t, err)
assert.Equal(t, err.Error(), tt.expectedErr)
assert.Error(t, err)
assert.ErrorIs(t, err, tt.expectedErr)
}
})
}

View File

@ -107,7 +107,7 @@ func createPostgreSQL(t *testing.T) *PostgreSQL {
assert.NotNil(t, pgs)
metadata := &state.Metadata{
Base: metadata.Base{Properties: map[string]string{connectionStringKey: fakeConnectionString}},
Base: metadata.Base{Properties: map[string]string{"connectionString": fakeConnectionString}},
}
err := pgs.Init(*metadata)

40
state/utils/ttl.go Normal file
View File

@ -0,0 +1,40 @@
/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
import (
"fmt"
"math"
"strconv"
)
// Key used for "ttlInSeconds" in metadata.
const MetadataTTLKey = "ttlInSeconds"
// ParseTTL parses the "ttlInSeconds" metadata property.
func ParseTTL(requestMetadata map[string]string) (*int, error) {
val, found := requestMetadata[MetadataTTLKey]
if found && val != "" {
parsedVal, err := strconv.ParseInt(val, 10, 0)
if err != nil {
return nil, fmt.Errorf("incorrect value for metadata '%s': %w", MetadataTTLKey, err)
}
if parsedVal < -1 || parsedVal > math.MaxInt32 {
return nil, fmt.Errorf("incorrect value for metadata '%s': must be -1 or greater", MetadataTTLKey)
}
i := int(parsedVal)
return &i, nil
}
return nil, nil
}

74
state/utils/ttl_test.go Normal file
View File

@ -0,0 +1,74 @@
/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package utils
import (
"math"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseTTL(t *testing.T) {
t.Run("TTL Not an integer", func(t *testing.T) {
ttlInSeconds := "not an integer"
ttl, err := ParseTTL(map[string]string{
MetadataTTLKey: ttlInSeconds,
})
require.Error(t, err)
assert.Nil(t, ttl)
})
t.Run("TTL specified with wrong key", func(t *testing.T) {
ttlInSeconds := 12345
ttl, err := ParseTTL(map[string]string{
"expirationTime": strconv.Itoa(ttlInSeconds),
})
require.NoError(t, err)
assert.Nil(t, ttl)
})
t.Run("TTL is a number", func(t *testing.T) {
ttlInSeconds := 12345
ttl, err := ParseTTL(map[string]string{
MetadataTTLKey: strconv.Itoa(ttlInSeconds),
})
require.NoError(t, err)
assert.Equal(t, *ttl, ttlInSeconds)
})
t.Run("TTL not set", func(t *testing.T) {
ttl, err := ParseTTL(map[string]string{})
require.NoError(t, err)
assert.Nil(t, ttl)
})
t.Run("TTL < -1", func(t *testing.T) {
ttl, err := ParseTTL(map[string]string{
MetadataTTLKey: "-3",
})
require.Error(t, err)
assert.Nil(t, ttl)
})
t.Run("TTL bigger than 32-bit", func(t *testing.T) {
ttl, err := ParseTTL(map[string]string{
MetadataTTLKey: strconv.FormatInt(math.MaxInt32+1, 10),
})
require.Error(t, err)
assert.Nil(t, ttl)
})
}

View File

@ -2,8 +2,24 @@
This project aims to test the PostgreSQL State Store component under various conditions.
To run these tests:
```sh
go test -v -tags certtests -count=1 .
```
## Test plan
## Initialization and migrations
Also test the `tableName` and `metadataTableName` metadata properties.
1. Initializes the component with names for tables that don't exist
2. Initializes the component with names for tables that don't exist, specifying an explicit schema
3. Initializes the component with all migrations performed (current level is "2")
4. Initializes the component with only the state table, created before the metadata table was added (implied migration level "1")
5. Initializes three components at the same time and ensure no race conditions exist in performing migrations
## Test for CRUD operations
1. Able to create and test connection.
@ -16,6 +32,15 @@ This project aims to test the PostgreSQL State Store component under various con
* Not prone to SQL injection on read
* Not prone to SQL injection on delete
## TTLs and cleanups
1. Correctly parse the `cleanupIntervalInSeconds` metadata property:
- No value uses the default value (3600 seconds)
- A positive value sets the interval to the given number of seconds
- A zero or negative value disables the cleanup
2. The cleanup method deletes expired records and updates the metadata table with the last time it ran
3. The cleanup method doesn't run if the last iteration was less than `cleanupIntervalInSeconds` or if another process is doing the cleanup
## Connection Recovery
1. When PostgreSQL goes down and then comes back up - client is able to connect

View File

@ -8,6 +8,7 @@ require (
github.com/dapr/dapr v1.9.5
github.com/dapr/go-sdk v1.6.0
github.com/dapr/kit v0.0.3
github.com/jackc/pgx/v5 v5.2.0
github.com/stretchr/testify v1.8.1
)
@ -61,7 +62,6 @@ require (
github.com/imdario/mergo v0.3.12 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
github.com/jackc/pgx/v5 v5.2.0 // indirect
github.com/jhump/protoreflect v1.13.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect

View File

@ -14,27 +14,36 @@ limitations under the License.
package main
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"strconv"
"sync/atomic"
"testing"
"time"
pgx "github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/tests/certification/embedded"
"github.com/dapr/components-contrib/tests/certification/flow"
"github.com/dapr/components-contrib/tests/certification/flow/dockercompose"
"github.com/dapr/components-contrib/tests/certification/flow/sidecar"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/state"
state_postgres "github.com/dapr/components-contrib/state/postgresql"
state_loader "github.com/dapr/dapr/pkg/components/state"
"github.com/dapr/dapr/pkg/runtime"
dapr_testing "github.com/dapr/dapr/pkg/testing"
"github.com/dapr/kit/logger"
"github.com/dapr/go-sdk/client"
"github.com/dapr/kit/logger"
)
const (
@ -42,6 +51,11 @@ const (
dockerComposeYAML = "docker-compose.yml"
stateStoreName = "statestore"
certificationTestPrefix = "stable-certification-"
connStringValue = "postgres://postgres:example@localhost:5432/dapr_test"
keyConnectionString = "connectionString"
keyCleanupInterval = "cleanupIntervalInSeconds"
keyTableName = "tableName"
keyMetadatTableName = "metadataTableName"
)
func TestPostgreSQL(t *testing.T) {
@ -59,8 +73,238 @@ func TestPostgreSQL(t *testing.T) {
currentGrpcPort := ports[0]
// Update this constant if you add more migrations
const migrationLevel = "2"
// Holds a DB client as the "postgres" (ie. "root") user which we'll use to validate migrations and other changes in state
var dbClient *pgx.Conn
connectStep := func(ctx flow.Context) error {
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
// Continue re-trying until the context times out, so we can wait for the DB to be up
for {
dbClient, err = pgx.Connect(connCtx, connStringValue)
if err == nil || connCtx.Err() != nil {
break
}
time.Sleep(750 * time.Millisecond)
}
return err
}
// Tests the "Init" method and the database migrations
// It also tests the metadata properties "tableName" and "metadataTableName"
initTest := func(ctx flow.Context) error {
md := state.Metadata{
Base: metadata.Base{
Name: "inittest",
Properties: map[string]string{
keyConnectionString: connStringValue,
keyCleanupInterval: "-1",
},
},
}
t.Run("initial state clean", func(t *testing.T) {
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
md.Properties[keyTableName] = "clean_state"
md.Properties[keyMetadatTableName] = "clean_metadata"
// Init and perform the migrations
err := storeObj.Init(md)
require.NoError(t, err, "failed to init")
// We should have the tables correctly created
err = tableExists(dbClient, "public", "clean_state")
assert.NoError(t, err, "state table does not exist")
err = tableExists(dbClient, "public", "clean_metadata")
assert.NoError(t, err, "metadata table does not exist")
// Ensure migration level is correct
level, err := getMigrationLevel(dbClient, "clean_metadata")
assert.NoError(t, err, "failed to get migration level")
assert.Equal(t, migrationLevel, level, "migration level mismatch: found '%s' but expected '%s'", level, migrationLevel)
err = storeObj.Close()
require.NoError(t, err, "failed to close component")
})
t.Run("initial state clean, with explicit schema name", func(t *testing.T) {
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
md.Properties[keyTableName] = "public.clean2_state"
md.Properties[keyMetadatTableName] = "public.clean2_metadata"
// Init and perform the migrations
err := storeObj.Init(md)
require.NoError(t, err, "failed to init")
// We should have the tables correctly created
err = tableExists(dbClient, "public", "clean2_state")
assert.NoError(t, err, "state table does not exist")
err = tableExists(dbClient, "public", "clean2_metadata")
assert.NoError(t, err, "metadata table does not exist")
// Ensure migration level is correct
level, err := getMigrationLevel(dbClient, "clean2_metadata")
assert.NoError(t, err, "failed to get migration level")
assert.Equal(t, migrationLevel, level, "migration level mismatch: found '%s' but expected '%s'", level, migrationLevel)
err = storeObj.Close()
require.NoError(t, err, "failed to close component")
})
t.Run("all migrations performed", func(t *testing.T) {
// Re-use "clean_state" and "clean_metadata"
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
md.Properties[keyTableName] = "clean_state"
md.Properties[keyMetadatTableName] = "clean_metadata"
// Should already have migration level 2
level, err := getMigrationLevel(dbClient, "clean_metadata")
assert.NoError(t, err, "failed to get migration level")
assert.Equal(t, migrationLevel, level, "migration level mismatch: found '%s' but expected '%s'", level, migrationLevel)
// Init and perform the migrations
err = storeObj.Init(md)
require.NoError(t, err, "failed to init")
// Ensure migration level is correct
level, err = getMigrationLevel(dbClient, "clean_metadata")
assert.NoError(t, err, "failed to get migration level")
assert.Equal(t, migrationLevel, level, "migration level mismatch: found '%s' but expected '%s'", level, migrationLevel)
err = storeObj.Close()
require.NoError(t, err, "failed to close component")
})
t.Run("migrate from implied level 1", func(t *testing.T) {
// Before we added the metadata table, the "implied" level 1 had only the state table
// Create that table to simulate the old state and validate the migration
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, err := dbClient.Exec(
ctx,
`CREATE TABLE pre_state (
key text NOT NULL PRIMARY KEY,
value jsonb NOT NULL,
isbinary boolean NOT NULL,
insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updatedate TIMESTAMP WITH TIME ZONE NULL
)`,
)
require.NoError(t, err, "failed to create initial migration state")
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
md.Properties[keyTableName] = "pre_state"
md.Properties[keyMetadatTableName] = "pre_metadata"
// Init and perform the migrations
err = storeObj.Init(md)
require.NoError(t, err, "failed to init")
// We should have the metadata table created
err = tableExists(dbClient, "public", "pre_metadata")
assert.NoError(t, err, "metadata table does not exist")
// Ensure migration level is correct
level, err := getMigrationLevel(dbClient, "pre_metadata")
assert.NoError(t, err, "failed to get migration level")
assert.Equal(t, migrationLevel, level, "migration level mismatch: found '%s' but expected '%s'", level, migrationLevel)
// Ensure the expiredate column has been added
var colExists bool
ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
err = dbClient.
QueryRow(ctx,
`SELECT EXISTS (
SELECT 1
FROM information_schema.columns
WHERE
table_schema = 'public'
AND table_name = 'pre_state'
AND column_name = 'expiredate'
)`,
).
Scan(&colExists)
assert.True(t, colExists, "column expiredate not found in updated table")
err = storeObj.Close()
require.NoError(t, err, "failed to close component")
})
t.Run("initialize components concurrently", func(t *testing.T) {
// Initializes 3 components concurrently using the same table names, and ensure that they perform migrations without conflicts and race conditions
md.Properties[keyTableName] = "mystate"
md.Properties[keyMetadatTableName] = "mymetadata"
errs := make(chan error, 3)
hasLogs := atomic.Int32{}
for i := 0; i < 3; i++ {
go func(i int) {
buf := &bytes.Buffer{}
l := logger.NewLogger("multi-init-" + strconv.Itoa(i))
l.SetOutput(io.MultiWriter(buf, os.Stdout))
// Init and perform the migrations
storeObj := state_postgres.NewPostgreSQLStateStore(l).(*state_postgres.PostgreSQL)
err := storeObj.Init(md)
if err != nil {
errs <- fmt.Errorf("%d failed to init: %w", i, err)
return
}
// One and only one of the loggers should have any message
if buf.Len() > 0 {
hasLogs.Add(1)
}
// Close the component right away
err = storeObj.Close()
if err != nil {
errs <- fmt.Errorf("%d failed to close: %w", i, err)
return
}
errs <- nil
}(i)
}
failed := false
for i := 0; i < 3; i++ {
select {
case err := <-errs:
failed = failed || !assert.NoError(t, err)
case <-time.After(time.Minute):
t.Fatal("timed out waiting for components to initialize")
}
}
if failed {
// Short-circuit
t.FailNow()
}
// Exactly one component should have written logs (which means generated any activity during migrations)
assert.Equal(t, int32(1), hasLogs.Load(), "expected 1 component to log anything to indicate migration activity, but got %d", hasLogs.Load())
// We should have the tables correctly created
err = tableExists(dbClient, "public", "mystate")
assert.NoError(t, err, "state table does not exist")
err = tableExists(dbClient, "public", "mymetadata")
assert.NoError(t, err, "metadata table does not exist")
// Ensure migration level is correct
level, err := getMigrationLevel(dbClient, "mymetadata")
assert.NoError(t, err, "failed to get migration level")
assert.Equal(t, migrationLevel, level, "migration level mismatch: found '%s' but expected '%s'", level, migrationLevel)
})
return nil
}
basicTest := func(ctx flow.Context) error {
client, err := client.NewClientWithPort(fmt.Sprint(currentGrpcPort))
client, err := client.NewClientWithPort(strconv.Itoa(currentGrpcPort))
if err != nil {
panic(err)
}
@ -199,7 +443,7 @@ func TestPostgreSQL(t *testing.T) {
}
testGetAfterPostgresRestart := func(ctx flow.Context) error {
client, err := client.NewClientWithPort(fmt.Sprint(currentGrpcPort))
client, err := client.NewClientWithPort(strconv.Itoa(currentGrpcPort))
if err != nil {
panic(err)
}
@ -214,7 +458,7 @@ func TestPostgreSQL(t *testing.T) {
// checks the state store component is not vulnerable to SQL injection
verifySQLInjectionTest := func(ctx flow.Context) error {
client, err := client.NewClientWithPort(fmt.Sprint(currentGrpcPort))
client, err := client.NewClientWithPort(strconv.Itoa(currentGrpcPort))
if err != nil {
panic(err)
}
@ -244,23 +488,310 @@ func TestPostgreSQL(t *testing.T) {
return nil
}
// Validates TTLs and garbage collections
ttlTest := func(ctx flow.Context) error {
md := state.Metadata{
Base: metadata.Base{
Name: "ttltest",
Properties: map[string]string{
keyConnectionString: connStringValue,
keyTableName: "ttl_state",
keyMetadatTableName: "ttl_metadata",
},
},
}
t.Run("parse cleanupIntervalInSeconds", func(t *testing.T) {
t.Run("default value", func(t *testing.T) {
// Default value is 1 hr
md.Properties[keyCleanupInterval] = ""
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
err := storeObj.Init(md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
dbAccess := storeObj.GetDBAccess().(*state_postgres.PostgresDBAccess)
require.NotNil(t, dbAccess)
cleanupInterval := dbAccess.GetCleanupInterval()
_ = assert.NotNil(t, cleanupInterval) &&
assert.Equal(t, time.Duration(1*time.Hour), *cleanupInterval)
})
t.Run("positive value", func(t *testing.T) {
// A positive value is interpreted in seconds
md.Properties[keyCleanupInterval] = "10"
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
err := storeObj.Init(md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
dbAccess := storeObj.GetDBAccess().(*state_postgres.PostgresDBAccess)
require.NotNil(t, dbAccess)
cleanupInterval := dbAccess.GetCleanupInterval()
_ = assert.NotNil(t, cleanupInterval) &&
assert.Equal(t, time.Duration(10*time.Second), *cleanupInterval)
})
t.Run("disabled", func(t *testing.T) {
// A value of <=0 means that the cleanup is disabled
md.Properties[keyCleanupInterval] = "0"
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
err := storeObj.Init(md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
dbAccess := storeObj.GetDBAccess().(*state_postgres.PostgresDBAccess)
require.NotNil(t, dbAccess)
cleanupInterval := dbAccess.GetCleanupInterval()
_ = assert.Nil(t, cleanupInterval)
})
})
t.Run("cleanup", func(t *testing.T) {
md := state.Metadata{
Base: metadata.Base{
Name: "ttltest",
Properties: map[string]string{
keyConnectionString: connStringValue,
keyTableName: "ttl_state",
keyMetadatTableName: "ttl_metadata",
},
},
}
t.Run("automatically delete expired records", func(t *testing.T) {
// Run every second
md.Properties[keyCleanupInterval] = "1"
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
err := storeObj.Init(md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
// Seed the database with some records
err = populateTTLRecords(ctx, dbClient)
require.NoError(t, err, "failed to seed records")
// Wait 2 seconds then verify we have only 10 rows left
time.Sleep(2 * time.Second)
count, err := countRowsInTable(ctx, dbClient, "ttl_state")
require.NoError(t, err, "failed to run query to count rows")
assert.Equal(t, 10, count)
// The "last-cleanup" value should be <= 1 second (+ a bit of buffer)
lastCleanup, err := loadLastCleanupInterval(ctx, dbClient, "ttl_metadata")
require.NoError(t, err, "failed to load value for 'last-cleanup'")
assert.LessOrEqual(t, lastCleanup, int64(1200))
// Wait 6 more seconds and verify there are no more rows left
time.Sleep(6 * time.Second)
count, err = countRowsInTable(ctx, dbClient, "ttl_state")
require.NoError(t, err, "failed to run query to count rows")
assert.Equal(t, 0, count)
// The "last-cleanup" value should be <= 1 second (+ a bit of buffer)
lastCleanup, err = loadLastCleanupInterval(ctx, dbClient, "ttl_metadata")
require.NoError(t, err, "failed to load value for 'last-cleanup'")
assert.LessOrEqual(t, lastCleanup, int64(1200))
})
t.Run("cleanup concurrency", func(t *testing.T) {
// Set to run every hour
// (we'll manually trigger more frequent iterations)
md.Properties[keyCleanupInterval] = "3600"
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
err := storeObj.Init(md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
dbAccess := storeObj.GetDBAccess().(*state_postgres.PostgresDBAccess)
require.NotNil(t, dbAccess)
// Seed the database with some records
err = populateTTLRecords(ctx, dbClient)
require.NoError(t, err, "failed to seed records")
// Validate that 20 records are present
count, err := countRowsInTable(ctx, dbClient, "ttl_state")
require.NoError(t, err, "failed to run query to count rows")
assert.Equal(t, 20, count)
// Set last-cleanup to 1s ago (less than 3600s)
err = setValueInMetadataTable(ctx, dbClient, "ttl_metadata", "'last-cleanup'", "CURRENT_TIMESTAMP - interval '1 second'")
require.NoError(t, err, "failed to set last-cleanup")
// The "last-cleanup" value should be ~1 second (+ a bit of buffer)
lastCleanup, err := loadLastCleanupInterval(ctx, dbClient, "ttl_metadata")
require.NoError(t, err, "failed to load value for 'last-cleanup'")
assert.LessOrEqual(t, lastCleanup, int64(1200))
lastCleanupValueOrig, err := getValueFromMetadataTable(ctx, dbClient, "ttl_metadata", "last-cleanup")
require.NoError(t, err, "failed to load absolute value for 'last-cleanup'")
require.NotEmpty(t, lastCleanupValueOrig)
// Trigger the background cleanup, which should do nothing because the last cleanup was < 3600s
err = dbAccess.CleanupExpired(ctx)
require.NoError(t, err, "CleanupExpired returned an error")
// Validate that 20 records are still present
count, err = countRowsInTable(ctx, dbClient, "ttl_state")
require.NoError(t, err, "failed to run query to count rows")
assert.Equal(t, 20, count)
// The "last-cleanup" value should not have been changed
lastCleanupValue, err := getValueFromMetadataTable(ctx, dbClient, "ttl_metadata", "last-cleanup")
require.NoError(t, err, "failed to load absolute value for 'last-cleanup'")
assert.Equal(t, lastCleanupValueOrig, lastCleanupValue)
})
})
return nil
}
flow.New(t, "Run tests").
Step(dockercompose.Run("db", dockerComposeYAML)).
Step("wait for component to start", flow.Sleep(10*time.Second)).
// No waiting here, as connectStep retries until it's ready (or there's a timeout)
//Step("wait for component to start", flow.Sleep(10*time.Second)).
Step("connect to the database", connectStep).
Step("run Init test", initTest).
Step(sidecar.Run(sidecarNamePrefix+"dockerDefault",
embedded.WithoutApp(),
embedded.WithDaprGRPCPort(currentGrpcPort),
embedded.WithComponentsPath("components/docker/default"),
runtime.WithStates(stateRegistry),
)).
Step("Run CRUD test", basicTest).
Step("Run eTag test", eTagTest).
Step("Run transactions test", transactionsTest).
Step("Run SQL injection test", verifySQLInjectionTest).
Step("run CRUD test", basicTest).
Step("run eTag test", eTagTest).
Step("run transactions test", transactionsTest).
Step("run SQL injection test", verifySQLInjectionTest).
Step("run TTL test", ttlTest).
Step("stop postgresql", dockercompose.Stop("db", dockerComposeYAML, "db")).
Step("wait for component to stop", flow.Sleep(10*time.Second)).
Step("start postgresql", dockercompose.Start("db", dockerComposeYAML, "db")).
Step("wait for component to start", flow.Sleep(10*time.Second)).
Step("Run connection test", testGetAfterPostgresRestart).
Step("run connection test", testGetAfterPostgresRestart).
Run()
}
func tableExists(dbClient *pgx.Conn, schema string, table string) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
var scanTable, scanSchema string
err := dbClient.QueryRow(
ctx,
"SELECT table_name, table_schema FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2",
table, schema,
).Scan(&scanTable, &scanSchema)
if err != nil {
return fmt.Errorf("error querying for table: %w", err)
}
if table != scanTable || schema != scanSchema {
return fmt.Errorf("found table '%s.%s' does not match", scanSchema, scanTable)
}
return nil
}
func getMigrationLevel(dbClient *pgx.Conn, metadataTable string) (level string, err error) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
err = dbClient.
QueryRow(ctx, fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, metadataTable)).
Scan(&level)
if err != nil && errors.Is(err, sql.ErrNoRows) {
err = nil
level = ""
}
return level, err
}
func populateTTLRecords(ctx context.Context, dbClient *pgx.Conn) error {
// Insert 10 records that have expired, and 10 that will expire in 6 seconds
exp := time.Now().Add(-1 * time.Minute)
rows := make([][]any, 20)
for i := 0; i < 10; i++ {
rows[i] = []any{
fmt.Sprintf("expired_%d", i),
json.RawMessage(fmt.Sprintf(`"value_%d"`, i)),
false,
exp,
}
}
exp = time.Now().Add(4 * time.Second)
for i := 0; i < 10; i++ {
rows[i+10] = []any{
fmt.Sprintf("notexpired_%d", i),
json.RawMessage(fmt.Sprintf(`"value_%d"`, i)),
false,
exp,
}
}
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
n, err := dbClient.CopyFrom(
queryCtx,
pgx.Identifier{"ttl_state"},
[]string{"key", "value", "isbinary", "expiredate"},
pgx.CopyFromRows(rows),
)
if err != nil {
return err
}
if n != 20 {
return fmt.Errorf("expected to copy 20 rows, but only got %d", n)
}
return nil
}
func countRowsInTable(ctx context.Context, dbClient *pgx.Conn, table string) (count int, err error) {
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
err = dbClient.QueryRow(queryCtx, "SELECT COUNT(key) FROM "+table).Scan(&count)
cancel()
return
}
func loadLastCleanupInterval(ctx context.Context, dbClient *pgx.Conn, table string) (lastCleanup int64, err error) {
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
err = dbClient.
QueryRow(queryCtx,
fmt.Sprintf("SELECT (EXTRACT('epoch' FROM CURRENT_TIMESTAMP - value::timestamp with time zone) * 1000)::bigint FROM %s WHERE key = 'last-cleanup'", table),
).
Scan(&lastCleanup)
cancel()
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
// Note this uses fmt.Sprintf and not parametrized queries-on purpose, so we can pass Postgres functions).
// Normally this would be a very bad idea, just don't do it... (do as I say don't do as I do :) ).
func setValueInMetadataTable(ctx context.Context, dbClient *pgx.Conn, table, key, value string) error {
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
_, err := dbClient.Exec(queryCtx,
//nolint:gosec
fmt.Sprintf(`INSERT INTO %[1]s (key, value) VALUES (%[2]s, %[3]s) ON CONFLICT (key) DO UPDATE SET value = %[3]s`, table, key, value),
)
cancel()
return err
}
func getValueFromMetadataTable(ctx context.Context, dbClient *pgx.Conn, table, key string) (value string, err error) {
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
err = dbClient.
QueryRow(queryCtx, fmt.Sprintf("SELECT value FROM %s WHERE key = $1", table), key).
Scan(&value)
cancel()
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}

View File

@ -20,8 +20,7 @@ components:
allOperations: false
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write" ]
- component: postgresql
allOperations: false
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "query", "first-write" ]
allOperations: true
- component: mysql.mysql
allOperations: false
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write" ]