Changed postgres state store to use pgx instead of db/sql
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
2582f154bd
commit
80b1bb14b7
|
|
@ -15,7 +15,6 @@ package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
@ -114,7 +113,7 @@ func (p *ConfigurationStore) Init(metadata configuration.Metadata) error {
|
||||||
exists := false
|
exists := false
|
||||||
err = p.client.QueryRow(ctx, QueryTableExists, p.metadata.configTable).Scan(&exists)
|
err = p.client.QueryRow(ctx, QueryTableExists, p.metadata.configTable).Scan(&exists)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
return fmt.Errorf(ErrorMissingTable, p.metadata.configTable)
|
return fmt.Errorf(ErrorMissingTable, p.metadata.configTable)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("error in checking if configtable '%s' exists - '%w'", p.metadata.configTable, err)
|
return fmt.Errorf("error in checking if configtable '%s' exists - '%w'", p.metadata.configTable, err)
|
||||||
|
|
@ -135,7 +134,7 @@ func (p *ConfigurationStore) Get(ctx context.Context, req *configuration.GetRequ
|
||||||
rows, err := p.client.Query(ctx, query, params...)
|
rows, err := p.client.Query(ctx, query, params...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If no rows exist, return an empty response, otherwise return the error.
|
// If no rows exist, return an empty response, otherwise return the error.
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
return &configuration.GetResponse{}, nil
|
return &configuration.GetResponse{}, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("error in querying configuration store: '%w'", err)
|
return nil, fmt.Errorf("error in querying configuration store: '%w'", err)
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,9 @@ package postgresql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
|
||||||
"github.com/dapr/components-contrib/state"
|
"github.com/dapr/components-contrib/state"
|
||||||
)
|
)
|
||||||
|
|
@ -34,10 +36,9 @@ type dbAccess interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Interface that contains methods for querying.
|
// Interface that contains methods for querying.
|
||||||
// Applies to both *sql.DB and *sql.Tx
|
// Applies to *pgx.Conn, *pgxpool.Pool, and pgx.Tx
|
||||||
type dbquerier interface {
|
type dbquerier interface {
|
||||||
Exec(query string, args ...any) (sql.Result, error)
|
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
|
||||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
|
||||||
QueryRow(query string, args ...any) *sql.Row
|
QueryRow(context.Context, string, ...interface{}) pgx.Row
|
||||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,20 +15,21 @@ package postgresql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
|
||||||
"github.com/dapr/kit/logger"
|
"github.com/dapr/kit/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Performs migrations for the database schema
|
// Performs migrations for the database schema
|
||||||
type migrations struct {
|
type migrations struct {
|
||||||
Logger logger.Logger
|
Logger logger.Logger
|
||||||
Conn *sql.DB
|
Conn pgxPoolConn
|
||||||
StateTableName string
|
StateTableName string
|
||||||
MetadataTableName string
|
MetadataTableName string
|
||||||
}
|
}
|
||||||
|
|
@ -42,7 +43,7 @@ func (m *migrations) Perform(ctx context.Context) error {
|
||||||
|
|
||||||
// Long timeout here as this query may block
|
// Long timeout here as this query may block
|
||||||
queryCtx, cancel := context.WithTimeout(ctx, time.Minute)
|
queryCtx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||||
_, err := m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_lock($1)", lockID)
|
_, err := m.Conn.Exec(queryCtx, "SELECT pg_advisory_lock($1)", lockID)
|
||||||
cancel()
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("faild to acquire advisory lock: %w", err)
|
return fmt.Errorf("faild to acquire advisory lock: %w", err)
|
||||||
|
|
@ -51,7 +52,7 @@ func (m *migrations) Perform(ctx context.Context) error {
|
||||||
// Release the lock
|
// Release the lock
|
||||||
defer func() {
|
defer func() {
|
||||||
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
|
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
|
||||||
_, err = m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_unlock($1)", lockID)
|
_, err = m.Conn.Exec(queryCtx, "SELECT pg_advisory_unlock($1)", lockID)
|
||||||
cancel()
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
|
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
|
||||||
|
|
@ -84,11 +85,11 @@ func (m *migrations) Perform(ctx context.Context) error {
|
||||||
)
|
)
|
||||||
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||||
err = m.Conn.
|
err = m.Conn.
|
||||||
QueryRowContext(queryCtx,
|
QueryRow(queryCtx,
|
||||||
fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName),
|
fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName),
|
||||||
).Scan(&migrationLevelStr)
|
).Scan(&migrationLevelStr)
|
||||||
cancel()
|
cancel()
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
// If there's no row...
|
// If there's no row...
|
||||||
migrationLevel = 0
|
migrationLevel = 0
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
|
|
@ -109,7 +110,7 @@ func (m *migrations) Perform(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||||
_, err = m.Conn.ExecContext(queryCtx,
|
_, err = m.Conn.Exec(queryCtx,
|
||||||
fmt.Sprintf(`INSERT INTO %s (key, value) VALUES ('migrations', $1) ON CONFLICT (key) DO UPDATE SET value = $1`, m.MetadataTableName),
|
fmt.Sprintf(`INSERT INTO %s (key, value) VALUES ('migrations', $1) ON CONFLICT (key) DO UPDATE SET value = $1`, m.MetadataTableName),
|
||||||
strconv.Itoa(i+1),
|
strconv.Itoa(i+1),
|
||||||
)
|
)
|
||||||
|
|
@ -126,7 +127,7 @@ func (m migrations) createMetadataTable(ctx context.Context) error {
|
||||||
m.Logger.Infof("Creating metadata table '%s'", m.MetadataTableName)
|
m.Logger.Infof("Creating metadata table '%s'", m.MetadataTableName)
|
||||||
// Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time
|
// Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time
|
||||||
// In the next step we'll acquire a lock so there won't be issues with concurrency
|
// In the next step we'll acquire a lock so there won't be issues with concurrency
|
||||||
_, err := m.Conn.Exec(fmt.Sprintf(
|
_, err := m.Conn.Exec(ctx, fmt.Sprintf(
|
||||||
`CREATE TABLE IF NOT EXISTS %s (
|
`CREATE TABLE IF NOT EXISTS %s (
|
||||||
key text NOT NULL PRIMARY KEY,
|
key text NOT NULL PRIMARY KEY,
|
||||||
value text NOT NULL
|
value text NOT NULL
|
||||||
|
|
@ -148,7 +149,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b
|
||||||
|
|
||||||
if schema == "" {
|
if schema == "" {
|
||||||
err = m.Conn.
|
err = m.Conn.
|
||||||
QueryRowContext(
|
QueryRow(
|
||||||
ctx,
|
ctx,
|
||||||
`SELECT table_name, table_schema
|
`SELECT table_name, table_schema
|
||||||
FROM information_schema.tables
|
FROM information_schema.tables
|
||||||
|
|
@ -158,7 +159,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b
|
||||||
Scan(&table, &schema)
|
Scan(&table, &schema)
|
||||||
} else {
|
} else {
|
||||||
err = m.Conn.
|
err = m.Conn.
|
||||||
QueryRowContext(
|
QueryRow(
|
||||||
ctx,
|
ctx,
|
||||||
`SELECT table_name, table_schema
|
`SELECT table_name, table_schema
|
||||||
FROM information_schema.tables
|
FROM information_schema.tables
|
||||||
|
|
@ -168,7 +169,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b
|
||||||
Scan(&table, &schema)
|
Scan(&table, &schema)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil && errors.Is(err, sql.ErrNoRows) {
|
if err != nil && errors.Is(err, pgx.ErrNoRows) {
|
||||||
return false, "", "", nil
|
return false, "", "", nil
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return false, "", "", fmt.Errorf("failed to check if table '%s' exists: %w", tableName, err)
|
return false, "", "", fmt.Errorf("failed to check if table '%s' exists: %w", tableName, err)
|
||||||
|
|
@ -195,6 +196,7 @@ var allMigrations = [2]func(ctx context.Context, m *migrations) error{
|
||||||
// We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table
|
// We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table
|
||||||
m.Logger.Infof("Creating state table '%s'", m.StateTableName)
|
m.Logger.Infof("Creating state table '%s'", m.StateTableName)
|
||||||
_, err := m.Conn.Exec(
|
_, err := m.Conn.Exec(
|
||||||
|
ctx,
|
||||||
fmt.Sprintf(
|
fmt.Sprintf(
|
||||||
`CREATE TABLE IF NOT EXISTS %s (
|
`CREATE TABLE IF NOT EXISTS %s (
|
||||||
key text NOT NULL PRIMARY KEY,
|
key text NOT NULL PRIMARY KEY,
|
||||||
|
|
@ -215,7 +217,7 @@ var allMigrations = [2]func(ctx context.Context, m *migrations) error{
|
||||||
// Migration 1: add the "expiredate" column
|
// Migration 1: add the "expiredate" column
|
||||||
func(ctx context.Context, m *migrations) error {
|
func(ctx context.Context, m *migrations) error {
|
||||||
m.Logger.Infof("Adding expiredate column to state table '%s'", m.StateTableName)
|
m.Logger.Infof("Adding expiredate column to state table '%s'", m.StateTableName)
|
||||||
_, err := m.Conn.Exec(fmt.Sprintf(
|
_, err := m.Conn.Exec(ctx, fmt.Sprintf(
|
||||||
`ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`,
|
`ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`,
|
||||||
m.StateTableName,
|
m.StateTableName,
|
||||||
))
|
))
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ package postgresql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
@ -23,15 +22,16 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
|
||||||
"github.com/dapr/components-contrib/metadata"
|
"github.com/dapr/components-contrib/metadata"
|
||||||
"github.com/dapr/components-contrib/state"
|
"github.com/dapr/components-contrib/state"
|
||||||
"github.com/dapr/components-contrib/state/query"
|
"github.com/dapr/components-contrib/state/query"
|
||||||
stateutils "github.com/dapr/components-contrib/state/utils"
|
stateutils "github.com/dapr/components-contrib/state/utils"
|
||||||
"github.com/dapr/kit/logger"
|
"github.com/dapr/kit/logger"
|
||||||
"github.com/dapr/kit/ptr"
|
"github.com/dapr/kit/ptr"
|
||||||
|
|
||||||
// Blank import for the underlying Postgres driver.
|
|
||||||
_ "github.com/jackc/pgx/v5/stdlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
@ -43,12 +43,24 @@ const (
|
||||||
|
|
||||||
var errMissingConnectionString = errors.New("missing connection string")
|
var errMissingConnectionString = errors.New("missing connection string")
|
||||||
|
|
||||||
|
// Interface that applies to *pgxpool.Pool.
|
||||||
|
// We need this to be able to mock the connection in tests.
|
||||||
|
type pgxPoolConn interface {
|
||||||
|
Begin(context.Context) (pgx.Tx, error)
|
||||||
|
BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
|
||||||
|
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
|
||||||
|
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
|
||||||
|
QueryRow(context.Context, string, ...interface{}) pgx.Row
|
||||||
|
Ping(context.Context) error
|
||||||
|
Close()
|
||||||
|
}
|
||||||
|
|
||||||
// PostgresDBAccess implements dbaccess.
|
// PostgresDBAccess implements dbaccess.
|
||||||
type PostgresDBAccess struct {
|
type PostgresDBAccess struct {
|
||||||
logger logger.Logger
|
logger logger.Logger
|
||||||
metadata postgresMetadataStruct
|
metadata postgresMetadataStruct
|
||||||
cleanupInterval *time.Duration
|
cleanupInterval *time.Duration
|
||||||
db *sql.DB
|
db pgxPoolConn
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
@ -81,23 +93,31 @@ func (p *PostgresDBAccess) Init(meta state.Metadata) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := sql.Open("pgx", p.metadata.ConnectionString)
|
config, err := pgxpool.ParseConfig(p.metadata.ConnectionString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to parse connection string: %w", err)
|
||||||
|
p.logger.Error(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if p.metadata.ConnectionMaxIdleTime > 0 {
|
||||||
|
config.MaxConnIdleTime = p.metadata.ConnectionMaxIdleTime
|
||||||
|
}
|
||||||
|
|
||||||
|
connCtx, connCancel := context.WithTimeout(p.ctx, 30*time.Second)
|
||||||
|
p.db, err = pgxpool.NewWithConfig(connCtx, config)
|
||||||
|
connCancel()
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to connect to the database: %w", err)
|
||||||
p.logger.Error(err)
|
p.logger.Error(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
p.db = db
|
|
||||||
|
|
||||||
pingCtx, pingCancel := context.WithTimeout(p.ctx, 30*time.Second)
|
pingCtx, pingCancel := context.WithTimeout(p.ctx, 30*time.Second)
|
||||||
pingErr := db.PingContext(pingCtx)
|
err = p.db.Ping(pingCtx)
|
||||||
pingCancel()
|
pingCancel()
|
||||||
if pingErr != nil {
|
|
||||||
return pingErr
|
|
||||||
}
|
|
||||||
|
|
||||||
p.db.SetConnMaxIdleTime(p.metadata.ConnectionMaxIdleTime)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to ping the database: %w", err)
|
||||||
|
p.logger.Error(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -117,8 +137,9 @@ func (p *PostgresDBAccess) Init(meta state.Metadata) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgresDBAccess) GetDB() *sql.DB {
|
func (p *PostgresDBAccess) GetDB() *pgxpool.Pool {
|
||||||
return p.db
|
// We can safely cast to *pgxpool.Pool because this method is never used in unit tests where we mock the DB
|
||||||
|
return p.db.(*pgxpool.Pool)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgresDBAccess) ParseMetadata(meta state.Metadata) error {
|
func (p *PostgresDBAccess) ParseMetadata(meta state.Metadata) error {
|
||||||
|
|
@ -193,8 +214,6 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
|
||||||
ttlSeconds = *ttl
|
ttlSeconds = *ttl
|
||||||
}
|
}
|
||||||
|
|
||||||
var result sql.Result
|
|
||||||
|
|
||||||
// Sprintf is required for table name because query.DB does not substitute parameters for table names.
|
// Sprintf is required for table name because query.DB does not substitute parameters for table names.
|
||||||
// Other parameters use query.DB parameter substitution.
|
// Other parameters use query.DB parameter substitution.
|
||||||
var (
|
var (
|
||||||
|
|
@ -246,8 +265,8 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
|
||||||
} else {
|
} else {
|
||||||
queryExpiredate = "NULL"
|
queryExpiredate = "NULL"
|
||||||
}
|
}
|
||||||
result, err = db.ExecContext(parentCtx, fmt.Sprintf(query, p.metadata.TableName, queryExpiredate), params...)
|
|
||||||
|
|
||||||
|
result, err := db.Exec(parentCtx, fmt.Sprintf(query, p.metadata.TableName, queryExpiredate), params...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if req.ETag != nil && *req.ETag != "" {
|
if req.ETag != nil && *req.ETag != "" {
|
||||||
return state.NewETagError(state.ETagMismatch, err)
|
return state.NewETagError(state.ETagMismatch, err)
|
||||||
|
|
@ -255,11 +274,7 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := result.RowsAffected()
|
if result.RowsAffected() != 1 {
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if rows != 1 {
|
|
||||||
return errors.New("no item was updated")
|
return errors.New("no item was updated")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -267,11 +282,20 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetRequest) error {
|
func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetRequest) error {
|
||||||
tx, err := p.db.BeginTx(parentCtx, nil)
|
ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
|
||||||
|
tx, err := p.db.Begin(ctx)
|
||||||
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer func() {
|
||||||
|
rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second)
|
||||||
|
rollbackErr := tx.Rollback(rollbackCtx)
|
||||||
|
rollbackCancel()
|
||||||
|
if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) {
|
||||||
|
p.logger.Errorf("Failed to rollback transaction in BulkSet: %v", rollbackErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if len(req) > 0 {
|
if len(req) > 0 {
|
||||||
for i := range req {
|
for i := range req {
|
||||||
|
|
@ -282,7 +306,9 @@ func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetReq
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Commit()
|
ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second)
|
||||||
|
err = tx.Commit(ctx)
|
||||||
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -299,7 +325,7 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest)
|
||||||
var (
|
var (
|
||||||
value []byte
|
value []byte
|
||||||
isBinary bool
|
isBinary bool
|
||||||
etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations
|
etag uint32
|
||||||
)
|
)
|
||||||
query := `SELECT
|
query := `SELECT
|
||||||
value, isbinary, xmin AS etag
|
value, isbinary, xmin AS etag
|
||||||
|
|
@ -307,12 +333,11 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest)
|
||||||
WHERE
|
WHERE
|
||||||
key = $1
|
key = $1
|
||||||
AND (expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP)`
|
AND (expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP)`
|
||||||
err := p.db.
|
err := p.db.QueryRow(parentCtx, fmt.Sprintf(query, p.metadata.TableName), req.Key).
|
||||||
QueryRowContext(parentCtx, fmt.Sprintf(query, p.metadata.TableName), req.Key).
|
|
||||||
Scan(&value, &isBinary, &etag)
|
Scan(&value, &isBinary, &etag)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If no rows exist, return an empty response, otherwise return the error.
|
// If no rows exist, return an empty response, otherwise return the error.
|
||||||
if err == sql.ErrNoRows {
|
if err == pgx.ErrNoRows {
|
||||||
return &state.GetResponse{}, nil
|
return &state.GetResponse{}, nil
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -334,14 +359,14 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest)
|
||||||
|
|
||||||
return &state.GetResponse{
|
return &state.GetResponse{
|
||||||
Data: data,
|
Data: data,
|
||||||
ETag: ptr.Of(strconv.FormatUint(etag, 10)),
|
ETag: ptr.Of(strconv.FormatUint(uint64(etag), 10)),
|
||||||
Metadata: req.Metadata,
|
Metadata: req.Metadata,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return &state.GetResponse{
|
return &state.GetResponse{
|
||||||
Data: value,
|
Data: value,
|
||||||
ETag: ptr.Of(strconv.FormatUint(etag, 10)),
|
ETag: ptr.Of(strconv.FormatUint(uint64(etag), 10)),
|
||||||
Metadata: req.Metadata,
|
Metadata: req.Metadata,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
@ -356,10 +381,9 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
|
||||||
return errors.New("missing key in delete operation")
|
return errors.New("missing key in delete operation")
|
||||||
}
|
}
|
||||||
|
|
||||||
var result sql.Result
|
var result pgconn.CommandTag
|
||||||
|
|
||||||
if req.ETag == nil || *req.ETag == "" {
|
if req.ETag == nil || *req.ETag == "" {
|
||||||
result, err = db.ExecContext(parentCtx, "DELETE FROM state WHERE key = $1", req.Key)
|
result, err = db.Exec(parentCtx, "DELETE FROM state WHERE key = $1", req.Key)
|
||||||
} else {
|
} else {
|
||||||
// Convert req.ETag to uint32 for postgres XID compatibility
|
// Convert req.ETag to uint32 for postgres XID compatibility
|
||||||
var etag64 uint64
|
var etag64 uint64
|
||||||
|
|
@ -367,20 +391,15 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return state.NewETagError(state.ETagInvalid, err)
|
return state.NewETagError(state.ETagInvalid, err)
|
||||||
}
|
}
|
||||||
etag := uint32(etag64)
|
|
||||||
|
|
||||||
result, err = db.ExecContext(parentCtx, "DELETE FROM state WHERE key = $1 AND xmin = $2", req.Key, etag)
|
result, err = db.Exec(parentCtx, "DELETE FROM state WHERE key = $1 AND xmin = $2", req.Key, uint32(etag64))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := result.RowsAffected()
|
rows := result.RowsAffected()
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows != 1 && req.ETag != nil && *req.ETag != "" {
|
if rows != 1 && req.ETag != nil && *req.ETag != "" {
|
||||||
return state.NewETagError(state.ETagMismatch, nil)
|
return state.NewETagError(state.ETagMismatch, nil)
|
||||||
}
|
}
|
||||||
|
|
@ -389,11 +408,20 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.DeleteRequest) error {
|
func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.DeleteRequest) error {
|
||||||
tx, err := p.db.BeginTx(parentCtx, nil)
|
ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
|
||||||
|
tx, err := p.db.Begin(ctx)
|
||||||
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer func() {
|
||||||
|
rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second)
|
||||||
|
rollbackErr := tx.Rollback(rollbackCtx)
|
||||||
|
rollbackCancel()
|
||||||
|
if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) {
|
||||||
|
p.logger.Errorf("Failed to rollback transaction in BulkDelete: %v", rollbackErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if len(req) > 0 {
|
if len(req) > 0 {
|
||||||
for i := range req {
|
for i := range req {
|
||||||
|
|
@ -404,7 +432,9 @@ func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.Del
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Commit()
|
ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second)
|
||||||
|
err = tx.Commit(ctx)
|
||||||
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -413,11 +443,20 @@ func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.Del
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *state.TransactionalStateRequest) error {
|
func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *state.TransactionalStateRequest) error {
|
||||||
tx, err := p.db.BeginTx(parentCtx, nil)
|
ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
|
||||||
|
tx, err := p.db.Begin(ctx)
|
||||||
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer func() {
|
||||||
|
rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second)
|
||||||
|
rollbackErr := tx.Rollback(rollbackCtx)
|
||||||
|
rollbackCancel()
|
||||||
|
if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) {
|
||||||
|
p.logger.Errorf("Failed to rollback transaction in ExecMulti: %v", rollbackErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for _, o := range request.Operations {
|
for _, o := range request.Operations {
|
||||||
switch o.Operation {
|
switch o.Operation {
|
||||||
|
|
@ -450,7 +489,9 @@ func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *stat
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Commit()
|
ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second)
|
||||||
|
err = tx.Commit(ctx)
|
||||||
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -520,19 +561,13 @@ func (p *PostgresDBAccess) CleanupExpired(ctx context.Context) error {
|
||||||
// Note we're not using the transaction here as we don't want this to be rolled back half-way or to lock the table unnecessarily
|
// Note we're not using the transaction here as we don't want this to be rolled back half-way or to lock the table unnecessarily
|
||||||
// Need to use fmt.Sprintf because we can't parametrize a table name
|
// Need to use fmt.Sprintf because we can't parametrize a table name
|
||||||
// Note we are not setting a timeout here as this query can take a "long" time, especially if there's no index on expiredate
|
// Note we are not setting a timeout here as this query can take a "long" time, especially if there's no index on expiredate
|
||||||
//nolint:gosec
|
|
||||||
stmt := fmt.Sprintf(`DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate < CURRENT_TIMESTAMP`, p.metadata.TableName)
|
stmt := fmt.Sprintf(`DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate < CURRENT_TIMESTAMP`, p.metadata.TableName)
|
||||||
res, err := p.db.ExecContext(ctx, stmt)
|
res, err := p.db.Exec(ctx, stmt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to execute query: %w", err)
|
return fmt.Errorf("failed to execute query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cleaned, err := res.RowsAffected()
|
p.logger.Infof("Removed %d expired rows", res.RowsAffected())
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to count affected rows: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
p.logger.Infof("Removed %d expired rows", cleaned)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -540,7 +575,7 @@ func (p *PostgresDBAccess) CleanupExpired(ctx context.Context) error {
|
||||||
// Returns true if the row was updated, which means that the cleanup can proceed.
|
// Returns true if the row was updated, which means that the cleanup can proceed.
|
||||||
func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier, cleanupInterval time.Duration) (bool, error) {
|
func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier, cleanupInterval time.Duration) (bool, error) {
|
||||||
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
res, err := db.ExecContext(queryCtx,
|
res, err := db.Exec(queryCtx,
|
||||||
fmt.Sprintf(`INSERT INTO %[1]s (key, value)
|
fmt.Sprintf(`INSERT INTO %[1]s (key, value)
|
||||||
VALUES ('last-cleanup', CURRENT_TIMESTAMP)
|
VALUES ('last-cleanup', CURRENT_TIMESTAMP)
|
||||||
ON CONFLICT (key)
|
ON CONFLICT (key)
|
||||||
|
|
@ -554,11 +589,7 @@ func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier,
|
||||||
return true, fmt.Errorf("failed to execute query: %w", err)
|
return true, fmt.Errorf("failed to execute query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := res.RowsAffected()
|
n := res.RowsAffected()
|
||||||
if err != nil {
|
|
||||||
return true, fmt.Errorf("failed to count affected rows: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return n > 0, nil
|
return n > 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -569,7 +600,8 @@ func (p *PostgresDBAccess) Close() error {
|
||||||
p.cancel = nil
|
p.cancel = nil
|
||||||
}
|
}
|
||||||
if p.db != nil {
|
if p.db != nil {
|
||||||
return p.db.Close()
|
p.db.Close()
|
||||||
|
p.db = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -16,24 +16,20 @@ package postgresql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
pgxmock "github.com/pashagolub/pgxmock/v2"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/dapr/components-contrib/metadata"
|
"github.com/dapr/components-contrib/metadata"
|
||||||
"github.com/dapr/components-contrib/state"
|
"github.com/dapr/components-contrib/state"
|
||||||
"github.com/dapr/kit/logger"
|
"github.com/dapr/kit/logger"
|
||||||
|
|
||||||
// Blank import for pgx
|
|
||||||
_ "github.com/jackc/pgx/v5/stdlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type mocks struct {
|
type mocks struct {
|
||||||
db *sql.DB
|
db pgxmock.PgxPoolIface
|
||||||
mock sqlmock.Sqlmock
|
|
||||||
pgDba *PostgresDBAccess
|
pgDba *PostgresDBAccess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -67,7 +63,7 @@ func TestGetSetValid(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
set, err := getSet(operation)
|
set, err := getSet(operation)
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "key1", set.Key)
|
assert.Equal(t, "key1", set.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -101,7 +97,7 @@ func TestGetDeleteValid(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
delete, err := getDelete(operation)
|
delete, err := getDelete(operation)
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "key1", delete.Key)
|
assert.Equal(t, "key1", delete.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -110,8 +106,10 @@ func TestMultiWithNoRequests(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.db.Close()
|
defer m.db.Close()
|
||||||
|
|
||||||
m.mock.ExpectBegin()
|
m.db.ExpectBegin()
|
||||||
m.mock.ExpectCommit()
|
m.db.ExpectCommit()
|
||||||
|
// There's also a rollback called after a commit, which is expected and will not have effect
|
||||||
|
m.db.ExpectRollback()
|
||||||
|
|
||||||
var operations []state.TransactionalStateOperation
|
var operations []state.TransactionalStateOperation
|
||||||
|
|
||||||
|
|
@ -121,7 +119,7 @@ func TestMultiWithNoRequests(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInvalidMultiInvalidAction(t *testing.T) {
|
func TestInvalidMultiInvalidAction(t *testing.T) {
|
||||||
|
|
@ -129,8 +127,8 @@ func TestInvalidMultiInvalidAction(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.db.Close()
|
defer m.db.Close()
|
||||||
|
|
||||||
m.mock.ExpectBegin()
|
m.db.ExpectBegin()
|
||||||
m.mock.ExpectRollback()
|
m.db.ExpectRollback()
|
||||||
|
|
||||||
var operations []state.TransactionalStateOperation
|
var operations []state.TransactionalStateOperation
|
||||||
|
|
||||||
|
|
@ -153,16 +151,19 @@ func TestValidSetRequest(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.db.Close()
|
defer m.db.Close()
|
||||||
|
|
||||||
m.mock.ExpectBegin()
|
setReq := createSetRequest()
|
||||||
m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1))
|
operations := []state.TransactionalStateOperation{
|
||||||
m.mock.ExpectCommit()
|
{Operation: state.Upsert, Request: setReq},
|
||||||
|
}
|
||||||
|
val, _ := json.Marshal(setReq.Value)
|
||||||
|
|
||||||
var operations []state.TransactionalStateOperation
|
m.db.ExpectBegin()
|
||||||
|
m.db.ExpectExec("INSERT INTO").
|
||||||
operations = append(operations, state.TransactionalStateOperation{
|
WithArgs(setReq.Key, string(val), false).
|
||||||
Operation: state.Upsert,
|
WillReturnResult(pgxmock.NewResult("INSERT", 1))
|
||||||
Request: createSetRequest(),
|
m.db.ExpectCommit()
|
||||||
})
|
// There's also a rollback called after a commit, which is expected and will not have effect
|
||||||
|
m.db.ExpectRollback()
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
|
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
|
||||||
|
|
@ -170,7 +171,7 @@ func TestValidSetRequest(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInvalidMultiSetRequest(t *testing.T) {
|
func TestInvalidMultiSetRequest(t *testing.T) {
|
||||||
|
|
@ -178,15 +179,15 @@ func TestInvalidMultiSetRequest(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.db.Close()
|
defer m.db.Close()
|
||||||
|
|
||||||
m.mock.ExpectBegin()
|
m.db.ExpectBegin()
|
||||||
m.mock.ExpectRollback()
|
m.db.ExpectRollback()
|
||||||
|
|
||||||
var operations []state.TransactionalStateOperation
|
operations := []state.TransactionalStateOperation{
|
||||||
|
{
|
||||||
operations = append(operations, state.TransactionalStateOperation{
|
Operation: state.Upsert,
|
||||||
Operation: state.Upsert,
|
Request: createDeleteRequest(), // Delete request is not valid for Upsert operation
|
||||||
Request: createDeleteRequest(), // Delete request is not valid for Upsert operation
|
},
|
||||||
})
|
}
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
|
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
|
||||||
|
|
@ -202,8 +203,8 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.db.Close()
|
defer m.db.Close()
|
||||||
|
|
||||||
m.mock.ExpectBegin()
|
m.db.ExpectBegin()
|
||||||
m.mock.ExpectRollback()
|
m.db.ExpectRollback()
|
||||||
|
|
||||||
var operations []state.TransactionalStateOperation
|
var operations []state.TransactionalStateOperation
|
||||||
|
|
||||||
|
|
@ -226,16 +227,18 @@ func TestValidMultiDeleteRequest(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.db.Close()
|
defer m.db.Close()
|
||||||
|
|
||||||
m.mock.ExpectBegin()
|
deleteReq := createDeleteRequest()
|
||||||
m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1))
|
operations := []state.TransactionalStateOperation{
|
||||||
m.mock.ExpectCommit()
|
{Operation: state.Delete, Request: deleteReq},
|
||||||
|
}
|
||||||
|
|
||||||
var operations []state.TransactionalStateOperation
|
m.db.ExpectBegin()
|
||||||
|
m.db.ExpectExec("DELETE FROM").
|
||||||
operations = append(operations, state.TransactionalStateOperation{
|
WithArgs(deleteReq.Key).
|
||||||
Operation: state.Delete,
|
WillReturnResult(pgxmock.NewResult("DELETE", 1))
|
||||||
Request: createDeleteRequest(),
|
m.db.ExpectCommit()
|
||||||
})
|
// There's also a rollback called after a commit, which is expected and will not have effect
|
||||||
|
m.db.ExpectRollback()
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
|
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
|
||||||
|
|
@ -243,7 +246,7 @@ func TestValidMultiDeleteRequest(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInvalidMultiDeleteRequest(t *testing.T) {
|
func TestInvalidMultiDeleteRequest(t *testing.T) {
|
||||||
|
|
@ -251,8 +254,8 @@ func TestInvalidMultiDeleteRequest(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.db.Close()
|
defer m.db.Close()
|
||||||
|
|
||||||
m.mock.ExpectBegin()
|
m.db.ExpectBegin()
|
||||||
m.mock.ExpectRollback()
|
m.db.ExpectRollback()
|
||||||
|
|
||||||
var operations []state.TransactionalStateOperation
|
var operations []state.TransactionalStateOperation
|
||||||
|
|
||||||
|
|
@ -275,8 +278,8 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.db.Close()
|
defer m.db.Close()
|
||||||
|
|
||||||
m.mock.ExpectBegin()
|
m.db.ExpectBegin()
|
||||||
m.mock.ExpectRollback()
|
m.db.ExpectRollback()
|
||||||
|
|
||||||
var operations []state.TransactionalStateOperation
|
var operations []state.TransactionalStateOperation
|
||||||
|
|
||||||
|
|
@ -299,23 +302,27 @@ func TestMultiOperationOrder(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.db.Close()
|
defer m.db.Close()
|
||||||
|
|
||||||
m.mock.ExpectBegin()
|
operations := []state.TransactionalStateOperation{
|
||||||
m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1))
|
{
|
||||||
m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1))
|
|
||||||
m.mock.ExpectCommit()
|
|
||||||
|
|
||||||
var operations []state.TransactionalStateOperation
|
|
||||||
|
|
||||||
operations = append(operations,
|
|
||||||
state.TransactionalStateOperation{
|
|
||||||
Operation: state.Upsert,
|
Operation: state.Upsert,
|
||||||
Request: state.SetRequest{Key: "key1", Value: "value1"},
|
Request: state.SetRequest{Key: "key1", Value: "value1"},
|
||||||
},
|
},
|
||||||
state.TransactionalStateOperation{
|
{
|
||||||
Operation: state.Delete,
|
Operation: state.Delete,
|
||||||
Request: state.DeleteRequest{Key: "key1"},
|
Request: state.DeleteRequest{Key: "key1"},
|
||||||
},
|
},
|
||||||
)
|
}
|
||||||
|
|
||||||
|
m.db.ExpectBegin()
|
||||||
|
m.db.ExpectExec("INSERT INTO").
|
||||||
|
WithArgs("key1", `"value1"`, false).
|
||||||
|
WillReturnResult(pgxmock.NewResult("INSERT", 1))
|
||||||
|
m.db.ExpectExec("DELETE FROM").
|
||||||
|
WithArgs("key1").
|
||||||
|
WillReturnResult(pgxmock.NewResult("DELETE", 1))
|
||||||
|
m.db.ExpectCommit()
|
||||||
|
// There's also a rollback called after a commit, which is expected and will not have effect
|
||||||
|
m.db.ExpectRollback()
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
|
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
|
||||||
|
|
@ -323,7 +330,7 @@ func TestMultiOperationOrder(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInvalidBulkSetNoKey(t *testing.T) {
|
func TestInvalidBulkSetNoKey(t *testing.T) {
|
||||||
|
|
@ -331,14 +338,13 @@ func TestInvalidBulkSetNoKey(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.db.Close()
|
defer m.db.Close()
|
||||||
|
|
||||||
m.mock.ExpectBegin()
|
m.db.ExpectBegin()
|
||||||
m.mock.ExpectRollback()
|
m.db.ExpectRollback()
|
||||||
|
|
||||||
var sets []state.SetRequest
|
sets := []state.SetRequest{
|
||||||
|
// Set request without key is not valid for Set operation
|
||||||
sets = append(sets, state.SetRequest{ // Set request without key is not valid for Set operation
|
{Value: "value1"},
|
||||||
Value: "value1",
|
}
|
||||||
})
|
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
err := m.pgDba.BulkSet(context.Background(), sets)
|
err := m.pgDba.BulkSet(context.Background(), sets)
|
||||||
|
|
@ -352,15 +358,13 @@ func TestInvalidBulkSetEmptyValue(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.db.Close()
|
defer m.db.Close()
|
||||||
|
|
||||||
m.mock.ExpectBegin()
|
m.db.ExpectBegin()
|
||||||
m.mock.ExpectRollback()
|
m.db.ExpectRollback()
|
||||||
|
|
||||||
var sets []state.SetRequest
|
sets := []state.SetRequest{
|
||||||
|
// Set request without value is not valid for Set operation
|
||||||
sets = append(sets, state.SetRequest{ // Set request without value is not valid for Set operation
|
{Key: "key1", Value: "value1"},
|
||||||
Key: "key1",
|
}
|
||||||
Value: "",
|
|
||||||
})
|
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
err := m.pgDba.BulkSet(context.Background(), sets)
|
err := m.pgDba.BulkSet(context.Background(), sets)
|
||||||
|
|
@ -374,22 +378,26 @@ func TestValidBulkSet(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.db.Close()
|
defer m.db.Close()
|
||||||
|
|
||||||
m.mock.ExpectBegin()
|
sets := []state.SetRequest{
|
||||||
m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1))
|
{
|
||||||
m.mock.ExpectCommit()
|
Key: "key1",
|
||||||
|
Value: "value1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
var sets []state.SetRequest
|
m.db.ExpectBegin()
|
||||||
|
m.db.ExpectExec("INSERT INTO").
|
||||||
sets = append(sets, state.SetRequest{
|
WithArgs("key1", `"value1"`, false).
|
||||||
Key: "key1",
|
WillReturnResult(pgxmock.NewResult("INSERT", 1))
|
||||||
Value: "value1",
|
m.db.ExpectCommit()
|
||||||
})
|
// There's also a rollback called after a commit, which is expected and will not have effect
|
||||||
|
m.db.ExpectRollback()
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
err := m.pgDba.BulkSet(context.Background(), sets)
|
err := m.pgDba.BulkSet(context.Background(), sets)
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInvalidBulkDeleteNoKey(t *testing.T) {
|
func TestInvalidBulkDeleteNoKey(t *testing.T) {
|
||||||
|
|
@ -397,8 +405,8 @@ func TestInvalidBulkDeleteNoKey(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.db.Close()
|
defer m.db.Close()
|
||||||
|
|
||||||
m.mock.ExpectBegin()
|
m.db.ExpectBegin()
|
||||||
m.mock.ExpectRollback()
|
m.db.ExpectRollback()
|
||||||
|
|
||||||
var deletes []state.DeleteRequest
|
var deletes []state.DeleteRequest
|
||||||
|
|
||||||
|
|
@ -418,21 +426,23 @@ func TestValidBulkDelete(t *testing.T) {
|
||||||
m, _ := mockDatabase(t)
|
m, _ := mockDatabase(t)
|
||||||
defer m.db.Close()
|
defer m.db.Close()
|
||||||
|
|
||||||
m.mock.ExpectBegin()
|
deletes := []state.DeleteRequest{
|
||||||
m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1))
|
{Key: "key1"},
|
||||||
m.mock.ExpectCommit()
|
}
|
||||||
|
|
||||||
var deletes []state.DeleteRequest
|
m.db.ExpectBegin()
|
||||||
|
m.db.ExpectExec("DELETE FROM").
|
||||||
deletes = append(deletes, state.DeleteRequest{
|
WithArgs("key1").
|
||||||
Key: "key1",
|
WillReturnResult(pgxmock.NewResult("DELETE", 1))
|
||||||
})
|
m.db.ExpectCommit()
|
||||||
|
// There's also a rollback called after a commit, which is expected and will not have effect
|
||||||
|
m.db.ExpectRollback()
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
err := m.pgDba.BulkDelete(context.Background(), deletes)
|
err := m.pgDba.BulkDelete(context.Background(), deletes)
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createSetRequest() state.SetRequest {
|
func createSetRequest() state.SetRequest {
|
||||||
|
|
@ -451,7 +461,7 @@ func createDeleteRequest() state.DeleteRequest {
|
||||||
func mockDatabase(t *testing.T) (*mocks, error) {
|
func mockDatabase(t *testing.T) (*mocks, error) {
|
||||||
logger := logger.NewLogger("test")
|
logger := logger.NewLogger("test")
|
||||||
|
|
||||||
db, mock, err := sqlmock.New()
|
db, err := pgxmock.NewPool()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||||
}
|
}
|
||||||
|
|
@ -463,7 +473,6 @@ func mockDatabase(t *testing.T) (*mocks, error) {
|
||||||
|
|
||||||
return &mocks{
|
return &mocks{
|
||||||
db: db,
|
db: db,
|
||||||
mock: mock,
|
|
||||||
pgDba: dba,
|
pgDba: dba,
|
||||||
}, err
|
}, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,18 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package postgresql
|
package postgresql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/dapr/components-contrib/metadata"
|
"github.com/dapr/components-contrib/metadata"
|
||||||
"github.com/dapr/components-contrib/state"
|
"github.com/dapr/components-contrib/state"
|
||||||
|
|
@ -162,7 +165,7 @@ func deleteItemThatDoesNotExist(t *testing.T, pgs *PostgreSQL) {
|
||||||
Key: randomKey(),
|
Key: randomKey(),
|
||||||
}
|
}
|
||||||
err := pgs.Delete(context.Background(), deleteReq)
|
err := pgs.Delete(context.Background(), deleteReq)
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) {
|
func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) {
|
||||||
|
|
@ -183,7 +186,7 @@ func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) {
|
||||||
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
|
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
|
||||||
Operations: operations,
|
Operations: operations,
|
||||||
})
|
})
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
for _, set := range setRequests {
|
for _, set := range setRequests {
|
||||||
assert.True(t, storeItemExists(t, set.Key))
|
assert.True(t, storeItemExists(t, set.Key))
|
||||||
|
|
@ -213,7 +216,7 @@ func multiWithDeleteOnly(t *testing.T, pgs *PostgreSQL) {
|
||||||
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
|
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
|
||||||
Operations: operations,
|
Operations: operations,
|
||||||
})
|
})
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
for _, delete := range deleteRequests {
|
for _, delete := range deleteRequests {
|
||||||
assert.False(t, storeItemExists(t, delete.Key))
|
assert.False(t, storeItemExists(t, delete.Key))
|
||||||
|
|
@ -256,7 +259,7 @@ func multiWithDeleteAndSet(t *testing.T, pgs *PostgreSQL) {
|
||||||
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
|
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
|
||||||
Operations: operations,
|
Operations: operations,
|
||||||
})
|
})
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
for _, delete := range deleteRequests {
|
for _, delete := range deleteRequests {
|
||||||
assert.False(t, storeItemExists(t, delete.Key))
|
assert.False(t, storeItemExists(t, delete.Key))
|
||||||
|
|
@ -384,15 +387,16 @@ func setUpdatesTheUpdatedateField(t *testing.T, pgs *PostgreSQL) {
|
||||||
|
|
||||||
// insertdate should have a value and updatedate should be nil
|
// insertdate should have a value and updatedate should be nil
|
||||||
_, insertdate, updatedate := getRowData(t, key)
|
_, insertdate, updatedate := getRowData(t, key)
|
||||||
|
assert.Nil(t, updatedate)
|
||||||
assert.NotNil(t, insertdate)
|
assert.NotNil(t, insertdate)
|
||||||
assert.Equal(t, "", updatedate.String)
|
|
||||||
|
|
||||||
// insertdate should not change, updatedate should have a value
|
// insertdate should not change, updatedate should have a value
|
||||||
value = &fakeItem{Color: "aqua"}
|
value = &fakeItem{Color: "aqua"}
|
||||||
setItem(t, pgs, key, value, nil)
|
setItem(t, pgs, key, value, nil)
|
||||||
_, newinsertdate, updatedate := getRowData(t, key)
|
_, newinsertdate, updatedate := getRowData(t, key)
|
||||||
assert.Equal(t, insertdate, newinsertdate) // The insertdate should not change.
|
assert.NotNil(t, updatedate)
|
||||||
assert.NotEqual(t, "", updatedate.String)
|
assert.NotNil(t, newinsertdate)
|
||||||
|
assert.True(t, insertdate.Equal(*newinsertdate)) // The insertdate should not change.
|
||||||
|
|
||||||
deleteItem(t, pgs, key, nil)
|
deleteItem(t, pgs, key, nil)
|
||||||
}
|
}
|
||||||
|
|
@ -420,7 +424,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err := pgs.BulkSet(context.Background(), setReq)
|
err := pgs.BulkSet(context.Background(), setReq)
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, storeItemExists(t, setReq[0].Key))
|
assert.True(t, storeItemExists(t, setReq[0].Key))
|
||||||
assert.True(t, storeItemExists(t, setReq[1].Key))
|
assert.True(t, storeItemExists(t, setReq[1].Key))
|
||||||
|
|
||||||
|
|
@ -434,7 +438,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pgs.BulkDelete(context.Background(), deleteReq)
|
err = pgs.BulkDelete(context.Background(), deleteReq)
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
assert.False(t, storeItemExists(t, setReq[0].Key))
|
assert.False(t, storeItemExists(t, setReq[0].Key))
|
||||||
assert.False(t, storeItemExists(t, setReq[1].Key))
|
assert.False(t, storeItemExists(t, setReq[1].Key))
|
||||||
}
|
}
|
||||||
|
|
@ -491,7 +495,7 @@ func setItem(t *testing.T, pgs *PostgreSQL, key string, value interface{}, etag
|
||||||
}
|
}
|
||||||
|
|
||||||
err := pgs.Set(context.Background(), setReq)
|
err := pgs.Set(context.Background(), setReq)
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
itemExists := storeItemExists(t, key)
|
itemExists := storeItemExists(t, key)
|
||||||
assert.True(t, itemExists)
|
assert.True(t, itemExists)
|
||||||
}
|
}
|
||||||
|
|
@ -524,25 +528,28 @@ func deleteItem(t *testing.T, pgs *PostgreSQL, key string, etag *string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func storeItemExists(t *testing.T, key string) bool {
|
func storeItemExists(t *testing.T, key string) bool {
|
||||||
db, err := sql.Open("pgx", getConnectionString())
|
ctx := context.Background()
|
||||||
assert.Nil(t, err)
|
db, err := pgx.Connect(ctx, getConnectionString())
|
||||||
defer db.Close()
|
require.NoError(t, err)
|
||||||
|
defer db.Close(ctx)
|
||||||
|
|
||||||
exists := false
|
exists := false
|
||||||
statement := fmt.Sprintf(`SELECT EXISTS (SELECT FROM %s WHERE key = $1)`, defaultTableName)
|
statement := fmt.Sprintf(`SELECT EXISTS (SELECT FROM %s WHERE key = $1)`, defaultTableName)
|
||||||
err = db.QueryRow(statement, key).Scan(&exists)
|
err = db.QueryRow(ctx, statement, key).Scan(&exists)
|
||||||
assert.Nil(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
return exists
|
return exists
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRowData(t *testing.T, key string) (returnValue string, insertdate sql.NullString, updatedate sql.NullString) {
|
func getRowData(t *testing.T, key string) (returnValue string, insertdate *time.Time, updatedate *time.Time) {
|
||||||
db, err := sql.Open("pgx", getConnectionString())
|
ctx := context.Background()
|
||||||
assert.Nil(t, err)
|
db, err := pgx.Connect(ctx, getConnectionString())
|
||||||
defer db.Close()
|
require.NoError(t, err)
|
||||||
|
defer db.Close(ctx)
|
||||||
|
|
||||||
err = db.QueryRow(fmt.Sprintf("SELECT value, insertdate, updatedate FROM %s WHERE key = $1", defaultTableName), key).Scan(&returnValue, &insertdate, &updatedate)
|
err = db.QueryRow(ctx, fmt.Sprintf("SELECT value, insertdate, updatedate FROM %s WHERE key = $1", defaultTableName), key).
|
||||||
assert.Nil(t, err)
|
Scan(&returnValue, &insertdate, &updatedate)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
return returnValue, insertdate, updatedate
|
return returnValue, insertdate, updatedate
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ package postgresql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
@ -140,8 +139,8 @@ func (q *Query) Finalize(filters string, qq *query.Query) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) {
|
func (q *Query) execute(ctx context.Context, logger logger.Logger, db dbquerier) ([]state.QueryItem, string, error) {
|
||||||
rows, err := db.QueryContext(ctx, q.query, q.params...)
|
rows, err := db.Query(ctx, q.query, q.params...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
@ -152,7 +151,7 @@ func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) (
|
||||||
var (
|
var (
|
||||||
key string
|
key string
|
||||||
data []byte
|
data []byte
|
||||||
etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations
|
etag uint32
|
||||||
)
|
)
|
||||||
if err = rows.Scan(&key, &data, &etag); err != nil {
|
if err = rows.Scan(&key, &data, &etag); err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
|
|
@ -160,7 +159,7 @@ func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) (
|
||||||
result := state.QueryItem{
|
result := state.QueryItem{
|
||||||
Key: key,
|
Key: key,
|
||||||
Data: data,
|
Data: data,
|
||||||
ETag: ptr.Of(strconv.FormatUint(etag, 10)),
|
ETag: ptr.Of(strconv.FormatUint(uint64(etag), 10)),
|
||||||
}
|
}
|
||||||
ret = append(ret, result)
|
ret = append(ret, result)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,7 @@ require (
|
||||||
github.com/imdario/mergo v0.3.12 // indirect
|
github.com/imdario/mergo v0.3.12 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
|
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
|
||||||
|
github.com/jackc/puddle/v2 v2.1.2 // indirect
|
||||||
github.com/jhump/protoreflect v1.13.0 // indirect
|
github.com/jhump/protoreflect v1.13.0 // indirect
|
||||||
github.com/josharian/intern v1.0.0 // indirect
|
github.com/josharian/intern v1.0.0 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,6 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a h1:
|
||||||
github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go.mod h1:C0A1KeiVHs+trY6gUTPhhGammbrZ30ZfXRW/nuT7HLw=
|
github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go.mod h1:C0A1KeiVHs+trY6gUTPhhGammbrZ30ZfXRW/nuT7HLw=
|
||||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
|
|
||||||
github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
|
github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
|
||||||
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
|
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
|
||||||
github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig=
|
github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig=
|
||||||
|
|
@ -288,6 +287,8 @@ github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHF
|
||||||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
|
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
|
||||||
github.com/jackc/pgx/v5 v5.2.0 h1:NdPpngX0Y6z6XDFKqmFQaE+bCtkqzvQIOt1wvBlAqs8=
|
github.com/jackc/pgx/v5 v5.2.0 h1:NdPpngX0Y6z6XDFKqmFQaE+bCtkqzvQIOt1wvBlAqs8=
|
||||||
github.com/jackc/pgx/v5 v5.2.0/go.mod h1:Ptn7zmohNsWEsdxRawMzk3gaKma2obW+NWTnKa0S4nk=
|
github.com/jackc/pgx/v5 v5.2.0/go.mod h1:Ptn7zmohNsWEsdxRawMzk3gaKma2obW+NWTnKa0S4nk=
|
||||||
|
github.com/jackc/puddle/v2 v2.1.2 h1:0f7vaaXINONKTsxYDn4otOAiJanX/BMeAtY//BXqzlg=
|
||||||
|
github.com/jackc/puddle/v2 v2.1.2/go.mod h1:2lpufsF5mRHO6SuZkm0fNYxM6SWHfvyFj62KwNzgels=
|
||||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||||
github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
|
github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
|
||||||
github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
|
github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
|
||||||
|
|
@ -385,6 +386,7 @@ github.com/openzipkin/zipkin-go v0.4.1/go.mod h1:qY0VqDSN1pOBN94dBc6w2GJlWLiovAy
|
||||||
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
|
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
|
||||||
github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY=
|
github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY=
|
||||||
github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
|
github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
|
||||||
|
github.com/pashagolub/pgxmock/v2 v2.4.0 h1:jNv7+svrNoMc31mvllSS/u7P2pT3gS3uY7DPRKIJNSY=
|
||||||
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI=
|
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI=
|
||||||
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE=
|
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE=
|
||||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ package main
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
@ -706,7 +705,7 @@ func getMigrationLevel(dbClient *pgx.Conn, metadataTable string) (level string,
|
||||||
err = dbClient.
|
err = dbClient.
|
||||||
QueryRow(ctx, fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, metadataTable)).
|
QueryRow(ctx, fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, metadataTable)).
|
||||||
Scan(&level)
|
Scan(&level)
|
||||||
if err != nil && errors.Is(err, sql.ErrNoRows) {
|
if err != nil && errors.Is(err, pgx.ErrNoRows) {
|
||||||
err = nil
|
err = nil
|
||||||
level = ""
|
level = ""
|
||||||
}
|
}
|
||||||
|
|
@ -766,7 +765,7 @@ func loadLastCleanupInterval(ctx context.Context, dbClient *pgx.Conn, table stri
|
||||||
).
|
).
|
||||||
Scan(&lastCleanup)
|
Scan(&lastCleanup)
|
||||||
cancel()
|
cancel()
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
@ -790,7 +789,7 @@ func getValueFromMetadataTable(ctx context.Context, dbClient *pgx.Conn, table, k
|
||||||
QueryRow(queryCtx, fmt.Sprintf("SELECT value FROM %s WHERE key = $1", table), key).
|
QueryRow(queryCtx, fmt.Sprintf("SELECT value FROM %s WHERE key = $1", table), key).
|
||||||
Scan(&value)
|
Scan(&value)
|
||||||
cancel()
|
cancel()
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -272,7 +272,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
|
||||||
req.Metadata = map[string]string{metadata.ContentType: scenario.contentType}
|
req.Metadata = map[string]string{metadata.ContentType: scenario.contentType}
|
||||||
}
|
}
|
||||||
res, err := statestore.Get(context.Background(), req)
|
res, err := statestore.Get(context.Background(), req)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assertEquals(t, scenario.value, res)
|
assertEquals(t, scenario.value, res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -287,13 +287,13 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
|
||||||
t.Logf("Querying value presence for %s", scenario.query)
|
t.Logf("Querying value presence for %s", scenario.query)
|
||||||
var req state.QueryRequest
|
var req state.QueryRequest
|
||||||
err := json.Unmarshal([]byte(scenario.query), &req.Query)
|
err := json.Unmarshal([]byte(scenario.query), &req.Query)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.Metadata = map[string]string{
|
req.Metadata = map[string]string{
|
||||||
metadata.ContentType: contenttype.JSONContentType,
|
metadata.ContentType: contenttype.JSONContentType,
|
||||||
metadata.QueryIndexName: "qIndx",
|
metadata.QueryIndexName: "qIndx",
|
||||||
}
|
}
|
||||||
resp, err := querier.Query(context.Background(), &req)
|
resp, err := querier.Query(context.Background(), &req)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, len(scenario.results), len(resp.Results))
|
assert.Equal(t, len(scenario.results), len(resp.Results))
|
||||||
for i := range scenario.results {
|
for i := range scenario.results {
|
||||||
var expected, actual interface{}
|
var expected, actual interface{}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue