Add expiredate column to state table

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2022-11-21 23:22:57 +00:00
parent d8c2ce47fd
commit ca729de1a9
1 changed files with 97 additions and 25 deletions

View File

@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/dapr/components-contrib/metadata"
@ -77,7 +78,6 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error {
if m.ConnectionString == "" {
p.logger.Error("Missing postgreSQL connection string")
return errors.New(errMissingConnectionString)
}
p.connectionString = m.ConnectionString
@ -112,8 +112,6 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error {
// Set makes an insert or update to the database.
func (p *postgresDBAccess) Set(req *state.SetRequest) error {
p.logger.Debug("Setting state value in PostgreSQL")
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
@ -187,7 +185,6 @@ func (p *postgresDBAccess) Set(req *state.SetRequest) error {
}
func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error {
p.logger.Debug("Executing BulkSet request")
tx, err := p.db.Begin()
if err != nil {
return err
@ -212,7 +209,6 @@ func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error {
// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned.
func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) {
p.logger.Debug("Getting state value from PostgreSQL")
if req.Key == "" {
return nil, errors.New("missing key in get operation")
}
@ -261,7 +257,6 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error
// Delete removes an item from the state store.
func (p *postgresDBAccess) Delete(req *state.DeleteRequest) (err error) {
p.logger.Debug("Deleting state value from PostgreSQL")
if req.Key == "" {
return errors.New("missing key in delete operation")
}
@ -299,7 +294,6 @@ func (p *postgresDBAccess) Delete(req *state.DeleteRequest) (err error) {
}
func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error {
p.logger.Debug("Executing BulkDelete request")
tx, err := p.db.Begin()
if err != nil {
return err
@ -321,8 +315,6 @@ func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error {
}
func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest) error {
p.logger.Debug("Executing PostgreSQL transaction")
tx, err := p.db.Begin()
if err != nil {
return err
@ -373,10 +365,9 @@ func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest
// Query executes a query against store.
func (p *postgresDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) {
p.logger.Debug("Getting query value from PostgreSQL")
q := &Query{
query: "",
params: []interface{}{},
params: []any{},
tableName: p.tableName,
}
qbuilder := query.NewQueryBuilder(q)
@ -404,33 +395,114 @@ func (p *postgresDBAccess) Close() error {
}
func (p *postgresDBAccess) ensureStateTable(stateTableName string) error {
exists, err := tableExists(p.db, stateTableName)
exists, schema, table, err := tableExists(p.db, stateTableName)
if err != nil {
return err
}
// Create the table if it doesn't exist
if !exists {
p.logger.Info("Creating PostgreSQL state table")
createTable := fmt.Sprintf(`CREATE TABLE %s (
key text NOT NULL PRIMARY KEY,
value jsonb NOT NULL,
isbinary boolean NOT NULL,
insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updatedate TIMESTAMP WITH TIME ZONE NULL);`, stateTableName)
_, err = p.db.Exec(createTable)
p.logger.Infof("Creating Postgres state table '%s'", stateTableName)
_, err = p.db.Exec(fmt.Sprintf(
`CREATE TABLE %s (
key text NOT NULL PRIMARY KEY,
value jsonb NOT NULL,
isbinary boolean NOT NULL,
insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
updatedate TIMESTAMP WITH TIME ZONE NULL,
expiredate TIMESTAMP WITH TIME ZONE
)`,
stateTableName,
))
if err != nil {
return err
return fmt.Errorf("failed to create state table: %w", err)
}
return nil
}
// If the table exists, ensure it has the "expiredate" column
exists, err = tableHasExpiredateCol(p.db, schema, table)
if err != nil {
return err
}
if !exists {
p.logger.Infof("Adding column 'expiredate' to Postgres state table '%s'", stateTableName)
_, err = p.db.Exec(fmt.Sprintf(`ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`, stateTableName))
if err != nil {
return fmt.Errorf("failed to add expiredate column to state table: %w", err)
}
}
return nil
}
func tableExists(db *sql.DB, tableName string) (bool, error) {
exists := false
err := db.QueryRow("SELECT EXISTS (SELECT FROM pg_tables where tablename = $1)", tableName).Scan(&exists)
// If the table exists, returns true and the name of the table and schema
func tableExists(db *sql.DB, tableName string) (exists bool, schema string, table string, err error) {
table, schema, err = tableSchemaName(tableName)
if err != nil {
return false, "", "", err
}
return exists, err
if schema == "" {
err = db.
QueryRow(`
SELECT
table_name, table_schema
FROM
information_schema.tables
WHERE
table_name = $1`, table).
Scan(&table, &schema)
} else {
err = db.
QueryRow(
`SELECT
table_name, table_schema
FROM
information_schema.tables
WHERE
table_schema = $1
AND table_name = $2`, schema, table).
Scan(&table, &schema)
}
if err != nil && errors.Is(err, sql.ErrNoRows) {
return false, "", "", nil
} else if err != nil {
return false, "", "", fmt.Errorf("failed to check if table %s exists: %w", tableName, err)
}
return true, schema, table, nil
}
func tableHasExpiredateCol(db *sql.DB, schema string, table string) (colExists bool, err error) {
err = db.
QueryRow(`SELECT EXISTS (
SELECT 1
FROM
information_schema.columns
WHERE
table_schema = $1
AND table_name = $2
AND column_name='expiredate'
)`, schema, table).
Scan(&colExists)
if err != nil {
return false, fmt.Errorf("failed to check if table %s.%s has 'expiredate' column: %w", schema, table, err)
}
return colExists, nil
}
// If the table name includes a schema (e.g. `schema.table`, returns the two parts separately)
func tableSchemaName(tableName string) (table string, schema string, err error) {
parts := strings.Split(tableName, ".")
switch len(parts) {
case 1:
return parts[0], "", nil
case 2:
return parts[1], parts[0], nil
default:
return "", "", errors.New("invalid table name: must be in the format 'table' or 'schema.table'")
}
}
// Returns the set requests.