Add expiredate column to state table

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2022-11-21 23:22:57 +00:00
parent d8c2ce47fd
commit ca729de1a9
1 changed files with 97 additions and 25 deletions

View File

@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/dapr/components-contrib/metadata"
@ -77,7 +78,6 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error {
if m.ConnectionString == "" {
p.logger.Error("Missing postgreSQL connection string")
return errors.New(errMissingConnectionString)
}
p.connectionString = m.ConnectionString
@ -112,8 +112,6 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error {
// Set makes an insert or update to the database.
func (p *postgresDBAccess) Set(req *state.SetRequest) error {
p.logger.Debug("Setting state value in PostgreSQL")
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
@ -187,7 +185,6 @@ func (p *postgresDBAccess) Set(req *state.SetRequest) error {
}
func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error {
p.logger.Debug("Executing BulkSet request")
tx, err := p.db.Begin()
if err != nil {
return err
@ -212,7 +209,6 @@ func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error {
// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned.
func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) {
p.logger.Debug("Getting state value from PostgreSQL")
if req.Key == "" {
return nil, errors.New("missing key in get operation")
}
@ -261,7 +257,6 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error
// Delete removes an item from the state store.
func (p *postgresDBAccess) Delete(req *state.DeleteRequest) (err error) {
p.logger.Debug("Deleting state value from PostgreSQL")
if req.Key == "" {
return errors.New("missing key in delete operation")
}
@ -299,7 +294,6 @@ func (p *postgresDBAccess) Delete(req *state.DeleteRequest) (err error) {
}
func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error {
p.logger.Debug("Executing BulkDelete request")
tx, err := p.db.Begin()
if err != nil {
return err
@ -321,8 +315,6 @@ func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error {
}
func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest) error {
p.logger.Debug("Executing PostgreSQL transaction")
tx, err := p.db.Begin()
if err != nil {
return err
@ -373,10 +365,9 @@ func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest
// Query executes a query against store.
func (p *postgresDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) {
p.logger.Debug("Getting query value from PostgreSQL")
q := &Query{
query: "",
params: []interface{}{},
params: []any{},
tableName: p.tableName,
}
qbuilder := query.NewQueryBuilder(q)
@ -404,33 +395,114 @@ func (p *postgresDBAccess) Close() error {
}
func (p *postgresDBAccess) ensureStateTable(stateTableName string) error {
exists, err := tableExists(p.db, stateTableName)
exists, schema, table, err := tableExists(p.db, stateTableName)
if err != nil {
return err
}
// Create the table if it doesn't exist
if !exists {
p.logger.Info("Creating PostgreSQL state table")
createTable := fmt.Sprintf(`CREATE TABLE %s (
p.logger.Infof("Creating Postgres state table '%s'", stateTableName)
_, err = p.db.Exec(fmt.Sprintf(
`CREATE TABLE %s (
key text NOT NULL PRIMARY KEY,
value jsonb NOT NULL,
isbinary boolean NOT NULL,
insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updatedate TIMESTAMP WITH TIME ZONE NULL);`, stateTableName)
_, err = p.db.Exec(createTable)
updatedate TIMESTAMP WITH TIME ZONE NULL,
expiredate TIMESTAMP WITH TIME ZONE
)`,
stateTableName,
))
if err != nil {
return fmt.Errorf("failed to create state table: %w", err)
}
return nil
}
// If the table exists, ensure it has the "expiredate" column
exists, err = tableHasExpiredateCol(p.db, schema, table)
if err != nil {
return err
}
if !exists {
p.logger.Infof("Adding column 'expiredate' to Postgres state table '%s'", stateTableName)
_, err = p.db.Exec(fmt.Sprintf(`ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`, stateTableName))
if err != nil {
return fmt.Errorf("failed to add expiredate column to state table: %w", err)
}
}
return nil
}
func tableExists(db *sql.DB, tableName string) (bool, error) {
exists := false
err := db.QueryRow("SELECT EXISTS (SELECT FROM pg_tables where tablename = $1)", tableName).Scan(&exists)
// If the table exists, returns true and the name of the table and schema
func tableExists(db *sql.DB, tableName string) (exists bool, schema string, table string, err error) {
table, schema, err = tableSchemaName(tableName)
if err != nil {
return false, "", "", err
}
return exists, err
if schema == "" {
err = db.
QueryRow(`
SELECT
table_name, table_schema
FROM
information_schema.tables
WHERE
table_name = $1`, table).
Scan(&table, &schema)
} else {
err = db.
QueryRow(
`SELECT
table_name, table_schema
FROM
information_schema.tables
WHERE
table_schema = $1
AND table_name = $2`, schema, table).
Scan(&table, &schema)
}
if err != nil && errors.Is(err, sql.ErrNoRows) {
return false, "", "", nil
} else if err != nil {
return false, "", "", fmt.Errorf("failed to check if table %s exists: %w", tableName, err)
}
return true, schema, table, nil
}
func tableHasExpiredateCol(db *sql.DB, schema string, table string) (colExists bool, err error) {
err = db.
QueryRow(`SELECT EXISTS (
SELECT 1
FROM
information_schema.columns
WHERE
table_schema = $1
AND table_name = $2
AND column_name='expiredate'
)`, schema, table).
Scan(&colExists)
if err != nil {
return false, fmt.Errorf("failed to check if table %s.%s has 'expiredate' column: %w", schema, table, err)
}
return colExists, nil
}
// If the table name includes a schema (e.g. `schema.table`, returns the two parts separately)
func tableSchemaName(tableName string) (table string, schema string, err error) {
parts := strings.Split(tableName, ".")
switch len(parts) {
case 1:
return parts[0], "", nil
case 2:
return parts[1], parts[0], nil
default:
return "", "", errors.New("invalid table name: must be in the format 'table' or 'schema.table'")
}
}
// Returns the set requests.