Merge pull request #2302 from ItalyPaleAle/postgres-ttl
Add TTL to postgres state store
This commit is contained in:
commit
ea9b623ccb
|
@ -21,7 +21,6 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -36,6 +35,7 @@ import (
|
|||
contribmeta "github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
"github.com/dapr/components-contrib/state/query"
|
||||
stateutils "github.com/dapr/components-contrib/state/utils"
|
||||
"github.com/dapr/kit/logger"
|
||||
"github.com/dapr/kit/ptr"
|
||||
)
|
||||
|
@ -77,7 +77,6 @@ type CosmosItem struct {
|
|||
|
||||
const (
|
||||
metadataPartitionKey = "partitionKey"
|
||||
metadataTTLKey = "ttlInSeconds"
|
||||
defaultTimeout = 20 * time.Second
|
||||
statusNotFound = "NotFound"
|
||||
)
|
||||
|
@ -481,7 +480,7 @@ func createUpsertItem(contentType string, req state.SetRequest, partitionKey str
|
|||
isBinary = false
|
||||
}
|
||||
|
||||
ttl, err := parseTTL(req.Metadata)
|
||||
ttl, err := stateutils.ParseTTL(req.Metadata)
|
||||
if err != nil {
|
||||
return CosmosItem{}, fmt.Errorf("error parsing TTL from metadata: %s", err)
|
||||
}
|
||||
|
@ -534,20 +533,6 @@ func populatePartitionMetadata(key string, requestMetadata map[string]string) st
|
|||
return key
|
||||
}
|
||||
|
||||
func parseTTL(requestMetadata map[string]string) (*int, error) {
|
||||
if val, found := requestMetadata[metadataTTLKey]; found && val != "" {
|
||||
parsedVal, err := strconv.ParseInt(val, 10, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i := int(parsedVal)
|
||||
|
||||
return &i, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func isNotFoundError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/dapr/components-contrib/state"
|
||||
stateutils "github.com/dapr/components-contrib/state/utils"
|
||||
)
|
||||
|
||||
type widget struct {
|
||||
|
@ -284,7 +285,7 @@ func TestCreateCosmosItemWithTTL(t *testing.T) {
|
|||
Key: "testKey",
|
||||
Value: value,
|
||||
Metadata: map[string]string{
|
||||
metadataTTLKey: strconv.Itoa(ttl),
|
||||
stateutils.MetadataTTLKey: strconv.Itoa(ttl),
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -316,7 +317,7 @@ func TestCreateCosmosItemWithTTL(t *testing.T) {
|
|||
Key: "testKey",
|
||||
Value: value,
|
||||
Metadata: map[string]string{
|
||||
metadataTTLKey: strconv.Itoa(ttl),
|
||||
stateutils.MetadataTTLKey: strconv.Itoa(ttl),
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -347,7 +348,7 @@ func TestCreateCosmosItemWithTTL(t *testing.T) {
|
|||
Key: "testKey",
|
||||
Value: value,
|
||||
Metadata: map[string]string{
|
||||
metadataTTLKey: "notattl",
|
||||
stateutils.MetadataTTLKey: "notattl",
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
stateutils "github.com/dapr/components-contrib/state/utils"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
|
@ -279,7 +280,7 @@ func (c *Cassandra) Set(ctx context.Context, req *state.SetRequest) error {
|
|||
session = sess
|
||||
}
|
||||
|
||||
ttl, err := parseTTL(req.Metadata)
|
||||
ttl, err := stateutils.ParseTTL(req.Metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing TTL from Metadata: %s", err)
|
||||
}
|
||||
|
@ -302,20 +303,6 @@ func (c *Cassandra) createSession(consistency gocql.Consistency) (*gocql.Session
|
|||
return session, nil
|
||||
}
|
||||
|
||||
func parseTTL(requestMetadata map[string]string) (*int, error) {
|
||||
if val, found := requestMetadata[metadataTTLKey]; found && val != "" {
|
||||
parsedVal, err := strconv.ParseInt(val, 10, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsedInt := int(parsedVal)
|
||||
|
||||
return &parsedInt, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) GetComponentMetadata() map[string]string {
|
||||
metadataStruct := cassandraMetadata{}
|
||||
metadataInfo := map[string]string{}
|
||||
|
|
|
@ -14,7 +14,6 @@ limitations under the License.
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -111,35 +110,3 @@ func TestGetCassandraMetadata(t *testing.T) {
|
|||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseTTL(t *testing.T) {
|
||||
t.Run("TTL Not an integer", func(t *testing.T) {
|
||||
ttlInSeconds := "not an integer"
|
||||
ttl, err := parseTTL(map[string]string{
|
||||
"ttlInSeconds": ttlInSeconds,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
t.Run("TTL specified with wrong key", func(t *testing.T) {
|
||||
ttlInSeconds := 12345
|
||||
ttl, err := parseTTL(map[string]string{
|
||||
"expirationTime": strconv.Itoa(ttlInSeconds),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
t.Run("TTL is a number", func(t *testing.T) {
|
||||
ttlInSeconds := 12345
|
||||
ttl, err := parseTTL(map[string]string{
|
||||
"ttlInSeconds": strconv.Itoa(ttlInSeconds),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, *ttl, ttlInSeconds)
|
||||
})
|
||||
t.Run("TTL not set", func(t *testing.T) {
|
||||
ttl, err := parseTTL(map[string]string{})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -93,7 +93,7 @@ func TestParseTTL(t *testing.T) {
|
|||
},
|
||||
})
|
||||
|
||||
assert.NotNil(t, err, "tll is not an integer")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
t.Run("TTL is a negative integer ends up translated to 0", func(t *testing.T) {
|
||||
|
|
|
@ -22,7 +22,6 @@ import (
|
|||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -34,6 +33,7 @@ import (
|
|||
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
stateutils "github.com/dapr/components-contrib/state/utils"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
|
@ -278,17 +278,13 @@ func (r *StateStore) writeDocument(ctx context.Context, req *state.SetRequest) e
|
|||
}
|
||||
|
||||
func (r *StateStore) convertTTLtoExpiryTime(req *state.SetRequest, metadata map[string]string) error {
|
||||
ttl, ttlerr := parseTTL(req.Metadata)
|
||||
ttl, ttlerr := stateutils.ParseTTL(req.Metadata)
|
||||
if ttlerr != nil {
|
||||
return fmt.Errorf("error in parsing TTL %w", ttlerr)
|
||||
return fmt.Errorf("error parsing TTL: %w", ttlerr)
|
||||
}
|
||||
if ttl != nil {
|
||||
if *ttl == -1 {
|
||||
r.logger.Debugf("TTL is set to -1; this means: never expire. ")
|
||||
} else {
|
||||
metadata[expiryTimeMetaLabel] = time.Now().UTC().Add(time.Second * time.Duration(*ttl)).Format(isoDateTimeFormat)
|
||||
r.logger.Debugf("Set %s in meta properties for object to ", expiryTimeMetaLabel, metadata[expiryTimeMetaLabel])
|
||||
}
|
||||
metadata[expiryTimeMetaLabel] = time.Now().UTC().Add(time.Second * time.Duration(*ttl)).Format(isoDateTimeFormat)
|
||||
r.logger.Debugf("Set %s in meta properties for object to ", expiryTimeMetaLabel, metadata[expiryTimeMetaLabel])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -367,20 +363,6 @@ func getFileName(key string) string {
|
|||
return path.Join(pr[0], pr[1])
|
||||
}
|
||||
|
||||
func parseTTL(requestMetadata map[string]string) (*int, error) {
|
||||
if val, found := requestMetadata[metadataTTLKey]; found && val != "" {
|
||||
parsedVal, err := strconv.ParseInt(val, 10, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error in parsing ttl metadata : %w", err)
|
||||
}
|
||||
parsedInt := int(parsedVal)
|
||||
|
||||
return &parsedInt, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
/**************** functions with OCI ObjectStorage Service interaction. */
|
||||
|
||||
func getNamespace(ctx context.Context, client objectstorage.ObjectStorageClient) (string, error) {
|
||||
|
|
|
@ -17,7 +17,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -394,36 +393,3 @@ func TestGetFilename(t *testing.T) {
|
|||
assert.Equal(t, "app-id-key", filename)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("TTL Not an integer", func(t *testing.T) {
|
||||
ttlInSeconds := "not an integer"
|
||||
ttl, err := parseTTL(map[string]string{
|
||||
"ttlInSeconds": ttlInSeconds,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
t.Run("TTL specified with wrong key", func(t *testing.T) {
|
||||
ttlInSeconds := 12345
|
||||
ttl, err := parseTTL(map[string]string{
|
||||
"expirationTime": strconv.Itoa(ttlInSeconds),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
t.Run("TTL is a number", func(t *testing.T) {
|
||||
ttlInSeconds := 12345
|
||||
ttl, err := parseTTL(map[string]string{
|
||||
"ttlInSeconds": strconv.Itoa(ttlInSeconds),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, *ttl, ttlInSeconds)
|
||||
})
|
||||
t.Run("TTL not set", func(t *testing.T) {
|
||||
ttl, err := parseTTL(map[string]string{})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -21,7 +21,6 @@ import (
|
|||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -658,43 +657,6 @@ func setItemWithNoKey(t *testing.T, ods *OracleDatabase) {
|
|||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestParseTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("TTL Not an integer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ttlInSeconds := "not an integer"
|
||||
ttl, err := parseTTL(map[string]string{
|
||||
"ttlInSeconds": ttlInSeconds,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
t.Run("TTL specified with wrong key", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ttlInSeconds := 12345
|
||||
ttl, err := parseTTL(map[string]string{
|
||||
"expirationTime": strconv.Itoa(ttlInSeconds),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
t.Run("TTL is a number", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ttlInSeconds := 12345
|
||||
ttl, err := parseTTL(map[string]string{
|
||||
"ttlInSeconds": strconv.Itoa(ttlInSeconds),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, *ttl, ttlInSeconds)
|
||||
})
|
||||
t.Run("TTL not set", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ttl, err := parseTTL(map[string]string{})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
}
|
||||
|
||||
func testSetItemWithInvalidTTL(t *testing.T, ods *OracleDatabase) {
|
||||
setReq := &state.SetRequest{
|
||||
Key: randomKey(),
|
||||
|
|
|
@ -20,13 +20,12 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
"github.com/dapr/components-contrib/state/utils"
|
||||
stateutils "github.com/dapr/components-contrib/state/utils"
|
||||
"github.com/dapr/kit/logger"
|
||||
|
||||
// Blank import for the underlying Oracle Database driver.
|
||||
|
@ -36,7 +35,6 @@ import (
|
|||
const (
|
||||
connectionStringKey = "connectionString"
|
||||
oracleWalletLocationKey = "oracleWalletLocation"
|
||||
metadataTTLKey = "ttlInSeconds"
|
||||
errMissingConnectionString = "missing connection string"
|
||||
tableName = "state"
|
||||
)
|
||||
|
@ -115,20 +113,6 @@ func (o *oracleDatabaseAccess) Init(metadata state.Metadata) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func parseTTL(requestMetadata map[string]string) (*int, error) {
|
||||
if val, found := requestMetadata[metadataTTLKey]; found && val != "" {
|
||||
parsedVal, err := strconv.ParseInt(val, 10, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error in parsing ttl metadata : %w", err)
|
||||
}
|
||||
parsedInt := int(parsedVal)
|
||||
|
||||
return &parsedInt, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Set makes an insert or update to the database.
|
||||
func (o *oracleDatabaseAccess) Set(ctx context.Context, req *state.SetRequest) error {
|
||||
o.logger.Debug("Setting state value in OracleDatabase")
|
||||
|
@ -149,19 +133,12 @@ func (o *oracleDatabaseAccess) Set(ctx context.Context, req *state.SetRequest) e
|
|||
return fmt.Errorf("when FirstWrite is to be enforced, a value must be provided for the ETag")
|
||||
}
|
||||
var ttlSeconds int
|
||||
ttl, ttlerr := parseTTL(req.Metadata)
|
||||
ttl, ttlerr := stateutils.ParseTTL(req.Metadata)
|
||||
if ttlerr != nil {
|
||||
return fmt.Errorf("error in parsing TTL %w", ttlerr)
|
||||
return fmt.Errorf("error parsing TTL: %w", ttlerr)
|
||||
}
|
||||
if ttl != nil {
|
||||
if *ttl == -1 {
|
||||
o.logger.Debugf("TTL is set to -1; this means: never expire. ")
|
||||
} else {
|
||||
if *ttl < -1 {
|
||||
return fmt.Errorf("incorrect value for %s %d", metadataTTLKey, *ttl)
|
||||
}
|
||||
ttlSeconds = *ttl
|
||||
}
|
||||
ttlSeconds = *ttl
|
||||
}
|
||||
requestValue := req.Value
|
||||
byteArray, isBinary := req.Value.([]uint8)
|
||||
|
@ -172,7 +149,7 @@ func (o *oracleDatabaseAccess) Set(ctx context.Context, req *state.SetRequest) e
|
|||
}
|
||||
|
||||
// Convert to json string.
|
||||
bt, _ := utils.Marshal(requestValue, json.Marshal)
|
||||
bt, _ := stateutils.Marshal(requestValue, json.Marshal)
|
||||
value := string(bt)
|
||||
|
||||
var result sql.Result
|
||||
|
|
|
@ -15,6 +15,7 @@ package postgresql
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/dapr/components-contrib/state"
|
||||
)
|
||||
|
@ -31,3 +32,12 @@ type dbAccess interface {
|
|||
Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error)
|
||||
Close() error // io.Closer
|
||||
}
|
||||
|
||||
// Interface that contains methods for querying.
|
||||
// Applies to both *sql.DB and *sql.Tx
|
||||
type dbquerier interface {
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
QueryRow(query string, args ...any) *sql.Row
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
|
|
@ -0,0 +1,227 @@
|
|||
/*
|
||||
Copyright 2021 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
// Performs migrations for the database schema
|
||||
type migrations struct {
|
||||
Logger logger.Logger
|
||||
Conn *sql.DB
|
||||
StateTableName string
|
||||
MetadataTableName string
|
||||
}
|
||||
|
||||
// Perform the required migrations
|
||||
func (m *migrations) Perform(ctx context.Context) error {
|
||||
// Use an advisory lock (with an arbitrary number) to ensure that no one else is performing migrations at the same time
|
||||
// This is the only way to also ensure we are not running multiple "CREATE TABLE IF NOT EXISTS" at the exact same time
|
||||
// See: https://www.postgresql.org/message-id/CA+TgmoZAdYVtwBfp1FL2sMZbiHCWT4UPrzRLNnX1Nb30Ku3-gg@mail.gmail.com
|
||||
const lockID = 42
|
||||
|
||||
// Long timeout here as this query may block
|
||||
queryCtx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
_, err := m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_lock($1)", lockID)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("faild to acquire advisory lock: %w", err)
|
||||
}
|
||||
|
||||
// Release the lock
|
||||
defer func() {
|
||||
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
|
||||
_, err = m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_unlock($1)", lockID)
|
||||
cancel()
|
||||
if err != nil {
|
||||
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
|
||||
m.Logger.Fatalf("Failed to release advisory lock: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Check if the metadata table exists, which we also use to store the migration level
|
||||
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||
exists, _, _, err := m.tableExists(queryCtx, m.MetadataTableName)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the table doesn't exist, create it
|
||||
if !exists {
|
||||
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||
err = m.createMetadataTable(queryCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Select the migration level
|
||||
var (
|
||||
migrationLevelStr string
|
||||
migrationLevel int
|
||||
)
|
||||
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||
err = m.Conn.
|
||||
QueryRowContext(queryCtx,
|
||||
fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName),
|
||||
).Scan(&migrationLevelStr)
|
||||
cancel()
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// If there's no row...
|
||||
migrationLevel = 0
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to read migration level: %w", err)
|
||||
} else {
|
||||
migrationLevel, err = strconv.Atoi(migrationLevelStr)
|
||||
if err != nil || migrationLevel < 0 {
|
||||
return fmt.Errorf("invalid migration level found in metadata table: %s", migrationLevelStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Perform the migrations
|
||||
for i := migrationLevel; i < len(allMigrations); i++ {
|
||||
m.Logger.Infof("Performing migration %d", i)
|
||||
err = allMigrations[i](ctx, m)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to perform migration %d: %w", i, err)
|
||||
}
|
||||
|
||||
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||
_, err = m.Conn.ExecContext(queryCtx,
|
||||
fmt.Sprintf(`INSERT INTO %s (key, value) VALUES ('migrations', $1) ON CONFLICT (key) DO UPDATE SET value = $1`, m.MetadataTableName),
|
||||
strconv.Itoa(i+1),
|
||||
)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update migration level in metadata table: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m migrations) createMetadataTable(ctx context.Context) error {
|
||||
m.Logger.Infof("Creating metadata table '%s'", m.MetadataTableName)
|
||||
// Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time
|
||||
// In the next step we'll acquire a lock so there won't be issues with concurrency
|
||||
_, err := m.Conn.Exec(fmt.Sprintf(
|
||||
`CREATE TABLE IF NOT EXISTS %s (
|
||||
key text NOT NULL PRIMARY KEY,
|
||||
value text NOT NULL
|
||||
)`,
|
||||
m.MetadataTableName,
|
||||
))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create metadata table: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the table exists, returns true and the name of the table and schema
|
||||
func (m migrations) tableExists(ctx context.Context, tableName string) (exists bool, schema string, table string, err error) {
|
||||
table, schema, err = m.tableSchemaName(tableName)
|
||||
if err != nil {
|
||||
return false, "", "", err
|
||||
}
|
||||
|
||||
if schema == "" {
|
||||
err = m.Conn.
|
||||
QueryRowContext(
|
||||
ctx,
|
||||
`SELECT table_name, table_schema
|
||||
FROM information_schema.tables
|
||||
WHERE table_name = $1`,
|
||||
table,
|
||||
).
|
||||
Scan(&table, &schema)
|
||||
} else {
|
||||
err = m.Conn.
|
||||
QueryRowContext(
|
||||
ctx,
|
||||
`SELECT table_name, table_schema
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = $1 AND table_name = $2`,
|
||||
schema, table,
|
||||
).
|
||||
Scan(&table, &schema)
|
||||
}
|
||||
|
||||
if err != nil && errors.Is(err, sql.ErrNoRows) {
|
||||
return false, "", "", nil
|
||||
} else if err != nil {
|
||||
return false, "", "", fmt.Errorf("failed to check if table '%s' exists: %w", tableName, err)
|
||||
}
|
||||
return true, schema, table, nil
|
||||
}
|
||||
|
||||
// If the table name includes a schema (e.g. `schema.table`, returns the two parts separately)
|
||||
func (m migrations) tableSchemaName(tableName string) (table string, schema string, err error) {
|
||||
parts := strings.Split(tableName, ".")
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
return parts[0], "", nil
|
||||
case 2:
|
||||
return parts[1], parts[0], nil
|
||||
default:
|
||||
return "", "", errors.New("invalid table name: must be in the format 'table' or 'schema.table'")
|
||||
}
|
||||
}
|
||||
|
||||
var allMigrations = [2]func(ctx context.Context, m *migrations) error{
|
||||
// Migration 0: create the state table
|
||||
func(ctx context.Context, m *migrations) error {
|
||||
// We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table
|
||||
m.Logger.Infof("Creating state table '%s'", m.StateTableName)
|
||||
_, err := m.Conn.Exec(
|
||||
fmt.Sprintf(
|
||||
`CREATE TABLE IF NOT EXISTS %s (
|
||||
key text NOT NULL PRIMARY KEY,
|
||||
value jsonb NOT NULL,
|
||||
isbinary boolean NOT NULL,
|
||||
insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updatedate TIMESTAMP WITH TIME ZONE NULL
|
||||
)`,
|
||||
m.StateTableName,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create state table: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Migration 1: add the "expiredate" column
|
||||
func(ctx context.Context, m *migrations) error {
|
||||
m.Logger.Infof("Adding expiredate column to state table '%s'", m.StateTableName)
|
||||
_, err := m.Conn.Exec(fmt.Sprintf(
|
||||
`ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`,
|
||||
m.StateTableName,
|
||||
))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update state table: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
|
@ -26,34 +26,38 @@ import (
|
|||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
"github.com/dapr/components-contrib/state/query"
|
||||
"github.com/dapr/components-contrib/state/utils"
|
||||
stateutils "github.com/dapr/components-contrib/state/utils"
|
||||
"github.com/dapr/kit/logger"
|
||||
"github.com/dapr/kit/ptr"
|
||||
|
||||
// Blank import for the underlying PostgreSQL driver.
|
||||
// Blank import for the underlying Postgres driver.
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
)
|
||||
|
||||
const (
|
||||
connectionStringKey = "connectionString"
|
||||
errMissingConnectionString = "missing connection string"
|
||||
defaultTableName = "state"
|
||||
defaultTableName = "state"
|
||||
defaultMetadataTableName = "dapr_metadata"
|
||||
cleanupIntervalKey = "cleanupIntervalInSeconds"
|
||||
defaultCleanupInternal = 3600 // In seconds = 1 hour
|
||||
)
|
||||
|
||||
// postgresDBAccess implements dbaccess.
|
||||
type postgresDBAccess struct {
|
||||
logger logger.Logger
|
||||
metadata postgresMetadataStruct
|
||||
db *sql.DB
|
||||
connectionString string
|
||||
tableName string
|
||||
var errMissingConnectionString = errors.New("missing connection string")
|
||||
|
||||
// PostgresDBAccess implements dbaccess.
|
||||
type PostgresDBAccess struct {
|
||||
logger logger.Logger
|
||||
metadata postgresMetadataStruct
|
||||
cleanupInterval *time.Duration
|
||||
db *sql.DB
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// newPostgresDBAccess creates a new instance of postgresAccess.
|
||||
func newPostgresDBAccess(logger logger.Logger) *postgresDBAccess {
|
||||
logger.Debug("Instantiating new PostgreSQL state store")
|
||||
func newPostgresDBAccess(logger logger.Logger) *PostgresDBAccess {
|
||||
logger.Debug("Instantiating new Postgres state store")
|
||||
|
||||
return &postgresDBAccess{
|
||||
return &PostgresDBAccess{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
@ -61,14 +65,66 @@ func newPostgresDBAccess(logger logger.Logger) *postgresDBAccess {
|
|||
type postgresMetadataStruct struct {
|
||||
ConnectionString string
|
||||
ConnectionMaxIdleTime time.Duration
|
||||
TableName string
|
||||
TableName string // Could be in the format "schema.table" or just "table"
|
||||
MetadataTableName string // Could be in the format "schema.table" or just "table"
|
||||
}
|
||||
|
||||
// Init sets up PostgreSQL connection and ensures that the state table exists.
|
||||
func (p *postgresDBAccess) Init(meta state.Metadata) error {
|
||||
p.logger.Debug("Initializing PostgreSQL state store")
|
||||
// Init sets up Postgres connection and ensures that the state table exists.
|
||||
func (p *PostgresDBAccess) Init(meta state.Metadata) error {
|
||||
p.logger.Debug("Initializing Postgres state store")
|
||||
|
||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
||||
|
||||
err := p.ParseMetadata(meta)
|
||||
if err != nil {
|
||||
p.logger.Errorf("Failed to parse metadata: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
db, err := sql.Open("pgx", p.metadata.ConnectionString)
|
||||
if err != nil {
|
||||
p.logger.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
p.db = db
|
||||
|
||||
pingCtx, pingCancel := context.WithTimeout(p.ctx, 30*time.Second)
|
||||
pingErr := db.PingContext(pingCtx)
|
||||
pingCancel()
|
||||
if pingErr != nil {
|
||||
return pingErr
|
||||
}
|
||||
|
||||
p.db.SetConnMaxIdleTime(p.metadata.ConnectionMaxIdleTime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
migrate := &migrations{
|
||||
Logger: p.logger,
|
||||
Conn: p.db,
|
||||
MetadataTableName: p.metadata.MetadataTableName,
|
||||
StateTableName: p.metadata.TableName,
|
||||
}
|
||||
err = migrate.Perform(p.ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.ScheduleCleanupExpiredData(p.ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgresDBAccess) GetDB() *sql.DB {
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *PostgresDBAccess) ParseMetadata(meta state.Metadata) error {
|
||||
m := postgresMetadataStruct{
|
||||
TableName: defaultTableName,
|
||||
TableName: defaultTableName,
|
||||
MetadataTableName: defaultMetadataTableName,
|
||||
}
|
||||
err := metadata.DecodeMetadata(meta.Properties, &m)
|
||||
if err != nil {
|
||||
|
@ -77,44 +133,33 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error {
|
|||
p.metadata = m
|
||||
|
||||
if m.ConnectionString == "" {
|
||||
p.logger.Error("Missing postgreSQL connection string")
|
||||
|
||||
return errors.New(errMissingConnectionString)
|
||||
}
|
||||
p.connectionString = m.ConnectionString
|
||||
|
||||
db, err := sql.Open("pgx", p.connectionString)
|
||||
if err != nil {
|
||||
p.logger.Error(err)
|
||||
|
||||
return err
|
||||
return errMissingConnectionString
|
||||
}
|
||||
|
||||
p.db = db
|
||||
s, ok := meta.Properties[cleanupIntervalKey]
|
||||
if ok && s != "" {
|
||||
cleanupIntervalInSec, err := strconv.ParseInt(s, 10, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value for '%s': %s", cleanupIntervalKey, s)
|
||||
}
|
||||
|
||||
pingErr := db.Ping()
|
||||
if pingErr != nil {
|
||||
return pingErr
|
||||
// Non-positive value from meta means disable auto cleanup.
|
||||
if cleanupIntervalInSec > 0 {
|
||||
p.cleanupInterval = ptr.Of(time.Duration(cleanupIntervalInSec) * time.Second)
|
||||
}
|
||||
} else {
|
||||
p.cleanupInterval = ptr.Of(defaultCleanupInternal * time.Second)
|
||||
}
|
||||
|
||||
p.db.SetConnMaxIdleTime(m.ConnectionMaxIdleTime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.ensureStateTable(m.TableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.tableName = m.TableName
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set makes an insert or update to the database.
|
||||
func (p *postgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error {
|
||||
p.logger.Debug("Setting state value in PostgreSQL")
|
||||
func (p *PostgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error {
|
||||
return p.doSet(ctx, p.db, req)
|
||||
}
|
||||
|
||||
func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *state.SetRequest) error {
|
||||
err := state.CheckRequestOptions(req.Options)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -135,22 +180,47 @@ func (p *postgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error
|
|||
}
|
||||
|
||||
// Convert to json string
|
||||
bt, _ := utils.Marshal(v, json.Marshal)
|
||||
bt, _ := stateutils.Marshal(v, json.Marshal)
|
||||
value := string(bt)
|
||||
|
||||
// TTL
|
||||
var ttlSeconds int
|
||||
ttl, ttlerr := stateutils.ParseTTL(req.Metadata)
|
||||
if ttlerr != nil {
|
||||
return fmt.Errorf("error parsing TTL: %w", ttlerr)
|
||||
}
|
||||
if ttl != nil {
|
||||
ttlSeconds = *ttl
|
||||
}
|
||||
|
||||
var result sql.Result
|
||||
|
||||
// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
|
||||
// Other parameters use sql.DB parameter substitution.
|
||||
if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || *req.ETag == "") {
|
||||
result, err = p.db.ExecContext(ctx, fmt.Sprintf(
|
||||
`INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3);`,
|
||||
p.tableName), req.Key, value, isBinary)
|
||||
} else if req.ETag == nil || *req.ETag == "" {
|
||||
result, err = p.db.ExecContext(ctx, fmt.Sprintf(
|
||||
`INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3)
|
||||
ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW();`,
|
||||
p.tableName), req.Key, value, isBinary)
|
||||
// Sprintf is required for table name because query.DB does not substitute parameters for table names.
|
||||
// Other parameters use query.DB parameter substitution.
|
||||
var (
|
||||
query string
|
||||
queryExpiredate string
|
||||
params []any
|
||||
)
|
||||
if req.ETag == nil || *req.ETag == "" {
|
||||
if req.Options.Concurrency == state.FirstWrite {
|
||||
query = `INSERT INTO %[1]s
|
||||
(key, value, isbinary, expiredate)
|
||||
VALUES
|
||||
($1, $2, $3, %[2]s)`
|
||||
} else {
|
||||
query = `INSERT INTO %[1]s
|
||||
(key, value, isbinary, expiredate)
|
||||
VALUES
|
||||
($1, $2, $3, %[2]s)
|
||||
ON CONFLICT (key)
|
||||
DO UPDATE SET
|
||||
value = $2,
|
||||
isbinary = $3,
|
||||
updatedate = CURRENT_TIMESTAMP,
|
||||
expiredate = %[2]s`
|
||||
}
|
||||
params = []any{req.Key, value, isBinary}
|
||||
} else {
|
||||
// Convert req.ETag to uint32 for postgres XID compatibility
|
||||
var etag64 uint64
|
||||
|
@ -158,20 +228,30 @@ func (p *postgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error
|
|||
if err != nil {
|
||||
return state.NewETagError(state.ETagInvalid, err)
|
||||
}
|
||||
etag := uint32(etag64)
|
||||
|
||||
// When an etag is provided do an update - no insert
|
||||
result, err = p.db.ExecContext(ctx, fmt.Sprintf(
|
||||
`UPDATE %s SET value = $1, isbinary = $2, updatedate = NOW()
|
||||
WHERE key = $3 AND xmin = $4;`,
|
||||
p.tableName), value, isBinary, req.Key, etag)
|
||||
query = `UPDATE %[1]s
|
||||
SET
|
||||
value = $1,
|
||||
isbinary = $2,
|
||||
updatedate = CURRENT_TIMESTAMP,
|
||||
expiredate = %[2]s
|
||||
WHERE
|
||||
key = $3
|
||||
AND xmin = $4`
|
||||
params = []any{value, isBinary, req.Key, uint32(etag64)}
|
||||
}
|
||||
|
||||
if ttlSeconds > 0 {
|
||||
queryExpiredate = "CURRENT_TIMESTAMP + interval '" + strconv.Itoa(ttlSeconds) + " seconds'"
|
||||
} else {
|
||||
queryExpiredate = "NULL"
|
||||
}
|
||||
result, err = db.ExecContext(parentCtx, fmt.Sprintf(query, p.metadata.TableName, queryExpiredate), params...)
|
||||
|
||||
if err != nil {
|
||||
if req.ETag != nil && *req.ETag != "" {
|
||||
return state.NewETagError(state.ETagMismatch, err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -179,7 +259,6 @@ func (p *postgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rows != 1 {
|
||||
return errors.New("no item was updated")
|
||||
}
|
||||
|
@ -187,33 +266,32 @@ func (p *postgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *postgresDBAccess) BulkSet(ctx context.Context, req []state.SetRequest) error {
|
||||
p.logger.Debug("Executing BulkSet request")
|
||||
tx, err := p.db.Begin()
|
||||
func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetRequest) error {
|
||||
tx, err := p.db.BeginTx(parentCtx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if len(req) > 0 {
|
||||
for _, s := range req {
|
||||
sa := s // Fix for gosec G601: Implicit memory aliasing in for loop.
|
||||
err = p.Set(ctx, &sa)
|
||||
for i := range req {
|
||||
err = p.doSet(parentCtx, tx, &req[i])
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned.
|
||||
func (p *postgresDBAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
|
||||
p.logger.Debug("Getting state value from PostgreSQL")
|
||||
func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
|
||||
if req.Key == "" {
|
||||
return nil, errors.New("missing key in get operation")
|
||||
}
|
||||
|
@ -223,7 +301,15 @@ func (p *postgresDBAccess) Get(ctx context.Context, req *state.GetRequest) (*sta
|
|||
isBinary bool
|
||||
etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations
|
||||
)
|
||||
err := p.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, isbinary, xmin as etag FROM %s WHERE key = $1", p.tableName), req.Key).Scan(&value, &isBinary, &etag)
|
||||
query := `SELECT
|
||||
value, isbinary, xmin AS etag
|
||||
FROM %s
|
||||
WHERE
|
||||
key = $1
|
||||
AND (expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP)`
|
||||
err := p.db.
|
||||
QueryRowContext(parentCtx, fmt.Sprintf(query, p.metadata.TableName), req.Key).
|
||||
Scan(&value, &isBinary, &etag)
|
||||
if err != nil {
|
||||
// If no rows exist, return an empty response, otherwise return the error.
|
||||
if err == sql.ErrNoRows {
|
||||
|
@ -261,8 +347,11 @@ func (p *postgresDBAccess) Get(ctx context.Context, req *state.GetRequest) (*sta
|
|||
}
|
||||
|
||||
// Delete removes an item from the state store.
|
||||
func (p *postgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) (err error) {
|
||||
p.logger.Debug("Deleting state value from PostgreSQL")
|
||||
func (p *PostgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) (err error) {
|
||||
return p.doDelete(ctx, p.db, req)
|
||||
}
|
||||
|
||||
func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req *state.DeleteRequest) (err error) {
|
||||
if req.Key == "" {
|
||||
return errors.New("missing key in delete operation")
|
||||
}
|
||||
|
@ -270,7 +359,7 @@ func (p *postgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest)
|
|||
var result sql.Result
|
||||
|
||||
if req.ETag == nil || *req.ETag == "" {
|
||||
result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1", req.Key)
|
||||
result, err = db.ExecContext(parentCtx, "DELETE FROM state WHERE key = $1", req.Key)
|
||||
} else {
|
||||
// Convert req.ETag to uint32 for postgres XID compatibility
|
||||
var etag64 uint64
|
||||
|
@ -280,7 +369,7 @@ func (p *postgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest)
|
|||
}
|
||||
etag := uint32(etag64)
|
||||
|
||||
result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1 and xmin = $2", req.Key, etag)
|
||||
result, err = db.ExecContext(parentCtx, "DELETE FROM state WHERE key = $1 AND xmin = $2", req.Key, etag)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -299,92 +388,88 @@ func (p *postgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest)
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *postgresDBAccess) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
|
||||
p.logger.Debug("Executing BulkDelete request")
|
||||
tx, err := p.db.Begin()
|
||||
func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.DeleteRequest) error {
|
||||
tx, err := p.db.BeginTx(parentCtx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if len(req) > 0 {
|
||||
for i := range req {
|
||||
err = p.Delete(ctx, &req[i])
|
||||
err = p.doDelete(parentCtx, tx, &req[i])
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *postgresDBAccess) ExecuteMulti(ctx context.Context, request *state.TransactionalStateRequest) error {
|
||||
p.logger.Debug("Executing PostgreSQL transaction")
|
||||
|
||||
tx, err := p.db.Begin()
|
||||
func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *state.TransactionalStateRequest) error {
|
||||
tx, err := p.db.BeginTx(parentCtx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, o := range request.Operations {
|
||||
switch o.Operation {
|
||||
case state.Upsert:
|
||||
var setReq state.SetRequest
|
||||
|
||||
setReq, err = getSet(o)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.Set(ctx, &setReq)
|
||||
err = p.doSet(parentCtx, tx, &setReq)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
case state.Delete:
|
||||
var delReq state.DeleteRequest
|
||||
|
||||
delReq, err = getDelete(o)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.Delete(ctx, &delReq)
|
||||
err = p.doDelete(parentCtx, tx, &delReq)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
default:
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("unsupported operation: %s", o.Operation)
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query executes a query against store.
|
||||
func (p *postgresDBAccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
|
||||
p.logger.Debug("Getting query value from PostgreSQL")
|
||||
func (p *PostgresDBAccess) Query(parentCtx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
|
||||
q := &Query{
|
||||
query: "",
|
||||
params: []interface{}{},
|
||||
tableName: p.tableName,
|
||||
params: []any{},
|
||||
tableName: p.metadata.TableName,
|
||||
}
|
||||
qbuilder := query.NewQueryBuilder(q)
|
||||
if err := qbuilder.BuildQuery(&req.Query); err != nil {
|
||||
return &state.QueryResponse{}, err
|
||||
}
|
||||
data, token, err := q.execute(ctx, p.logger, p.db)
|
||||
data, token, err := q.execute(parentCtx, p.logger, p.db)
|
||||
if err != nil {
|
||||
return &state.QueryResponse{}, err
|
||||
}
|
||||
|
@ -395,8 +480,94 @@ func (p *postgresDBAccess) Query(ctx context.Context, req *state.QueryRequest) (
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (p *PostgresDBAccess) ScheduleCleanupExpiredData(ctx context.Context) {
|
||||
if p.cleanupInterval == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Infof("Schedule expired data clean up every %d seconds", int(p.cleanupInterval.Seconds()))
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(*p.cleanupInterval)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := p.CleanupExpired(ctx)
|
||||
if err != nil {
|
||||
p.logger.Errorf("Error removing expired data: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
p.logger.Debug("Stopped background cleanup of expired data")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *PostgresDBAccess) CleanupExpired(ctx context.Context) error {
|
||||
// Check if the last iteration was too recent
|
||||
// This performs an atomic operation, so allows coordination with other daprd processes too
|
||||
canContinue, err := p.UpdateLastCleanup(ctx, p.db, *p.cleanupInterval)
|
||||
if err != nil {
|
||||
// Log errors only
|
||||
p.logger.Warnf("Failed to read last cleanup time from database: %v", err)
|
||||
}
|
||||
if !canContinue {
|
||||
p.logger.Debug("Last cleanup was performed too recently")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Note we're not using the transaction here as we don't want this to be rolled back half-way or to lock the table unnecessarily
|
||||
// Need to use fmt.Sprintf because we can't parametrize a table name
|
||||
// Note we are not setting a timeout here as this query can take a "long" time, especially if there's no index on expiredate
|
||||
//nolint:gosec
|
||||
stmt := fmt.Sprintf(`DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate < CURRENT_TIMESTAMP`, p.metadata.TableName)
|
||||
res, err := p.db.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute query: %w", err)
|
||||
}
|
||||
|
||||
cleaned, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to count affected rows: %w", err)
|
||||
}
|
||||
|
||||
p.logger.Infof("Removed %d expired rows", cleaned)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateLastCleanup sets the 'last-cleanup' value only if it's less than cleanupInterval.
|
||||
// Returns true if the row was updated, which means that the cleanup can proceed.
|
||||
func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier, cleanupInterval time.Duration) (bool, error) {
|
||||
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
res, err := db.ExecContext(queryCtx,
|
||||
fmt.Sprintf(`INSERT INTO %[1]s (key, value)
|
||||
VALUES ('last-cleanup', CURRENT_TIMESTAMP)
|
||||
ON CONFLICT (key)
|
||||
DO UPDATE SET value = CURRENT_TIMESTAMP
|
||||
WHERE (EXTRACT('epoch' FROM CURRENT_TIMESTAMP - %[1]s.value::timestamp with time zone) * 1000)::bigint > $1`,
|
||||
p.metadata.MetadataTableName),
|
||||
cleanupInterval.Milliseconds()-100, // Subtract 100ms for some buffer
|
||||
)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("failed to execute query: %w", err)
|
||||
}
|
||||
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("failed to count affected rows: %w", err)
|
||||
}
|
||||
|
||||
return n > 0, nil
|
||||
}
|
||||
|
||||
// Close implements io.Close.
|
||||
func (p *postgresDBAccess) Close() error {
|
||||
func (p *PostgresDBAccess) Close() error {
|
||||
if p.cancel != nil {
|
||||
p.cancel()
|
||||
p.cancel = nil
|
||||
}
|
||||
if p.db != nil {
|
||||
return p.db.Close()
|
||||
}
|
||||
|
@ -404,34 +575,10 @@ func (p *postgresDBAccess) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *postgresDBAccess) ensureStateTable(stateTableName string) error {
|
||||
exists, err := tableExists(p.db, stateTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
p.logger.Info("Creating PostgreSQL state table")
|
||||
createTable := fmt.Sprintf(`CREATE TABLE %s (
|
||||
key text NOT NULL PRIMARY KEY,
|
||||
value jsonb NOT NULL,
|
||||
isbinary boolean NOT NULL,
|
||||
insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updatedate TIMESTAMP WITH TIME ZONE NULL);`, stateTableName)
|
||||
_, err = p.db.Exec(createTable)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func tableExists(db *sql.DB, tableName string) (bool, error) {
|
||||
exists := false
|
||||
err := db.QueryRow("SELECT EXISTS (SELECT FROM pg_tables where tablename = $1)", tableName).Scan(&exists)
|
||||
|
||||
return exists, err
|
||||
// GetCleanupInterval returns the cleanupInterval property.
|
||||
// This is primarily used for tests.
|
||||
func (p *PostgresDBAccess) GetCleanupInterval() *time.Duration {
|
||||
return p.cleanupInterval
|
||||
}
|
||||
|
||||
// Returns the set requests.
|
||||
|
|
|
@ -18,18 +18,23 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
"github.com/dapr/kit/logger"
|
||||
|
||||
// Blank import for pgx
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
)
|
||||
|
||||
type mocks struct {
|
||||
db *sql.DB
|
||||
mock sqlmock.Sqlmock
|
||||
pgDba *postgresDBAccess
|
||||
pgDba *PostgresDBAccess
|
||||
}
|
||||
|
||||
func TestGetSetWithWrongType(t *testing.T) {
|
||||
|
@ -451,7 +456,7 @@ func mockDatabase(t *testing.T) (*mocks, error) {
|
|||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
|
||||
dba := &postgresDBAccess{
|
||||
dba := &PostgresDBAccess{
|
||||
logger: logger,
|
||||
db: db,
|
||||
}
|
||||
|
@ -462,3 +467,95 @@ func mockDatabase(t *testing.T) (*mocks, error) {
|
|||
pgDba: dba,
|
||||
}, err
|
||||
}
|
||||
|
||||
func TestParseMetadata(t *testing.T) {
|
||||
t.Run("missing connection string", func(t *testing.T) {
|
||||
p := &PostgresDBAccess{}
|
||||
props := map[string]string{}
|
||||
|
||||
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, errMissingConnectionString)
|
||||
})
|
||||
|
||||
t.Run("has connection string", func(t *testing.T) {
|
||||
p := &PostgresDBAccess{}
|
||||
props := map[string]string{
|
||||
"connectionString": "foo",
|
||||
}
|
||||
|
||||
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("default table name", func(t *testing.T) {
|
||||
p := &PostgresDBAccess{}
|
||||
props := map[string]string{
|
||||
"connectionString": "foo",
|
||||
}
|
||||
|
||||
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, p.metadata.TableName, defaultTableName)
|
||||
})
|
||||
|
||||
t.Run("custom table name", func(t *testing.T) {
|
||||
p := &PostgresDBAccess{}
|
||||
props := map[string]string{
|
||||
"connectionString": "foo",
|
||||
"tableName": "mytable",
|
||||
}
|
||||
|
||||
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, p.metadata.TableName, "mytable")
|
||||
})
|
||||
|
||||
t.Run("default cleanupIntervalInSeconds", func(t *testing.T) {
|
||||
p := &PostgresDBAccess{}
|
||||
props := map[string]string{
|
||||
"connectionString": "foo",
|
||||
}
|
||||
|
||||
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.NoError(t, err)
|
||||
_ = assert.NotNil(t, p.cleanupInterval) &&
|
||||
assert.Equal(t, *p.cleanupInterval, defaultCleanupInternal*time.Second)
|
||||
})
|
||||
|
||||
t.Run("invalid cleanupIntervalInSeconds", func(t *testing.T) {
|
||||
p := &PostgresDBAccess{}
|
||||
props := map[string]string{
|
||||
"connectionString": "foo",
|
||||
"cleanupIntervalInSeconds": "NaN",
|
||||
}
|
||||
|
||||
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("positive cleanupIntervalInSeconds", func(t *testing.T) {
|
||||
p := &PostgresDBAccess{}
|
||||
props := map[string]string{
|
||||
"connectionString": "foo",
|
||||
"cleanupIntervalInSeconds": "42",
|
||||
}
|
||||
|
||||
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.NoError(t, err)
|
||||
_ = assert.NotNil(t, p.cleanupInterval) &&
|
||||
assert.Equal(t, *p.cleanupInterval, 42*time.Second)
|
||||
})
|
||||
|
||||
t.Run("zero cleanupIntervalInSeconds", func(t *testing.T) {
|
||||
p := &PostgresDBAccess{}
|
||||
props := map[string]string{
|
||||
"connectionString": "foo",
|
||||
"cleanupIntervalInSeconds": "0",
|
||||
}
|
||||
|
||||
err := p.ParseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, p.cleanupInterval)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -24,7 +24,6 @@ import (
|
|||
|
||||
// PostgreSQL state store.
|
||||
type PostgreSQL struct {
|
||||
features []state.Feature
|
||||
logger logger.Logger
|
||||
dbaccess dbAccess
|
||||
}
|
||||
|
@ -40,7 +39,6 @@ func NewPostgreSQLStateStore(logger logger.Logger) state.Store {
|
|||
// This unexported constructor allows injecting a dbAccess instance for unit testing.
|
||||
func newPostgreSQLStateStore(logger logger.Logger, dba dbAccess) *PostgreSQL {
|
||||
return &PostgreSQL{
|
||||
features: []state.Feature{state.FeatureETag, state.FeatureTransactional, state.FeatureQueryAPI},
|
||||
logger: logger,
|
||||
dbaccess: dba,
|
||||
}
|
||||
|
@ -53,7 +51,7 @@ func (p *PostgreSQL) Init(metadata state.Metadata) error {
|
|||
|
||||
// Features returns the features available in this state store.
|
||||
func (p *PostgreSQL) Features() []state.Feature {
|
||||
return p.features
|
||||
return []state.Feature{state.FeatureETag, state.FeatureTransactional, state.FeatureQueryAPI}
|
||||
}
|
||||
|
||||
// Delete removes an entity from the store.
|
||||
|
@ -102,10 +100,15 @@ func (p *PostgreSQL) Close() error {
|
|||
if p.dbaccess != nil {
|
||||
return p.dbaccess.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns the dbaccess property.
|
||||
// This method is used in tests.
|
||||
func (p *PostgreSQL) GetDBAccess() dbAccess {
|
||||
return p.dbaccess
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) GetComponentMetadata() map[string]string {
|
||||
metadataStruct := postgresMetadataStruct{}
|
||||
metadataInfo := map[string]string{}
|
||||
|
|
|
@ -49,7 +49,7 @@ func TestPostgreSQLIntegration(t *testing.T) {
|
|||
})
|
||||
|
||||
metadata := state.Metadata{
|
||||
Base: metadata.Base{Properties: map[string]string{connectionStringKey: connectionString}},
|
||||
Base: metadata.Base{Properties: map[string]string{"connectionString": connectionString}},
|
||||
}
|
||||
|
||||
pgs := NewPostgreSQLStateStore(logger.NewLogger("test")).(*PostgreSQL)
|
||||
|
@ -62,11 +62,6 @@ func TestPostgreSQLIntegration(t *testing.T) {
|
|||
t.Fatal(error)
|
||||
}
|
||||
|
||||
t.Run("Create table succeeds", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCreateTable(t, pgs.dbaccess.(*postgresDBAccess))
|
||||
})
|
||||
|
||||
t.Run("Get Set Delete one item", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
setGetUpdateDeleteOneItem(t, pgs)
|
||||
|
@ -161,33 +156,6 @@ func setGetUpdateDeleteOneItem(t *testing.T, pgs *PostgreSQL) {
|
|||
deleteItem(t, pgs, key, getResponse.ETag)
|
||||
}
|
||||
|
||||
// testCreateTable tests the ability to create the state table.
|
||||
func testCreateTable(t *testing.T, dba *postgresDBAccess) {
|
||||
tableName := "test_state"
|
||||
|
||||
// Drop the table if it already exists
|
||||
exists, err := tableExists(dba.db, tableName)
|
||||
assert.Nil(t, err)
|
||||
if exists {
|
||||
dropTable(t, dba.db, tableName)
|
||||
}
|
||||
|
||||
// Create the state table and test for its existence
|
||||
err = dba.ensureStateTable(tableName)
|
||||
assert.Nil(t, err)
|
||||
exists, err = tableExists(dba.db, tableName)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Drop the state table
|
||||
dropTable(t, dba.db, tableName)
|
||||
}
|
||||
|
||||
func dropTable(t *testing.T, db *sql.DB, tableName string) {
|
||||
_, err := db.Exec(fmt.Sprintf("DROP TABLE %s", tableName))
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func deleteItemThatDoesNotExist(t *testing.T, pgs *PostgreSQL) {
|
||||
// Delete the item with a key not in the store
|
||||
deleteReq := &state.DeleteRequest{
|
||||
|
@ -477,7 +445,7 @@ func testInitConfiguration(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
props map[string]string
|
||||
expectedErr string
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "Empty",
|
||||
|
@ -486,8 +454,8 @@ func testInitConfiguration(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "Valid connection string",
|
||||
props: map[string]string{connectionStringKey: getConnectionString()},
|
||||
expectedErr: "",
|
||||
props: map[string]string{"connectionString": getConnectionString()},
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -501,11 +469,11 @@ func testInitConfiguration(t *testing.T) {
|
|||
}
|
||||
|
||||
err := p.Init(metadata)
|
||||
if tt.expectedErr == "" {
|
||||
assert.Nil(t, err)
|
||||
if tt.expectedErr == nil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, err.Error(), tt.expectedErr)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, tt.expectedErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -107,7 +107,7 @@ func createPostgreSQL(t *testing.T) *PostgreSQL {
|
|||
assert.NotNil(t, pgs)
|
||||
|
||||
metadata := &state.Metadata{
|
||||
Base: metadata.Base{Properties: map[string]string{connectionStringKey: fakeConnectionString}},
|
||||
Base: metadata.Base{Properties: map[string]string{"connectionString": fakeConnectionString}},
|
||||
}
|
||||
|
||||
err := pgs.Init(*metadata)
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
/*
|
||||
Copyright 2021 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Key used for "ttlInSeconds" in metadata.
|
||||
const MetadataTTLKey = "ttlInSeconds"
|
||||
|
||||
// ParseTTL parses the "ttlInSeconds" metadata property.
|
||||
func ParseTTL(requestMetadata map[string]string) (*int, error) {
|
||||
val, found := requestMetadata[MetadataTTLKey]
|
||||
if found && val != "" {
|
||||
parsedVal, err := strconv.ParseInt(val, 10, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("incorrect value for metadata '%s': %w", MetadataTTLKey, err)
|
||||
}
|
||||
if parsedVal < -1 || parsedVal > math.MaxInt32 {
|
||||
return nil, fmt.Errorf("incorrect value for metadata '%s': must be -1 or greater", MetadataTTLKey)
|
||||
}
|
||||
i := int(parsedVal)
|
||||
return &i, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
/*
|
||||
Copyright 2021 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseTTL(t *testing.T) {
|
||||
t.Run("TTL Not an integer", func(t *testing.T) {
|
||||
ttlInSeconds := "not an integer"
|
||||
ttl, err := ParseTTL(map[string]string{
|
||||
MetadataTTLKey: ttlInSeconds,
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
|
||||
t.Run("TTL specified with wrong key", func(t *testing.T) {
|
||||
ttlInSeconds := 12345
|
||||
ttl, err := ParseTTL(map[string]string{
|
||||
"expirationTime": strconv.Itoa(ttlInSeconds),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
|
||||
t.Run("TTL is a number", func(t *testing.T) {
|
||||
ttlInSeconds := 12345
|
||||
ttl, err := ParseTTL(map[string]string{
|
||||
MetadataTTLKey: strconv.Itoa(ttlInSeconds),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, *ttl, ttlInSeconds)
|
||||
})
|
||||
|
||||
t.Run("TTL not set", func(t *testing.T) {
|
||||
ttl, err := ParseTTL(map[string]string{})
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
|
||||
t.Run("TTL < -1", func(t *testing.T) {
|
||||
ttl, err := ParseTTL(map[string]string{
|
||||
MetadataTTLKey: "-3",
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
|
||||
t.Run("TTL bigger than 32-bit", func(t *testing.T) {
|
||||
ttl, err := ParseTTL(map[string]string{
|
||||
MetadataTTLKey: strconv.FormatInt(math.MaxInt32+1, 10),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, ttl)
|
||||
})
|
||||
}
|
|
@ -2,8 +2,24 @@
|
|||
|
||||
This project aims to test the PostgreSQL State Store component under various conditions.
|
||||
|
||||
To run these tests:
|
||||
|
||||
```sh
|
||||
go test -v -tags certtests -count=1 .
|
||||
```
|
||||
|
||||
## Test plan
|
||||
|
||||
## Initialization and migrations
|
||||
|
||||
Also test the `tableName` and `metadataTableName` metadata properties.
|
||||
|
||||
1. Initializes the component with names for tables that don't exist
|
||||
2. Initializes the component with names for tables that don't exist, specifying an explicit schema
|
||||
3. Initializes the component with all migrations performed (current level is "2")
|
||||
4. Initializes the component with only the state table, created before the metadata table was added (implied migration level "1")
|
||||
5. Initializes three components at the same time and ensure no race conditions exist in performing migrations
|
||||
|
||||
## Test for CRUD operations
|
||||
|
||||
1. Able to create and test connection.
|
||||
|
@ -16,6 +32,15 @@ This project aims to test the PostgreSQL State Store component under various con
|
|||
* Not prone to SQL injection on read
|
||||
* Not prone to SQL injection on delete
|
||||
|
||||
## TTLs and cleanups
|
||||
|
||||
1. Correctly parse the `cleanupIntervalInSeconds` metadata property:
|
||||
- No value uses the default value (3600 seconds)
|
||||
- A positive value sets the interval to the given number of seconds
|
||||
- A zero or negative value disables the cleanup
|
||||
2. The cleanup method deletes expired records and updates the metadata table with the last time it ran
|
||||
3. The cleanup method doesn't run if the last iteration was less than `cleanupIntervalInSeconds` or if another process is doing the cleanup
|
||||
|
||||
## Connection Recovery
|
||||
|
||||
1. When PostgreSQL goes down and then comes back up - client is able to connect
|
||||
|
|
|
@ -8,6 +8,7 @@ require (
|
|||
github.com/dapr/dapr v1.9.5
|
||||
github.com/dapr/go-sdk v1.6.0
|
||||
github.com/dapr/kit v0.0.3
|
||||
github.com/jackc/pgx/v5 v5.2.0
|
||||
github.com/stretchr/testify v1.8.1
|
||||
)
|
||||
|
||||
|
@ -61,7 +62,6 @@ require (
|
|||
github.com/imdario/mergo v0.3.12 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
|
||||
github.com/jackc/pgx/v5 v5.2.0 // indirect
|
||||
github.com/jhump/protoreflect v1.13.0 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
|
|
|
@ -14,27 +14,36 @@ limitations under the License.
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
pgx "github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/tests/certification/embedded"
|
||||
"github.com/dapr/components-contrib/tests/certification/flow"
|
||||
"github.com/dapr/components-contrib/tests/certification/flow/dockercompose"
|
||||
"github.com/dapr/components-contrib/tests/certification/flow/sidecar"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dapr/components-contrib/state"
|
||||
state_postgres "github.com/dapr/components-contrib/state/postgresql"
|
||||
state_loader "github.com/dapr/dapr/pkg/components/state"
|
||||
"github.com/dapr/dapr/pkg/runtime"
|
||||
dapr_testing "github.com/dapr/dapr/pkg/testing"
|
||||
"github.com/dapr/kit/logger"
|
||||
|
||||
"github.com/dapr/go-sdk/client"
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -42,6 +51,11 @@ const (
|
|||
dockerComposeYAML = "docker-compose.yml"
|
||||
stateStoreName = "statestore"
|
||||
certificationTestPrefix = "stable-certification-"
|
||||
connStringValue = "postgres://postgres:example@localhost:5432/dapr_test"
|
||||
keyConnectionString = "connectionString"
|
||||
keyCleanupInterval = "cleanupIntervalInSeconds"
|
||||
keyTableName = "tableName"
|
||||
keyMetadatTableName = "metadataTableName"
|
||||
)
|
||||
|
||||
func TestPostgreSQL(t *testing.T) {
|
||||
|
@ -59,8 +73,238 @@ func TestPostgreSQL(t *testing.T) {
|
|||
|
||||
currentGrpcPort := ports[0]
|
||||
|
||||
// Update this constant if you add more migrations
|
||||
const migrationLevel = "2"
|
||||
|
||||
// Holds a DB client as the "postgres" (ie. "root") user which we'll use to validate migrations and other changes in state
|
||||
var dbClient *pgx.Conn
|
||||
connectStep := func(ctx flow.Context) error {
|
||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Continue re-trying until the context times out, so we can wait for the DB to be up
|
||||
for {
|
||||
dbClient, err = pgx.Connect(connCtx, connStringValue)
|
||||
if err == nil || connCtx.Err() != nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(750 * time.Millisecond)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Tests the "Init" method and the database migrations
|
||||
// It also tests the metadata properties "tableName" and "metadataTableName"
|
||||
initTest := func(ctx flow.Context) error {
|
||||
md := state.Metadata{
|
||||
Base: metadata.Base{
|
||||
Name: "inittest",
|
||||
Properties: map[string]string{
|
||||
keyConnectionString: connStringValue,
|
||||
keyCleanupInterval: "-1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("initial state clean", func(t *testing.T) {
|
||||
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
|
||||
md.Properties[keyTableName] = "clean_state"
|
||||
md.Properties[keyMetadatTableName] = "clean_metadata"
|
||||
|
||||
// Init and perform the migrations
|
||||
err := storeObj.Init(md)
|
||||
require.NoError(t, err, "failed to init")
|
||||
|
||||
// We should have the tables correctly created
|
||||
err = tableExists(dbClient, "public", "clean_state")
|
||||
assert.NoError(t, err, "state table does not exist")
|
||||
err = tableExists(dbClient, "public", "clean_metadata")
|
||||
assert.NoError(t, err, "metadata table does not exist")
|
||||
|
||||
// Ensure migration level is correct
|
||||
level, err := getMigrationLevel(dbClient, "clean_metadata")
|
||||
assert.NoError(t, err, "failed to get migration level")
|
||||
assert.Equal(t, migrationLevel, level, "migration level mismatch: found '%s' but expected '%s'", level, migrationLevel)
|
||||
|
||||
err = storeObj.Close()
|
||||
require.NoError(t, err, "failed to close component")
|
||||
})
|
||||
|
||||
t.Run("initial state clean, with explicit schema name", func(t *testing.T) {
|
||||
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
|
||||
md.Properties[keyTableName] = "public.clean2_state"
|
||||
md.Properties[keyMetadatTableName] = "public.clean2_metadata"
|
||||
|
||||
// Init and perform the migrations
|
||||
err := storeObj.Init(md)
|
||||
require.NoError(t, err, "failed to init")
|
||||
|
||||
// We should have the tables correctly created
|
||||
err = tableExists(dbClient, "public", "clean2_state")
|
||||
assert.NoError(t, err, "state table does not exist")
|
||||
err = tableExists(dbClient, "public", "clean2_metadata")
|
||||
assert.NoError(t, err, "metadata table does not exist")
|
||||
|
||||
// Ensure migration level is correct
|
||||
level, err := getMigrationLevel(dbClient, "clean2_metadata")
|
||||
assert.NoError(t, err, "failed to get migration level")
|
||||
assert.Equal(t, migrationLevel, level, "migration level mismatch: found '%s' but expected '%s'", level, migrationLevel)
|
||||
|
||||
err = storeObj.Close()
|
||||
require.NoError(t, err, "failed to close component")
|
||||
})
|
||||
|
||||
t.Run("all migrations performed", func(t *testing.T) {
|
||||
// Re-use "clean_state" and "clean_metadata"
|
||||
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
|
||||
md.Properties[keyTableName] = "clean_state"
|
||||
md.Properties[keyMetadatTableName] = "clean_metadata"
|
||||
|
||||
// Should already have migration level 2
|
||||
level, err := getMigrationLevel(dbClient, "clean_metadata")
|
||||
assert.NoError(t, err, "failed to get migration level")
|
||||
assert.Equal(t, migrationLevel, level, "migration level mismatch: found '%s' but expected '%s'", level, migrationLevel)
|
||||
|
||||
// Init and perform the migrations
|
||||
err = storeObj.Init(md)
|
||||
require.NoError(t, err, "failed to init")
|
||||
|
||||
// Ensure migration level is correct
|
||||
level, err = getMigrationLevel(dbClient, "clean_metadata")
|
||||
assert.NoError(t, err, "failed to get migration level")
|
||||
assert.Equal(t, migrationLevel, level, "migration level mismatch: found '%s' but expected '%s'", level, migrationLevel)
|
||||
|
||||
err = storeObj.Close()
|
||||
require.NoError(t, err, "failed to close component")
|
||||
})
|
||||
|
||||
t.Run("migrate from implied level 1", func(t *testing.T) {
|
||||
// Before we added the metadata table, the "implied" level 1 had only the state table
|
||||
// Create that table to simulate the old state and validate the migration
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
_, err := dbClient.Exec(
|
||||
ctx,
|
||||
`CREATE TABLE pre_state (
|
||||
key text NOT NULL PRIMARY KEY,
|
||||
value jsonb NOT NULL,
|
||||
isbinary boolean NOT NULL,
|
||||
insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updatedate TIMESTAMP WITH TIME ZONE NULL
|
||||
)`,
|
||||
)
|
||||
require.NoError(t, err, "failed to create initial migration state")
|
||||
|
||||
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
|
||||
md.Properties[keyTableName] = "pre_state"
|
||||
md.Properties[keyMetadatTableName] = "pre_metadata"
|
||||
|
||||
// Init and perform the migrations
|
||||
err = storeObj.Init(md)
|
||||
require.NoError(t, err, "failed to init")
|
||||
|
||||
// We should have the metadata table created
|
||||
err = tableExists(dbClient, "public", "pre_metadata")
|
||||
assert.NoError(t, err, "metadata table does not exist")
|
||||
|
||||
// Ensure migration level is correct
|
||||
level, err := getMigrationLevel(dbClient, "pre_metadata")
|
||||
assert.NoError(t, err, "failed to get migration level")
|
||||
assert.Equal(t, migrationLevel, level, "migration level mismatch: found '%s' but expected '%s'", level, migrationLevel)
|
||||
|
||||
// Ensure the expiredate column has been added
|
||||
var colExists bool
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
err = dbClient.
|
||||
QueryRow(ctx,
|
||||
`SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE
|
||||
table_schema = 'public'
|
||||
AND table_name = 'pre_state'
|
||||
AND column_name = 'expiredate'
|
||||
)`,
|
||||
).
|
||||
Scan(&colExists)
|
||||
assert.True(t, colExists, "column expiredate not found in updated table")
|
||||
|
||||
err = storeObj.Close()
|
||||
require.NoError(t, err, "failed to close component")
|
||||
})
|
||||
|
||||
t.Run("initialize components concurrently", func(t *testing.T) {
|
||||
// Initializes 3 components concurrently using the same table names, and ensure that they perform migrations without conflicts and race conditions
|
||||
md.Properties[keyTableName] = "mystate"
|
||||
md.Properties[keyMetadatTableName] = "mymetadata"
|
||||
|
||||
errs := make(chan error, 3)
|
||||
hasLogs := atomic.Int32{}
|
||||
for i := 0; i < 3; i++ {
|
||||
go func(i int) {
|
||||
buf := &bytes.Buffer{}
|
||||
l := logger.NewLogger("multi-init-" + strconv.Itoa(i))
|
||||
l.SetOutput(io.MultiWriter(buf, os.Stdout))
|
||||
|
||||
// Init and perform the migrations
|
||||
storeObj := state_postgres.NewPostgreSQLStateStore(l).(*state_postgres.PostgreSQL)
|
||||
err := storeObj.Init(md)
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("%d failed to init: %w", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
// One and only one of the loggers should have any message
|
||||
if buf.Len() > 0 {
|
||||
hasLogs.Add(1)
|
||||
}
|
||||
|
||||
// Close the component right away
|
||||
err = storeObj.Close()
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("%d failed to close: %w", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
errs <- nil
|
||||
}(i)
|
||||
}
|
||||
|
||||
failed := false
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case err := <-errs:
|
||||
failed = failed || !assert.NoError(t, err)
|
||||
case <-time.After(time.Minute):
|
||||
t.Fatal("timed out waiting for components to initialize")
|
||||
}
|
||||
}
|
||||
if failed {
|
||||
// Short-circuit
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// Exactly one component should have written logs (which means generated any activity during migrations)
|
||||
assert.Equal(t, int32(1), hasLogs.Load(), "expected 1 component to log anything to indicate migration activity, but got %d", hasLogs.Load())
|
||||
|
||||
// We should have the tables correctly created
|
||||
err = tableExists(dbClient, "public", "mystate")
|
||||
assert.NoError(t, err, "state table does not exist")
|
||||
err = tableExists(dbClient, "public", "mymetadata")
|
||||
assert.NoError(t, err, "metadata table does not exist")
|
||||
|
||||
// Ensure migration level is correct
|
||||
level, err := getMigrationLevel(dbClient, "mymetadata")
|
||||
assert.NoError(t, err, "failed to get migration level")
|
||||
assert.Equal(t, migrationLevel, level, "migration level mismatch: found '%s' but expected '%s'", level, migrationLevel)
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
basicTest := func(ctx flow.Context) error {
|
||||
client, err := client.NewClientWithPort(fmt.Sprint(currentGrpcPort))
|
||||
client, err := client.NewClientWithPort(strconv.Itoa(currentGrpcPort))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -199,7 +443,7 @@ func TestPostgreSQL(t *testing.T) {
|
|||
}
|
||||
|
||||
testGetAfterPostgresRestart := func(ctx flow.Context) error {
|
||||
client, err := client.NewClientWithPort(fmt.Sprint(currentGrpcPort))
|
||||
client, err := client.NewClientWithPort(strconv.Itoa(currentGrpcPort))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -214,7 +458,7 @@ func TestPostgreSQL(t *testing.T) {
|
|||
|
||||
// checks the state store component is not vulnerable to SQL injection
|
||||
verifySQLInjectionTest := func(ctx flow.Context) error {
|
||||
client, err := client.NewClientWithPort(fmt.Sprint(currentGrpcPort))
|
||||
client, err := client.NewClientWithPort(strconv.Itoa(currentGrpcPort))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -244,23 +488,310 @@ func TestPostgreSQL(t *testing.T) {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Validates TTLs and garbage collections
|
||||
ttlTest := func(ctx flow.Context) error {
|
||||
md := state.Metadata{
|
||||
Base: metadata.Base{
|
||||
Name: "ttltest",
|
||||
Properties: map[string]string{
|
||||
keyConnectionString: connStringValue,
|
||||
keyTableName: "ttl_state",
|
||||
keyMetadatTableName: "ttl_metadata",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("parse cleanupIntervalInSeconds", func(t *testing.T) {
|
||||
t.Run("default value", func(t *testing.T) {
|
||||
// Default value is 1 hr
|
||||
md.Properties[keyCleanupInterval] = ""
|
||||
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
|
||||
|
||||
err := storeObj.Init(md)
|
||||
require.NoError(t, err, "failed to init")
|
||||
defer storeObj.Close()
|
||||
|
||||
dbAccess := storeObj.GetDBAccess().(*state_postgres.PostgresDBAccess)
|
||||
require.NotNil(t, dbAccess)
|
||||
|
||||
cleanupInterval := dbAccess.GetCleanupInterval()
|
||||
_ = assert.NotNil(t, cleanupInterval) &&
|
||||
assert.Equal(t, time.Duration(1*time.Hour), *cleanupInterval)
|
||||
})
|
||||
|
||||
t.Run("positive value", func(t *testing.T) {
|
||||
// A positive value is interpreted in seconds
|
||||
md.Properties[keyCleanupInterval] = "10"
|
||||
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
|
||||
|
||||
err := storeObj.Init(md)
|
||||
require.NoError(t, err, "failed to init")
|
||||
defer storeObj.Close()
|
||||
|
||||
dbAccess := storeObj.GetDBAccess().(*state_postgres.PostgresDBAccess)
|
||||
require.NotNil(t, dbAccess)
|
||||
|
||||
cleanupInterval := dbAccess.GetCleanupInterval()
|
||||
_ = assert.NotNil(t, cleanupInterval) &&
|
||||
assert.Equal(t, time.Duration(10*time.Second), *cleanupInterval)
|
||||
})
|
||||
|
||||
t.Run("disabled", func(t *testing.T) {
|
||||
// A value of <=0 means that the cleanup is disabled
|
||||
md.Properties[keyCleanupInterval] = "0"
|
||||
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
|
||||
|
||||
err := storeObj.Init(md)
|
||||
require.NoError(t, err, "failed to init")
|
||||
defer storeObj.Close()
|
||||
|
||||
dbAccess := storeObj.GetDBAccess().(*state_postgres.PostgresDBAccess)
|
||||
require.NotNil(t, dbAccess)
|
||||
|
||||
cleanupInterval := dbAccess.GetCleanupInterval()
|
||||
_ = assert.Nil(t, cleanupInterval)
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
t.Run("cleanup", func(t *testing.T) {
|
||||
md := state.Metadata{
|
||||
Base: metadata.Base{
|
||||
Name: "ttltest",
|
||||
Properties: map[string]string{
|
||||
keyConnectionString: connStringValue,
|
||||
keyTableName: "ttl_state",
|
||||
keyMetadatTableName: "ttl_metadata",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("automatically delete expired records", func(t *testing.T) {
|
||||
// Run every second
|
||||
md.Properties[keyCleanupInterval] = "1"
|
||||
|
||||
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
|
||||
err := storeObj.Init(md)
|
||||
require.NoError(t, err, "failed to init")
|
||||
defer storeObj.Close()
|
||||
|
||||
// Seed the database with some records
|
||||
err = populateTTLRecords(ctx, dbClient)
|
||||
require.NoError(t, err, "failed to seed records")
|
||||
|
||||
// Wait 2 seconds then verify we have only 10 rows left
|
||||
time.Sleep(2 * time.Second)
|
||||
count, err := countRowsInTable(ctx, dbClient, "ttl_state")
|
||||
require.NoError(t, err, "failed to run query to count rows")
|
||||
assert.Equal(t, 10, count)
|
||||
|
||||
// The "last-cleanup" value should be <= 1 second (+ a bit of buffer)
|
||||
lastCleanup, err := loadLastCleanupInterval(ctx, dbClient, "ttl_metadata")
|
||||
require.NoError(t, err, "failed to load value for 'last-cleanup'")
|
||||
assert.LessOrEqual(t, lastCleanup, int64(1200))
|
||||
|
||||
// Wait 6 more seconds and verify there are no more rows left
|
||||
time.Sleep(6 * time.Second)
|
||||
count, err = countRowsInTable(ctx, dbClient, "ttl_state")
|
||||
require.NoError(t, err, "failed to run query to count rows")
|
||||
assert.Equal(t, 0, count)
|
||||
|
||||
// The "last-cleanup" value should be <= 1 second (+ a bit of buffer)
|
||||
lastCleanup, err = loadLastCleanupInterval(ctx, dbClient, "ttl_metadata")
|
||||
require.NoError(t, err, "failed to load value for 'last-cleanup'")
|
||||
assert.LessOrEqual(t, lastCleanup, int64(1200))
|
||||
})
|
||||
|
||||
t.Run("cleanup concurrency", func(t *testing.T) {
|
||||
// Set to run every hour
|
||||
// (we'll manually trigger more frequent iterations)
|
||||
md.Properties[keyCleanupInterval] = "3600"
|
||||
|
||||
storeObj := state_postgres.NewPostgreSQLStateStore(log).(*state_postgres.PostgreSQL)
|
||||
err := storeObj.Init(md)
|
||||
require.NoError(t, err, "failed to init")
|
||||
defer storeObj.Close()
|
||||
|
||||
dbAccess := storeObj.GetDBAccess().(*state_postgres.PostgresDBAccess)
|
||||
require.NotNil(t, dbAccess)
|
||||
|
||||
// Seed the database with some records
|
||||
err = populateTTLRecords(ctx, dbClient)
|
||||
require.NoError(t, err, "failed to seed records")
|
||||
|
||||
// Validate that 20 records are present
|
||||
count, err := countRowsInTable(ctx, dbClient, "ttl_state")
|
||||
require.NoError(t, err, "failed to run query to count rows")
|
||||
assert.Equal(t, 20, count)
|
||||
|
||||
// Set last-cleanup to 1s ago (less than 3600s)
|
||||
err = setValueInMetadataTable(ctx, dbClient, "ttl_metadata", "'last-cleanup'", "CURRENT_TIMESTAMP - interval '1 second'")
|
||||
require.NoError(t, err, "failed to set last-cleanup")
|
||||
|
||||
// The "last-cleanup" value should be ~1 second (+ a bit of buffer)
|
||||
lastCleanup, err := loadLastCleanupInterval(ctx, dbClient, "ttl_metadata")
|
||||
require.NoError(t, err, "failed to load value for 'last-cleanup'")
|
||||
assert.LessOrEqual(t, lastCleanup, int64(1200))
|
||||
lastCleanupValueOrig, err := getValueFromMetadataTable(ctx, dbClient, "ttl_metadata", "last-cleanup")
|
||||
require.NoError(t, err, "failed to load absolute value for 'last-cleanup'")
|
||||
require.NotEmpty(t, lastCleanupValueOrig)
|
||||
|
||||
// Trigger the background cleanup, which should do nothing because the last cleanup was < 3600s
|
||||
err = dbAccess.CleanupExpired(ctx)
|
||||
require.NoError(t, err, "CleanupExpired returned an error")
|
||||
|
||||
// Validate that 20 records are still present
|
||||
count, err = countRowsInTable(ctx, dbClient, "ttl_state")
|
||||
require.NoError(t, err, "failed to run query to count rows")
|
||||
assert.Equal(t, 20, count)
|
||||
|
||||
// The "last-cleanup" value should not have been changed
|
||||
lastCleanupValue, err := getValueFromMetadataTable(ctx, dbClient, "ttl_metadata", "last-cleanup")
|
||||
require.NoError(t, err, "failed to load absolute value for 'last-cleanup'")
|
||||
assert.Equal(t, lastCleanupValueOrig, lastCleanupValue)
|
||||
})
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
flow.New(t, "Run tests").
|
||||
Step(dockercompose.Run("db", dockerComposeYAML)).
|
||||
Step("wait for component to start", flow.Sleep(10*time.Second)).
|
||||
// No waiting here, as connectStep retries until it's ready (or there's a timeout)
|
||||
//Step("wait for component to start", flow.Sleep(10*time.Second)).
|
||||
Step("connect to the database", connectStep).
|
||||
Step("run Init test", initTest).
|
||||
Step(sidecar.Run(sidecarNamePrefix+"dockerDefault",
|
||||
embedded.WithoutApp(),
|
||||
embedded.WithDaprGRPCPort(currentGrpcPort),
|
||||
embedded.WithComponentsPath("components/docker/default"),
|
||||
runtime.WithStates(stateRegistry),
|
||||
)).
|
||||
Step("Run CRUD test", basicTest).
|
||||
Step("Run eTag test", eTagTest).
|
||||
Step("Run transactions test", transactionsTest).
|
||||
Step("Run SQL injection test", verifySQLInjectionTest).
|
||||
Step("run CRUD test", basicTest).
|
||||
Step("run eTag test", eTagTest).
|
||||
Step("run transactions test", transactionsTest).
|
||||
Step("run SQL injection test", verifySQLInjectionTest).
|
||||
Step("run TTL test", ttlTest).
|
||||
Step("stop postgresql", dockercompose.Stop("db", dockerComposeYAML, "db")).
|
||||
Step("wait for component to stop", flow.Sleep(10*time.Second)).
|
||||
Step("start postgresql", dockercompose.Start("db", dockerComposeYAML, "db")).
|
||||
Step("wait for component to start", flow.Sleep(10*time.Second)).
|
||||
Step("Run connection test", testGetAfterPostgresRestart).
|
||||
Step("run connection test", testGetAfterPostgresRestart).
|
||||
Run()
|
||||
}
|
||||
|
||||
func tableExists(dbClient *pgx.Conn, schema string, table string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var scanTable, scanSchema string
|
||||
err := dbClient.QueryRow(
|
||||
ctx,
|
||||
"SELECT table_name, table_schema FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2",
|
||||
table, schema,
|
||||
).Scan(&scanTable, &scanSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error querying for table: %w", err)
|
||||
}
|
||||
if table != scanTable || schema != scanSchema {
|
||||
return fmt.Errorf("found table '%s.%s' does not match", scanSchema, scanTable)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMigrationLevel(dbClient *pgx.Conn, metadataTable string) (level string, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = dbClient.
|
||||
QueryRow(ctx, fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, metadataTable)).
|
||||
Scan(&level)
|
||||
if err != nil && errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
level = ""
|
||||
}
|
||||
return level, err
|
||||
}
|
||||
|
||||
func populateTTLRecords(ctx context.Context, dbClient *pgx.Conn) error {
|
||||
// Insert 10 records that have expired, and 10 that will expire in 6 seconds
|
||||
exp := time.Now().Add(-1 * time.Minute)
|
||||
rows := make([][]any, 20)
|
||||
for i := 0; i < 10; i++ {
|
||||
rows[i] = []any{
|
||||
fmt.Sprintf("expired_%d", i),
|
||||
json.RawMessage(fmt.Sprintf(`"value_%d"`, i)),
|
||||
false,
|
||||
exp,
|
||||
}
|
||||
}
|
||||
exp = time.Now().Add(4 * time.Second)
|
||||
for i := 0; i < 10; i++ {
|
||||
rows[i+10] = []any{
|
||||
fmt.Sprintf("notexpired_%d", i),
|
||||
json.RawMessage(fmt.Sprintf(`"value_%d"`, i)),
|
||||
false,
|
||||
exp,
|
||||
}
|
||||
}
|
||||
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
n, err := dbClient.CopyFrom(
|
||||
queryCtx,
|
||||
pgx.Identifier{"ttl_state"},
|
||||
[]string{"key", "value", "isbinary", "expiredate"},
|
||||
pgx.CopyFromRows(rows),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n != 20 {
|
||||
return fmt.Errorf("expected to copy 20 rows, but only got %d", n)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func countRowsInTable(ctx context.Context, dbClient *pgx.Conn, table string) (count int, err error) {
|
||||
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
err = dbClient.QueryRow(queryCtx, "SELECT COUNT(key) FROM "+table).Scan(&count)
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
func loadLastCleanupInterval(ctx context.Context, dbClient *pgx.Conn, table string) (lastCleanup int64, err error) {
|
||||
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
err = dbClient.
|
||||
QueryRow(queryCtx,
|
||||
fmt.Sprintf("SELECT (EXTRACT('epoch' FROM CURRENT_TIMESTAMP - value::timestamp with time zone) * 1000)::bigint FROM %s WHERE key = 'last-cleanup'", table),
|
||||
).
|
||||
Scan(&lastCleanup)
|
||||
cancel()
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Note this uses fmt.Sprintf and not parametrized queries-on purpose, so we can pass Postgres functions).
|
||||
// Normally this would be a very bad idea, just don't do it... (do as I say don't do as I do :) ).
|
||||
func setValueInMetadataTable(ctx context.Context, dbClient *pgx.Conn, table, key, value string) error {
|
||||
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
_, err := dbClient.Exec(queryCtx,
|
||||
//nolint:gosec
|
||||
fmt.Sprintf(`INSERT INTO %[1]s (key, value) VALUES (%[2]s, %[3]s) ON CONFLICT (key) DO UPDATE SET value = %[3]s`, table, key, value),
|
||||
)
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
|
||||
func getValueFromMetadataTable(ctx context.Context, dbClient *pgx.Conn, table, key string) (value string, err error) {
|
||||
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
err = dbClient.
|
||||
QueryRow(queryCtx, fmt.Sprintf("SELECT value FROM %s WHERE key = $1", table), key).
|
||||
Scan(&value)
|
||||
cancel()
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -20,8 +20,7 @@ components:
|
|||
allOperations: false
|
||||
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write" ]
|
||||
- component: postgresql
|
||||
allOperations: false
|
||||
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "query", "first-write" ]
|
||||
allOperations: true
|
||||
- component: mysql.mysql
|
||||
allOperations: false
|
||||
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write" ]
|
||||
|
|
Loading…
Reference in New Issue