Changed postgres state store to use pgx instead of db/sql
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									2582f154bd
								
							
						
					
					
						commit
						80b1bb14b7
					
				|  | @ -15,7 +15,6 @@ package postgres | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"database/sql" |  | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | @ -114,7 +113,7 @@ func (p *ConfigurationStore) Init(metadata configuration.Metadata) error { | ||||||
| 	exists := false | 	exists := false | ||||||
| 	err = p.client.QueryRow(ctx, QueryTableExists, p.metadata.configTable).Scan(&exists) | 	err = p.client.QueryRow(ctx, QueryTableExists, p.metadata.configTable).Scan(&exists) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		if err == sql.ErrNoRows { | 		if errors.Is(err, pgx.ErrNoRows) { | ||||||
| 			return fmt.Errorf(ErrorMissingTable, p.metadata.configTable) | 			return fmt.Errorf(ErrorMissingTable, p.metadata.configTable) | ||||||
| 		} | 		} | ||||||
| 		return fmt.Errorf("error in checking if configtable '%s' exists - '%w'", p.metadata.configTable, err) | 		return fmt.Errorf("error in checking if configtable '%s' exists - '%w'", p.metadata.configTable, err) | ||||||
|  | @ -135,7 +134,7 @@ func (p *ConfigurationStore) Get(ctx context.Context, req *configuration.GetRequ | ||||||
| 	rows, err := p.client.Query(ctx, query, params...) | 	rows, err := p.client.Query(ctx, query, params...) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		// If no rows exist, return an empty response, otherwise return the error.
 | 		// If no rows exist, return an empty response, otherwise return the error.
 | ||||||
| 		if err == sql.ErrNoRows { | 		if errors.Is(err, pgx.ErrNoRows) { | ||||||
| 			return &configuration.GetResponse{}, nil | 			return &configuration.GetResponse{}, nil | ||||||
| 		} | 		} | ||||||
| 		return nil, fmt.Errorf("error in querying configuration store: '%w'", err) | 		return nil, fmt.Errorf("error in querying configuration store: '%w'", err) | ||||||
|  |  | ||||||
|  | @ -15,7 +15,9 @@ package postgresql | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"database/sql" | 
 | ||||||
|  | 	"github.com/jackc/pgx/v5" | ||||||
|  | 	"github.com/jackc/pgx/v5/pgconn" | ||||||
| 
 | 
 | ||||||
| 	"github.com/dapr/components-contrib/state" | 	"github.com/dapr/components-contrib/state" | ||||||
| ) | ) | ||||||
|  | @ -34,10 +36,9 @@ type dbAccess interface { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Interface that contains methods for querying.
 | // Interface that contains methods for querying.
 | ||||||
| // Applies to both *sql.DB and *sql.Tx
 | // Applies to *pgx.Conn, *pgxpool.Pool, and pgx.Tx
 | ||||||
| type dbquerier interface { | type dbquerier interface { | ||||||
| 	Exec(query string, args ...any) (sql.Result, error) | 	Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) | ||||||
| 	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) | 	Query(context.Context, string, ...interface{}) (pgx.Rows, error) | ||||||
| 	QueryRow(query string, args ...any) *sql.Row | 	QueryRow(context.Context, string, ...interface{}) pgx.Row | ||||||
| 	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -15,20 +15,21 @@ package postgresql | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"database/sql" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/jackc/pgx/v5" | ||||||
|  | 
 | ||||||
| 	"github.com/dapr/kit/logger" | 	"github.com/dapr/kit/logger" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Performs migrations for the database schema
 | // Performs migrations for the database schema
 | ||||||
| type migrations struct { | type migrations struct { | ||||||
| 	Logger            logger.Logger | 	Logger            logger.Logger | ||||||
| 	Conn              *sql.DB | 	Conn              pgxPoolConn | ||||||
| 	StateTableName    string | 	StateTableName    string | ||||||
| 	MetadataTableName string | 	MetadataTableName string | ||||||
| } | } | ||||||
|  | @ -42,7 +43,7 @@ func (m *migrations) Perform(ctx context.Context) error { | ||||||
| 
 | 
 | ||||||
| 	// Long timeout here as this query may block
 | 	// Long timeout here as this query may block
 | ||||||
| 	queryCtx, cancel := context.WithTimeout(ctx, time.Minute) | 	queryCtx, cancel := context.WithTimeout(ctx, time.Minute) | ||||||
| 	_, err := m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_lock($1)", lockID) | 	_, err := m.Conn.Exec(queryCtx, "SELECT pg_advisory_lock($1)", lockID) | ||||||
| 	cancel() | 	cancel() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("faild to acquire advisory lock: %w", err) | 		return fmt.Errorf("faild to acquire advisory lock: %w", err) | ||||||
|  | @ -51,7 +52,7 @@ func (m *migrations) Perform(ctx context.Context) error { | ||||||
| 	// Release the lock
 | 	// Release the lock
 | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		queryCtx, cancel = context.WithTimeout(ctx, time.Minute) | 		queryCtx, cancel = context.WithTimeout(ctx, time.Minute) | ||||||
| 		_, err = m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_unlock($1)", lockID) | 		_, err = m.Conn.Exec(queryCtx, "SELECT pg_advisory_unlock($1)", lockID) | ||||||
| 		cancel() | 		cancel() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
 | 			// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
 | ||||||
|  | @ -84,11 +85,11 @@ func (m *migrations) Perform(ctx context.Context) error { | ||||||
| 	) | 	) | ||||||
| 	queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) | 	queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) | ||||||
| 	err = m.Conn. | 	err = m.Conn. | ||||||
| 		QueryRowContext(queryCtx, | 		QueryRow(queryCtx, | ||||||
| 			fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName), | 			fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName), | ||||||
| 		).Scan(&migrationLevelStr) | 		).Scan(&migrationLevelStr) | ||||||
| 	cancel() | 	cancel() | ||||||
| 	if errors.Is(err, sql.ErrNoRows) { | 	if errors.Is(err, pgx.ErrNoRows) { | ||||||
| 		// If there's no row...
 | 		// If there's no row...
 | ||||||
| 		migrationLevel = 0 | 		migrationLevel = 0 | ||||||
| 	} else if err != nil { | 	} else if err != nil { | ||||||
|  | @ -109,7 +110,7 @@ func (m *migrations) Perform(ctx context.Context) error { | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) | 		queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) | ||||||
| 		_, err = m.Conn.ExecContext(queryCtx, | 		_, err = m.Conn.Exec(queryCtx, | ||||||
| 			fmt.Sprintf(`INSERT INTO %s (key, value) VALUES ('migrations', $1) ON CONFLICT (key) DO UPDATE SET value = $1`, m.MetadataTableName), | 			fmt.Sprintf(`INSERT INTO %s (key, value) VALUES ('migrations', $1) ON CONFLICT (key) DO UPDATE SET value = $1`, m.MetadataTableName), | ||||||
| 			strconv.Itoa(i+1), | 			strconv.Itoa(i+1), | ||||||
| 		) | 		) | ||||||
|  | @ -126,7 +127,7 @@ func (m migrations) createMetadataTable(ctx context.Context) error { | ||||||
| 	m.Logger.Infof("Creating metadata table '%s'", m.MetadataTableName) | 	m.Logger.Infof("Creating metadata table '%s'", m.MetadataTableName) | ||||||
| 	// Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time
 | 	// Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time
 | ||||||
| 	// In the next step we'll acquire a lock so there won't be issues with concurrency
 | 	// In the next step we'll acquire a lock so there won't be issues with concurrency
 | ||||||
| 	_, err := m.Conn.Exec(fmt.Sprintf( | 	_, err := m.Conn.Exec(ctx, fmt.Sprintf( | ||||||
| 		`CREATE TABLE IF NOT EXISTS %s ( | 		`CREATE TABLE IF NOT EXISTS %s ( | ||||||
| 			key text NOT NULL PRIMARY KEY, | 			key text NOT NULL PRIMARY KEY, | ||||||
| 			value text NOT NULL | 			value text NOT NULL | ||||||
|  | @ -148,7 +149,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b | ||||||
| 
 | 
 | ||||||
| 	if schema == "" { | 	if schema == "" { | ||||||
| 		err = m.Conn. | 		err = m.Conn. | ||||||
| 			QueryRowContext( | 			QueryRow( | ||||||
| 				ctx, | 				ctx, | ||||||
| 				`SELECT table_name, table_schema | 				`SELECT table_name, table_schema | ||||||
| 				FROM information_schema.tables  | 				FROM information_schema.tables  | ||||||
|  | @ -158,7 +159,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b | ||||||
| 			Scan(&table, &schema) | 			Scan(&table, &schema) | ||||||
| 	} else { | 	} else { | ||||||
| 		err = m.Conn. | 		err = m.Conn. | ||||||
| 			QueryRowContext( | 			QueryRow( | ||||||
| 				ctx, | 				ctx, | ||||||
| 				`SELECT table_name, table_schema | 				`SELECT table_name, table_schema | ||||||
| 				FROM information_schema.tables  | 				FROM information_schema.tables  | ||||||
|  | @ -168,7 +169,7 @@ func (m migrations) tableExists(ctx context.Context, tableName string) (exists b | ||||||
| 			Scan(&table, &schema) | 			Scan(&table, &schema) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err != nil && errors.Is(err, sql.ErrNoRows) { | 	if err != nil && errors.Is(err, pgx.ErrNoRows) { | ||||||
| 		return false, "", "", nil | 		return false, "", "", nil | ||||||
| 	} else if err != nil { | 	} else if err != nil { | ||||||
| 		return false, "", "", fmt.Errorf("failed to check if table '%s' exists: %w", tableName, err) | 		return false, "", "", fmt.Errorf("failed to check if table '%s' exists: %w", tableName, err) | ||||||
|  | @ -195,6 +196,7 @@ var allMigrations = [2]func(ctx context.Context, m *migrations) error{ | ||||||
| 		// We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table
 | 		// We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table
 | ||||||
| 		m.Logger.Infof("Creating state table '%s'", m.StateTableName) | 		m.Logger.Infof("Creating state table '%s'", m.StateTableName) | ||||||
| 		_, err := m.Conn.Exec( | 		_, err := m.Conn.Exec( | ||||||
|  | 			ctx, | ||||||
| 			fmt.Sprintf( | 			fmt.Sprintf( | ||||||
| 				`CREATE TABLE IF NOT EXISTS %s ( | 				`CREATE TABLE IF NOT EXISTS %s ( | ||||||
| 					key text NOT NULL PRIMARY KEY, | 					key text NOT NULL PRIMARY KEY, | ||||||
|  | @ -215,7 +217,7 @@ var allMigrations = [2]func(ctx context.Context, m *migrations) error{ | ||||||
| 	// Migration 1: add the "expiredate" column
 | 	// Migration 1: add the "expiredate" column
 | ||||||
| 	func(ctx context.Context, m *migrations) error { | 	func(ctx context.Context, m *migrations) error { | ||||||
| 		m.Logger.Infof("Adding expiredate column to state table '%s'", m.StateTableName) | 		m.Logger.Infof("Adding expiredate column to state table '%s'", m.StateTableName) | ||||||
| 		_, err := m.Conn.Exec(fmt.Sprintf( | 		_, err := m.Conn.Exec(ctx, fmt.Sprintf( | ||||||
| 			`ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`, | 			`ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`, | ||||||
| 			m.StateTableName, | 			m.StateTableName, | ||||||
| 		)) | 		)) | ||||||
|  |  | ||||||
|  | @ -15,7 +15,6 @@ package postgresql | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"database/sql" |  | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
|  | @ -23,15 +22,16 @@ import ( | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/jackc/pgx/v5" | ||||||
|  | 	"github.com/jackc/pgx/v5/pgconn" | ||||||
|  | 	"github.com/jackc/pgx/v5/pgxpool" | ||||||
|  | 
 | ||||||
| 	"github.com/dapr/components-contrib/metadata" | 	"github.com/dapr/components-contrib/metadata" | ||||||
| 	"github.com/dapr/components-contrib/state" | 	"github.com/dapr/components-contrib/state" | ||||||
| 	"github.com/dapr/components-contrib/state/query" | 	"github.com/dapr/components-contrib/state/query" | ||||||
| 	stateutils "github.com/dapr/components-contrib/state/utils" | 	stateutils "github.com/dapr/components-contrib/state/utils" | ||||||
| 	"github.com/dapr/kit/logger" | 	"github.com/dapr/kit/logger" | ||||||
| 	"github.com/dapr/kit/ptr" | 	"github.com/dapr/kit/ptr" | ||||||
| 
 |  | ||||||
| 	// Blank import for the underlying Postgres driver.
 |  | ||||||
| 	_ "github.com/jackc/pgx/v5/stdlib" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
|  | @ -43,12 +43,24 @@ const ( | ||||||
| 
 | 
 | ||||||
| var errMissingConnectionString = errors.New("missing connection string") | var errMissingConnectionString = errors.New("missing connection string") | ||||||
| 
 | 
 | ||||||
|  | // Interface that applies to *pgxpool.Pool.
 | ||||||
|  | // We need this to be able to mock the connection in tests.
 | ||||||
|  | type pgxPoolConn interface { | ||||||
|  | 	Begin(context.Context) (pgx.Tx, error) | ||||||
|  | 	BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) | ||||||
|  | 	Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) | ||||||
|  | 	Query(context.Context, string, ...interface{}) (pgx.Rows, error) | ||||||
|  | 	QueryRow(context.Context, string, ...interface{}) pgx.Row | ||||||
|  | 	Ping(context.Context) error | ||||||
|  | 	Close() | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // PostgresDBAccess implements dbaccess.
 | // PostgresDBAccess implements dbaccess.
 | ||||||
| type PostgresDBAccess struct { | type PostgresDBAccess struct { | ||||||
| 	logger          logger.Logger | 	logger          logger.Logger | ||||||
| 	metadata        postgresMetadataStruct | 	metadata        postgresMetadataStruct | ||||||
| 	cleanupInterval *time.Duration | 	cleanupInterval *time.Duration | ||||||
| 	db              *sql.DB | 	db              pgxPoolConn | ||||||
| 	ctx             context.Context | 	ctx             context.Context | ||||||
| 	cancel          context.CancelFunc | 	cancel          context.CancelFunc | ||||||
| } | } | ||||||
|  | @ -81,23 +93,31 @@ func (p *PostgresDBAccess) Init(meta state.Metadata) error { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	db, err := sql.Open("pgx", p.metadata.ConnectionString) | 	config, err := pgxpool.ParseConfig(p.metadata.ConnectionString) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		err = fmt.Errorf("failed to parse connection string: %w", err) | ||||||
|  | 		p.logger.Error(err) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if p.metadata.ConnectionMaxIdleTime > 0 { | ||||||
|  | 		config.MaxConnIdleTime = p.metadata.ConnectionMaxIdleTime | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	connCtx, connCancel := context.WithTimeout(p.ctx, 30*time.Second) | ||||||
|  | 	p.db, err = pgxpool.NewWithConfig(connCtx, config) | ||||||
|  | 	connCancel() | ||||||
|  | 	if err != nil { | ||||||
|  | 		err = fmt.Errorf("failed to connect to the database: %w", err) | ||||||
| 		p.logger.Error(err) | 		p.logger.Error(err) | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	p.db = db |  | ||||||
| 
 |  | ||||||
| 	pingCtx, pingCancel := context.WithTimeout(p.ctx, 30*time.Second) | 	pingCtx, pingCancel := context.WithTimeout(p.ctx, 30*time.Second) | ||||||
| 	pingErr := db.PingContext(pingCtx) | 	err = p.db.Ping(pingCtx) | ||||||
| 	pingCancel() | 	pingCancel() | ||||||
| 	if pingErr != nil { |  | ||||||
| 		return pingErr |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	p.db.SetConnMaxIdleTime(p.metadata.ConnectionMaxIdleTime) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		err = fmt.Errorf("failed to ping the database: %w", err) | ||||||
|  | 		p.logger.Error(err) | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -117,8 +137,9 @@ func (p *PostgresDBAccess) Init(meta state.Metadata) error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *PostgresDBAccess) GetDB() *sql.DB { | func (p *PostgresDBAccess) GetDB() *pgxpool.Pool { | ||||||
| 	return p.db | 	// We can safely cast to *pgxpool.Pool because this method is never used in unit tests where we mock the DB
 | ||||||
|  | 	return p.db.(*pgxpool.Pool) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *PostgresDBAccess) ParseMetadata(meta state.Metadata) error { | func (p *PostgresDBAccess) ParseMetadata(meta state.Metadata) error { | ||||||
|  | @ -193,8 +214,6 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s | ||||||
| 		ttlSeconds = *ttl | 		ttlSeconds = *ttl | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var result sql.Result |  | ||||||
| 
 |  | ||||||
| 	// Sprintf is required for table name because query.DB does not substitute parameters for table names.
 | 	// Sprintf is required for table name because query.DB does not substitute parameters for table names.
 | ||||||
| 	// Other parameters use query.DB parameter substitution.
 | 	// Other parameters use query.DB parameter substitution.
 | ||||||
| 	var ( | 	var ( | ||||||
|  | @ -246,8 +265,8 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s | ||||||
| 	} else { | 	} else { | ||||||
| 		queryExpiredate = "NULL" | 		queryExpiredate = "NULL" | ||||||
| 	} | 	} | ||||||
| 	result, err = db.ExecContext(parentCtx, fmt.Sprintf(query, p.metadata.TableName, queryExpiredate), params...) |  | ||||||
| 
 | 
 | ||||||
|  | 	result, err := db.Exec(parentCtx, fmt.Sprintf(query, p.metadata.TableName, queryExpiredate), params...) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		if req.ETag != nil && *req.ETag != "" { | 		if req.ETag != nil && *req.ETag != "" { | ||||||
| 			return state.NewETagError(state.ETagMismatch, err) | 			return state.NewETagError(state.ETagMismatch, err) | ||||||
|  | @ -255,11 +274,7 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	rows, err := result.RowsAffected() | 	if result.RowsAffected() != 1 { | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	if rows != 1 { |  | ||||||
| 		return errors.New("no item was updated") | 		return errors.New("no item was updated") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -267,11 +282,20 @@ func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *s | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetRequest) error { | func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetRequest) error { | ||||||
| 	tx, err := p.db.BeginTx(parentCtx, nil) | 	ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second) | ||||||
|  | 	tx, err := p.db.Begin(ctx) | ||||||
|  | 	cancel() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("failed to begin transaction: %w", err) | 		return fmt.Errorf("failed to begin transaction: %w", err) | ||||||
| 	} | 	} | ||||||
| 	defer tx.Rollback() | 	defer func() { | ||||||
|  | 		rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second) | ||||||
|  | 		rollbackErr := tx.Rollback(rollbackCtx) | ||||||
|  | 		rollbackCancel() | ||||||
|  | 		if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) { | ||||||
|  | 			p.logger.Errorf("Failed to rollback transaction in BulkSet: %v", rollbackErr) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
| 
 | 
 | ||||||
| 	if len(req) > 0 { | 	if len(req) > 0 { | ||||||
| 		for i := range req { | 		for i := range req { | ||||||
|  | @ -282,7 +306,9 @@ func (p *PostgresDBAccess) BulkSet(parentCtx context.Context, req []state.SetReq | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = tx.Commit() | 	ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second) | ||||||
|  | 	err = tx.Commit(ctx) | ||||||
|  | 	cancel() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("failed to commit transaction: %w", err) | 		return fmt.Errorf("failed to commit transaction: %w", err) | ||||||
| 	} | 	} | ||||||
|  | @ -299,7 +325,7 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest) | ||||||
| 	var ( | 	var ( | ||||||
| 		value    []byte | 		value    []byte | ||||||
| 		isBinary bool | 		isBinary bool | ||||||
| 		etag     uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations
 | 		etag     uint32 | ||||||
| 	) | 	) | ||||||
| 	query := `SELECT | 	query := `SELECT | ||||||
| 			value, isbinary, xmin AS etag | 			value, isbinary, xmin AS etag | ||||||
|  | @ -307,12 +333,11 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest) | ||||||
| 			WHERE | 			WHERE | ||||||
| 				key = $1 | 				key = $1 | ||||||
| 				AND (expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP)` | 				AND (expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP)` | ||||||
| 	err := p.db. | 	err := p.db.QueryRow(parentCtx, fmt.Sprintf(query, p.metadata.TableName), req.Key). | ||||||
| 		QueryRowContext(parentCtx, fmt.Sprintf(query, p.metadata.TableName), req.Key). |  | ||||||
| 		Scan(&value, &isBinary, &etag) | 		Scan(&value, &isBinary, &etag) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		// If no rows exist, return an empty response, otherwise return the error.
 | 		// If no rows exist, return an empty response, otherwise return the error.
 | ||||||
| 		if err == sql.ErrNoRows { | 		if err == pgx.ErrNoRows { | ||||||
| 			return &state.GetResponse{}, nil | 			return &state.GetResponse{}, nil | ||||||
| 		} | 		} | ||||||
| 		return nil, err | 		return nil, err | ||||||
|  | @ -334,14 +359,14 @@ func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest) | ||||||
| 
 | 
 | ||||||
| 		return &state.GetResponse{ | 		return &state.GetResponse{ | ||||||
| 			Data:     data, | 			Data:     data, | ||||||
| 			ETag:     ptr.Of(strconv.FormatUint(etag, 10)), | 			ETag:     ptr.Of(strconv.FormatUint(uint64(etag), 10)), | ||||||
| 			Metadata: req.Metadata, | 			Metadata: req.Metadata, | ||||||
| 		}, nil | 		}, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &state.GetResponse{ | 	return &state.GetResponse{ | ||||||
| 		Data:     value, | 		Data:     value, | ||||||
| 		ETag:     ptr.Of(strconv.FormatUint(etag, 10)), | 		ETag:     ptr.Of(strconv.FormatUint(uint64(etag), 10)), | ||||||
| 		Metadata: req.Metadata, | 		Metadata: req.Metadata, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  | @ -356,10 +381,9 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req | ||||||
| 		return errors.New("missing key in delete operation") | 		return errors.New("missing key in delete operation") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var result sql.Result | 	var result pgconn.CommandTag | ||||||
| 
 |  | ||||||
| 	if req.ETag == nil || *req.ETag == "" { | 	if req.ETag == nil || *req.ETag == "" { | ||||||
| 		result, err = db.ExecContext(parentCtx, "DELETE FROM state WHERE key = $1", req.Key) | 		result, err = db.Exec(parentCtx, "DELETE FROM state WHERE key = $1", req.Key) | ||||||
| 	} else { | 	} else { | ||||||
| 		// Convert req.ETag to uint32 for postgres XID compatibility
 | 		// Convert req.ETag to uint32 for postgres XID compatibility
 | ||||||
| 		var etag64 uint64 | 		var etag64 uint64 | ||||||
|  | @ -367,20 +391,15 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return state.NewETagError(state.ETagInvalid, err) | 			return state.NewETagError(state.ETagInvalid, err) | ||||||
| 		} | 		} | ||||||
| 		etag := uint32(etag64) |  | ||||||
| 
 | 
 | ||||||
| 		result, err = db.ExecContext(parentCtx, "DELETE FROM state WHERE key = $1 AND xmin = $2", req.Key, etag) | 		result, err = db.Exec(parentCtx, "DELETE FROM state WHERE key = $1 AND xmin = $2", req.Key, uint32(etag64)) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	rows, err := result.RowsAffected() | 	rows := result.RowsAffected() | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if rows != 1 && req.ETag != nil && *req.ETag != "" { | 	if rows != 1 && req.ETag != nil && *req.ETag != "" { | ||||||
| 		return state.NewETagError(state.ETagMismatch, nil) | 		return state.NewETagError(state.ETagMismatch, nil) | ||||||
| 	} | 	} | ||||||
|  | @ -389,11 +408,20 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.DeleteRequest) error { | func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.DeleteRequest) error { | ||||||
| 	tx, err := p.db.BeginTx(parentCtx, nil) | 	ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second) | ||||||
|  | 	tx, err := p.db.Begin(ctx) | ||||||
|  | 	cancel() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("failed to begin transaction: %w", err) | 		return fmt.Errorf("failed to begin transaction: %w", err) | ||||||
| 	} | 	} | ||||||
| 	defer tx.Rollback() | 	defer func() { | ||||||
|  | 		rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second) | ||||||
|  | 		rollbackErr := tx.Rollback(rollbackCtx) | ||||||
|  | 		rollbackCancel() | ||||||
|  | 		if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) { | ||||||
|  | 			p.logger.Errorf("Failed to rollback transaction in BulkDelete: %v", rollbackErr) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
| 
 | 
 | ||||||
| 	if len(req) > 0 { | 	if len(req) > 0 { | ||||||
| 		for i := range req { | 		for i := range req { | ||||||
|  | @ -404,7 +432,9 @@ func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.Del | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = tx.Commit() | 	ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second) | ||||||
|  | 	err = tx.Commit(ctx) | ||||||
|  | 	cancel() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("failed to commit transaction: %w", err) | 		return fmt.Errorf("failed to commit transaction: %w", err) | ||||||
| 	} | 	} | ||||||
|  | @ -413,11 +443,20 @@ func (p *PostgresDBAccess) BulkDelete(parentCtx context.Context, req []state.Del | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *state.TransactionalStateRequest) error { | func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *state.TransactionalStateRequest) error { | ||||||
| 	tx, err := p.db.BeginTx(parentCtx, nil) | 	ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second) | ||||||
|  | 	tx, err := p.db.Begin(ctx) | ||||||
|  | 	cancel() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("failed to begin transaction: %w", err) | 		return fmt.Errorf("failed to begin transaction: %w", err) | ||||||
| 	} | 	} | ||||||
| 	defer tx.Rollback() | 	defer func() { | ||||||
|  | 		rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, 30*time.Second) | ||||||
|  | 		rollbackErr := tx.Rollback(rollbackCtx) | ||||||
|  | 		rollbackCancel() | ||||||
|  | 		if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) { | ||||||
|  | 			p.logger.Errorf("Failed to rollback transaction in ExecMulti: %v", rollbackErr) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
| 
 | 
 | ||||||
| 	for _, o := range request.Operations { | 	for _, o := range request.Operations { | ||||||
| 		switch o.Operation { | 		switch o.Operation { | ||||||
|  | @ -450,7 +489,9 @@ func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *stat | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = tx.Commit() | 	ctx, cancel = context.WithTimeout(parentCtx, 30*time.Second) | ||||||
|  | 	err = tx.Commit(ctx) | ||||||
|  | 	cancel() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("failed to commit transaction: %w", err) | 		return fmt.Errorf("failed to commit transaction: %w", err) | ||||||
| 	} | 	} | ||||||
|  | @ -520,19 +561,13 @@ func (p *PostgresDBAccess) CleanupExpired(ctx context.Context) error { | ||||||
| 	// Note we're not using the transaction here as we don't want this to be rolled back half-way or to lock the table unnecessarily
 | 	// Note we're not using the transaction here as we don't want this to be rolled back half-way or to lock the table unnecessarily
 | ||||||
| 	// Need to use fmt.Sprintf because we can't parametrize a table name
 | 	// Need to use fmt.Sprintf because we can't parametrize a table name
 | ||||||
| 	// Note we are not setting a timeout here as this query can take a "long" time, especially if there's no index on expiredate
 | 	// Note we are not setting a timeout here as this query can take a "long" time, especially if there's no index on expiredate
 | ||||||
| 	//nolint:gosec
 |  | ||||||
| 	stmt := fmt.Sprintf(`DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate < CURRENT_TIMESTAMP`, p.metadata.TableName) | 	stmt := fmt.Sprintf(`DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate < CURRENT_TIMESTAMP`, p.metadata.TableName) | ||||||
| 	res, err := p.db.ExecContext(ctx, stmt) | 	res, err := p.db.Exec(ctx, stmt) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("failed to execute query: %w", err) | 		return fmt.Errorf("failed to execute query: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	cleaned, err := res.RowsAffected() | 	p.logger.Infof("Removed %d expired rows", res.RowsAffected()) | ||||||
| 	if err != nil { |  | ||||||
| 		return fmt.Errorf("failed to count affected rows: %w", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	p.logger.Infof("Removed %d expired rows", cleaned) |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -540,7 +575,7 @@ func (p *PostgresDBAccess) CleanupExpired(ctx context.Context) error { | ||||||
| // Returns true if the row was updated, which means that the cleanup can proceed.
 | // Returns true if the row was updated, which means that the cleanup can proceed.
 | ||||||
| func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier, cleanupInterval time.Duration) (bool, error) { | func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier, cleanupInterval time.Duration) (bool, error) { | ||||||
| 	queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second) | 	queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second) | ||||||
| 	res, err := db.ExecContext(queryCtx, | 	res, err := db.Exec(queryCtx, | ||||||
| 		fmt.Sprintf(`INSERT INTO %[1]s (key, value) | 		fmt.Sprintf(`INSERT INTO %[1]s (key, value) | ||||||
| 			VALUES ('last-cleanup', CURRENT_TIMESTAMP) | 			VALUES ('last-cleanup', CURRENT_TIMESTAMP) | ||||||
| 			ON CONFLICT (key) | 			ON CONFLICT (key) | ||||||
|  | @ -554,11 +589,7 @@ func (p *PostgresDBAccess) UpdateLastCleanup(ctx context.Context, db dbquerier, | ||||||
| 		return true, fmt.Errorf("failed to execute query: %w", err) | 		return true, fmt.Errorf("failed to execute query: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	n, err := res.RowsAffected() | 	n := res.RowsAffected() | ||||||
| 	if err != nil { |  | ||||||
| 		return true, fmt.Errorf("failed to count affected rows: %w", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return n > 0, nil | 	return n > 0, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -569,7 +600,8 @@ func (p *PostgresDBAccess) Close() error { | ||||||
| 		p.cancel = nil | 		p.cancel = nil | ||||||
| 	} | 	} | ||||||
| 	if p.db != nil { | 	if p.db != nil { | ||||||
| 		return p.db.Close() | 		p.db.Close() | ||||||
|  | 		p.db = nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  |  | ||||||
|  | @ -16,24 +16,20 @@ package postgresql | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"database/sql" | 	"encoding/json" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/DATA-DOG/go-sqlmock" | 	pgxmock "github.com/pashagolub/pgxmock/v2" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| 
 | 
 | ||||||
| 	"github.com/dapr/components-contrib/metadata" | 	"github.com/dapr/components-contrib/metadata" | ||||||
| 	"github.com/dapr/components-contrib/state" | 	"github.com/dapr/components-contrib/state" | ||||||
| 	"github.com/dapr/kit/logger" | 	"github.com/dapr/kit/logger" | ||||||
| 
 |  | ||||||
| 	// Blank import for pgx
 |  | ||||||
| 	_ "github.com/jackc/pgx/v5/stdlib" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type mocks struct { | type mocks struct { | ||||||
| 	db    *sql.DB | 	db    pgxmock.PgxPoolIface | ||||||
| 	mock  sqlmock.Sqlmock |  | ||||||
| 	pgDba *PostgresDBAccess | 	pgDba *PostgresDBAccess | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -67,7 +63,7 @@ func TestGetSetValid(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	set, err := getSet(operation) | 	set, err := getSet(operation) | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| 	assert.Equal(t, "key1", set.Key) | 	assert.Equal(t, "key1", set.Key) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -101,7 +97,7 @@ func TestGetDeleteValid(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	delete, err := getDelete(operation) | 	delete, err := getDelete(operation) | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| 	assert.Equal(t, "key1", delete.Key) | 	assert.Equal(t, "key1", delete.Key) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -110,8 +106,10 @@ func TestMultiWithNoRequests(t *testing.T) { | ||||||
| 	m, _ := mockDatabase(t) | 	m, _ := mockDatabase(t) | ||||||
| 	defer m.db.Close() | 	defer m.db.Close() | ||||||
| 
 | 
 | ||||||
| 	m.mock.ExpectBegin() | 	m.db.ExpectBegin() | ||||||
| 	m.mock.ExpectCommit() | 	m.db.ExpectCommit() | ||||||
|  | 	// There's also a rollback called after a commit, which is expected and will not have effect
 | ||||||
|  | 	m.db.ExpectRollback() | ||||||
| 
 | 
 | ||||||
| 	var operations []state.TransactionalStateOperation | 	var operations []state.TransactionalStateOperation | ||||||
| 
 | 
 | ||||||
|  | @ -121,7 +119,7 @@ func TestMultiWithNoRequests(t *testing.T) { | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	// Assert
 | 	// Assert
 | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestInvalidMultiInvalidAction(t *testing.T) { | func TestInvalidMultiInvalidAction(t *testing.T) { | ||||||
|  | @ -129,8 +127,8 @@ func TestInvalidMultiInvalidAction(t *testing.T) { | ||||||
| 	m, _ := mockDatabase(t) | 	m, _ := mockDatabase(t) | ||||||
| 	defer m.db.Close() | 	defer m.db.Close() | ||||||
| 
 | 
 | ||||||
| 	m.mock.ExpectBegin() | 	m.db.ExpectBegin() | ||||||
| 	m.mock.ExpectRollback() | 	m.db.ExpectRollback() | ||||||
| 
 | 
 | ||||||
| 	var operations []state.TransactionalStateOperation | 	var operations []state.TransactionalStateOperation | ||||||
| 
 | 
 | ||||||
|  | @ -153,16 +151,19 @@ func TestValidSetRequest(t *testing.T) { | ||||||
| 	m, _ := mockDatabase(t) | 	m, _ := mockDatabase(t) | ||||||
| 	defer m.db.Close() | 	defer m.db.Close() | ||||||
| 
 | 
 | ||||||
| 	m.mock.ExpectBegin() | 	setReq := createSetRequest() | ||||||
| 	m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1)) | 	operations := []state.TransactionalStateOperation{ | ||||||
| 	m.mock.ExpectCommit() | 		{Operation: state.Upsert, Request: setReq}, | ||||||
|  | 	} | ||||||
|  | 	val, _ := json.Marshal(setReq.Value) | ||||||
| 
 | 
 | ||||||
| 	var operations []state.TransactionalStateOperation | 	m.db.ExpectBegin() | ||||||
| 
 | 	m.db.ExpectExec("INSERT INTO"). | ||||||
| 	operations = append(operations, state.TransactionalStateOperation{ | 		WithArgs(setReq.Key, string(val), false). | ||||||
| 		Operation: state.Upsert, | 		WillReturnResult(pgxmock.NewResult("INSERT", 1)) | ||||||
| 		Request:   createSetRequest(), | 	m.db.ExpectCommit() | ||||||
| 	}) | 	// There's also a rollback called after a commit, which is expected and will not have effect
 | ||||||
|  | 	m.db.ExpectRollback() | ||||||
| 
 | 
 | ||||||
| 	// Act
 | 	// Act
 | ||||||
| 	err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ | 	err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ | ||||||
|  | @ -170,7 +171,7 @@ func TestValidSetRequest(t *testing.T) { | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	// Assert
 | 	// Assert
 | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestInvalidMultiSetRequest(t *testing.T) { | func TestInvalidMultiSetRequest(t *testing.T) { | ||||||
|  | @ -178,15 +179,15 @@ func TestInvalidMultiSetRequest(t *testing.T) { | ||||||
| 	m, _ := mockDatabase(t) | 	m, _ := mockDatabase(t) | ||||||
| 	defer m.db.Close() | 	defer m.db.Close() | ||||||
| 
 | 
 | ||||||
| 	m.mock.ExpectBegin() | 	m.db.ExpectBegin() | ||||||
| 	m.mock.ExpectRollback() | 	m.db.ExpectRollback() | ||||||
| 
 | 
 | ||||||
| 	var operations []state.TransactionalStateOperation | 	operations := []state.TransactionalStateOperation{ | ||||||
| 
 | 		{ | ||||||
| 	operations = append(operations, state.TransactionalStateOperation{ |  | ||||||
| 			Operation: state.Upsert, | 			Operation: state.Upsert, | ||||||
| 			Request:   createDeleteRequest(), // Delete request is not valid for Upsert operation
 | 			Request:   createDeleteRequest(), // Delete request is not valid for Upsert operation
 | ||||||
| 	}) | 		}, | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	// Act
 | 	// Act
 | ||||||
| 	err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ | 	err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ | ||||||
|  | @ -202,8 +203,8 @@ func TestInvalidMultiSetRequestNoKey(t *testing.T) { | ||||||
| 	m, _ := mockDatabase(t) | 	m, _ := mockDatabase(t) | ||||||
| 	defer m.db.Close() | 	defer m.db.Close() | ||||||
| 
 | 
 | ||||||
| 	m.mock.ExpectBegin() | 	m.db.ExpectBegin() | ||||||
| 	m.mock.ExpectRollback() | 	m.db.ExpectRollback() | ||||||
| 
 | 
 | ||||||
| 	var operations []state.TransactionalStateOperation | 	var operations []state.TransactionalStateOperation | ||||||
| 
 | 
 | ||||||
|  | @ -226,16 +227,18 @@ func TestValidMultiDeleteRequest(t *testing.T) { | ||||||
| 	m, _ := mockDatabase(t) | 	m, _ := mockDatabase(t) | ||||||
| 	defer m.db.Close() | 	defer m.db.Close() | ||||||
| 
 | 
 | ||||||
| 	m.mock.ExpectBegin() | 	deleteReq := createDeleteRequest() | ||||||
| 	m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1)) | 	operations := []state.TransactionalStateOperation{ | ||||||
| 	m.mock.ExpectCommit() | 		{Operation: state.Delete, Request: deleteReq}, | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	var operations []state.TransactionalStateOperation | 	m.db.ExpectBegin() | ||||||
| 
 | 	m.db.ExpectExec("DELETE FROM"). | ||||||
| 	operations = append(operations, state.TransactionalStateOperation{ | 		WithArgs(deleteReq.Key). | ||||||
| 		Operation: state.Delete, | 		WillReturnResult(pgxmock.NewResult("DELETE", 1)) | ||||||
| 		Request:   createDeleteRequest(), | 	m.db.ExpectCommit() | ||||||
| 	}) | 	// There's also a rollback called after a commit, which is expected and will not have effect
 | ||||||
|  | 	m.db.ExpectRollback() | ||||||
| 
 | 
 | ||||||
| 	// Act
 | 	// Act
 | ||||||
| 	err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ | 	err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ | ||||||
|  | @ -243,7 +246,7 @@ func TestValidMultiDeleteRequest(t *testing.T) { | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	// Assert
 | 	// Assert
 | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestInvalidMultiDeleteRequest(t *testing.T) { | func TestInvalidMultiDeleteRequest(t *testing.T) { | ||||||
|  | @ -251,8 +254,8 @@ func TestInvalidMultiDeleteRequest(t *testing.T) { | ||||||
| 	m, _ := mockDatabase(t) | 	m, _ := mockDatabase(t) | ||||||
| 	defer m.db.Close() | 	defer m.db.Close() | ||||||
| 
 | 
 | ||||||
| 	m.mock.ExpectBegin() | 	m.db.ExpectBegin() | ||||||
| 	m.mock.ExpectRollback() | 	m.db.ExpectRollback() | ||||||
| 
 | 
 | ||||||
| 	var operations []state.TransactionalStateOperation | 	var operations []state.TransactionalStateOperation | ||||||
| 
 | 
 | ||||||
|  | @ -275,8 +278,8 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) { | ||||||
| 	m, _ := mockDatabase(t) | 	m, _ := mockDatabase(t) | ||||||
| 	defer m.db.Close() | 	defer m.db.Close() | ||||||
| 
 | 
 | ||||||
| 	m.mock.ExpectBegin() | 	m.db.ExpectBegin() | ||||||
| 	m.mock.ExpectRollback() | 	m.db.ExpectRollback() | ||||||
| 
 | 
 | ||||||
| 	var operations []state.TransactionalStateOperation | 	var operations []state.TransactionalStateOperation | ||||||
| 
 | 
 | ||||||
|  | @ -299,23 +302,27 @@ func TestMultiOperationOrder(t *testing.T) { | ||||||
| 	m, _ := mockDatabase(t) | 	m, _ := mockDatabase(t) | ||||||
| 	defer m.db.Close() | 	defer m.db.Close() | ||||||
| 
 | 
 | ||||||
| 	m.mock.ExpectBegin() | 	operations := []state.TransactionalStateOperation{ | ||||||
| 	m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1)) | 		{ | ||||||
| 	m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1)) |  | ||||||
| 	m.mock.ExpectCommit() |  | ||||||
| 
 |  | ||||||
| 	var operations []state.TransactionalStateOperation |  | ||||||
| 
 |  | ||||||
| 	operations = append(operations, |  | ||||||
| 		state.TransactionalStateOperation{ |  | ||||||
| 			Operation: state.Upsert, | 			Operation: state.Upsert, | ||||||
| 			Request:   state.SetRequest{Key: "key1", Value: "value1"}, | 			Request:   state.SetRequest{Key: "key1", Value: "value1"}, | ||||||
| 		}, | 		}, | ||||||
| 		state.TransactionalStateOperation{ | 		{ | ||||||
| 			Operation: state.Delete, | 			Operation: state.Delete, | ||||||
| 			Request:   state.DeleteRequest{Key: "key1"}, | 			Request:   state.DeleteRequest{Key: "key1"}, | ||||||
| 		}, | 		}, | ||||||
| 	) | 	} | ||||||
|  | 
 | ||||||
|  | 	m.db.ExpectBegin() | ||||||
|  | 	m.db.ExpectExec("INSERT INTO"). | ||||||
|  | 		WithArgs("key1", `"value1"`, false). | ||||||
|  | 		WillReturnResult(pgxmock.NewResult("INSERT", 1)) | ||||||
|  | 	m.db.ExpectExec("DELETE FROM"). | ||||||
|  | 		WithArgs("key1"). | ||||||
|  | 		WillReturnResult(pgxmock.NewResult("DELETE", 1)) | ||||||
|  | 	m.db.ExpectCommit() | ||||||
|  | 	// There's also a rollback called after a commit, which is expected and will not have effect
 | ||||||
|  | 	m.db.ExpectRollback() | ||||||
| 
 | 
 | ||||||
| 	// Act
 | 	// Act
 | ||||||
| 	err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ | 	err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ | ||||||
|  | @ -323,7 +330,7 @@ func TestMultiOperationOrder(t *testing.T) { | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	// Assert
 | 	// Assert
 | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestInvalidBulkSetNoKey(t *testing.T) { | func TestInvalidBulkSetNoKey(t *testing.T) { | ||||||
|  | @ -331,14 +338,13 @@ func TestInvalidBulkSetNoKey(t *testing.T) { | ||||||
| 	m, _ := mockDatabase(t) | 	m, _ := mockDatabase(t) | ||||||
| 	defer m.db.Close() | 	defer m.db.Close() | ||||||
| 
 | 
 | ||||||
| 	m.mock.ExpectBegin() | 	m.db.ExpectBegin() | ||||||
| 	m.mock.ExpectRollback() | 	m.db.ExpectRollback() | ||||||
| 
 | 
 | ||||||
| 	var sets []state.SetRequest | 	sets := []state.SetRequest{ | ||||||
| 
 | 		// Set request without key is not valid for Set operation
 | ||||||
| 	sets = append(sets, state.SetRequest{ // Set request without key is not valid for Set operation
 | 		{Value: "value1"}, | ||||||
| 		Value: "value1", | 	} | ||||||
| 	}) |  | ||||||
| 
 | 
 | ||||||
| 	// Act
 | 	// Act
 | ||||||
| 	err := m.pgDba.BulkSet(context.Background(), sets) | 	err := m.pgDba.BulkSet(context.Background(), sets) | ||||||
|  | @ -352,15 +358,13 @@ func TestInvalidBulkSetEmptyValue(t *testing.T) { | ||||||
| 	m, _ := mockDatabase(t) | 	m, _ := mockDatabase(t) | ||||||
| 	defer m.db.Close() | 	defer m.db.Close() | ||||||
| 
 | 
 | ||||||
| 	m.mock.ExpectBegin() | 	m.db.ExpectBegin() | ||||||
| 	m.mock.ExpectRollback() | 	m.db.ExpectRollback() | ||||||
| 
 | 
 | ||||||
| 	var sets []state.SetRequest | 	sets := []state.SetRequest{ | ||||||
| 
 | 		// Set request without value is not valid for Set operation
 | ||||||
| 	sets = append(sets, state.SetRequest{ // Set request without value is not valid for Set operation
 | 		{Key: "key1", Value: "value1"}, | ||||||
| 		Key:   "key1", | 	} | ||||||
| 		Value: "", |  | ||||||
| 	}) |  | ||||||
| 
 | 
 | ||||||
| 	// Act
 | 	// Act
 | ||||||
| 	err := m.pgDba.BulkSet(context.Background(), sets) | 	err := m.pgDba.BulkSet(context.Background(), sets) | ||||||
|  | @ -374,22 +378,26 @@ func TestValidBulkSet(t *testing.T) { | ||||||
| 	m, _ := mockDatabase(t) | 	m, _ := mockDatabase(t) | ||||||
| 	defer m.db.Close() | 	defer m.db.Close() | ||||||
| 
 | 
 | ||||||
| 	m.mock.ExpectBegin() | 	sets := []state.SetRequest{ | ||||||
| 	m.mock.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(1, 1)) | 		{ | ||||||
| 	m.mock.ExpectCommit() |  | ||||||
| 
 |  | ||||||
| 	var sets []state.SetRequest |  | ||||||
| 
 |  | ||||||
| 	sets = append(sets, state.SetRequest{ |  | ||||||
| 			Key:   "key1", | 			Key:   "key1", | ||||||
| 			Value: "value1", | 			Value: "value1", | ||||||
| 	}) | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	m.db.ExpectBegin() | ||||||
|  | 	m.db.ExpectExec("INSERT INTO"). | ||||||
|  | 		WithArgs("key1", `"value1"`, false). | ||||||
|  | 		WillReturnResult(pgxmock.NewResult("INSERT", 1)) | ||||||
|  | 	m.db.ExpectCommit() | ||||||
|  | 	// There's also a rollback called after a commit, which is expected and will not have effect
 | ||||||
|  | 	m.db.ExpectRollback() | ||||||
| 
 | 
 | ||||||
| 	// Act
 | 	// Act
 | ||||||
| 	err := m.pgDba.BulkSet(context.Background(), sets) | 	err := m.pgDba.BulkSet(context.Background(), sets) | ||||||
| 
 | 
 | ||||||
| 	// Assert
 | 	// Assert
 | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestInvalidBulkDeleteNoKey(t *testing.T) { | func TestInvalidBulkDeleteNoKey(t *testing.T) { | ||||||
|  | @ -397,8 +405,8 @@ func TestInvalidBulkDeleteNoKey(t *testing.T) { | ||||||
| 	m, _ := mockDatabase(t) | 	m, _ := mockDatabase(t) | ||||||
| 	defer m.db.Close() | 	defer m.db.Close() | ||||||
| 
 | 
 | ||||||
| 	m.mock.ExpectBegin() | 	m.db.ExpectBegin() | ||||||
| 	m.mock.ExpectRollback() | 	m.db.ExpectRollback() | ||||||
| 
 | 
 | ||||||
| 	var deletes []state.DeleteRequest | 	var deletes []state.DeleteRequest | ||||||
| 
 | 
 | ||||||
|  | @ -418,21 +426,23 @@ func TestValidBulkDelete(t *testing.T) { | ||||||
| 	m, _ := mockDatabase(t) | 	m, _ := mockDatabase(t) | ||||||
| 	defer m.db.Close() | 	defer m.db.Close() | ||||||
| 
 | 
 | ||||||
| 	m.mock.ExpectBegin() | 	deletes := []state.DeleteRequest{ | ||||||
| 	m.mock.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(1, 1)) | 		{Key: "key1"}, | ||||||
| 	m.mock.ExpectCommit() | 	} | ||||||
| 
 | 
 | ||||||
| 	var deletes []state.DeleteRequest | 	m.db.ExpectBegin() | ||||||
| 
 | 	m.db.ExpectExec("DELETE FROM"). | ||||||
| 	deletes = append(deletes, state.DeleteRequest{ | 		WithArgs("key1"). | ||||||
| 		Key: "key1", | 		WillReturnResult(pgxmock.NewResult("DELETE", 1)) | ||||||
| 	}) | 	m.db.ExpectCommit() | ||||||
|  | 	// There's also a rollback called after a commit, which is expected and will not have effect
 | ||||||
|  | 	m.db.ExpectRollback() | ||||||
| 
 | 
 | ||||||
| 	// Act
 | 	// Act
 | ||||||
| 	err := m.pgDba.BulkDelete(context.Background(), deletes) | 	err := m.pgDba.BulkDelete(context.Background(), deletes) | ||||||
| 
 | 
 | ||||||
| 	// Assert
 | 	// Assert
 | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func createSetRequest() state.SetRequest { | func createSetRequest() state.SetRequest { | ||||||
|  | @ -451,7 +461,7 @@ func createDeleteRequest() state.DeleteRequest { | ||||||
| func mockDatabase(t *testing.T) (*mocks, error) { | func mockDatabase(t *testing.T) (*mocks, error) { | ||||||
| 	logger := logger.NewLogger("test") | 	logger := logger.NewLogger("test") | ||||||
| 
 | 
 | ||||||
| 	db, mock, err := sqlmock.New() | 	db, err := pgxmock.NewPool() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) | 		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) | ||||||
| 	} | 	} | ||||||
|  | @ -463,7 +473,6 @@ func mockDatabase(t *testing.T) (*mocks, error) { | ||||||
| 
 | 
 | ||||||
| 	return &mocks{ | 	return &mocks{ | ||||||
| 		db:    db, | 		db:    db, | ||||||
| 		mock:  mock, |  | ||||||
| 		pgDba: dba, | 		pgDba: dba, | ||||||
| 	}, err | 	}, err | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -12,18 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| See the License for the specific language governing permissions and | See the License for the specific language governing permissions and | ||||||
| limitations under the License. | limitations under the License. | ||||||
| */ | */ | ||||||
|  | 
 | ||||||
| package postgresql | package postgresql | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"database/sql" |  | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"os" | 	"os" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/google/uuid" | 	"github.com/google/uuid" | ||||||
|  | 	"github.com/jackc/pgx/v5" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
|  | 	"github.com/stretchr/testify/require" | ||||||
| 
 | 
 | ||||||
| 	"github.com/dapr/components-contrib/metadata" | 	"github.com/dapr/components-contrib/metadata" | ||||||
| 	"github.com/dapr/components-contrib/state" | 	"github.com/dapr/components-contrib/state" | ||||||
|  | @ -162,7 +165,7 @@ func deleteItemThatDoesNotExist(t *testing.T, pgs *PostgreSQL) { | ||||||
| 		Key: randomKey(), | 		Key: randomKey(), | ||||||
| 	} | 	} | ||||||
| 	err := pgs.Delete(context.Background(), deleteReq) | 	err := pgs.Delete(context.Background(), deleteReq) | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) { | func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) { | ||||||
|  | @ -183,7 +186,7 @@ func multiWithSetOnly(t *testing.T, pgs *PostgreSQL) { | ||||||
| 	err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{ | 	err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{ | ||||||
| 		Operations: operations, | 		Operations: operations, | ||||||
| 	}) | 	}) | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	for _, set := range setRequests { | 	for _, set := range setRequests { | ||||||
| 		assert.True(t, storeItemExists(t, set.Key)) | 		assert.True(t, storeItemExists(t, set.Key)) | ||||||
|  | @ -213,7 +216,7 @@ func multiWithDeleteOnly(t *testing.T, pgs *PostgreSQL) { | ||||||
| 	err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{ | 	err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{ | ||||||
| 		Operations: operations, | 		Operations: operations, | ||||||
| 	}) | 	}) | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	for _, delete := range deleteRequests { | 	for _, delete := range deleteRequests { | ||||||
| 		assert.False(t, storeItemExists(t, delete.Key)) | 		assert.False(t, storeItemExists(t, delete.Key)) | ||||||
|  | @ -256,7 +259,7 @@ func multiWithDeleteAndSet(t *testing.T, pgs *PostgreSQL) { | ||||||
| 	err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{ | 	err := pgs.Multi(context.Background(), &state.TransactionalStateRequest{ | ||||||
| 		Operations: operations, | 		Operations: operations, | ||||||
| 	}) | 	}) | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	for _, delete := range deleteRequests { | 	for _, delete := range deleteRequests { | ||||||
| 		assert.False(t, storeItemExists(t, delete.Key)) | 		assert.False(t, storeItemExists(t, delete.Key)) | ||||||
|  | @ -384,15 +387,16 @@ func setUpdatesTheUpdatedateField(t *testing.T, pgs *PostgreSQL) { | ||||||
| 
 | 
 | ||||||
| 	// insertdate should have a value and updatedate should be nil
 | 	// insertdate should have a value and updatedate should be nil
 | ||||||
| 	_, insertdate, updatedate := getRowData(t, key) | 	_, insertdate, updatedate := getRowData(t, key) | ||||||
|  | 	assert.Nil(t, updatedate) | ||||||
| 	assert.NotNil(t, insertdate) | 	assert.NotNil(t, insertdate) | ||||||
| 	assert.Equal(t, "", updatedate.String) |  | ||||||
| 
 | 
 | ||||||
| 	// insertdate should not change, updatedate should have a value
 | 	// insertdate should not change, updatedate should have a value
 | ||||||
| 	value = &fakeItem{Color: "aqua"} | 	value = &fakeItem{Color: "aqua"} | ||||||
| 	setItem(t, pgs, key, value, nil) | 	setItem(t, pgs, key, value, nil) | ||||||
| 	_, newinsertdate, updatedate := getRowData(t, key) | 	_, newinsertdate, updatedate := getRowData(t, key) | ||||||
| 	assert.Equal(t, insertdate, newinsertdate) // The insertdate should not change.
 | 	assert.NotNil(t, updatedate) | ||||||
| 	assert.NotEqual(t, "", updatedate.String) | 	assert.NotNil(t, newinsertdate) | ||||||
|  | 	assert.True(t, insertdate.Equal(*newinsertdate)) // The insertdate should not change.
 | ||||||
| 
 | 
 | ||||||
| 	deleteItem(t, pgs, key, nil) | 	deleteItem(t, pgs, key, nil) | ||||||
| } | } | ||||||
|  | @ -420,7 +424,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err := pgs.BulkSet(context.Background(), setReq) | 	err := pgs.BulkSet(context.Background(), setReq) | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| 	assert.True(t, storeItemExists(t, setReq[0].Key)) | 	assert.True(t, storeItemExists(t, setReq[0].Key)) | ||||||
| 	assert.True(t, storeItemExists(t, setReq[1].Key)) | 	assert.True(t, storeItemExists(t, setReq[1].Key)) | ||||||
| 
 | 
 | ||||||
|  | @ -434,7 +438,7 @@ func testBulkSetAndBulkDelete(t *testing.T, pgs *PostgreSQL) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = pgs.BulkDelete(context.Background(), deleteReq) | 	err = pgs.BulkDelete(context.Background(), deleteReq) | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| 	assert.False(t, storeItemExists(t, setReq[0].Key)) | 	assert.False(t, storeItemExists(t, setReq[0].Key)) | ||||||
| 	assert.False(t, storeItemExists(t, setReq[1].Key)) | 	assert.False(t, storeItemExists(t, setReq[1].Key)) | ||||||
| } | } | ||||||
|  | @ -491,7 +495,7 @@ func setItem(t *testing.T, pgs *PostgreSQL, key string, value interface{}, etag | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err := pgs.Set(context.Background(), setReq) | 	err := pgs.Set(context.Background(), setReq) | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| 	itemExists := storeItemExists(t, key) | 	itemExists := storeItemExists(t, key) | ||||||
| 	assert.True(t, itemExists) | 	assert.True(t, itemExists) | ||||||
| } | } | ||||||
|  | @ -524,25 +528,28 @@ func deleteItem(t *testing.T, pgs *PostgreSQL, key string, etag *string) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func storeItemExists(t *testing.T, key string) bool { | func storeItemExists(t *testing.T, key string) bool { | ||||||
| 	db, err := sql.Open("pgx", getConnectionString()) | 	ctx := context.Background() | ||||||
| 	assert.Nil(t, err) | 	db, err := pgx.Connect(ctx, getConnectionString()) | ||||||
| 	defer db.Close() | 	require.NoError(t, err) | ||||||
|  | 	defer db.Close(ctx) | ||||||
| 
 | 
 | ||||||
| 	exists := false | 	exists := false | ||||||
| 	statement := fmt.Sprintf(`SELECT EXISTS (SELECT FROM %s WHERE key = $1)`, defaultTableName) | 	statement := fmt.Sprintf(`SELECT EXISTS (SELECT FROM %s WHERE key = $1)`, defaultTableName) | ||||||
| 	err = db.QueryRow(statement, key).Scan(&exists) | 	err = db.QueryRow(ctx, statement, key).Scan(&exists) | ||||||
| 	assert.Nil(t, err) | 	assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	return exists | 	return exists | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getRowData(t *testing.T, key string) (returnValue string, insertdate sql.NullString, updatedate sql.NullString) { | func getRowData(t *testing.T, key string) (returnValue string, insertdate *time.Time, updatedate *time.Time) { | ||||||
| 	db, err := sql.Open("pgx", getConnectionString()) | 	ctx := context.Background() | ||||||
| 	assert.Nil(t, err) | 	db, err := pgx.Connect(ctx, getConnectionString()) | ||||||
| 	defer db.Close() | 	require.NoError(t, err) | ||||||
|  | 	defer db.Close(ctx) | ||||||
| 
 | 
 | ||||||
| 	err = db.QueryRow(fmt.Sprintf("SELECT value, insertdate, updatedate FROM %s WHERE key = $1", defaultTableName), key).Scan(&returnValue, &insertdate, &updatedate) | 	err = db.QueryRow(ctx, fmt.Sprintf("SELECT value, insertdate, updatedate FROM %s WHERE key = $1", defaultTableName), key). | ||||||
| 	assert.Nil(t, err) | 		Scan(&returnValue, &insertdate, &updatedate) | ||||||
|  | 	assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	return returnValue, insertdate, updatedate | 	return returnValue, insertdate, updatedate | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -16,7 +16,6 @@ package postgresql | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"database/sql" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | @ -140,8 +139,8 @@ func (q *Query) Finalize(filters string, qq *query.Query) error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) { | func (q *Query) execute(ctx context.Context, logger logger.Logger, db dbquerier) ([]state.QueryItem, string, error) { | ||||||
| 	rows, err := db.QueryContext(ctx, q.query, q.params...) | 	rows, err := db.Query(ctx, q.query, q.params...) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, "", err | 		return nil, "", err | ||||||
| 	} | 	} | ||||||
|  | @ -152,7 +151,7 @@ func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) ( | ||||||
| 		var ( | 		var ( | ||||||
| 			key  string | 			key  string | ||||||
| 			data []byte | 			data []byte | ||||||
| 			etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations
 | 			etag uint32 | ||||||
| 		) | 		) | ||||||
| 		if err = rows.Scan(&key, &data, &etag); err != nil { | 		if err = rows.Scan(&key, &data, &etag); err != nil { | ||||||
| 			return nil, "", err | 			return nil, "", err | ||||||
|  | @ -160,7 +159,7 @@ func (q *Query) execute(ctx context.Context, logger logger.Logger, db *sql.DB) ( | ||||||
| 		result := state.QueryItem{ | 		result := state.QueryItem{ | ||||||
| 			Key:  key, | 			Key:  key, | ||||||
| 			Data: data, | 			Data: data, | ||||||
| 			ETag: ptr.Of(strconv.FormatUint(etag, 10)), | 			ETag: ptr.Of(strconv.FormatUint(uint64(etag), 10)), | ||||||
| 		} | 		} | ||||||
| 		ret = append(ret, result) | 		ret = append(ret, result) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -62,6 +62,7 @@ require ( | ||||||
| 	github.com/imdario/mergo v0.3.12 // indirect | 	github.com/imdario/mergo v0.3.12 // indirect | ||||||
| 	github.com/jackc/pgpassfile v1.0.0 // indirect | 	github.com/jackc/pgpassfile v1.0.0 // indirect | ||||||
| 	github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect | 	github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect | ||||||
|  | 	github.com/jackc/puddle/v2 v2.1.2 // indirect | ||||||
| 	github.com/jhump/protoreflect v1.13.0 // indirect | 	github.com/jhump/protoreflect v1.13.0 // indirect | ||||||
| 	github.com/josharian/intern v1.0.0 // indirect | 	github.com/josharian/intern v1.0.0 // indirect | ||||||
| 	github.com/json-iterator/go v1.1.12 // indirect | 	github.com/json-iterator/go v1.1.12 // indirect | ||||||
|  |  | ||||||
|  | @ -37,7 +37,6 @@ github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a h1: | ||||||
| github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go.mod h1:C0A1KeiVHs+trY6gUTPhhGammbrZ30ZfXRW/nuT7HLw= | github.com/AdhityaRamadhanus/fasthttpcors v0.0.0-20170121111917-d4c07198763a/go.mod h1:C0A1KeiVHs+trY6gUTPhhGammbrZ30ZfXRW/nuT7HLw= | ||||||
| github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= | ||||||
| github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= | github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= | ||||||
| github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= |  | ||||||
| github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= | github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= | ||||||
| github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= | github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= | ||||||
| github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= | github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49Ezig= | ||||||
|  | @ -288,6 +287,8 @@ github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHF | ||||||
| github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= | github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= | ||||||
| github.com/jackc/pgx/v5 v5.2.0 h1:NdPpngX0Y6z6XDFKqmFQaE+bCtkqzvQIOt1wvBlAqs8= | github.com/jackc/pgx/v5 v5.2.0 h1:NdPpngX0Y6z6XDFKqmFQaE+bCtkqzvQIOt1wvBlAqs8= | ||||||
| github.com/jackc/pgx/v5 v5.2.0/go.mod h1:Ptn7zmohNsWEsdxRawMzk3gaKma2obW+NWTnKa0S4nk= | github.com/jackc/pgx/v5 v5.2.0/go.mod h1:Ptn7zmohNsWEsdxRawMzk3gaKma2obW+NWTnKa0S4nk= | ||||||
|  | github.com/jackc/puddle/v2 v2.1.2 h1:0f7vaaXINONKTsxYDn4otOAiJanX/BMeAtY//BXqzlg= | ||||||
|  | github.com/jackc/puddle/v2 v2.1.2/go.mod h1:2lpufsF5mRHO6SuZkm0fNYxM6SWHfvyFj62KwNzgels= | ||||||
| github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= | github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= | ||||||
| github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= | github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= | ||||||
| github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= | github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= | ||||||
|  | @ -385,6 +386,7 @@ github.com/openzipkin/zipkin-go v0.4.1/go.mod h1:qY0VqDSN1pOBN94dBc6w2GJlWLiovAy | ||||||
| github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= | github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= | ||||||
| github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= | github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= | ||||||
| github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= | github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= | ||||||
|  | github.com/pashagolub/pgxmock/v2 v2.4.0 h1:jNv7+svrNoMc31mvllSS/u7P2pT3gS3uY7DPRKIJNSY= | ||||||
| github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI= | github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI= | ||||||
| github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= | github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= | ||||||
| github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= | github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= | ||||||
|  |  | ||||||
|  | @ -16,7 +16,6 @@ package main | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
| 	"database/sql" |  | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | @ -706,7 +705,7 @@ func getMigrationLevel(dbClient *pgx.Conn, metadataTable string) (level string, | ||||||
| 	err = dbClient. | 	err = dbClient. | ||||||
| 		QueryRow(ctx, fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, metadataTable)). | 		QueryRow(ctx, fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, metadataTable)). | ||||||
| 		Scan(&level) | 		Scan(&level) | ||||||
| 	if err != nil && errors.Is(err, sql.ErrNoRows) { | 	if err != nil && errors.Is(err, pgx.ErrNoRows) { | ||||||
| 		err = nil | 		err = nil | ||||||
| 		level = "" | 		level = "" | ||||||
| 	} | 	} | ||||||
|  | @ -766,7 +765,7 @@ func loadLastCleanupInterval(ctx context.Context, dbClient *pgx.Conn, table stri | ||||||
| 		). | 		). | ||||||
| 		Scan(&lastCleanup) | 		Scan(&lastCleanup) | ||||||
| 	cancel() | 	cancel() | ||||||
| 	if errors.Is(err, sql.ErrNoRows) { | 	if errors.Is(err, pgx.ErrNoRows) { | ||||||
| 		err = nil | 		err = nil | ||||||
| 	} | 	} | ||||||
| 	return | 	return | ||||||
|  | @ -790,7 +789,7 @@ func getValueFromMetadataTable(ctx context.Context, dbClient *pgx.Conn, table, k | ||||||
| 		QueryRow(queryCtx, fmt.Sprintf("SELECT value FROM %s WHERE key = $1", table), key). | 		QueryRow(queryCtx, fmt.Sprintf("SELECT value FROM %s WHERE key = $1", table), key). | ||||||
| 		Scan(&value) | 		Scan(&value) | ||||||
| 	cancel() | 	cancel() | ||||||
| 	if errors.Is(err, sql.ErrNoRows) { | 	if errors.Is(err, pgx.ErrNoRows) { | ||||||
| 		err = nil | 		err = nil | ||||||
| 	} | 	} | ||||||
| 	return | 	return | ||||||
|  |  | ||||||
|  | @ -272,7 +272,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St | ||||||
| 						req.Metadata = map[string]string{metadata.ContentType: scenario.contentType} | 						req.Metadata = map[string]string{metadata.ContentType: scenario.contentType} | ||||||
| 					} | 					} | ||||||
| 					res, err := statestore.Get(context.Background(), req) | 					res, err := statestore.Get(context.Background(), req) | ||||||
| 					assert.Nil(t, err) | 					require.NoError(t, err) | ||||||
| 					assertEquals(t, scenario.value, res) | 					assertEquals(t, scenario.value, res) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
|  | @ -287,13 +287,13 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St | ||||||
| 				t.Logf("Querying value presence for %s", scenario.query) | 				t.Logf("Querying value presence for %s", scenario.query) | ||||||
| 				var req state.QueryRequest | 				var req state.QueryRequest | ||||||
| 				err := json.Unmarshal([]byte(scenario.query), &req.Query) | 				err := json.Unmarshal([]byte(scenario.query), &req.Query) | ||||||
| 				assert.NoError(t, err) | 				require.NoError(t, err) | ||||||
| 				req.Metadata = map[string]string{ | 				req.Metadata = map[string]string{ | ||||||
| 					metadata.ContentType:    contenttype.JSONContentType, | 					metadata.ContentType:    contenttype.JSONContentType, | ||||||
| 					metadata.QueryIndexName: "qIndx", | 					metadata.QueryIndexName: "qIndx", | ||||||
| 				} | 				} | ||||||
| 				resp, err := querier.Query(context.Background(), &req) | 				resp, err := querier.Query(context.Background(), &req) | ||||||
| 				assert.NoError(t, err) | 				require.NoError(t, err) | ||||||
| 				assert.Equal(t, len(scenario.results), len(resp.Results)) | 				assert.Equal(t, len(scenario.results), len(resp.Results)) | ||||||
| 				for i := range scenario.results { | 				for i := range scenario.results { | ||||||
| 					var expected, actual interface{} | 					var expected, actual interface{} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue