Changed postgres state store to use pgx instead of db/sql
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
2582f154bd
commit
80b1bb14b7
|
|
@ -15,7 +15,6 @@ package postgres
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -114,7 +113,7 @@ func (p *ConfigurationStore) Init(metadata configuration.Metadata) error {
|
|||
exists := false
|
||||
err = p.client.QueryRow(ctx, QueryTableExists, p.metadata.configTable).Scan(&exists)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return fmt.Errorf(ErrorMissingTable, p.metadata.configTable)
|
||||
}
|
||||
return fmt.Errorf("error in checking if configtable '%s' exists - '%w'", p.metadata.configTable, err)
|
||||
|
|
@ -135,7 +134,7 @@ func (p *ConfigurationStore) Get(ctx context.Context, req *configuration.GetRequ
|
|||
rows, err := p.client.Query(ctx, query, params...)
|
||||
if err != nil {
|
||||
// If no rows exist, return an empty response, otherwise return the error.
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return &configuration.GetResponse{}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("error in querying configuration store: '%w'", err)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ package postgresql
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
||||
"github.com/dapr/components-contrib/state"
|
||||
)
|
||||
|
|
@ -34,10 +36,9 @@ type dbAccess interface {
|
|||
}
|
||||
|
||||
// Interface that contains methods for querying.
|
||||
// Applies to both *sql.DB and *sql.Tx
|
||||
// Applies to *pgx.Conn, *pgxpool.Pool, and pgx.Tx
|
||||
type dbquerier interface {
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
QueryRow(query string, args ...any) *sql.Row
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
|
||||
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
|
||||
QueryRow(context.Context, string, ...interface{}) pgx.Row
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,20 +15,21 @@ package postgresql
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
// Performs migrations for the database schema
|
||||
type migrations struct {
|
||||
Logger logger.Logger
|
||||
Conn *sql.DB
|
||||
Conn pgxPoolConn
|
||||
StateTableName string
|
||||
MetadataTableName string
|
||||
}
|
||||
|
|
@ -42,7 +43,7 @@ func (m *migrations) Perform(ctx context.Context) error {
|
|||
|
||||
// Long timeout here as this query may block
|
||||
queryCtx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
_, err := m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_lock($1)", lockID)
|
||||
_, err := m.Conn.Exec(queryCtx, "SELECT pg_advisory_lock($1)", lockID)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("faild to acquire advisory lock: %w", err)
|
||||
|
|
@ -51,7 +52,7 @@ func (m *migrations) Perform(ctx context.Context) error {
|
|||
// Release the lock
|
||||
defer func() {
|
||||
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
|
||||
_, err = m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_unlock($1)", lockID)
|
||||
_, err = m.Conn.Exec(queryCtx, "SELECT pg_advisory_unlock($1)", lockID)
|
||||
cancel()
|
||||
if err != nil {
|
||||
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
|
||||
|
|
@ -84,11 +85,11 @@ func (m *migrations) Perform(ctx context.Context) error {
|
|||
)
|
||||
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||
err = m.Conn.
|
||||
QueryRowContext(queryCtx,
|
||||
QueryRow(queryCtx,
|
||||
fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName),
|
||||
).Scan(&migrationLevelStr)
|
||||
cancel()
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
// If there's no row...
|
||||
migrationLevel = 0
|
||||
} else if err != nil {
|
||||
|
|
@ -109,7 +110,7 @@ func (m *migrations) Perform(ctx context.Context) error {
|
|||
}
|
||||
|
||||
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||
_, err = m.Conn.ExecContext(queryCtx,
|
||||
_, err = m.Conn.Exec(queryCtx,
|
||||
fmt.Sprintf(`INSERT INTO %s (key, value) VALUES ('migrations', $1) ON CONFLICT (key) DO UPDATE SET value = $1`, m.MetadataTableName),
|
||||
strconv.Itoa(i+1),
|
||||
)
|
||||
|
|
@ -126,7 +127,7 @@ func (m migrations) createMetadataTable(ctx context.Context) error {
|
|||
m.Logger.Infof("Creating metadata table '%s'", m.MetadataTableName)
|
||||
// Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time
|
||||
// In the next step we'll acquire a lock so there won't be issues with concurrency
|
||||
_, err := m.Conn.Exec(fmt.Sprintf(
|
||||
_, err := m.Conn.Exec(ctx, fmt.Sprintf(
|
||||
`CREATE TABLE IF NOT EXISTS %s (
|
||||
key text NOT NULL PRIMARY KEY,
|
||||
value text NOT NULL
|
||||
|
|
@ -148,7 +149,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b
|
|||
|
||||
if schema == "" {
|
||||
err = m.Conn.
|
||||
QueryRowContext(
|
||||
QueryRow(
|
||||
ctx,
|
||||
`SELECT table_name, table_schema
|
||||
FROM information_schema.tables
|
||||
|
|
@ -158,7 +159,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b
|
|||
Scan(&table, &schema)
|
||||
} else {
|
||||
err = m.Conn.
|
||||
QueryRowContext(
|
||||
QueryRow(
|
||||
ctx,
|
||||
`SELECT table_name, table_schema
|
||||
FROM information_schema.tables
|
||||
|
|
@ -168,7 +169,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b
|
|||
Scan(&table, &schema)
|
||||
}
|
||||
|
||||
if err != nil && errors.Is(err, sql.ErrNoRows) {
|
||||
if err != nil && errors.Is(err, pgx.ErrNoRows) {
|
||||
return false, "", "", nil
|
||||
} else if err != nil {
|
||||
return false, "", "", fmt.Errorf("failed to check if table '%s' exists: %w", tableName, err)
|
||||
|
|
@ -195,6 +196,7 @@ var allMigrations = [2]func(ctx context.Context, m *migrations) error{
|
|||
// We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table
|
||||
m.Logger.Infof("Creating state table '%s'", m.StateTableName)
|
||||
_, err := m.Conn.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`CREATE TABLE IF NOT EXISTS %s (
|
||||
key text NOT NULL PRIMARY KEY,
|
||||
|
|
@ -215,7 +217,7 @@ var allMigrations = [2]func(ctx context.Context, m *migrations) error{
|
|||
// Migration 1: add the "expiredate" column
|
||||
func(ctx context.Context, m *migrations) error {
|
||||
m.Logger.Infof("Adding expiredate column to state table '%s'", m.StateTableName)
|
||||
_, err := m.Conn.Exec(fmt.Sprintf(
|
||||
_, err := m.Conn.Exec(ctx, fmt.Sprintf(
|
||||
`ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`,
|
||||
m.StateTableName,
|
||||
))
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ package postgresql
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
|
@ -23,15 +22,16 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
"github.com/dapr/components-contrib/state/query"
|
||||
stateutils "github.com/dapr/components-contrib/state/utils"
|
||||
"github.com/dapr/kit/logger"
|
||||
"github.com/dapr/kit/ptr"
|
||||
|
||||
// Blank import for the underlying Postgres driver.
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -43,12 +43,24 @@ const (
|
|||
|
||||
var errMissingConnectionString = errors.New("missing connection string")
|
||||
|
||||
// Interface that applies to *pgxpool.Pool.
|
||||
// We need this to be able to mock the connection in tests.
|
||||
type pgxPoolConn interface {
|
||||
Begin(context.Context) (pgx.Tx, error)
|
||||
BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
|
||||
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
|
||||
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
|
||||
QueryRow(context.Context, string, ...interface{}) pgx.Row
|
||||
Ping(context.Context) error
|
||||
Close()
|
||||
}
|
||||
|
||||
// PostgresDBAccess implements dbaccess.
|
||||
type PostgresDBAccess struct {
|
||||
logger logger.Logger
|
||||
metadata postgresMetadataStruct
|
||||
cleanupInterval *time.Duration
|
||||
db *sql.DB
|
||||
db pgxPoolConn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
|
@ -81,23 +93,31 @@ func (p *PostgresDBAccess) Init(meta state.Metadata) error {
|
|||
return err
|
||||
}
|
||||
|
||||
db, err := sql.Open("pgx", p.metadata.ConnectionString)
|
||||
config, err := pgxpool.ParseConfig(p.metadata.ConnectionString)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse connection string: %w", err)
|
||||
p.logger.Error(err)
|
||||
return err
|
||||
}
|
||||
if p.metadata.ConnectionMaxIdleTime > 0 {
|
||||
config.MaxConnIdleTime = p.metadata.ConnectionMaxIdleTime
|
||||
}
|
||||
|
||||
connCtx, connCancel := context.WithTimeout(p.ctx, 30*time.Second)
|
||||
p.db, err = pgxpool.NewWithConfig(connCtx, config)
|
||||
connCancel()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to connect to the database: %w", err)
|
||||
p.logger.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
p.db = db
|
||||
|
||||
pingCtx, pingCancel := context.WithTimeout(p.ctx, 30*time.Second)
|
||||
pingErr := db.PingContext(pingCtx)
|
||||
err = p.db.Ping(pingCtx)
|
||||
pingCancel()
|
||||
if pingErr != nil {
|
||||
return pingErr
|
||||
}
|
||||
|
||||
p.db.SetConnMaxIdleTime(p.metadata.ConnectionMaxIdleTime)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to ping the database: %w", err)
|
||||
p.logger.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -117,8 +137,9 @@ func (p *PostgresDBAccess) Init(meta state.Metadata) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgresDBAccess) GetDB() *sql.DB {
|
||||
return p.db
|
||||
func (p *PostgresDBAccess) GetDB() *pgxpool.Pool {
|
||||
// We can safely cast to *pgxpool.Pool because this method is never used in unit tests where we mock the DB
|
||||
return p.db.(*pgxpool.Pool)
|
||||
}
|
||||
|
||||
func (p *PostgresDBAccess) ParseMetadata(meta state.Metadata) error {
|
||||
|
|
@ -193,8 +214,6 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
|
|||
ttlSeconds = *ttl
|
||||
}
|
||||
|
||||
var result sql.Result
|
||||
|
||||
// Sprintf is required for table name because query.DB does not substitute parameters for table names.
|
||||
// Other parameters use query.DB parameter substitution.
|
||||
var (
|
||||
|
|
@ -246,8 +265,8 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
|
|||
} else {
|
||||
queryExpiredate = "NULL"
|
||||
}
|
||||
result, err = db.ExecContext(parentCtx, fmt.Sprintf(query, p.metadata.TableName, queryExpiredate), params...)
|
||||
|
||||
result, err := db.Exec(parentCtx, fmt.Sprintf(query, p.metadata.TableName, queryExpiredate), params...)
|
||||
if err != nil {
|
||||
if req.ETag != nil && *req.ETag != "" {
|
||||
return state.NewETagError(state.ETagMismatch, err)
|
||||
|
|
@ -255,11 +274,7 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
|
|||
return err
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rows != 1 {
|
||||
if result.RowsAffected() != 1 {
|
||||
return errors.New("no item was updated")
|
||||
}
|
||||
|
||||
|
|
@ -267,11 +282,20 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
|
|||
}
|
||||
|
||||
func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetRequest) error {
|
||||
tx, err := p.db.BeginTx(parentCtx, nil)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
|
||||
tx, err := p.db.Begin(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
defer func() {
|
||||
rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second)
|
||||
rollbackErr := tx.Rollback(rollbackCtx)
|
||||
rollbackCancel()
|
||||
if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) {
|
||||
p.logger.Errorf("Failed to rollback transaction in BulkSet: %v", rollbackErr)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(req) > 0 {
|
||||
for i := range req {
|
||||
|
|
@ -282,7 +306,9 @@ func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetReq
|
|||
}
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second)
|
||||
err = tx.Commit(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
|
@ -299,7 +325,7 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest)
|
|||
var (
|
||||
value []byte
|
||||
isBinary bool
|
||||
etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations
|
||||
etag uint32
|
||||
)
|
||||
query := `SELECT
|
||||
value, isbinary, xmin AS etag
|
||||
|
|
@ -307,12 +333,11 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest)
|
|||
WHERE
|
||||
key = $1
|
||||
AND (expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP)`
|
||||
err := p.db.
|
||||
QueryRowContext(parentCtx, fmt.Sprintf(query, p.metadata.TableName), req.Key).
|
||||
err := p.db.QueryRow(parentCtx, fmt.Sprintf(query, p.metadata.TableName), req.Key).
|
||||
Scan(&value, &isBinary, &etag)
|
||||
if err != nil {
|
||||
// If no rows exist, return an empty response, otherwise return the error.
|
||||
if err == sql.ErrNoRows {
|
||||
if err == pgx.ErrNoRows {
|
||||
return &state.GetResponse{}, nil
|
||||
}
|
||||
return nil, err
|
||||
|
|
@ -334,14 +359,14 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest)
|
|||
|
||||
return &state.GetResponse{
|
||||
Data: data,
|
||||
ETag: ptr.Of(strconv.FormatUint(etag, 10)),
|
||||
ETag: ptr.Of(strconv.FormatUint(uint64(etag), 10)),
|
||||
Metadata: req.Metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &state.GetResponse{
|
||||
Data: value,
|
||||
ETag: ptr.Of(strconv.FormatUint(etag, 10)),
|
||||
ETag: ptr.Of(strconv.FormatUint(uint64(etag), 10)),
|
||||
Metadata: req.Metadata,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -356,10 +381,9 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
|
|||
return errors.New("missing key in delete operation")
|
||||
}
|
||||
|
||||
var result sql.Result
|
||||
|
||||
var result pgconn.CommandTag
|
||||
if req.ETag == nil || *req.ETag == "" {
|
||||
result, err = db.ExecContext(parentCtx, "DELETE FROM state WHERE key = $1", req.Key)
|
||||
result, err = db.Exec(parentCtx, "DELETE FROM state WHERE key = $1", req.Key)
|
||||
} else {
|
||||
// Convert req.ETag to uint32 for postgres XID compatibility
|
||||
var etag64 uint64
|
||||
|
|
@ -367,20 +391,15 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
|
|||
if err != nil {
|
||||
return state.NewETagError(state.ETagInvalid, err)
|
||||
}
|
||||
etag := uint32(etag64)
|
||||
|
||||
result, err = db.ExecContext(parentCtx, "DELETE FROM state WHERE key = $1 AND xmin = $2", req.Key, etag)
|
||||
result, err = db.Exec(parentCtx, "DELETE FROM state WHERE key = $1 AND xmin = $2", req.Key, uint32(etag64))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows := result.RowsAffected()
|
||||
if rows != 1 && req.ETag != nil && *req.ETag != "" {
|
||||
return state.NewETagError(state.ETagMismatch, nil)
|
||||
}
|
||||
|
|
@ -389,11 +408,20 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
|
|||
}
|
||||
|
||||
func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.DeleteRequest) error {
|
||||
tx, err := p.db.BeginTx(parentCtx, nil)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
|
||||
tx, err := p.db.Begin(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
defer func() {
|
||||
rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second)
|
||||
rollbackErr := tx.Rollback(rollbackCtx)
|
||||
rollbackCancel()
|
||||
if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) {
|
||||
p.logger.Errorf("Failed to rollback transaction in BulkDelete: %v", rollbackErr)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(req) > 0 {
|
||||
for i := range req {
|
||||
|
|
@ -404,7 +432,9 @@ func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.Del
|
|||
}
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second)
|
||||
err = tx.Commit(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
|
@ -413,11 +443,20 @@ func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.Del
|
|||
}
|
||||
|
||||
func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *state.TransactionalStateRequest) error {
|
||||
tx, err := p.db.BeginTx(parentCtx, nil)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
|
||||
tx, err := p.db.Begin(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
defer func() {
|
||||
rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second)
|
||||
rollbackErr := tx.Rollback(rollbackCtx)
|
||||
rollbackCancel()
|
||||
if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) {
|
||||
p.logger.Errorf("Failed to rollback transaction in ExecMulti: %v", rollbackErr)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, o := range request.Operations {
|
||||
switch o.Operation {
|
||||
|
|
@ -450,7 +489,9 @@ func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *stat
|
|||
}
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second)
|
||||
err = tx.Commit(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
|
@ -520,19 +561,13 @@ func (p *PostgresDBAccess) CleanupExpired(ctx context.Context) error {
|
|||
// Note we're not using the transaction here as we don't want this to be rolled back half-way or to lock the table unnecessarily
|
||||
// Need to use fmt.Sprintf because we can't parametrize a table name
|
||||
// Note we are not setting a timeout here as this query can take a "long" time, especially if there's no index on expiredate
|
||||
//nolint:gosec
|
||||
stmt := fmt.Sprintf(`DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate < CURRENT_TIMESTAMP`, p.metadata.TableName)
|
||||
res, err := p.db.ExecContext(ctx, stmt)
|
||||
res, err := p.db.Exec(ctx, stmt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute query: %w", err)
|
||||
}
|
||||
|
||||
cleaned, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to count affected rows: %w", err)
|
||||
}
|
||||
|
||||
p.logger.Infof("Removed %d expired rows", cleaned)
|
||||
p.logger.Infof("Removed %d expired rows", res.RowsAffected())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -540,7 +575,7 @@ func (p *PostgresDBAccess) CleanupExpired(ctx context.Context) error {
|
|||
// Returns true if the row was updated, which means that the cleanup can proceed.
|
||||
func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier, cleanupInterval time.Duration) (bool, error) {
|
||||
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
res, err := db.ExecContext(queryCtx,
|
||||
res, err := db.Exec(queryCtx,
|
||||
fmt.Sprintf(`INSERT INTO %[1]s (key, value)
|
||||
VALUES ('last-cleanup', CURRENT_TIMESTAMP)
|
||||
ON CONFLICT (key)
|
||||
|
|
@ -554,11 +589,7 @@ func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier,
|
|||
return true, fmt.Errorf("failed to execute query: %w", err)
|
||||
}
|
||||
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("failed to count affected rows: %w", err)
|
||||
}
|
||||
|
||||
n := res.RowsAffected()
|
||||
return n > 0, nil
|
||||
}
|
||||
|
||||
|
|
@ -569,7 +600,8 @@ func (p *PostgresDBAccess) Close() error {
|
|||
p.cancel = nil
|
||||
}
|
||||
if p.db != nil {
|
||||
return p.db.Close()
|
||||
p.db.Close()
|
||||
p.db = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -16,24 +16,20 @@ package postgresql
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
pgxmock "github.com/pashagolub/pgxmock/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
"github.com/dapr/kit/logger"
|
||||
|
||||
// Blank import for pgx
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
)
|
||||
|
||||
type mocks struct {
|
||||
db *sql.DB
|
||||
mock sqlmock.Sqlmock
|
||||
db pgxmock.PgxPoolIface
|
||||
pgDba *PostgresDBAccess
|
||||
}
|
||||
|
||||
|
|
@ -67,7 +63,7 @@ func TestGetSetValid(t *testing.T) {
|
|||
}
|
||||
|
||||
set, err := getSet(operation)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "key1", set.Key)
|
||||
}
|
||||
|
||||
|
|
@ -101,7 +97,7 @@ func TestGetDeleteValid(t *testing.T) {
|
|||
}
|
||||
|
||||
delete, err := getDelete(operation)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "key1", delete.Key)
|
||||
}
|
||||
|
||||
|
|
@ -110,8 +106,10 @@ func TestMultiWithNoRequests(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.db.Close()
|
||||
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectCommit()
|
||||
m.db.ExpectBegin()
|
||||
m.db.ExpectCommit()
|
||||
// There's also a rollback called after a commit, which is expected and will not have effect
|
||||
m.db.ExpectRollback()
|
||||
|
||||
var operations []state.TransactionalStateOperation
|
||||
|
||||
|
|
@ -121,7 +119,7 @@ func TestMultiWithNoRequests(t *testing.T) {
|
|||
})
|
||||
|
||||
// Assert
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInvalidMultiInvalidAction(t *testing.T) {
|
||||
|
|
@ -129,8 +127,8 @@ func TestInvalidMultiInvalidAction(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.db.Close()
|
||||
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectRollback()
|
||||
m.db.ExpectBegin()
|
||||
m.db.ExpectRollback()
|
||||
|
||||
var operations []state.TransactionalStateOperation
|
||||
|
||||
|
|
@ -153,16 +151,19 @@ func TestValidSetRequest(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.db.Close()
|
||||
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
m.mock.ExpectCommit()
|
||||
setReq := createSetRequest()
|
||||
operations := []state.TransactionalStateOperation{
|
||||
{Operation: state.Upsert, Request: setReq},
|
||||
}
|
||||
val, _ := json.Marshal(setReq.Value)
|
||||
|
||||
var operations []state.TransactionalStateOperation
|
||||
|
||||
operations = append(operations, state.TransactionalStateOperation{
|
||||
Operation: state.Upsert,
|
||||
Request: createSetRequest(),
|
||||
})
|
||||
m.db.ExpectBegin()
|
||||
m.db.ExpectExec("INSERT INTO").
|
||||
WithArgs(setReq.Key, string(val), false).
|
||||
WillReturnResult(pgxmock.NewResult("INSERT", 1))
|
||||
m.db.ExpectCommit()
|
||||
// There's also a rollback called after a commit, which is expected and will not have effect
|
||||
m.db.ExpectRollback()
|
||||
|
||||
// Act
|
||||
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
|
||||
|
|
@ -170,7 +171,7 @@ func TestValidSetRequest(t *testing.T) {
|
|||
})
|
||||
|
||||
// Assert
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInvalidMultiSetRequest(t *testing.T) {
|
||||
|
|
@ -178,15 +179,15 @@ func TestInvalidMultiSetRequest(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.db.Close()
|
||||
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectRollback()
|
||||
m.db.ExpectBegin()
|
||||
m.db.ExpectRollback()
|
||||
|
||||
var operations []state.TransactionalStateOperation
|
||||
|
||||
operations = append(operations, state.TransactionalStateOperation{
|
||||
Operation: state.Upsert,
|
||||
Request: createDeleteRequest(), // Delete request is not valid for Upsert operation
|
||||
})
|
||||
operations := []state.TransactionalStateOperation{
|
||||
{
|
||||
Operation: state.Upsert,
|
||||
Request: createDeleteRequest(), // Delete request is not valid for Upsert operation
|
||||
},
|
||||
}
|
||||
|
||||
// Act
|
||||
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
|
||||
|
|
@ -202,8 +203,8 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.db.Close()
|
||||
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectRollback()
|
||||
m.db.ExpectBegin()
|
||||
m.db.ExpectRollback()
|
||||
|
||||
var operations []state.TransactionalStateOperation
|
||||
|
||||
|
|
@ -226,16 +227,18 @@ func TestValidMultiDeleteRequest(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.db.Close()
|
||||
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
m.mock.ExpectCommit()
|
||||
deleteReq := createDeleteRequest()
|
||||
operations := []state.TransactionalStateOperation{
|
||||
{Operation: state.Delete, Request: deleteReq},
|
||||
}
|
||||
|
||||
var operations []state.TransactionalStateOperation
|
||||
|
||||
operations = append(operations, state.TransactionalStateOperation{
|
||||
Operation: state.Delete,
|
||||
Request: createDeleteRequest(),
|
||||
})
|
||||
m.db.ExpectBegin()
|
||||
m.db.ExpectExec("DELETE FROM").
|
||||
WithArgs(deleteReq.Key).
|
||||
WillReturnResult(pgxmock.NewResult("DELETE", 1))
|
||||
m.db.ExpectCommit()
|
||||
// There's also a rollback called after a commit, which is expected and will not have effect
|
||||
m.db.ExpectRollback()
|
||||
|
||||
// Act
|
||||
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
|
||||
|
|
@ -243,7 +246,7 @@ func TestValidMultiDeleteRequest(t *testing.T) {
|
|||
})
|
||||
|
||||
// Assert
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInvalidMultiDeleteRequest(t *testing.T) {
|
||||
|
|
@ -251,8 +254,8 @@ func TestInvalidMultiDeleteRequest(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.db.Close()
|
||||
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectRollback()
|
||||
m.db.ExpectBegin()
|
||||
m.db.ExpectRollback()
|
||||
|
||||
var operations []state.TransactionalStateOperation
|
||||
|
||||
|
|
@ -275,8 +278,8 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.db.Close()
|
||||
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectRollback()
|
||||
m.db.ExpectBegin()
|
||||
m.db.ExpectRollback()
|
||||
|
||||
var operations []state.TransactionalStateOperation
|
||||
|
||||
|
|
@ -299,23 +302,27 @@ func TestMultiOperationOrder(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.db.Close()
|
||||
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
m.mock.ExpectCommit()
|
||||
|
||||
var operations []state.TransactionalStateOperation
|
||||
|
||||
operations = append(operations,
|
||||
state.TransactionalStateOperation{
|
||||
operations := []state.TransactionalStateOperation{
|
||||
{
|
||||
Operation: state.Upsert,
|
||||
Request: state.SetRequest{Key: "key1", Value: "value1"},
|
||||
},
|
||||
state.TransactionalStateOperation{
|
||||
{
|
||||
Operation: state.Delete,
|
||||
Request: state.DeleteRequest{Key: "key1"},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
m.db.ExpectBegin()
|
||||
m.db.ExpectExec("INSERT INTO").
|
||||
WithArgs("key1", `"value1"`, false).
|
||||
WillReturnResult(pgxmock.NewResult("INSERT", 1))
|
||||
m.db.ExpectExec("DELETE FROM").
|
||||
WithArgs("key1").
|
||||
WillReturnResult(pgxmock.NewResult("DELETE", 1))
|
||||
m.db.ExpectCommit()
|
||||
// There's also a rollback called after a commit, which is expected and will not have effect
|
||||
m.db.ExpectRollback()
|
||||
|
||||
// Act
|
||||
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
|
||||
|
|
@ -323,7 +330,7 @@ func TestMultiOperationOrder(t *testing.T) {
|
|||
})
|
||||
|
||||
// Assert
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInvalidBulkSetNoKey(t *testing.T) {
|
||||
|
|
@ -331,14 +338,13 @@ func TestInvalidBulkSetNoKey(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.db.Close()
|
||||
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectRollback()
|
||||
m.db.ExpectBegin()
|
||||
m.db.ExpectRollback()
|
||||
|
||||
var sets []state.SetRequest
|
||||
|
||||
sets = append(sets, state.SetRequest{ // Set request without key is not valid for Set operation
|
||||
Value: "value1",
|
||||
})
|
||||
sets := []state.SetRequest{
|
||||
// Set request without key is not valid for Set operation
|
||||
{Value: "value1"},
|
||||
}
|
||||
|
||||
// Act
|
||||
err := m.pgDba.BulkSet(context.Background(), sets)
|
||||
|
|
@ -352,15 +358,13 @@ func TestInvalidBulkSetEmptyValue(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.db.Close()
|
||||
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectRollback()
|
||||
m.db.ExpectBegin()
|
||||
m.db.ExpectRollback()
|
||||
|
||||
var sets []state.SetRequest
|
||||
|
||||
sets = append(sets, state.SetRequest{ // Set request without value is not valid for Set operation
|
||||
Key: "key1",
|
||||
Value: "",
|
||||
})
|
||||
sets := []state.SetRequest{
|
||||
// Set request without value is not valid for Set operation
|
||||
{Key: "key1", Value: "value1"},
|
||||
}
|
||||
|
||||
// Act
|
||||
err := m.pgDba.BulkSet(context.Background(), sets)
|
||||
|
|
@ -374,22 +378,26 @@ func TestValidBulkSet(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.db.Close()
|
||||
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
m.mock.ExpectCommit()
|
||||
sets := []state.SetRequest{
|
||||
{
|
||||
Key: "key1",
|
||||
Value: "value1",
|
||||
},
|
||||
}
|
||||
|
||||
var sets []state.SetRequest
|
||||
|
||||
sets = append(sets, state.SetRequest{
|
||||
Key: "key1",
|
||||
Value: "value1",
|
||||
})
|
||||
m.db.ExpectBegin()
|
||||
m.db.ExpectExec("INSERT INTO").
|
||||
WithArgs("key1", `"value1"`, false).
|
||||
WillReturnResult(pgxmock.NewResult("INSERT", 1))
|
||||
m.db.ExpectCommit()
|
||||
// There's also a rollback called after a commit, which is expected and will not have effect
|
||||
m.db.ExpectRollback()
|
||||
|
||||
// Act
|
||||
err := m.pgDba.BulkSet(context.Background(), sets)
|
||||
|
||||
// Assert
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInvalidBulkDeleteNoKey(t *testing.T) {
|
||||
|
|
@ -397,8 +405,8 @@ func TestInvalidBulkDeleteNoKey(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.db.Close()
|
||||
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectRollback()
|
||||
m.db.ExpectBegin()
|
||||
m.db.ExpectRollback()
|
||||
|
||||
var deletes []state.DeleteRequest
|
||||
|
||||
|
|
@ -418,21 +426,23 @@ func TestValidBulkDelete(t *testing.T) {
|
|||
m, _ := mockDatabase(t)
|
||||
defer m.db.Close()
|
||||
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
m.mock.ExpectCommit()
|
||||
deletes := []state.DeleteRequest{
|
||||
{Key: "key1"},
|
||||
}
|
||||
|
||||
var deletes []state.DeleteRequest
|
||||
|
||||
deletes = append(deletes, state.DeleteRequest{
|
||||
Key: "key1",
|
||||
})
|
||||
m.db.ExpectBegin()
|
||||
m.db.ExpectExec("DELETE FROM").
|
||||
WithArgs("key1").
|
||||
WillReturnResult(pgxmock.NewResult("DELETE", 1))
|
||||
m.db.ExpectCommit()
|
||||
// There's also a rollback called after a commit, which is expected and will not have effect
|
||||
m.db.ExpectRollback()
|
||||
|
||||
// Act
|
||||
err := m.pgDba.BulkDelete(context.Background(), deletes)
|
||||
|
||||
// Assert
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func createSetRequest() state.SetRequest {
|
||||
|
|
@ -451,7 +461,7 @@ func createDeleteRequest() state.DeleteRequest {
|
|||
func mockDatabase(t *testing.T) (*mocks, error) {
|
||||
logger := logger.NewLogger("test")
|
||||
|
||||
db, mock, err := sqlmock.New()
|
||||
db, err := pgxmock.NewPool()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
|
|
@ -463,7 +473,6 @@ func mockDatabase(t *testing.T) (*mocks, error) {
|
|||
|
||||
return &mocks{
|
||||
db: db,
|
||||
mock: mock,
|
||||
pgDba: dba,
|
||||
}, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,18 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
|
|
@ -162,7 +165,7 @@ func deleteItemThatDoesNotExist(t *testing.T, pgs *PostgreSQL) {
|
|||
Key: randomKey(),
|
||||
}
|
||||
err := pgs.Delete(context.Background(), deleteReq)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) {
|
||||
|
|
@ -183,7 +186,7 @@ func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) {
|
|||
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
|
||||
Operations: operations,
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, set := range setRequests {
|
||||
assert.True(t, storeItemExists(t, set.Key))
|
||||
|
|
@ -213,7 +216,7 @@ func multiWithDeleteOnly(t *testing.T, pgs *PostgreSQL) {
|
|||
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
|
||||
Operations: operations,
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, delete := range deleteRequests {
|
||||
assert.False(t, storeItemExists(t, delete.Key))
|
||||
|
|
@ -256,7 +259,7 @@ func multiWithDeleteAndSet(t *testing.T, pgs *PostgreSQL) {
|
|||
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
|
||||
Operations: operations,
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, delete := range deleteRequests {
|
||||
assert.False(t, storeItemExists(t, delete.Key))
|
||||
|
|
@ -384,15 +387,16 @@ func setUpdatesTheUpdatedateField(t *testing.T, pgs *PostgreSQL) {
|
|||
|
||||
// insertdate should have a value and updatedate should be nil
|
||||
_, insertdate, updatedate := getRowData(t, key)
|
||||
assert.Nil(t, updatedate)
|
||||
assert.NotNil(t, insertdate)
|
||||
assert.Equal(t, "", updatedate.String)
|
||||
|
||||
// insertdate should not change, updatedate should have a value
|
||||
value = &fakeItem{Color: "aqua"}
|
||||
setItem(t, pgs, key, value, nil)
|
||||
_, newinsertdate, updatedate := getRowData(t, key)
|
||||
assert.Equal(t, insertdate, newinsertdate) // The insertdate should not change.
|
||||
assert.NotEqual(t, "", updatedate.String)
|
||||
assert.NotNil(t, updatedate)
|
||||
assert.NotNil(t, newinsertdate)
|
||||
assert.True(t, insertdate.Equal(*newinsertdate)) // The insertdate should not change.
|
||||
|
||||
deleteItem(t, pgs, key, nil)
|
||||
}
|
||||
|
|
@ -420,7 +424,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) {
|
|||
}
|
||||
|
||||
err := pgs.BulkSet(context.Background(), setReq)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, storeItemExists(t, setReq[0].Key))
|
||||
assert.True(t, storeItemExists(t, setReq[1].Key))
|
||||
|
||||
|
|
@ -434,7 +438,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) {
|
|||
}
|
||||
|
||||
err = pgs.BulkDelete(context.Background(), deleteReq)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, storeItemExists(t, setReq[0].Key))
|
||||
assert.False(t, storeItemExists(t, setReq[1].Key))
|
||||
}
|
||||
|
|
@ -491,7 +495,7 @@ func setItem(t *testing.T, pgs *PostgreSQL, key string, value interface{}, etag
|
|||
}
|
||||
|
||||
err := pgs.Set(context.Background(), setReq)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
itemExists := storeItemExists(t, key)
|
||||
assert.True(t, itemExists)
|
||||
}
|
||||
|
|
@ -524,25 +528,28 @@ func deleteItem(t *testing.T, pgs *PostgreSQL, key string, etag *string) {
|
|||
}
|
||||
|
||||
func storeItemExists(t *testing.T, key string) bool {
|
||||
db, err := sql.Open("pgx", getConnectionString())
|
||||
assert.Nil(t, err)
|
||||
defer db.Close()
|
||||
ctx := context.Background()
|
||||
db, err := pgx.Connect(ctx, getConnectionString())
|
||||
require.NoError(t, err)
|
||||
defer db.Close(ctx)
|
||||
|
||||
exists := false
|
||||
statement := fmt.Sprintf(`SELECT EXISTS (SELECT FROM %s WHERE key = $1)`, defaultTableName)
|
||||
err = db.QueryRow(statement, key).Scan(&exists)
|
||||
assert.Nil(t, err)
|
||||
err = db.QueryRow(ctx, statement, key).Scan(&exists)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
func getRowData(t *testing.T, key string) (returnValue string, insertdate sql.NullString, updatedate sql.NullString) {
|
||||
db, err := sql.Open("pgx", getConnectionString())
|
||||
assert.Nil(t, err)
|
||||
defer db.Close()
|
||||
func getRowData(t *testing.T, key string) (returnValue string, insertdate *time.Time, updatedate *time.Time) {
|
||||
ctx := context.Background()
|
||||
db, err := pgx.Connect(ctx, getConnectionString())
|
||||
require.NoError(t, err)
|
||||
defer db.Close(ctx)
|
||||
|
||||
err = db.QueryRow(fmt.Sprintf("SELECT value, insertdate, updatedate FROM %s WHERE key = $1", defaultTableName), key).Scan(&returnValue, &insertdate, &updatedate)
|
||||
assert.Nil(t, err)
|
||||
err = db.QueryRow(ctx, fmt.Sprintf("SELECT value, insertdate, updatedate FROM %s WHERE key = $1", defaultTableName), key).
|
||||
Scan(&returnValue, &insertdate, &updatedate)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return returnValue, insertdate, updatedate
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ package postgresql
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
|
@ -140,8 +139,8 @@ func (q *Query) Finalize(filters string, qq *query.Query) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) {
|
||||
rows, err := db.QueryContext(ctx, q.query, q.params...)
|
||||
func (q *Query) execute(ctx context.Context, logger logger.Logger, db dbquerier) ([]state.QueryItem, string, error) {
|
||||
rows, err := db.Query(ctx, q.query, q.params...)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
|
@ -152,7 +151,7 @@ func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) (
|
|||
var (
|
||||
key string
|
||||
data []byte
|
||||
etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations
|
||||
etag uint32
|
||||
)
|
||||
if err = rows.Scan(&key, &data, &etag); err != nil {
|
||||
return nil, "", err
|
||||
|
|
@ -160,7 +159,7 @@ func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) (
|
|||
result := state.QueryItem{
|
||||
Key: key,
|
||||
Data: data,
|
||||
ETag: ptr.Of(strconv.FormatUint(etag, 10)),
|
||||
ETag: ptr.Of(strconv.FormatUint(uint64(etag), 10)),
|
||||
}
|
||||
ret = append(ret, result)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ require (
|
|||
github.com/imdario/mergo v0.3.12 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
|
||||
github.com/jackc/puddle/v2 v2.1.2 // indirect
|
||||
github.com/jhump/protoreflect v1.13.0 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a h1:
|
|||
github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go.mod h1:C0A1KeiVHs+trY6gUTPhhGammbrZ30ZfXRW/nuT7HLw=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
|
||||
github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
|
||||
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
|
||||
github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig=
|
||||
|
|
@ -288,6 +287,8 @@ github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHF
|
|||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
|
||||
github.com/jackc/pgx/v5 v5.2.0 h1:NdPpngX0Y6z6XDFKqmFQaE+bCtkqzvQIOt1wvBlAqs8=
|
||||
github.com/jackc/pgx/v5 v5.2.0/go.mod h1:Ptn7zmohNsWEsdxRawMzk3gaKma2obW+NWTnKa0S4nk=
|
||||
github.com/jackc/puddle/v2 v2.1.2 h1:0f7vaaXINONKTsxYDn4otOAiJanX/BMeAtY//BXqzlg=
|
||||
github.com/jackc/puddle/v2 v2.1.2/go.mod h1:2lpufsF5mRHO6SuZkm0fNYxM6SWHfvyFj62KwNzgels=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
|
||||
github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
|
||||
|
|
@ -385,6 +386,7 @@ github.com/openzipkin/zipkin-go v0.4.1/go.mod h1:qY0VqDSN1pOBN94dBc6w2GJlWLiovAy
|
|||
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
|
||||
github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY=
|
||||
github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
|
||||
github.com/pashagolub/pgxmock/v2 v2.4.0 h1:jNv7+svrNoMc31mvllSS/u7P2pT3gS3uY7DPRKIJNSY=
|
||||
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI=
|
||||
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ package main
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -706,7 +705,7 @@ func getMigrationLevel(dbClient *pgx.Conn, metadataTable string) (level string,
|
|||
err = dbClient.
|
||||
QueryRow(ctx, fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, metadataTable)).
|
||||
Scan(&level)
|
||||
if err != nil && errors.Is(err, sql.ErrNoRows) {
|
||||
if err != nil && errors.Is(err, pgx.ErrNoRows) {
|
||||
err = nil
|
||||
level = ""
|
||||
}
|
||||
|
|
@ -766,7 +765,7 @@ func loadLastCleanupInterval(ctx context.Context, dbClient *pgx.Conn, table stri
|
|||
).
|
||||
Scan(&lastCleanup)
|
||||
cancel()
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
|
|
@ -790,7 +789,7 @@ func getValueFromMetadataTable(ctx context.Context, dbClient *pgx.Conn, table, k
|
|||
QueryRow(queryCtx, fmt.Sprintf("SELECT value FROM %s WHERE key = $1", table), key).
|
||||
Scan(&value)
|
||||
cancel()
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
|
|
|
|||
|
|
@ -272,7 +272,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
|
|||
req.Metadata = map[string]string{metadata.ContentType: scenario.contentType}
|
||||
}
|
||||
res, err := statestore.Get(context.Background(), req)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assertEquals(t, scenario.value, res)
|
||||
}
|
||||
}
|
||||
|
|
@ -287,13 +287,13 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
|
|||
t.Logf("Querying value presence for %s", scenario.query)
|
||||
var req state.QueryRequest
|
||||
err := json.Unmarshal([]byte(scenario.query), &req.Query)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
req.Metadata = map[string]string{
|
||||
metadata.ContentType: contenttype.JSONContentType,
|
||||
metadata.QueryIndexName: "qIndx",
|
||||
}
|
||||
resp, err := querier.Query(context.Background(), &req)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(scenario.results), len(resp.Results))
|
||||
for i := range scenario.results {
|
||||
var expected, actual interface{}
|
||||
|
|
|
|||
Loading…
Reference in New Issue