Changed postgres state store to use pgx instead of db/sql

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2022-12-20 22:45:38 +00:00
parent 2582f154bd
commit 80b1bb14b7
11 changed files with 269 additions and 218 deletions

View File

@ -15,7 +15,6 @@ package postgres
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
@ -114,7 +113,7 @@ func (p *ConfigurationStore) Init(metadata configuration.Metadata) error {
exists := false
err = p.client.QueryRow(ctx, QueryTableExists, p.metadata.configTable).Scan(&exists)
if err != nil {
if err == sql.ErrNoRows {
if errors.Is(err, pgx.ErrNoRows) {
return fmt.Errorf(ErrorMissingTable, p.metadata.configTable)
}
return fmt.Errorf("error in checking if configtable '%s' exists - '%w'", p.metadata.configTable, err)
@ -135,7 +134,7 @@ func (p *ConfigurationStore) Get(ctx context.Context, req *configuration.GetRequ
rows, err := p.client.Query(ctx, query, params...)
if err != nil {
// If no rows exist, return an empty response, otherwise return the error.
if err == sql.ErrNoRows {
if errors.Is(err, pgx.ErrNoRows) {
return &configuration.GetResponse{}, nil
}
return nil, fmt.Errorf("error in querying configuration store: '%w'", err)

View File

@ -15,7 +15,9 @@ package postgresql
import (
"context"
"database/sql"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/dapr/components-contrib/state"
)
@ -34,10 +36,9 @@ type dbAccess interface {
}
// Interface that contains methods for querying.
// Applies to both *sql.DB and *sql.Tx
// Applies to *pgx.Conn, *pgxpool.Pool, and pgx.Tx
type dbquerier interface {
Exec(query string, args ...any) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryRow(query string, args ...any) *sql.Row
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row
}

View File

@ -15,20 +15,21 @@ package postgresql
import (
"context"
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/dapr/kit/logger"
)
// Performs migrations for the database schema
type migrations struct {
Logger logger.Logger
Conn *sql.DB
Conn pgxPoolConn
StateTableName string
MetadataTableName string
}
@ -42,7 +43,7 @@ func (m *migrations) Perform(ctx context.Context) error {
// Long timeout here as this query may block
queryCtx, cancel := context.WithTimeout(ctx, time.Minute)
_, err := m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_lock($1)", lockID)
_, err := m.Conn.Exec(queryCtx, "SELECT pg_advisory_lock($1)", lockID)
cancel()
if err != nil {
return fmt.Errorf("faild to acquire advisory lock: %w", err)
@ -51,7 +52,7 @@ func (m *migrations) Perform(ctx context.Context) error {
// Release the lock
defer func() {
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
_, err = m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_unlock($1)", lockID)
_, err = m.Conn.Exec(queryCtx, "SELECT pg_advisory_unlock($1)", lockID)
cancel()
if err != nil {
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
@ -84,11 +85,11 @@ func (m *migrations) Perform(ctx context.Context) error {
)
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
err = m.Conn.
QueryRowContext(queryCtx,
QueryRow(queryCtx,
fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName),
).Scan(&migrationLevelStr)
cancel()
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, pgx.ErrNoRows) {
// If there's no row...
migrationLevel = 0
} else if err != nil {
@ -109,7 +110,7 @@ func (m *migrations) Perform(ctx context.Context) error {
}
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
_, err = m.Conn.ExecContext(queryCtx,
_, err = m.Conn.Exec(queryCtx,
fmt.Sprintf(`INSERT INTO %s (key, value) VALUES ('migrations', $1) ON CONFLICT (key) DO UPDATE SET value = $1`, m.MetadataTableName),
strconv.Itoa(i+1),
)
@ -126,7 +127,7 @@ func (m migrations) createMetadataTable(ctx context.Context) error {
m.Logger.Infof("Creating metadata table '%s'", m.MetadataTableName)
// Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time
// In the next step we'll acquire a lock so there won't be issues with concurrency
_, err := m.Conn.Exec(fmt.Sprintf(
_, err := m.Conn.Exec(ctx, fmt.Sprintf(
`CREATE TABLE IF NOT EXISTS %s (
key text NOT NULL PRIMARY KEY,
value text NOT NULL
@ -148,7 +149,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b
if schema == "" {
err = m.Conn.
QueryRowContext(
QueryRow(
ctx,
`SELECT table_name, table_schema
FROM information_schema.tables
@ -158,7 +159,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b
Scan(&table, &schema)
} else {
err = m.Conn.
QueryRowContext(
QueryRow(
ctx,
`SELECT table_name, table_schema
FROM information_schema.tables
@ -168,7 +169,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b
Scan(&table, &schema)
}
if err != nil && errors.Is(err, sql.ErrNoRows) {
if err != nil && errors.Is(err, pgx.ErrNoRows) {
return false, "", "", nil
} else if err != nil {
return false, "", "", fmt.Errorf("failed to check if table '%s' exists: %w", tableName, err)
@ -195,6 +196,7 @@ var allMigrations = [2]func(ctx context.Context, m *migrations) error{
// We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table
m.Logger.Infof("Creating state table '%s'", m.StateTableName)
_, err := m.Conn.Exec(
ctx,
fmt.Sprintf(
`CREATE TABLE IF NOT EXISTS %s (
key text NOT NULL PRIMARY KEY,
@ -215,7 +217,7 @@ var allMigrations = [2]func(ctx context.Context, m *migrations) error{
// Migration 1: add the "expiredate" column
func(ctx context.Context, m *migrations) error {
m.Logger.Infof("Adding expiredate column to state table '%s'", m.StateTableName)
_, err := m.Conn.Exec(fmt.Sprintf(
_, err := m.Conn.Exec(ctx, fmt.Sprintf(
`ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`,
m.StateTableName,
))

View File

@ -15,7 +15,6 @@ package postgresql
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
@ -23,15 +22,16 @@ import (
"strconv"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
"github.com/dapr/components-contrib/state/query"
stateutils "github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/ptr"
// Blank import for the underlying Postgres driver.
_ "github.com/jackc/pgx/v5/stdlib"
)
const (
@ -43,12 +43,24 @@ const (
var errMissingConnectionString = errors.New("missing connection string")
// Interface that applies to *pgxpool.Pool.
// We need this to be able to mock the connection in tests.
type pgxPoolConn interface {
Begin(context.Context) (pgx.Tx, error)
BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row
Ping(context.Context) error
Close()
}
// PostgresDBAccess implements dbaccess.
type PostgresDBAccess struct {
logger logger.Logger
metadata postgresMetadataStruct
cleanupInterval *time.Duration
db *sql.DB
db pgxPoolConn
ctx context.Context
cancel context.CancelFunc
}
@ -81,23 +93,31 @@ func (p *PostgresDBAccess) Init(meta state.Metadata) error {
return err
}
db, err := sql.Open("pgx", p.metadata.ConnectionString)
config, err := pgxpool.ParseConfig(p.metadata.ConnectionString)
if err != nil {
err = fmt.Errorf("failed to parse connection string: %w", err)
p.logger.Error(err)
return err
}
if p.metadata.ConnectionMaxIdleTime > 0 {
config.MaxConnIdleTime = p.metadata.ConnectionMaxIdleTime
}
connCtx, connCancel := context.WithTimeout(p.ctx, 30*time.Second)
p.db, err = pgxpool.NewWithConfig(connCtx, config)
connCancel()
if err != nil {
err = fmt.Errorf("failed to connect to the database: %w", err)
p.logger.Error(err)
return err
}
p.db = db
pingCtx, pingCancel := context.WithTimeout(p.ctx, 30*time.Second)
pingErr := db.PingContext(pingCtx)
err = p.db.Ping(pingCtx)
pingCancel()
if pingErr != nil {
return pingErr
}
p.db.SetConnMaxIdleTime(p.metadata.ConnectionMaxIdleTime)
if err != nil {
err = fmt.Errorf("failed to ping the database: %w", err)
p.logger.Error(err)
return err
}
@ -117,8 +137,9 @@ func (p *PostgresDBAccess) Init(meta state.Metadata) error {
return nil
}
func (p *PostgresDBAccess) GetDB() *sql.DB {
return p.db
func (p *PostgresDBAccess) GetDB() *pgxpool.Pool {
// We can safely cast to *pgxpool.Pool because this method is never used in unit tests where we mock the DB
return p.db.(*pgxpool.Pool)
}
func (p *PostgresDBAccess) ParseMetadata(meta state.Metadata) error {
@ -193,8 +214,6 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
ttlSeconds = *ttl
}
var result sql.Result
// Sprintf is required for table name because query.DB does not substitute parameters for table names.
// Other parameters use query.DB parameter substitution.
var (
@ -246,8 +265,8 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
} else {
queryExpiredate = "NULL"
}
result, err = db.ExecContext(parentCtx, fmt.Sprintf(query, p.metadata.TableName, queryExpiredate), params...)
result, err := db.Exec(parentCtx, fmt.Sprintf(query, p.metadata.TableName, queryExpiredate), params...)
if err != nil {
if req.ETag != nil && *req.ETag != "" {
return state.NewETagError(state.ETagMismatch, err)
@ -255,11 +274,7 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows != 1 {
if result.RowsAffected() != 1 {
return errors.New("no item was updated")
}
@ -267,11 +282,20 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s
}
func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetRequest) error {
tx, err := p.db.BeginTx(parentCtx, nil)
ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
tx, err := p.db.Begin(ctx)
cancel()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback()
defer func() {
rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second)
rollbackErr := tx.Rollback(rollbackCtx)
rollbackCancel()
if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) {
p.logger.Errorf("Failed to rollback transaction in BulkSet: %v", rollbackErr)
}
}()
if len(req) > 0 {
for i := range req {
@ -282,7 +306,9 @@ func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetReq
}
}
err = tx.Commit()
ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second)
err = tx.Commit(ctx)
cancel()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
@ -299,7 +325,7 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest)
var (
value []byte
isBinary bool
etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations
etag uint32
)
query := `SELECT
value, isbinary, xmin AS etag
@ -307,12 +333,11 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest)
WHERE
key = $1
AND (expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP)`
err := p.db.
QueryRowContext(parentCtx, fmt.Sprintf(query, p.metadata.TableName), req.Key).
err := p.db.QueryRow(parentCtx, fmt.Sprintf(query, p.metadata.TableName), req.Key).
Scan(&value, &isBinary, &etag)
if err != nil {
// If no rows exist, return an empty response, otherwise return the error.
if err == sql.ErrNoRows {
if err == pgx.ErrNoRows {
return &state.GetResponse{}, nil
}
return nil, err
@ -334,14 +359,14 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest)
return &state.GetResponse{
Data: data,
ETag: ptr.Of(strconv.FormatUint(etag, 10)),
ETag: ptr.Of(strconv.FormatUint(uint64(etag), 10)),
Metadata: req.Metadata,
}, nil
}
return &state.GetResponse{
Data: value,
ETag: ptr.Of(strconv.FormatUint(etag, 10)),
ETag: ptr.Of(strconv.FormatUint(uint64(etag), 10)),
Metadata: req.Metadata,
}, nil
}
@ -356,10 +381,9 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
return errors.New("missing key in delete operation")
}
var result sql.Result
var result pgconn.CommandTag
if req.ETag == nil || *req.ETag == "" {
result, err = db.ExecContext(parentCtx, "DELETE FROM state WHERE key = $1", req.Key)
result, err = db.Exec(parentCtx, "DELETE FROM state WHERE key = $1", req.Key)
} else {
// Convert req.ETag to uint32 for postgres XID compatibility
var etag64 uint64
@ -367,20 +391,15 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
if err != nil {
return state.NewETagError(state.ETagInvalid, err)
}
etag := uint32(etag64)
result, err = db.ExecContext(parentCtx, "DELETE FROM state WHERE key = $1 AND xmin = $2", req.Key, etag)
result, err = db.Exec(parentCtx, "DELETE FROM state WHERE key = $1 AND xmin = $2", req.Key, uint32(etag64))
}
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
rows := result.RowsAffected()
if rows != 1 && req.ETag != nil && *req.ETag != "" {
return state.NewETagError(state.ETagMismatch, nil)
}
@ -389,11 +408,20 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
}
func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.DeleteRequest) error {
tx, err := p.db.BeginTx(parentCtx, nil)
ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
tx, err := p.db.Begin(ctx)
cancel()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback()
defer func() {
rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second)
rollbackErr := tx.Rollback(rollbackCtx)
rollbackCancel()
if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) {
p.logger.Errorf("Failed to rollback transaction in BulkDelete: %v", rollbackErr)
}
}()
if len(req) > 0 {
for i := range req {
@ -404,7 +432,9 @@ func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.Del
}
}
err = tx.Commit()
ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second)
err = tx.Commit(ctx)
cancel()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
@ -413,11 +443,20 @@ func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.Del
}
func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *state.TransactionalStateRequest) error {
tx, err := p.db.BeginTx(parentCtx, nil)
ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
tx, err := p.db.Begin(ctx)
cancel()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.Rollback()
defer func() {
rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second)
rollbackErr := tx.Rollback(rollbackCtx)
rollbackCancel()
if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) {
p.logger.Errorf("Failed to rollback transaction in ExecMulti: %v", rollbackErr)
}
}()
for _, o := range request.Operations {
switch o.Operation {
@ -450,7 +489,9 @@ func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *stat
}
}
err = tx.Commit()
ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second)
err = tx.Commit(ctx)
cancel()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
@ -520,19 +561,13 @@ func (p *PostgresDBAccess) CleanupExpired(ctx context.Context) error {
// Note we're not using the transaction here as we don't want this to be rolled back half-way or to lock the table unnecessarily
// Need to use fmt.Sprintf because we can't parametrize a table name
// Note we are not setting a timeout here as this query can take a "long" time, especially if there's no index on expiredate
//nolint:gosec
stmt := fmt.Sprintf(`DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate < CURRENT_TIMESTAMP`, p.metadata.TableName)
res, err := p.db.ExecContext(ctx, stmt)
res, err := p.db.Exec(ctx, stmt)
if err != nil {
return fmt.Errorf("failed to execute query: %w", err)
}
cleaned, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("failed to count affected rows: %w", err)
}
p.logger.Infof("Removed %d expired rows", cleaned)
p.logger.Infof("Removed %d expired rows", res.RowsAffected())
return nil
}
@ -540,7 +575,7 @@ func (p *PostgresDBAccess) CleanupExpired(ctx context.Context) error {
// Returns true if the row was updated, which means that the cleanup can proceed.
func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier, cleanupInterval time.Duration) (bool, error) {
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
res, err := db.ExecContext(queryCtx,
res, err := db.Exec(queryCtx,
fmt.Sprintf(`INSERT INTO %[1]s (key, value)
VALUES ('last-cleanup', CURRENT_TIMESTAMP)
ON CONFLICT (key)
@ -554,11 +589,7 @@ func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier,
return true, fmt.Errorf("failed to execute query: %w", err)
}
n, err := res.RowsAffected()
if err != nil {
return true, fmt.Errorf("failed to count affected rows: %w", err)
}
n := res.RowsAffected()
return n > 0, nil
}
@ -569,7 +600,8 @@ func (p *PostgresDBAccess) Close() error {
p.cancel = nil
}
if p.db != nil {
return p.db.Close()
p.db.Close()
p.db = nil
}
return nil

View File

@ -16,24 +16,20 @@ package postgresql
import (
"context"
"database/sql"
"encoding/json"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
pgxmock "github.com/pashagolub/pgxmock/v2"
"github.com/stretchr/testify/assert"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
"github.com/dapr/kit/logger"
// Blank import for pgx
_ "github.com/jackc/pgx/v5/stdlib"
)
type mocks struct {
db *sql.DB
mock sqlmock.Sqlmock
db pgxmock.PgxPoolIface
pgDba *PostgresDBAccess
}
@ -67,7 +63,7 @@ func TestGetSetValid(t *testing.T) {
}
set, err := getSet(operation)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, "key1", set.Key)
}
@ -101,7 +97,7 @@ func TestGetDeleteValid(t *testing.T) {
}
delete, err := getDelete(operation)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, "key1", delete.Key)
}
@ -110,8 +106,10 @@ func TestMultiWithNoRequests(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()
m.mock.ExpectBegin()
m.mock.ExpectCommit()
m.db.ExpectBegin()
m.db.ExpectCommit()
// There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()
var operations []state.TransactionalStateOperation
@ -121,7 +119,7 @@ func TestMultiWithNoRequests(t *testing.T) {
})
// Assert
assert.Nil(t, err)
assert.NoError(t, err)
}
func TestInvalidMultiInvalidAction(t *testing.T) {
@ -129,8 +127,8 @@ func TestInvalidMultiInvalidAction(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()
m.mock.ExpectBegin()
m.mock.ExpectRollback()
m.db.ExpectBegin()
m.db.ExpectRollback()
var operations []state.TransactionalStateOperation
@ -153,16 +151,19 @@ func TestValidSetRequest(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()
m.mock.ExpectBegin()
m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1))
m.mock.ExpectCommit()
setReq := createSetRequest()
operations := []state.TransactionalStateOperation{
{Operation: state.Upsert, Request: setReq},
}
val, _ := json.Marshal(setReq.Value)
var operations []state.TransactionalStateOperation
operations = append(operations, state.TransactionalStateOperation{
Operation: state.Upsert,
Request: createSetRequest(),
})
m.db.ExpectBegin()
m.db.ExpectExec("INSERT INTO").
WithArgs(setReq.Key, string(val), false).
WillReturnResult(pgxmock.NewResult("INSERT", 1))
m.db.ExpectCommit()
// There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()
// Act
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
@ -170,7 +171,7 @@ func TestValidSetRequest(t *testing.T) {
})
// Assert
assert.Nil(t, err)
assert.NoError(t, err)
}
func TestInvalidMultiSetRequest(t *testing.T) {
@ -178,15 +179,15 @@ func TestInvalidMultiSetRequest(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()
m.mock.ExpectBegin()
m.mock.ExpectRollback()
m.db.ExpectBegin()
m.db.ExpectRollback()
var operations []state.TransactionalStateOperation
operations = append(operations, state.TransactionalStateOperation{
Operation: state.Upsert,
Request: createDeleteRequest(), // Delete request is not valid for Upsert operation
})
operations := []state.TransactionalStateOperation{
{
Operation: state.Upsert,
Request: createDeleteRequest(), // Delete request is not valid for Upsert operation
},
}
// Act
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
@ -202,8 +203,8 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()
m.mock.ExpectBegin()
m.mock.ExpectRollback()
m.db.ExpectBegin()
m.db.ExpectRollback()
var operations []state.TransactionalStateOperation
@ -226,16 +227,18 @@ func TestValidMultiDeleteRequest(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()
m.mock.ExpectBegin()
m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1))
m.mock.ExpectCommit()
deleteReq := createDeleteRequest()
operations := []state.TransactionalStateOperation{
{Operation: state.Delete, Request: deleteReq},
}
var operations []state.TransactionalStateOperation
operations = append(operations, state.TransactionalStateOperation{
Operation: state.Delete,
Request: createDeleteRequest(),
})
m.db.ExpectBegin()
m.db.ExpectExec("DELETE FROM").
WithArgs(deleteReq.Key).
WillReturnResult(pgxmock.NewResult("DELETE", 1))
m.db.ExpectCommit()
// There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()
// Act
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
@ -243,7 +246,7 @@ func TestValidMultiDeleteRequest(t *testing.T) {
})
// Assert
assert.Nil(t, err)
assert.NoError(t, err)
}
func TestInvalidMultiDeleteRequest(t *testing.T) {
@ -251,8 +254,8 @@ func TestInvalidMultiDeleteRequest(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()
m.mock.ExpectBegin()
m.mock.ExpectRollback()
m.db.ExpectBegin()
m.db.ExpectRollback()
var operations []state.TransactionalStateOperation
@ -275,8 +278,8 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()
m.mock.ExpectBegin()
m.mock.ExpectRollback()
m.db.ExpectBegin()
m.db.ExpectRollback()
var operations []state.TransactionalStateOperation
@ -299,23 +302,27 @@ func TestMultiOperationOrder(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()
m.mock.ExpectBegin()
m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1))
m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1))
m.mock.ExpectCommit()
var operations []state.TransactionalStateOperation
operations = append(operations,
state.TransactionalStateOperation{
operations := []state.TransactionalStateOperation{
{
Operation: state.Upsert,
Request: state.SetRequest{Key: "key1", Value: "value1"},
},
state.TransactionalStateOperation{
{
Operation: state.Delete,
Request: state.DeleteRequest{Key: "key1"},
},
)
}
m.db.ExpectBegin()
m.db.ExpectExec("INSERT INTO").
WithArgs("key1", `"value1"`, false).
WillReturnResult(pgxmock.NewResult("INSERT", 1))
m.db.ExpectExec("DELETE FROM").
WithArgs("key1").
WillReturnResult(pgxmock.NewResult("DELETE", 1))
m.db.ExpectCommit()
// There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()
// Act
err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{
@ -323,7 +330,7 @@ func TestMultiOperationOrder(t *testing.T) {
})
// Assert
assert.Nil(t, err)
assert.NoError(t, err)
}
func TestInvalidBulkSetNoKey(t *testing.T) {
@ -331,14 +338,13 @@ func TestInvalidBulkSetNoKey(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()
m.mock.ExpectBegin()
m.mock.ExpectRollback()
m.db.ExpectBegin()
m.db.ExpectRollback()
var sets []state.SetRequest
sets = append(sets, state.SetRequest{ // Set request without key is not valid for Set operation
Value: "value1",
})
sets := []state.SetRequest{
// Set request without key is not valid for Set operation
{Value: "value1"},
}
// Act
err := m.pgDba.BulkSet(context.Background(), sets)
@ -352,15 +358,13 @@ func TestInvalidBulkSetEmptyValue(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()
m.mock.ExpectBegin()
m.mock.ExpectRollback()
m.db.ExpectBegin()
m.db.ExpectRollback()
var sets []state.SetRequest
sets = append(sets, state.SetRequest{ // Set request without value is not valid for Set operation
Key: "key1",
Value: "",
})
sets := []state.SetRequest{
// Set request without value is not valid for Set operation
{Key: "key1", Value: "value1"},
}
// Act
err := m.pgDba.BulkSet(context.Background(), sets)
@ -374,22 +378,26 @@ func TestValidBulkSet(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()
m.mock.ExpectBegin()
m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1))
m.mock.ExpectCommit()
sets := []state.SetRequest{
{
Key: "key1",
Value: "value1",
},
}
var sets []state.SetRequest
sets = append(sets, state.SetRequest{
Key: "key1",
Value: "value1",
})
m.db.ExpectBegin()
m.db.ExpectExec("INSERT INTO").
WithArgs("key1", `"value1"`, false).
WillReturnResult(pgxmock.NewResult("INSERT", 1))
m.db.ExpectCommit()
// There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()
// Act
err := m.pgDba.BulkSet(context.Background(), sets)
// Assert
assert.Nil(t, err)
assert.NoError(t, err)
}
func TestInvalidBulkDeleteNoKey(t *testing.T) {
@ -397,8 +405,8 @@ func TestInvalidBulkDeleteNoKey(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()
m.mock.ExpectBegin()
m.mock.ExpectRollback()
m.db.ExpectBegin()
m.db.ExpectRollback()
var deletes []state.DeleteRequest
@ -418,21 +426,23 @@ func TestValidBulkDelete(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()
m.mock.ExpectBegin()
m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1))
m.mock.ExpectCommit()
deletes := []state.DeleteRequest{
{Key: "key1"},
}
var deletes []state.DeleteRequest
deletes = append(deletes, state.DeleteRequest{
Key: "key1",
})
m.db.ExpectBegin()
m.db.ExpectExec("DELETE FROM").
WithArgs("key1").
WillReturnResult(pgxmock.NewResult("DELETE", 1))
m.db.ExpectCommit()
// There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()
// Act
err := m.pgDba.BulkDelete(context.Background(), deletes)
// Assert
assert.Nil(t, err)
assert.NoError(t, err)
}
func createSetRequest() state.SetRequest {
@ -451,7 +461,7 @@ func createDeleteRequest() state.DeleteRequest {
func mockDatabase(t *testing.T) (*mocks, error) {
logger := logger.NewLogger("test")
db, mock, err := sqlmock.New()
db, err := pgxmock.NewPool()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
@ -463,7 +473,6 @@ func mockDatabase(t *testing.T) (*mocks, error) {
return &mocks{
db: db,
mock: mock,
pgDba: dba,
}, err
}

View File

@ -12,18 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package postgresql
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"os"
"testing"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
@ -162,7 +165,7 @@ func deleteItemThatDoesNotExist(t *testing.T, pgs *PostgreSQL) {
Key: randomKey(),
}
err := pgs.Delete(context.Background(), deleteReq)
assert.Nil(t, err)
assert.NoError(t, err)
}
func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) {
@ -183,7 +186,7 @@ func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) {
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
assert.NoError(t, err)
for _, set := range setRequests {
assert.True(t, storeItemExists(t, set.Key))
@ -213,7 +216,7 @@ func multiWithDeleteOnly(t *testing.T, pgs *PostgreSQL) {
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
assert.NoError(t, err)
for _, delete := range deleteRequests {
assert.False(t, storeItemExists(t, delete.Key))
@ -256,7 +259,7 @@ func multiWithDeleteAndSet(t *testing.T, pgs *PostgreSQL) {
err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
assert.NoError(t, err)
for _, delete := range deleteRequests {
assert.False(t, storeItemExists(t, delete.Key))
@ -384,15 +387,16 @@ func setUpdatesTheUpdatedateField(t *testing.T, pgs *PostgreSQL) {
// insertdate should have a value and updatedate should be nil
_, insertdate, updatedate := getRowData(t, key)
assert.Nil(t, updatedate)
assert.NotNil(t, insertdate)
assert.Equal(t, "", updatedate.String)
// insertdate should not change, updatedate should have a value
value = &fakeItem{Color: "aqua"}
setItem(t, pgs, key, value, nil)
_, newinsertdate, updatedate := getRowData(t, key)
assert.Equal(t, insertdate, newinsertdate) // The insertdate should not change.
assert.NotEqual(t, "", updatedate.String)
assert.NotNil(t, updatedate)
assert.NotNil(t, newinsertdate)
assert.True(t, insertdate.Equal(*newinsertdate)) // The insertdate should not change.
deleteItem(t, pgs, key, nil)
}
@ -420,7 +424,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) {
}
err := pgs.BulkSet(context.Background(), setReq)
assert.Nil(t, err)
assert.NoError(t, err)
assert.True(t, storeItemExists(t, setReq[0].Key))
assert.True(t, storeItemExists(t, setReq[1].Key))
@ -434,7 +438,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) {
}
err = pgs.BulkDelete(context.Background(), deleteReq)
assert.Nil(t, err)
assert.NoError(t, err)
assert.False(t, storeItemExists(t, setReq[0].Key))
assert.False(t, storeItemExists(t, setReq[1].Key))
}
@ -491,7 +495,7 @@ func setItem(t *testing.T, pgs *PostgreSQL, key string, value interface{}, etag
}
err := pgs.Set(context.Background(), setReq)
assert.Nil(t, err)
assert.NoError(t, err)
itemExists := storeItemExists(t, key)
assert.True(t, itemExists)
}
@ -524,25 +528,28 @@ func deleteItem(t *testing.T, pgs *PostgreSQL, key string, etag *string) {
}
func storeItemExists(t *testing.T, key string) bool {
db, err := sql.Open("pgx", getConnectionString())
assert.Nil(t, err)
defer db.Close()
ctx := context.Background()
db, err := pgx.Connect(ctx, getConnectionString())
require.NoError(t, err)
defer db.Close(ctx)
exists := false
statement := fmt.Sprintf(`SELECT EXISTS (SELECT FROM %s WHERE key = $1)`, defaultTableName)
err = db.QueryRow(statement, key).Scan(&exists)
assert.Nil(t, err)
err = db.QueryRow(ctx, statement, key).Scan(&exists)
assert.NoError(t, err)
return exists
}
func getRowData(t *testing.T, key string) (returnValue string, insertdate sql.NullString, updatedate sql.NullString) {
db, err := sql.Open("pgx", getConnectionString())
assert.Nil(t, err)
defer db.Close()
func getRowData(t *testing.T, key string) (returnValue string, insertdate *time.Time, updatedate *time.Time) {
ctx := context.Background()
db, err := pgx.Connect(ctx, getConnectionString())
require.NoError(t, err)
defer db.Close(ctx)
err = db.QueryRow(fmt.Sprintf("SELECT value, insertdate, updatedate FROM %s WHERE key = $1", defaultTableName), key).Scan(&returnValue, &insertdate, &updatedate)
assert.Nil(t, err)
err = db.QueryRow(ctx, fmt.Sprintf("SELECT value, insertdate, updatedate FROM %s WHERE key = $1", defaultTableName), key).
Scan(&returnValue, &insertdate, &updatedate)
assert.NoError(t, err)
return returnValue, insertdate, updatedate
}

View File

@ -16,7 +16,6 @@ package postgresql
import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
@ -140,8 +139,8 @@ func (q *Query) Finalize(filters string, qq *query.Query) error {
return nil
}
func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) {
rows, err := db.QueryContext(ctx, q.query, q.params...)
func (q *Query) execute(ctx context.Context, logger logger.Logger, db dbquerier) ([]state.QueryItem, string, error) {
rows, err := db.Query(ctx, q.query, q.params...)
if err != nil {
return nil, "", err
}
@ -152,7 +151,7 @@ func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) (
var (
key string
data []byte
etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations
etag uint32
)
if err = rows.Scan(&key, &data, &etag); err != nil {
return nil, "", err
@ -160,7 +159,7 @@ func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) (
result := state.QueryItem{
Key: key,
Data: data,
ETag: ptr.Of(strconv.FormatUint(etag, 10)),
ETag: ptr.Of(strconv.FormatUint(uint64(etag), 10)),
}
ret = append(ret, result)
}

View File

@ -62,6 +62,7 @@ require (
github.com/imdario/mergo v0.3.12 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
github.com/jackc/puddle/v2 v2.1.2 // indirect
github.com/jhump/protoreflect v1.13.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect

View File

@ -37,7 +37,6 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a h1:
github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go.mod h1:C0A1KeiVHs+trY6gUTPhhGammbrZ30ZfXRW/nuT7HLw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig=
@ -288,6 +287,8 @@ github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHF
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
github.com/jackc/pgx/v5 v5.2.0 h1:NdPpngX0Y6z6XDFKqmFQaE+bCtkqzvQIOt1wvBlAqs8=
github.com/jackc/pgx/v5 v5.2.0/go.mod h1:Ptn7zmohNsWEsdxRawMzk3gaKma2obW+NWTnKa0S4nk=
github.com/jackc/puddle/v2 v2.1.2 h1:0f7vaaXINONKTsxYDn4otOAiJanX/BMeAtY//BXqzlg=
github.com/jackc/puddle/v2 v2.1.2/go.mod h1:2lpufsF5mRHO6SuZkm0fNYxM6SWHfvyFj62KwNzgels=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
@ -385,6 +386,7 @@ github.com/openzipkin/zipkin-go v0.4.1/go.mod h1:qY0VqDSN1pOBN94dBc6w2GJlWLiovAy
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY=
github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/pashagolub/pgxmock/v2 v2.4.0 h1:jNv7+svrNoMc31mvllSS/u7P2pT3gS3uY7DPRKIJNSY=
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI=
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

View File

@ -16,7 +16,6 @@ package main
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
@ -706,7 +705,7 @@ func getMigrationLevel(dbClient *pgx.Conn, metadataTable string) (level string,
err = dbClient.
QueryRow(ctx, fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, metadataTable)).
Scan(&level)
if err != nil && errors.Is(err, sql.ErrNoRows) {
if err != nil && errors.Is(err, pgx.ErrNoRows) {
err = nil
level = ""
}
@ -766,7 +765,7 @@ func loadLastCleanupInterval(ctx context.Context, dbClient *pgx.Conn, table stri
).
Scan(&lastCleanup)
cancel()
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, pgx.ErrNoRows) {
err = nil
}
return
@ -790,7 +789,7 @@ func getValueFromMetadataTable(ctx context.Context, dbClient *pgx.Conn, table, k
QueryRow(queryCtx, fmt.Sprintf("SELECT value FROM %s WHERE key = $1", table), key).
Scan(&value)
cancel()
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, pgx.ErrNoRows) {
err = nil
}
return

View File

@ -272,7 +272,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
req.Metadata = map[string]string{metadata.ContentType: scenario.contentType}
}
res, err := statestore.Get(context.Background(), req)
assert.Nil(t, err)
require.NoError(t, err)
assertEquals(t, scenario.value, res)
}
}
@ -287,13 +287,13 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St
t.Logf("Querying value presence for %s", scenario.query)
var req state.QueryRequest
err := json.Unmarshal([]byte(scenario.query), &req.Query)
assert.NoError(t, err)
require.NoError(t, err)
req.Metadata = map[string]string{
metadata.ContentType: contenttype.JSONContentType,
metadata.QueryIndexName: "qIndx",
}
resp, err := querier.Query(context.Background(), &req)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, len(scenario.results), len(resp.Results))
for i := range scenario.results {
var expected, actual interface{}