Changed postgres state store to use pgx instead of db/sql

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2022-12-20 22:45:38 +00:00
parent 2582f154bd
commit 80b1bb14b7
11 changed files with 269 additions and 218 deletions

View File

@ -15,7 +15,6 @@ package postgres
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -114,7 +113,7 @@ func (p *ConfigurationStore) Init(metadata configuration.Metadata) error {
exists := false exists := false
err = p.client.QueryRow(ctx, QueryTableExists, p.metadata.configTable).Scan(&exists) err = p.client.QueryRow(ctx, QueryTableExists, p.metadata.configTable).Scan(&exists)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, pgx.ErrNoRows) {
return fmt.Errorf(ErrorMissingTable, p.metadata.configTable) return fmt.Errorf(ErrorMissingTable, p.metadata.configTable)
} }
return fmt.Errorf("error in checking if configtable '%s' exists - '%w'", p.metadata.configTable, err) return fmt.Errorf("error in checking if configtable '%s' exists - '%w'", p.metadata.configTable, err)
@ -135,7 +134,7 @@ func (p *ConfigurationStore) Get(ctx context.Context, req *configuration.GetRequ
rows, err := p.client.Query(ctx, query, params...) rows, err := p.client.Query(ctx, query, params...)
if err != nil { if err != nil {
// If no rows exist, return an empty response, otherwise return the error. // If no rows exist, return an empty response, otherwise return the error.
if err == sql.ErrNoRows { if errors.Is(err, pgx.ErrNoRows) {
return &configuration.GetResponse{}, nil return &configuration.GetResponse{}, nil
} }
return nil, fmt.Errorf("error in querying configuration store: '%w'", err) return nil, fmt.Errorf("error in querying configuration store: '%w'", err)

View File

@ -15,7 +15,9 @@ package postgresql
import ( import (
"context" "context"
"database/sql"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state"
) )
@ -34,10 +36,9 @@ type dbAccess interface {
} }
// Interface that contains methods for querying. // Interface that contains methods for querying.
// Applies to both *sql.DB and *sql.Tx // Applies to *pgx.Conn, *pgxpool.Pool, and pgx.Tx
type dbquerier interface { type dbquerier interface {
Exec(query string, args ...any) (sql.Result, error) Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) Query(context.Context, string, ...interface{}) (pgx.Rows, error)
QueryRow(query string, args ...any) *sql.Row QueryRow(context.Context, string, ...interface{}) pgx.Row
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
} }

View File

@ -15,20 +15,21 @@ package postgresql
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/jackc/pgx/v5"
"github.com/dapr/kit/logger" "github.com/dapr/kit/logger"
) )
// Performs migrations for the database schema // Performs migrations for the database schema
type migrations struct { type migrations struct {
Logger logger.Logger Logger logger.Logger
Conn *sql.DB Conn pgxPoolConn
StateTableName string StateTableName string
MetadataTableName string MetadataTableName string
} }
@ -42,7 +43,7 @@ func (m *migrations) Perform(ctx context.Context) error {
// Long timeout here as this query may block // Long timeout here as this query may block
queryCtx, cancel := context.WithTimeout(ctx, time.Minute) queryCtx, cancel := context.WithTimeout(ctx, time.Minute)
_, err := m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_lock($1)", lockID) _, err := m.Conn.Exec(queryCtx, "SELECT pg_advisory_lock($1)", lockID)
cancel() cancel()
if err != nil { if err != nil {
return fmt.Errorf("faild to acquire advisory lock: %w", err) return fmt.Errorf("faild to acquire advisory lock: %w", err)
@ -51,7 +52,7 @@ func (m *migrations) Perform(ctx context.Context) error {
// Release the lock // Release the lock
defer func() { defer func() {
queryCtx, cancel = context.WithTimeout(ctx, time.Minute) queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
_, err = m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_unlock($1)", lockID) _, err = m.Conn.Exec(queryCtx, "SELECT pg_advisory_unlock($1)", lockID)
cancel() cancel()
if err != nil { if err != nil {
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around // Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
@ -84,11 +85,11 @@ func (m *migrations) Perform(ctx context.Context) error {
) )
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
err = m.Conn. err = m.Conn.
QueryRowContext(queryCtx, QueryRow(queryCtx,
fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName), fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName),
).Scan(&migrationLevelStr) ).Scan(&migrationLevelStr)
cancel() cancel()
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
// If there's no row... // If there's no row...
migrationLevel = 0 migrationLevel = 0
} else if err != nil { } else if err != nil {
@ -109,7 +110,7 @@ func (m *migrations) Perform(ctx context.Context) error {
} }
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
_, err = m.Conn.ExecContext(queryCtx, _, err = m.Conn.Exec(queryCtx,
fmt.Sprintf(`INSERT INTO %s (key, value) VALUES ('migrations', $1) ON CONFLICT (key) DO UPDATE SET value = $1`, m.MetadataTableName), fmt.Sprintf(`INSERT INTO %s (key, value) VALUES ('migrations', $1) ON CONFLICT (key) DO UPDATE SET value = $1`, m.MetadataTableName),
strconv.Itoa(i+1), strconv.Itoa(i+1),
) )
@ -126,7 +127,7 @@ func (m migrations) createMetadataTable(ctx context.Context) error {
m.Logger.Infof("Creating metadata table '%s'", m.MetadataTableName) m.Logger.Infof("Creating metadata table '%s'", m.MetadataTableName)
// Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time // Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time
// In the next step we'll acquire a lock so there won't be issues with concurrency // In the next step we'll acquire a lock so there won't be issues with concurrency
_, err := m.Conn.Exec(fmt.Sprintf( _, err := m.Conn.Exec(ctx, fmt.Sprintf(
`CREATE TABLE IF NOT EXISTS %s ( `CREATE TABLE IF NOT EXISTS %s (
key text NOT NULL PRIMARY KEY, key text NOT NULL PRIMARY KEY,
value text NOT NULL value text NOT NULL
@ -148,7 +149,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b
if schema == "" { if schema == "" {
err = m.Conn. err = m.Conn.
QueryRowContext( QueryRow(
ctx, ctx,
`SELECT table_name, table_schema `SELECT table_name, table_schema
FROM information_schema.tables FROM information_schema.tables
@ -158,7 +159,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b
Scan(&table, &schema) Scan(&table, &schema)
} else { } else {
err = m.Conn. err = m.Conn.
QueryRowContext( QueryRow(
ctx, ctx,
`SELECT table_name, table_schema `SELECT table_name, table_schema
FROM information_schema.tables FROM information_schema.tables
@ -168,7 +169,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b
Scan(&table, &schema) Scan(&table, &schema)
} }
if err != nil && errors.Is(err, sql.ErrNoRows) { if err != nil && errors.Is(err, pgx.ErrNoRows) {
return false, "", "", nil return false, "", "", nil
} else if err != nil { } else if err != nil {
return false, "", "", fmt.Errorf("failed to check if table '%s' exists: %w", tableName, err) return false, "", "", fmt.Errorf("failed to check if table '%s' exists: %w", tableName, err)
@ -195,6 +196,7 @@ var allMigrations = [2]func(ctx context.Context, m *migrations) error{
// We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table // We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table
m.Logger.Infof("Creating state table '%s'", m.StateTableName) m.Logger.Infof("Creating state table '%s'", m.StateTableName)
_, err := m.Conn.Exec( _, err := m.Conn.Exec(
ctx,
fmt.Sprintf( fmt.Sprintf(
`CREATE TABLE IF NOT EXISTS %s ( `CREATE TABLE IF NOT EXISTS %s (
key text NOT NULL PRIMARY KEY, key text NOT NULL PRIMARY KEY,
@ -215,7 +217,7 @@ var allMigrations = [2]func(ctx context.Context, m *migrations) error{
// Migration 1: add the "expiredate" column // Migration 1: add the "expiredate" column
func(ctx context.Context, m *migrations) error { func(ctx context.Context, m *migrations) error {
m.Logger.Infof("Adding expiredate column to state table '%s'", m.StateTableName) m.Logger.Infof("Adding expiredate column to state table '%s'", m.StateTableName)
_, err := m.Conn.Exec(fmt.Sprintf( _, err := m.Conn.Exec(ctx, fmt.Sprintf(
`ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`, `ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`,
m.StateTableName, m.StateTableName,
)) ))

View File

@ -15,7 +15,6 @@ package postgresql
import ( import (
"context" "context"
"database/sql"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
@ -23,15 +22,16 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state"
"github.com/dapr/components-contrib/state/query" "github.com/dapr/components-contrib/state/query"
stateutils "github.com/dapr/components-contrib/state/utils" stateutils "github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger" "github.com/dapr/kit/logger"
"github.com/dapr/kit/ptr" "github.com/dapr/kit/ptr"
// Blank import for the underlying Postgres driver.
_ "github.com/jackc/pgx/v5/stdlib"
) )
const ( const (
@ -43,12 +43,24 @@ const (
var errMissingConnectionString = errors.New("missing connection string") var errMissingConnectionString = errors.New("missing connection string")
// Interface that applies to *pgxpool.Pool.
// We need this to be able to mock the connection in tests.
type pgxPoolConn interface {
Begin(context.Context) (pgx.Tx, error)
BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row
Ping(context.Context) error
Close()
}
// PostgresDBAccess implements dbaccess. // PostgresDBAccess implements dbaccess.
type PostgresDBAccess struct { type PostgresDBAccess struct {
logger logger.Logger logger logger.Logger
metadata postgresMetadataStruct metadata postgresMetadataStruct
cleanupInterval *time.Duration cleanupInterval *time.Duration
db *sql.DB db pgxPoolConn
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
} }
@ -81,23 +93,31 @@ func (p *PostgresDBAccess) Init(meta state.Metadata) error {
return err return err
} }
db, err := sql.Open("pgx", p.metadata.ConnectionString) config, err := pgxpool.ParseConfig(p.metadata.ConnectionString)
if err != nil { if err != nil {
err = fmt.Errorf("failed to parse connection string: %w", err)
p.logger.Error(err)
return err
}
if p.metadata.ConnectionMaxIdleTime > 0 {
config.MaxConnIdleTime = p.metadata.ConnectionMaxIdleTime
}
connCtx, connCancel := context.WithTimeout(p.ctx, 30*time.Second)
p.db, err = pgxpool.NewWithConfig(connCtx, config)
connCancel()
if err != nil {
err = fmt.Errorf("failed to connect to the database: %w", err)
p.logger.Error(err) p.logger.Error(err)
return err return err
} }
p.db = db
pingCtx, pingCancel := context.WithTimeout(p.ctx, 30*time.Second) pingCtx, pingCancel := context.WithTimeout(p.ctx, 30*time.Second)
pingErr := db.PingContext(pingCtx) err = p.db.Ping(pingCtx)
pingCancel() pingCancel()
if pingErr != nil {
return pingErr
}
p.db.SetConnMaxIdleTime(p.metadata.ConnectionMaxIdleTime)
if err != nil { if err != nil {
err = fmt.Errorf("failed to ping the database: %w", err)
p.logger.Error(err)
return err return err
} }
@ -117,8 +137,9 @@ func (p *PostgresDBAccess) Init(meta state.Metadata) error {
return nil return nil
} }
func (p *PostgresDBAccess) GetDB() *sql.DB { func (p *PostgresDBAccess) GetDB() *pgxpool.Pool {
return p.db // We can safely cast to *pgxpool.Pool because this method is never used in unit tests where we mock the DB
return p.db.(*pgxpool.Pool)
} }
func (p *PostgresDBAccess) ParseMetadata(meta state.Metadata) error { func (p *PostgresDBAccess) ParseMetadata(meta state.Metadata) error {
@ -193,8 +214,6 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
ttlSeconds = *ttl ttlSeconds = *ttl
} }
var result sql.Result
// Sprintf is required for table name because query.DB does not substitute parameters for table names. // Sprintf is required for table name because query.DB does not substitute parameters for table names.
// Other parameters use query.DB parameter substitution. // Other parameters use query.DB parameter substitution.
var ( var (
@ -246,8 +265,8 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
} else { } else {
queryExpiredate = "NULL" queryExpiredate = "NULL"
} }
result, err = db.ExecContext(parentCtx, fmt.Sprintf(query, p.metadata.TableName, queryExpiredate), params...)
result, err := db.Exec(parentCtx, fmt.Sprintf(query, p.metadata.TableName, queryExpiredate), params...)
if err != nil { if err != nil {
if req.ETag != nil && *req.ETag != "" { if req.ETag != nil && *req.ETag != "" {
return state.NewETagError(state.ETagMismatch, err) return state.NewETagError(state.ETagMismatch, err)
@ -255,11 +274,7 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
return err return err
} }
rows, err := result.RowsAffected() if result.RowsAffected() != 1 {
if err != nil {
return err
}
if rows != 1 {
return errors.New("no item was updated") return errors.New("no item was updated")
} }
@ -267,11 +282,20 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
} }
func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetRequest) error { func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetRequest) error {
tx, err := p.db.BeginTx(parentCtx, nil) ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
tx, err := p.db.Begin(ctx)
cancel()
if err != nil { if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err) return fmt.Errorf("failed to begin transaction: %w", err)
} }
defer tx.Rollback() defer func() {
rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second)
rollbackErr := tx.Rollback(rollbackCtx)
rollbackCancel()
if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) {
p.logger.Errorf("Failed to rollback transaction in BulkSet: %v", rollbackErr)
}
}()
if len(req) > 0 { if len(req) > 0 {
for i := range req { for i := range req {
@ -282,7 +306,9 @@ func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetReq
} }
} }
err = tx.Commit() ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second)
err = tx.Commit(ctx)
cancel()
if err != nil { if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err) return fmt.Errorf("failed to commit transaction: %w", err)
} }
@ -299,7 +325,7 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest)
var ( var (
value []byte value []byte
isBinary bool isBinary bool
etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations etag uint32
) )
query := `SELECT query := `SELECT
value, isbinary, xmin AS etag value, isbinary, xmin AS etag
@ -307,12 +333,11 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest)
WHERE WHERE
key = $1 key = $1
AND (expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP)` AND (expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP)`
err := p.db. err := p.db.QueryRow(parentCtx, fmt.Sprintf(query, p.metadata.TableName), req.Key).
QueryRowContext(parentCtx, fmt.Sprintf(query, p.metadata.TableName), req.Key).
Scan(&value, &isBinary, &etag) Scan(&value, &isBinary, &etag)
if err != nil { if err != nil {
// If no rows exist, return an empty response, otherwise return the error. // If no rows exist, return an empty response, otherwise return the error.
if err == sql.ErrNoRows { if err == pgx.ErrNoRows {
return &state.GetResponse{}, nil return &state.GetResponse{}, nil
} }
return nil, err return nil, err
@ -334,14 +359,14 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest)
return &state.GetResponse{ return &state.GetResponse{
Data: data, Data: data,
ETag: ptr.Of(strconv.FormatUint(etag, 10)), ETag: ptr.Of(strconv.FormatUint(uint64(etag), 10)),
Metadata: req.Metadata, Metadata: req.Metadata,
}, nil }, nil
} }
return &state.GetResponse{ return &state.GetResponse{
Data: value, Data: value,
ETag: ptr.Of(strconv.FormatUint(etag, 10)), ETag: ptr.Of(strconv.FormatUint(uint64(etag), 10)),
Metadata: req.Metadata, Metadata: req.Metadata,
}, nil }, nil
} }
@ -356,10 +381,9 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
return errors.New("missing key in delete operation") return errors.New("missing key in delete operation")
} }
var result sql.Result var result pgconn.CommandTag
if req.ETag == nil || *req.ETag == "" { if req.ETag == nil || *req.ETag == "" {
result, err = db.ExecContext(parentCtx, "DELETE FROM state WHERE key = $1", req.Key) result, err = db.Exec(parentCtx, "DELETE FROM state WHERE key = $1", req.Key)
} else { } else {
// Convert req.ETag to uint32 for postgres XID compatibility // Convert req.ETag to uint32 for postgres XID compatibility
var etag64 uint64 var etag64 uint64
@ -367,20 +391,15 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
if err != nil { if err != nil {
return state.NewETagError(state.ETagInvalid, err) return state.NewETagError(state.ETagInvalid, err)
} }
etag := uint32(etag64)
result, err = db.ExecContext(parentCtx, "DELETE FROM state WHERE key = $1 AND xmin = $2", req.Key, etag) result, err = db.Exec(parentCtx, "DELETE FROM state WHERE key = $1 AND xmin = $2", req.Key, uint32(etag64))
} }
if err != nil { if err != nil {
return err return err
} }
rows, err := result.RowsAffected() rows := result.RowsAffected()
if err != nil {
return err
}
if rows != 1 && req.ETag != nil && *req.ETag != "" { if rows != 1 && req.ETag != nil && *req.ETag != "" {
return state.NewETagError(state.ETagMismatch, nil) return state.NewETagError(state.ETagMismatch, nil)
} }
@ -389,11 +408,20 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
} }
func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.DeleteRequest) error { func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.DeleteRequest) error {
tx, err := p.db.BeginTx(parentCtx, nil) ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
tx, err := p.db.Begin(ctx)
cancel()
if err != nil { if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err) return fmt.Errorf("failed to begin transaction: %w", err)
} }
defer tx.Rollback() defer func() {
rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second)
rollbackErr := tx.Rollback(rollbackCtx)
rollbackCancel()
if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) {
p.logger.Errorf("Failed to rollback transaction in BulkDelete: %v", rollbackErr)
}
}()
if len(req) > 0 { if len(req) > 0 {
for i := range req { for i := range req {
@ -404,7 +432,9 @@ func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.Del
} }
} }
err = tx.Commit() ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second)
err = tx.Commit(ctx)
cancel()
if err != nil { if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err) return fmt.Errorf("failed to commit transaction: %w", err)
} }
@ -413,11 +443,20 @@ func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.Del
} }
func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *state.TransactionalStateRequest) error { func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *state.TransactionalStateRequest) error {
tx, err := p.db.BeginTx(parentCtx, nil) ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
tx, err := p.db.Begin(ctx)
cancel()
if err != nil { if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err) return fmt.Errorf("failed to begin transaction: %w", err)
} }
defer tx.Rollback() defer func() {
rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second)
rollbackErr := tx.Rollback(rollbackCtx)
rollbackCancel()
if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) {
p.logger.Errorf("Failed to rollback transaction in ExecMulti: %v", rollbackErr)
}
}()
for _, o := range request.Operations { for _, o := range request.Operations {
switch o.Operation { switch o.Operation {
@ -450,7 +489,9 @@ func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *stat
} }
} }
err = tx.Commit() ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second)
err = tx.Commit(ctx)
cancel()
if err != nil { if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err) return fmt.Errorf("failed to commit transaction: %w", err)
} }
@ -520,19 +561,13 @@ func (p *PostgresDBAccess) CleanupExpired(ctx context.Context) error {
// Note we're not using the transaction here as we don't want this to be rolled back half-way or to lock the table unnecessarily // Note we're not using the transaction here as we don't want this to be rolled back half-way or to lock the table unnecessarily
// Need to use fmt.Sprintf because we can't parametrize a table name // Need to use fmt.Sprintf because we can't parametrize a table name
// Note we are not setting a timeout here as this query can take a "long" time, especially if there's no index on expiredate // Note we are not setting a timeout here as this query can take a "long" time, especially if there's no index on expiredate
//nolint:gosec
stmt := fmt.Sprintf(`DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate < CURRENT_TIMESTAMP`, p.metadata.TableName) stmt := fmt.Sprintf(`DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate < CURRENT_TIMESTAMP`, p.metadata.TableName)
res, err := p.db.ExecContext(ctx, stmt) res, err := p.db.Exec(ctx, stmt)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute query: %w", err) return fmt.Errorf("failed to execute query: %w", err)
} }
cleaned, err := res.RowsAffected() p.logger.Infof("Removed %d expired rows", res.RowsAffected())
if err != nil {
return fmt.Errorf("failed to count affected rows: %w", err)
}
p.logger.Infof("Removed %d expired rows", cleaned)
return nil return nil
} }
@ -540,7 +575,7 @@ func (p *PostgresDBAccess) CleanupExpired(ctx context.Context) error {
// Returns true if the row was updated, which means that the cleanup can proceed. // Returns true if the row was updated, which means that the cleanup can proceed.
func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier, cleanupInterval time.Duration) (bool, error) { func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier, cleanupInterval time.Duration) (bool, error) {
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second) queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
res, err := db.ExecContext(queryCtx, res, err := db.Exec(queryCtx,
fmt.Sprintf(`INSERT INTO %[1]s (key, value) fmt.Sprintf(`INSERT INTO %[1]s (key, value)
VALUES ('last-cleanup', CURRENT_TIMESTAMP) VALUES ('last-cleanup', CURRENT_TIMESTAMP)
ON CONFLICT (key) ON CONFLICT (key)
@ -554,11 +589,7 @@ func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier,
return true, fmt.Errorf("failed to execute query: %w", err) return true, fmt.Errorf("failed to execute query: %w", err)
} }
n, err := res.RowsAffected() n := res.RowsAffected()
if err != nil {
return true, fmt.Errorf("failed to count affected rows: %w", err)
}
return n > 0, nil return n > 0, nil
} }
@ -569,7 +600,8 @@ func (p *PostgresDBAccess) Close() error {
p.cancel = nil p.cancel = nil
} }
if p.db != nil { if p.db != nil {
return p.db.Close() p.db.Close()
p.db = nil
} }
return nil return nil

View File

@ -16,24 +16,20 @@ package postgresql
import ( import (
"context" "context"
"database/sql" "encoding/json"
"testing" "testing"
"time" "time"
"github.com/DATA-DOG/go-sqlmock" pgxmock "github.com/pashagolub/pgxmock/v2"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state"
"github.com/dapr/kit/logger" "github.com/dapr/kit/logger"
// Blank import for pgx
_ "github.com/jackc/pgx/v5/stdlib"
) )
type mocks struct { type mocks struct {
db *sql.DB db pgxmock.PgxPoolIface
mock sqlmock.Sqlmock
pgDba *PostgresDBAccess pgDba *PostgresDBAccess
} }
@ -67,7 +63,7 @@ func TestGetSetValid(t *testing.T) {
} }
set, err := getSet(operation) set, err := getSet(operation)
assert.Nil(t, err) assert.NoError(t, err)
assert.Equal(t, "key1", set.Key) assert.Equal(t, "key1", set.Key)
} }
@ -101,7 +97,7 @@ func TestGetDeleteValid(t *testing.T) {
} }
delete, err := getDelete(operation) delete, err := getDelete(operation)
assert.Nil(t, err) assert.NoError(t, err)
assert.Equal(t, "key1", delete.Key) assert.Equal(t, "key1", delete.Key)
} }
@ -110,8 +106,10 @@ func TestMultiWithNoRequests(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.db.Close() defer m.db.Close()
m.mock.ExpectBegin() m.db.ExpectBegin()
m.mock.ExpectCommit() m.db.ExpectCommit()
// There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()
var operations []state.TransactionalStateOperation var operations []state.TransactionalStateOperation
@ -121,7 +119,7 @@ func TestMultiWithNoRequests(t *testing.T) {
}) })
// Assert // Assert
assert.Nil(t, err) assert.NoError(t, err)
} }
func TestInvalidMultiInvalidAction(t *testing.T) { func TestInvalidMultiInvalidAction(t *testing.T) {
@ -129,8 +127,8 @@ func TestInvalidMultiInvalidAction(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.db.Close() defer m.db.Close()
m.mock.ExpectBegin() m.db.ExpectBegin()
m.mock.ExpectRollback() m.db.ExpectRollback()
var operations []state.TransactionalStateOperation var operations []state.TransactionalStateOperation
@ -153,16 +151,19 @@ func TestValidSetRequest(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.db.Close() defer m.db.Close()
m.mock.ExpectBegin() setReq := createSetRequest()
m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1)) operations := []state.TransactionalStateOperation{
m.mock.ExpectCommit() {Operation: state.Upsert, Request: setReq},
}
val, _ := json.Marshal(setReq.Value)
var operations []state.TransactionalStateOperation m.db.ExpectBegin()
m.db.ExpectExec("INSERT INTO").
operations = append(operations, state.TransactionalStateOperation{ WithArgs(setReq.Key, string(val), false).
Operation: state.Upsert, WillReturnResult(pgxmock.NewResult("INSERT", 1))
Request: createSetRequest(), m.db.ExpectCommit()
}) // There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()
// Act // Act
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
@ -170,7 +171,7 @@ func TestValidSetRequest(t *testing.T) {
}) })
// Assert // Assert
assert.Nil(t, err) assert.NoError(t, err)
} }
func TestInvalidMultiSetRequest(t *testing.T) { func TestInvalidMultiSetRequest(t *testing.T) {
@ -178,15 +179,15 @@ func TestInvalidMultiSetRequest(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.db.Close() defer m.db.Close()
m.mock.ExpectBegin() m.db.ExpectBegin()
m.mock.ExpectRollback() m.db.ExpectRollback()
var operations []state.TransactionalStateOperation operations := []state.TransactionalStateOperation{
{
operations = append(operations, state.TransactionalStateOperation{ Operation: state.Upsert,
Operation: state.Upsert, Request: createDeleteRequest(), // Delete request is not valid for Upsert operation
Request: createDeleteRequest(), // Delete request is not valid for Upsert operation },
}) }
// Act // Act
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
@ -202,8 +203,8 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.db.Close() defer m.db.Close()
m.mock.ExpectBegin() m.db.ExpectBegin()
m.mock.ExpectRollback() m.db.ExpectRollback()
var operations []state.TransactionalStateOperation var operations []state.TransactionalStateOperation
@ -226,16 +227,18 @@ func TestValidMultiDeleteRequest(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.db.Close() defer m.db.Close()
m.mock.ExpectBegin() deleteReq := createDeleteRequest()
m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1)) operations := []state.TransactionalStateOperation{
m.mock.ExpectCommit() {Operation: state.Delete, Request: deleteReq},
}
var operations []state.TransactionalStateOperation m.db.ExpectBegin()
m.db.ExpectExec("DELETE FROM").
operations = append(operations, state.TransactionalStateOperation{ WithArgs(deleteReq.Key).
Operation: state.Delete, WillReturnResult(pgxmock.NewResult("DELETE", 1))
Request: createDeleteRequest(), m.db.ExpectCommit()
}) // There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()
// Act // Act
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
@ -243,7 +246,7 @@ func TestValidMultiDeleteRequest(t *testing.T) {
}) })
// Assert // Assert
assert.Nil(t, err) assert.NoError(t, err)
} }
func TestInvalidMultiDeleteRequest(t *testing.T) { func TestInvalidMultiDeleteRequest(t *testing.T) {
@ -251,8 +254,8 @@ func TestInvalidMultiDeleteRequest(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.db.Close() defer m.db.Close()
m.mock.ExpectBegin() m.db.ExpectBegin()
m.mock.ExpectRollback() m.db.ExpectRollback()
var operations []state.TransactionalStateOperation var operations []state.TransactionalStateOperation
@ -275,8 +278,8 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.db.Close() defer m.db.Close()
m.mock.ExpectBegin() m.db.ExpectBegin()
m.mock.ExpectRollback() m.db.ExpectRollback()
var operations []state.TransactionalStateOperation var operations []state.TransactionalStateOperation
@ -299,23 +302,27 @@ func TestMultiOperationOrder(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.db.Close() defer m.db.Close()
m.mock.ExpectBegin() operations := []state.TransactionalStateOperation{
m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1)) {
m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1))
m.mock.ExpectCommit()
var operations []state.TransactionalStateOperation
operations = append(operations,
state.TransactionalStateOperation{
Operation: state.Upsert, Operation: state.Upsert,
Request: state.SetRequest{Key: "key1", Value: "value1"}, Request: state.SetRequest{Key: "key1", Value: "value1"},
}, },
state.TransactionalStateOperation{ {
Operation: state.Delete, Operation: state.Delete,
Request: state.DeleteRequest{Key: "key1"}, Request: state.DeleteRequest{Key: "key1"},
}, },
) }
m.db.ExpectBegin()
m.db.ExpectExec("INSERT INTO").
WithArgs("key1", `"value1"`, false).
WillReturnResult(pgxmock.NewResult("INSERT", 1))
m.db.ExpectExec("DELETE FROM").
WithArgs("key1").
WillReturnResult(pgxmock.NewResult("DELETE", 1))
m.db.ExpectCommit()
// There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()
// Act // Act
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
@ -323,7 +330,7 @@ func TestMultiOperationOrder(t *testing.T) {
}) })
// Assert // Assert
assert.Nil(t, err) assert.NoError(t, err)
} }
func TestInvalidBulkSetNoKey(t *testing.T) { func TestInvalidBulkSetNoKey(t *testing.T) {
@ -331,14 +338,13 @@ func TestInvalidBulkSetNoKey(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.db.Close() defer m.db.Close()
m.mock.ExpectBegin() m.db.ExpectBegin()
m.mock.ExpectRollback() m.db.ExpectRollback()
var sets []state.SetRequest sets := []state.SetRequest{
// Set request without key is not valid for Set operation
sets = append(sets, state.SetRequest{ // Set request without key is not valid for Set operation {Value: "value1"},
Value: "value1", }
})
// Act // Act
err := m.pgDba.BulkSet(context.Background(), sets) err := m.pgDba.BulkSet(context.Background(), sets)
@ -352,15 +358,13 @@ func TestInvalidBulkSetEmptyValue(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.db.Close() defer m.db.Close()
m.mock.ExpectBegin() m.db.ExpectBegin()
m.mock.ExpectRollback() m.db.ExpectRollback()
var sets []state.SetRequest sets := []state.SetRequest{
// Set request without value is not valid for Set operation
sets = append(sets, state.SetRequest{ // Set request without value is not valid for Set operation {Key: "key1", Value: "value1"},
Key: "key1", }
Value: "",
})
// Act // Act
err := m.pgDba.BulkSet(context.Background(), sets) err := m.pgDba.BulkSet(context.Background(), sets)
@ -374,22 +378,26 @@ func TestValidBulkSet(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.db.Close() defer m.db.Close()
m.mock.ExpectBegin() sets := []state.SetRequest{
m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1)) {
m.mock.ExpectCommit() Key: "key1",
Value: "value1",
},
}
var sets []state.SetRequest m.db.ExpectBegin()
m.db.ExpectExec("INSERT INTO").
sets = append(sets, state.SetRequest{ WithArgs("key1", `"value1"`, false).
Key: "key1", WillReturnResult(pgxmock.NewResult("INSERT", 1))
Value: "value1", m.db.ExpectCommit()
}) // There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()
// Act // Act
err := m.pgDba.BulkSet(context.Background(), sets) err := m.pgDba.BulkSet(context.Background(), sets)
// Assert // Assert
assert.Nil(t, err) assert.NoError(t, err)
} }
func TestInvalidBulkDeleteNoKey(t *testing.T) { func TestInvalidBulkDeleteNoKey(t *testing.T) {
@ -397,8 +405,8 @@ func TestInvalidBulkDeleteNoKey(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.db.Close() defer m.db.Close()
m.mock.ExpectBegin() m.db.ExpectBegin()
m.mock.ExpectRollback() m.db.ExpectRollback()
var deletes []state.DeleteRequest var deletes []state.DeleteRequest
@ -418,21 +426,23 @@ func TestValidBulkDelete(t *testing.T) {
m, _ := mockDatabase(t) m, _ := mockDatabase(t)
defer m.db.Close() defer m.db.Close()
m.mock.ExpectBegin() deletes := []state.DeleteRequest{
m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1)) {Key: "key1"},
m.mock.ExpectCommit() }
var deletes []state.DeleteRequest m.db.ExpectBegin()
m.db.ExpectExec("DELETE FROM").
deletes = append(deletes, state.DeleteRequest{ WithArgs("key1").
Key: "key1", WillReturnResult(pgxmock.NewResult("DELETE", 1))
}) m.db.ExpectCommit()
// There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()
// Act // Act
err := m.pgDba.BulkDelete(context.Background(), deletes) err := m.pgDba.BulkDelete(context.Background(), deletes)
// Assert // Assert
assert.Nil(t, err) assert.NoError(t, err)
} }
func createSetRequest() state.SetRequest { func createSetRequest() state.SetRequest {
@ -451,7 +461,7 @@ func createDeleteRequest() state.DeleteRequest {
func mockDatabase(t *testing.T) (*mocks, error) { func mockDatabase(t *testing.T) (*mocks, error) {
logger := logger.NewLogger("test") logger := logger.NewLogger("test")
db, mock, err := sqlmock.New() db, err := pgxmock.NewPool()
if err != nil { if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
} }
@ -463,7 +473,6 @@ func mockDatabase(t *testing.T) (*mocks, error) {
return &mocks{ return &mocks{
db: db, db: db,
mock: mock,
pgDba: dba, pgDba: dba,
}, err }, err
} }

View File

@ -12,18 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
package postgresql package postgresql
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
"testing" "testing"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state"
@ -162,7 +165,7 @@ func deleteItemThatDoesNotExist(t *testing.T, pgs *PostgreSQL) {
Key: randomKey(), Key: randomKey(),
} }
err := pgs.Delete(context.Background(), deleteReq) err := pgs.Delete(context.Background(), deleteReq)
assert.Nil(t, err) assert.NoError(t, err)
} }
func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) { func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) {
@ -183,7 +186,7 @@ func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) {
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{ err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.NoError(t, err)
for _, set := range setRequests { for _, set := range setRequests {
assert.True(t, storeItemExists(t, set.Key)) assert.True(t, storeItemExists(t, set.Key))
@ -213,7 +216,7 @@ func multiWithDeleteOnly(t *testing.T, pgs *PostgreSQL) {
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{ err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.NoError(t, err)
for _, delete := range deleteRequests { for _, delete := range deleteRequests {
assert.False(t, storeItemExists(t, delete.Key)) assert.False(t, storeItemExists(t, delete.Key))
@ -256,7 +259,7 @@ func multiWithDeleteAndSet(t *testing.T, pgs *PostgreSQL) {
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{ err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
Operations: operations, Operations: operations,
}) })
assert.Nil(t, err) assert.NoError(t, err)
for _, delete := range deleteRequests { for _, delete := range deleteRequests {
assert.False(t, storeItemExists(t, delete.Key)) assert.False(t, storeItemExists(t, delete.Key))
@ -384,15 +387,16 @@ func setUpdatesTheUpdatedateField(t *testing.T, pgs *PostgreSQL) {
// insertdate should have a value and updatedate should be nil // insertdate should have a value and updatedate should be nil
_, insertdate, updatedate := getRowData(t, key) _, insertdate, updatedate := getRowData(t, key)
assert.Nil(t, updatedate)
assert.NotNil(t, insertdate) assert.NotNil(t, insertdate)
assert.Equal(t, "", updatedate.String)
// insertdate should not change, updatedate should have a value // insertdate should not change, updatedate should have a value
value = &fakeItem{Color: "aqua"} value = &fakeItem{Color: "aqua"}
setItem(t, pgs, key, value, nil) setItem(t, pgs, key, value, nil)
_, newinsertdate, updatedate := getRowData(t, key) _, newinsertdate, updatedate := getRowData(t, key)
assert.Equal(t, insertdate, newinsertdate) // The insertdate should not change. assert.NotNil(t, updatedate)
assert.NotEqual(t, "", updatedate.String) assert.NotNil(t, newinsertdate)
assert.True(t, insertdate.Equal(*newinsertdate)) // The insertdate should not change.
deleteItem(t, pgs, key, nil) deleteItem(t, pgs, key, nil)
} }
@ -420,7 +424,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) {
} }
err := pgs.BulkSet(context.Background(), setReq) err := pgs.BulkSet(context.Background(), setReq)
assert.Nil(t, err) assert.NoError(t, err)
assert.True(t, storeItemExists(t, setReq[0].Key)) assert.True(t, storeItemExists(t, setReq[0].Key))
assert.True(t, storeItemExists(t, setReq[1].Key)) assert.True(t, storeItemExists(t, setReq[1].Key))
@ -434,7 +438,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) {
} }
err = pgs.BulkDelete(context.Background(), deleteReq) err = pgs.BulkDelete(context.Background(), deleteReq)
assert.Nil(t, err) assert.NoError(t, err)
assert.False(t, storeItemExists(t, setReq[0].Key)) assert.False(t, storeItemExists(t, setReq[0].Key))
assert.False(t, storeItemExists(t, setReq[1].Key)) assert.False(t, storeItemExists(t, setReq[1].Key))
} }
@ -491,7 +495,7 @@ func setItem(t *testing.T, pgs *PostgreSQL, key string, value interface{}, etag
} }
err := pgs.Set(context.Background(), setReq) err := pgs.Set(context.Background(), setReq)
assert.Nil(t, err) assert.NoError(t, err)
itemExists := storeItemExists(t, key) itemExists := storeItemExists(t, key)
assert.True(t, itemExists) assert.True(t, itemExists)
} }
@ -524,25 +528,28 @@ func deleteItem(t *testing.T, pgs *PostgreSQL, key string, etag *string) {
} }
func storeItemExists(t *testing.T, key string) bool { func storeItemExists(t *testing.T, key string) bool {
db, err := sql.Open("pgx", getConnectionString()) ctx := context.Background()
assert.Nil(t, err) db, err := pgx.Connect(ctx, getConnectionString())
defer db.Close() require.NoError(t, err)
defer db.Close(ctx)
exists := false exists := false
statement := fmt.Sprintf(`SELECT EXISTS (SELECT FROM %s WHERE key = $1)`, defaultTableName) statement := fmt.Sprintf(`SELECT EXISTS (SELECT FROM %s WHERE key = $1)`, defaultTableName)
err = db.QueryRow(statement, key).Scan(&exists) err = db.QueryRow(ctx, statement, key).Scan(&exists)
assert.Nil(t, err) assert.NoError(t, err)
return exists return exists
} }
func getRowData(t *testing.T, key string) (returnValue string, insertdate sql.NullString, updatedate sql.NullString) { func getRowData(t *testing.T, key string) (returnValue string, insertdate *time.Time, updatedate *time.Time) {
db, err := sql.Open("pgx", getConnectionString()) ctx := context.Background()
assert.Nil(t, err) db, err := pgx.Connect(ctx, getConnectionString())
defer db.Close() require.NoError(t, err)
defer db.Close(ctx)
err = db.QueryRow(fmt.Sprintf("SELECT value, insertdate, updatedate FROM %s WHERE key = $1", defaultTableName), key).Scan(&returnValue, &insertdate, &updatedate) err = db.QueryRow(ctx, fmt.Sprintf("SELECT value, insertdate, updatedate FROM %s WHERE key = $1", defaultTableName), key).
assert.Nil(t, err) Scan(&returnValue, &insertdate, &updatedate)
assert.NoError(t, err)
return returnValue, insertdate, updatedate return returnValue, insertdate, updatedate
} }

