components-contrib/state/postgresql/postgresdbaccess.go

287 lines
7.2 KiB
Go

// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// ------------------------------------------------------------
package postgresql
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"strconv"
"github.com/dapr/components-contrib/state"
"github.com/dapr/dapr/pkg/logger"
// Blank import for the underlying PostgreSQL driver
_ "github.com/jackc/pgx/v4/stdlib"
)
const (
	// connectionStringKey is the metadata property that carries the
	// PostgreSQL connection string.
	connectionStringKey = "connectionString"

	// errMissingConnectionString is the error text returned by Init when that
	// property is absent or empty.
	errMissingConnectionString = "missing connection string"

	// tableName is the name of the table that holds the stored state.
	tableName = "state"
)
// postgresDBAccess implements dbaccess
// (NOTE(review): the dbaccess interface is declared outside this file - confirm
// the method set there). It keeps the state in a PostgreSQL table reached
// through database/sql.
type postgresDBAccess struct {
// logger receives the debug, info and error messages emitted by this type.
logger logger.Logger
// metadata is the component metadata handed to Init, kept as received.
metadata state.Metadata
// db is the database/sql handle opened by Init; it is nil until Init has
// opened it, which is why Close checks for nil.
db *sql.DB
// connectionString is the value of the connectionString metadata property
// that Init passes to sql.Open.
connectionString string
}
// newPostgresDBAccess returns a postgresDBAccess that logs through the given
// logger. No database connection is opened here; that happens in Init.
func newPostgresDBAccess(logger logger.Logger) *postgresDBAccess {
	logger.Debug("Instantiating new PostgreSQL state store")

	access := new(postgresDBAccess)
	access.logger = logger

	return access
}
// Init sets up the PostgreSQL connection and ensures that the state table exists.
//
// The connection string is read from metadata.Properties[connectionStringKey];
// a missing or empty value is reported with the errMissingConnectionString
// message. The database handle is then opened with the "pgx" driver, verified
// with Ping, and the state table is created when it does not exist yet.
//
// If Ping or the table check fails, the freshly opened handle is closed before
// the error is returned so that a failed Init does not leave a connection pool
// behind. p.db keeps pointing at the closed handle, so later calls fail with
// an error instead of dereferencing nil, and Close stays safe to call.
func (p *postgresDBAccess) Init(metadata state.Metadata) error {
	p.logger.Debug("Initializing PostgreSQL state store")
	p.metadata = metadata

	connectionString, ok := metadata.Properties[connectionStringKey]
	if !ok || connectionString == "" {
		p.logger.Error("Missing postgreSQL connection string")
		return errors.New(errMissingConnectionString)
	}
	p.connectionString = connectionString

	db, err := sql.Open("pgx", p.connectionString)
	if err != nil {
		p.logger.Error(err)
		return err
	}
	p.db = db

	// sql.Open may only validate its arguments; Ping is what actually proves
	// that the database is reachable.
	if err = db.Ping(); err != nil {
		// Best-effort cleanup: the Ping error is the one worth reporting.
		_ = db.Close()
		return err
	}

	if err = p.ensureStateTable(tableName); err != nil {
		// Best-effort cleanup: the table error is the one worth reporting.
		_ = db.Close()
		return err
	}

	return nil
}
// sqlExecer is the statement-execution method shared by *sql.DB and *sql.Tx.
// The write helpers accept it so that the same code can run either directly
// against the connection pool or inside a transaction.
type sqlExecer interface {
	Exec(query string, args ...interface{}) (sql.Result, error)
}

// Set makes an insert or update to the database. Retry handling is delegated
// to state.SetWithRetries, which is given setValue as the operation to run.
func (p *postgresDBAccess) Set(req *state.SetRequest) error {
	return state.SetWithRetries(p.setValue, req)
}

// setValue is an internal implementation of set to enable passing the logic to
// state.SetWithRetries as a func. It runs the statement directly on the pool.
func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
	return p.execSet(p.db, req)
}

// execSet validates req and writes its value through db.
//
// Without an ETag the row is inserted, or overwritten when the key already
// exists. With an ETag only an UPDATE is issued, restricted to the row whose
// xmin equals the ETag. An error is returned when the request options are
// rejected by state.CheckSetRequestOptions, the key is empty, the value cannot
// be marshalled to JSON, the ETag is not an integer, the statement fails, or
// it does not affect exactly one row.
func (p *postgresDBAccess) execSet(db sqlExecer, req *state.SetRequest) error {
	p.logger.Debug("Setting state value in PostgreSQL")

	if err := state.CheckSetRequestOptions(req); err != nil {
		return err
	}
	if req.Key == "" {
		return errors.New("missing key in set operation")
	}

	// The value is stored in its JSON encoding.
	valueBytes, err := json.Marshal(req.Value)
	if err != nil {
		return err
	}
	value := string(valueBytes)

	// Sprintf is required for the table name because database/sql does not
	// substitute parameters for table names. All other values are passed as
	// statement parameters.
	if req.ETag == "" {
		result, execErr := db.Exec(fmt.Sprintf(
			`INSERT INTO %s (key, value) VALUES ($1, $2)
ON CONFLICT (key) DO UPDATE SET value = $2, updatedate = NOW();`,
			tableName), req.Key, value)

		return p.returnSingleDBResult(result, execErr)
	}

	// Convert req.ETag to integer for postgres compatibility.
	etag, err := strconv.Atoi(req.ETag)
	if err != nil {
		return err
	}

	// When an etag is provided do an update - no insert.
	result, err := db.Exec(fmt.Sprintf(
		`UPDATE %s SET value = $1, updatedate = NOW()
WHERE key = $2 AND xmin = $3;`,
		tableName), value, req.Key, etag)

	return p.returnSingleDBResult(result, err)
}

// Get returns data from the database. If data does not exist for the key an
// empty state.GetResponse will be returned. The row's xmin is returned as the
// ETag and the request metadata is echoed back in the response.
func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) {
	p.logger.Debug("Getting state value from PostgreSQL")

	if req.Key == "" {
		return nil, errors.New("missing key in get operation")
	}

	var (
		value string
		etag  int
	)
	query := fmt.Sprintf("SELECT value, xmin as etag FROM %s WHERE key = $1", tableName)
	err := p.db.QueryRow(query, req.Key).Scan(&value, &etag)
	if errors.Is(err, sql.ErrNoRows) {
		// A missing key is not an error: report it as an empty response.
		return &state.GetResponse{}, nil
	}
	if err != nil {
		return nil, err
	}

	return &state.GetResponse{
		Data:     []byte(value),
		ETag:     strconv.Itoa(etag),
		Metadata: req.Metadata,
	}, nil
}

// Delete removes an item from the state store. Retry handling is delegated to
// state.DeleteWithRetries, which is given deleteValue as the operation to run.
func (p *postgresDBAccess) Delete(req *state.DeleteRequest) error {
	return state.DeleteWithRetries(p.deleteValue, req)
}

// deleteValue is an internal implementation of delete to enable passing the
// logic to state.DeleteWithRetries as a func. It runs the statement directly
// on the pool.
func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error {
	return p.execDelete(p.db, req)
}

// execDelete validates req and deletes its key through db.
//
// Without an ETag the row is deleted by key alone; with an ETag the row must
// also have a matching xmin. An error is returned when the key is empty, the
// ETag is not an integer, the statement fails, or it does not affect exactly
// one row.
func (p *postgresDBAccess) execDelete(db sqlExecer, req *state.DeleteRequest) error {
	p.logger.Debug("Deleting state value from PostgreSQL")

	if req.Key == "" {
		return errors.New("missing key in delete operation")
	}

	if req.ETag == "" {
		result, err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE key = $1", tableName), req.Key)
		return p.returnSingleDBResult(result, err)
	}

	// Convert req.ETag to integer for postgres compatibility.
	etag, err := strconv.Atoi(req.ETag)
	if err != nil {
		return err
	}

	result, err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE key = $1 and xmin = $2", tableName), req.Key, etag)

	return p.returnSingleDBResult(result, err)
}

// ExecuteMulti applies all deletes and then all sets inside one transaction.
//
// Every statement is executed on the transaction itself, so either all
// operations are committed or, on the first failure, none of them are. The
// first error encountered is returned after the transaction has been rolled
// back. Statements are executed once each: the retry wrappers used by Set and
// Delete are not applied here, because PostgreSQL rejects further statements
// in a transaction after one of them has failed.
func (p *postgresDBAccess) ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error {
	p.logger.Debug("Executing multiple PostgreSQL operations")

	tx, err := p.db.Begin()
	if err != nil {
		return err
	}

	for _, d := range deletes {
		da := d // Fix for gosec G601: Implicit memory aliasing in for loop.
		if err = p.execDelete(tx, &da); err != nil {
			p.rollback(tx)
			return err
		}
	}

	for _, s := range sets {
		sa := s // Fix for gosec G601: Implicit memory aliasing in for loop.
		if err = p.execSet(tx, &sa); err != nil {
			p.rollback(tx)
			return err
		}
	}

	return tx.Commit()
}

// rollback aborts tx and logs a rollback failure; the caller returns the
// error that triggered the rollback, which is the more useful one.
func (p *postgresDBAccess) rollback(tx *sql.Tx) {
	if rollbackErr := tx.Rollback(); rollbackErr != nil {
		p.logger.Error(rollbackErr)
	}
}
// returnSingleDBResult turns the outcome of an Exec call into a single error.
//
// A non-nil err is logged at debug level and returned as is. Otherwise the
// affected-row count is inspected: a failure to read it, zero rows, or more
// than one row are each logged as errors and returned; anything else yields
// nil.
func (p *postgresDBAccess) returnSingleDBResult(result sql.Result, err error) error {
	if err != nil {
		p.logger.Debug(err)
		return err
	}

	affected, err := result.RowsAffected()
	if err != nil {
		p.logger.Error(err)
		return err
	}

	var failure error
	switch {
	case affected == 0:
		failure = errors.New("database operation failed: no rows match given key and etag")
	case affected > 1:
		failure = errors.New("database operation failed: more than one row affected, expected one")
	default:
		return nil
	}

	p.logger.Error(failure)

	return failure
}
// Close implements io.Closer. It closes the database handle when one has been
// opened and is a no-op returning nil otherwise.
func (p *postgresDBAccess) Close() error {
	if p.db == nil {
		return nil
	}

	return p.db.Close()
}
// ensureStateTable creates the table named stateTableName when tableExists
// reports that it is not there yet. The table has a text primary key, a JSON
// value, an insert timestamp defaulting to NOW() and a nullable update
// timestamp. Errors from the existence check or the CREATE TABLE statement are
// returned unchanged.
func (p *postgresDBAccess) ensureStateTable(stateTableName string) error {
	found, err := tableExists(p.db, stateTableName)
	if err != nil {
		return err
	}
	if found {
		return nil
	}

	p.logger.Info("Creating PostgreSQL state table")

	// The table name cannot be passed as a statement parameter, hence Sprintf.
	ddl := fmt.Sprintf(`CREATE TABLE %s (
key text NOT NULL PRIMARY KEY,
value json NOT NULL,
insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updatedate TIMESTAMP WITH TIME ZONE NULL);`, stateTableName)

	if _, err = p.db.Exec(ddl); err != nil {
		return err
	}

	return nil
}
// tableExists reports whether pg_tables lists a table called tableName.
// The lookup filters on the table name only. The boolean is false whenever
// the query or the scan of its result fails, in which case the error is
// returned alongside it.
func tableExists(db *sql.DB, tableName string) (bool, error) {
	const query = "SELECT EXISTS (SELECT FROM pg_tables where tablename = $1)"

	var found bool
	err := db.QueryRow(query, tableName).Scan(&found)

	return found, err
}