287 lines
7.2 KiB
Go
287 lines
7.2 KiB
Go
// ------------------------------------------------------------
|
|
// Copyright (c) Microsoft Corporation.
|
|
// Licensed under the MIT License.
|
|
// ------------------------------------------------------------
|
|
|
|
package postgresql
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
|
|
"github.com/dapr/components-contrib/state"
|
|
"github.com/dapr/dapr/pkg/logger"
|
|
|
|
// Blank import for the underlying PostgreSQL driver
|
|
_ "github.com/jackc/pgx/v4/stdlib"
|
|
)
|
|
|
|
// Package-level configuration for the PostgreSQL state store.
const (
	// tableName is the table holding key/value state rows.
	tableName = "state"

	// connectionStringKey is the metadata property carrying the PostgreSQL
	// connection string.
	connectionStringKey = "connectionString"

	// errMissingConnectionString is the message returned when the
	// connectionStringKey property is absent or empty.
	errMissingConnectionString = "missing connection string"
)
|
|
|
|
// postgresDBAccess implements dbaccess
|
|
type postgresDBAccess struct {
|
|
logger logger.Logger
|
|
metadata state.Metadata
|
|
db *sql.DB
|
|
connectionString string
|
|
}
|
|
|
|
// newPostgresDBAccess creates a new instance of postgresAccess
|
|
func newPostgresDBAccess(logger logger.Logger) *postgresDBAccess {
|
|
logger.Debug("Instantiating new PostgreSQL state store")
|
|
return &postgresDBAccess{
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// Init sets up PostgreSQL connection and ensures that the state table exists
|
|
func (p *postgresDBAccess) Init(metadata state.Metadata) error {
|
|
p.logger.Debug("Initializing PostgreSQL state store")
|
|
p.metadata = metadata
|
|
|
|
if val, ok := metadata.Properties[connectionStringKey]; ok && val != "" {
|
|
p.connectionString = val
|
|
} else {
|
|
p.logger.Error("Missing postgreSQL connection string")
|
|
return fmt.Errorf(errMissingConnectionString)
|
|
}
|
|
|
|
db, err := sql.Open("pgx", p.connectionString)
|
|
if err != nil {
|
|
p.logger.Error(err)
|
|
return err
|
|
}
|
|
|
|
p.db = db
|
|
|
|
pingErr := db.Ping()
|
|
if pingErr != nil {
|
|
return pingErr
|
|
}
|
|
|
|
err = p.ensureStateTable(tableName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Set makes an insert or update to the database.
|
|
func (p *postgresDBAccess) Set(req *state.SetRequest) error {
|
|
return state.SetWithRetries(p.setValue, req)
|
|
}
|
|
|
|
// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func.
|
|
func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
|
|
p.logger.Debug("Setting state value in PostgreSQL")
|
|
|
|
err := state.CheckSetRequestOptions(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if req.Key == "" {
|
|
return fmt.Errorf("missing key in set operation")
|
|
}
|
|
|
|
var valueBytes []byte
|
|
|
|
// Convert to json string
|
|
valueBytes, err = json.Marshal(req.Value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
value := string(valueBytes)
|
|
|
|
var result sql.Result
|
|
|
|
// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
|
|
// Other parameters use sql.DB parameter substitution.
|
|
if req.ETag == "" {
|
|
result, err = p.db.Exec(fmt.Sprintf(
|
|
`INSERT INTO %s (key, value) VALUES ($1, $2)
|
|
ON CONFLICT (key) DO UPDATE SET value = $2, updatedate = NOW();`,
|
|
tableName), req.Key, value)
|
|
} else {
|
|
// Convert req.ETag to integer for postgres compatibility
|
|
var etag int
|
|
etag, err = strconv.Atoi(req.ETag)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// When an etag is provided do an update - no insert
|
|
result, err = p.db.Exec(fmt.Sprintf(
|
|
`UPDATE %s SET value = $1, updatedate = NOW()
|
|
WHERE key = $2 AND xmin = $3;`,
|
|
tableName), value, req.Key, etag)
|
|
}
|
|
|
|
return p.returnSingleDBResult(result, err)
|
|
}
|
|
|
|
// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned.
|
|
func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) {
|
|
p.logger.Debug("Getting state value from PostgreSQL")
|
|
if req.Key == "" {
|
|
return nil, fmt.Errorf("missing key in get operation")
|
|
}
|
|
|
|
var value string
|
|
var etag int
|
|
err := p.db.QueryRow(fmt.Sprintf("SELECT value, xmin as etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &etag)
|
|
if err != nil {
|
|
// If no rows exist, return an empty response, otherwise return the error.
|
|
if err == sql.ErrNoRows {
|
|
return &state.GetResponse{}, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
response := &state.GetResponse{
|
|
Data: []byte(value),
|
|
ETag: strconv.Itoa(etag),
|
|
Metadata: req.Metadata,
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// Delete removes an item from the state store.
|
|
func (p *postgresDBAccess) Delete(req *state.DeleteRequest) error {
|
|
return state.DeleteWithRetries(p.deleteValue, req)
|
|
}
|
|
|
|
// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func.
|
|
func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error {
|
|
p.logger.Debug("Deleting state value from PostgreSQL")
|
|
if req.Key == "" {
|
|
return fmt.Errorf("missing key in delete operation")
|
|
}
|
|
|
|
var result sql.Result
|
|
var err error
|
|
|
|
if req.ETag == "" {
|
|
result, err = p.db.Exec("DELETE FROM state WHERE key = $1", req.Key)
|
|
} else {
|
|
// Convert req.ETag to integer for postgres compatibility
|
|
etag, conversionError := strconv.Atoi(req.ETag)
|
|
if conversionError != nil {
|
|
return conversionError
|
|
}
|
|
|
|
result, err = p.db.Exec("DELETE FROM state WHERE key = $1 and xmin = $2", req.Key, etag)
|
|
}
|
|
|
|
return p.returnSingleDBResult(result, err)
|
|
}
|
|
|
|
func (p *postgresDBAccess) ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error {
|
|
p.logger.Debug("Executing multiple PostgreSQL operations")
|
|
tx, err := p.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(deletes) > 0 {
|
|
for _, d := range deletes {
|
|
da := d // Fix for gosec G601: Implicit memory aliasing in for loop.
|
|
err = p.Delete(&da)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(sets) > 0 {
|
|
for _, s := range sets {
|
|
sa := s // Fix for gosec G601: Implicit memory aliasing in for loop.
|
|
err = p.Set(&sa)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
err = tx.Commit()
|
|
return err
|
|
}
|
|
|
|
// Verifies that the sql.Result affected only one row and no errors exist
|
|
func (p *postgresDBAccess) returnSingleDBResult(result sql.Result, err error) error {
|
|
if err != nil {
|
|
p.logger.Debug(err)
|
|
return err
|
|
}
|
|
|
|
rowsAffected, resultErr := result.RowsAffected()
|
|
|
|
if resultErr != nil {
|
|
p.logger.Error(resultErr)
|
|
return resultErr
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
noRowsErr := errors.New("database operation failed: no rows match given key and etag")
|
|
p.logger.Error(noRowsErr)
|
|
return noRowsErr
|
|
}
|
|
|
|
if rowsAffected > 1 {
|
|
tooManyRowsErr := errors.New("database operation failed: more than one row affected, expected one")
|
|
p.logger.Error(tooManyRowsErr)
|
|
return tooManyRowsErr
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Close implements io.Close
|
|
func (p *postgresDBAccess) Close() error {
|
|
if p.db != nil {
|
|
return p.db.Close()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *postgresDBAccess) ensureStateTable(stateTableName string) error {
|
|
exists, err := tableExists(p.db, stateTableName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !exists {
|
|
p.logger.Info("Creating PostgreSQL state table")
|
|
createTable := fmt.Sprintf(`CREATE TABLE %s (
|
|
key text NOT NULL PRIMARY KEY,
|
|
value json NOT NULL,
|
|
insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
|
updatedate TIMESTAMP WITH TIME ZONE NULL);`, stateTableName)
|
|
_, err = p.db.Exec(createTable)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// tableExists reports whether PostgreSQL's pg_tables catalog lists a table
// called name. If the query fails it returns false together with the error.
//
// NOTE(review): the lookup does not filter on schemaname, so a same-named table
// in any schema counts as existing — confirm this is intended.
func tableExists(db *sql.DB, name string) (bool, error) {
	// name (rather than tableName) avoids shadowing the package-level const.
	var exists bool
	err := db.QueryRow("SELECT EXISTS (SELECT FROM pg_tables where tablename = $1)", name).Scan(&exists)
	if err != nil {
		return false, err
	}
	return exists, nil
}
|