View File

@ -16,7 +16,6 @@ package postgresql
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@ -140,8 +139,8 @@ func (q *Query) Finalize(filters string, qq *query.Query) error {
return nil return nil
} }
func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) { func (q *Query) execute(ctx context.Context, logger logger.Logger, db dbquerier) ([]state.QueryItem, string, error) {
rows, err := db.QueryContext(ctx, q.query, q.params...) rows, err := db.Query(ctx, q.query, q.params...)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
@ -152,7 +151,7 @@ func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) (
var ( var (
key string key string
data []byte data []byte
etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations etag uint32
) )
if err = rows.Scan(&key, &data, &etag); err != nil { if err = rows.Scan(&key, &data, &etag); err != nil {
return nil, "", err return nil, "", err
@ -160,7 +159,7 @@ func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) (
result := state.QueryItem{ result := state.QueryItem{
Key: key, Key: key,
Data: data, Data: data,
ETag: ptr.Of(strconv.FormatUint(etag, 10)), ETag: ptr.Of(strconv.FormatUint(uint64(etag), 10)),
} }
ret = append(ret, result) ret = append(ret, result)
} }

View File

@ -62,6 +62,7 @@ require (
github.com/imdario/mergo v0.3.12 // indirect github.com/imdario/mergo v0.3.12 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
github.com/jackc/puddle/v2 v2.1.2 // indirect
github.com/jhump/protoreflect v1.13.0 // indirect github.com/jhump/protoreflect v1.13.0 // indirect
github.com/josharian/intern v1.0.0 // indirect github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect

View File

@ -37,7 +37,6 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a h1:
github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go.mod h1:C0A1KeiVHs+trY6gUTPhhGammbrZ30ZfXRW/nuT7HLw= github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go.mod h1:C0A1KeiVHs+trY6gUTPhhGammbrZ30ZfXRW/nuT7HLw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig=
@ -288,6 +287,8 @@ github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHF
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
github.com/jackc/pgx/v5 v5.2.0 h1:NdPpngX0Y6z6XDFKqmFQaE+bCtkqzvQIOt1wvBlAqs8= github.com/jackc/pgx/v5 v5.2.0 h1:NdPpngX0Y6z6XDFKqmFQaE+bCtkqzvQIOt1wvBlAqs8=
github.com/jackc/pgx/v5 v5.2.0/go.mod h1:Ptn7zmohNsWEsdxRawMzk3gaKma2obW+NWTnKa0S4nk= github.com/jackc/pgx/v5 v5.2.0/go.mod h1:Ptn7zmohNsWEsdxRawMzk3gaKma2obW+NWTnKa0S4nk=
github.com/jackc/puddle/v2 v2.1.2 h1:0f7vaaXINONKTsxYDn4otOAiJanX/BMeAtY//BXqzlg=
github.com/jackc/puddle/v2 v2.1.2/go.mod h1:2lpufsF5mRHO6SuZkm0fNYxM6SWHfvyFj62KwNzgels=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
@ -385,6 +386,7 @@ github.com/openzipkin/zipkin-go v0.4.1/go.mod h1:qY0VqDSN1pOBN94dBc6w2GJlWLiovAy
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY=
github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/pashagolub/pgxmock/v2 v2.4.0 h1:jNv7+svrNoMc31mvllSS/u7P2pT3gS3uY7DPRKIJNSY=
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI= github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI=
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

View File

@ -16,7 +16,6 @@ package main
import ( import (
"bytes" "bytes"
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -706,7 +705,7 @@ func getMigrationLevel(dbClient *pgx.Conn, metadataTable string) (level string,
err = dbClient. err = dbClient.
QueryRow(ctx, fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, metadataTable)). QueryRow(ctx, fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, metadataTable)).
Scan(&level) Scan(&level)
if err != nil && errors.Is(err, sql.ErrNoRows) { if err != nil && errors.Is(err, pgx.ErrNoRows) {
err = nil err = nil
level = "" level = ""
} }
@ -766,7 +765,7 @@ func loadLastCleanupInterval(ctx context.Context, dbClient *pgx.Conn, table stri
). ).
Scan(&lastCleanup) Scan(&lastCleanup)
cancel() cancel()
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
err = nil err = nil
} }
return return
@ -790,7 +789,7 @@ func getValueFromMetadataTable(ctx context.Context, dbClient *pgx.Conn, table, k
QueryRow(queryCtx, fmt.Sprintf("SELECT value FROM %s WHERE key = $1", table), key). QueryRow(queryCtx, fmt.Sprintf("SELECT value FROM %s WHERE key = $1", table), key).
Scan(&value) Scan(&value)
cancel() cancel()
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
err = nil err = nil
} }
return return

View File

@ -272,7 +272,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
req.Metadata = map[string]string{metadata.ContentType: scenario.contentType} req.Metadata = map[string]string{metadata.ContentType: scenario.contentType}
} }
res, err := statestore.Get(context.Background(), req) res, err := statestore.Get(context.Background(), req)
assert.Nil(t, err) require.NoError(t, err)
assertEquals(t, scenario.value, res) assertEquals(t, scenario.value, res)
} }
} }
@ -287,13 +287,13 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
t.Logf("Querying value presence for %s", scenario.query) t.Logf("Querying value presence for %s", scenario.query)
var req state.QueryRequest var req state.QueryRequest
err := json.Unmarshal([]byte(scenario.query), &req.Query) err := json.Unmarshal([]byte(scenario.query), &req.Query)
assert.NoError(t, err) require.NoError(t, err)
req.Metadata = map[string]string{ req.Metadata = map[string]string{
metadata.ContentType: contenttype.JSONContentType, metadata.ContentType: contenttype.JSONContentType,
metadata.QueryIndexName: "qIndx", metadata.QueryIndexName: "qIndx",
} }
resp, err := querier.Query(context.Background(), &req) resp, err := querier.Query(context.Background(), &req)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(scenario.results), len(resp.Results)) assert.Equal(t, len(scenario.results), len(resp.Results))
for i := range scenario.results { for i := range scenario.results {
var expected, actual interface{} var expected, actual interface{}