Add expiredate column to state table
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
d8c2ce47fd
commit
ca729de1a9
|
@ -20,6 +20,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
|
@ -77,7 +78,6 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error {
|
|||
|
||||
if m.ConnectionString == "" {
|
||||
p.logger.Error("Missing postgreSQL connection string")
|
||||
|
||||
return errors.New(errMissingConnectionString)
|
||||
}
|
||||
p.connectionString = m.ConnectionString
|
||||
|
@ -112,8 +112,6 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error {
|
|||
|
||||
// Set makes an insert or update to the database.
|
||||
func (p *postgresDBAccess) Set(req *state.SetRequest) error {
|
||||
p.logger.Debug("Setting state value in PostgreSQL")
|
||||
|
||||
err := state.CheckRequestOptions(req.Options)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -187,7 +185,6 @@ func (p *postgresDBAccess) Set(req *state.SetRequest) error {
|
|||
}
|
||||
|
||||
func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error {
|
||||
p.logger.Debug("Executing BulkSet request")
|
||||
tx, err := p.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -212,7 +209,6 @@ func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error {
|
|||
|
||||
// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned.
|
||||
func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) {
|
||||
p.logger.Debug("Getting state value from PostgreSQL")
|
||||
if req.Key == "" {
|
||||
return nil, errors.New("missing key in get operation")
|
||||
}
|
||||
|
@ -261,7 +257,6 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error
|
|||
|
||||
// Delete removes an item from the state store.
|
||||
func (p *postgresDBAccess) Delete(req *state.DeleteRequest) (err error) {
|
||||
p.logger.Debug("Deleting state value from PostgreSQL")
|
||||
if req.Key == "" {
|
||||
return errors.New("missing key in delete operation")
|
||||
}
|
||||
|
@ -299,7 +294,6 @@ func (p *postgresDBAccess) Delete(req *state.DeleteRequest) (err error) {
|
|||
}
|
||||
|
||||
func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error {
|
||||
p.logger.Debug("Executing BulkDelete request")
|
||||
tx, err := p.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -321,8 +315,6 @@ func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error {
|
|||
}
|
||||
|
||||
func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest) error {
|
||||
p.logger.Debug("Executing PostgreSQL transaction")
|
||||
|
||||
tx, err := p.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -373,10 +365,9 @@ func (p *postgresDBAccess) ExecuteMulti(request *state.TransactionalStateRequest
|
|||
|
||||
// Query executes a query against store.
|
||||
func (p *postgresDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) {
|
||||
p.logger.Debug("Getting query value from PostgreSQL")
|
||||
q := &Query{
|
||||
query: "",
|
||||
params: []interface{}{},
|
||||
params: []any{},
|
||||
tableName: p.tableName,
|
||||
}
|
||||
qbuilder := query.NewQueryBuilder(q)
|
||||
|
@ -404,33 +395,114 @@ func (p *postgresDBAccess) Close() error {
|
|||
}
|
||||
|
||||
func (p *postgresDBAccess) ensureStateTable(stateTableName string) error {
|
||||
exists, err := tableExists(p.db, stateTableName)
|
||||
exists, schema, table, err := tableExists(p.db, stateTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create the table if it doesn't exist
|
||||
if !exists {
|
||||
p.logger.Info("Creating PostgreSQL state table")
|
||||
createTable := fmt.Sprintf(`CREATE TABLE %s (
|
||||
key text NOT NULL PRIMARY KEY,
|
||||
value jsonb NOT NULL,
|
||||
isbinary boolean NOT NULL,
|
||||
insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updatedate TIMESTAMP WITH TIME ZONE NULL);`, stateTableName)
|
||||
_, err = p.db.Exec(createTable)
|
||||
p.logger.Infof("Creating Postgres state table '%s'", stateTableName)
|
||||
_, err = p.db.Exec(fmt.Sprintf(
|
||||
`CREATE TABLE %s (
|
||||
key text NOT NULL PRIMARY KEY,
|
||||
value jsonb NOT NULL,
|
||||
isbinary boolean NOT NULL,
|
||||
insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
updatedate TIMESTAMP WITH TIME ZONE NULL,
|
||||
expiredate TIMESTAMP WITH TIME ZONE
|
||||
)`,
|
||||
stateTableName,
|
||||
))
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to create state table: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the table exists, ensure it has the "expiredate" column
|
||||
exists, err = tableHasExpiredateCol(p.db, schema, table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
p.logger.Infof("Adding column 'expiredate' to Postgres state table '%s'", stateTableName)
|
||||
_, err = p.db.Exec(fmt.Sprintf(`ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`, stateTableName))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add expiredate column to state table: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func tableExists(db *sql.DB, tableName string) (bool, error) {
|
||||
exists := false
|
||||
err := db.QueryRow("SELECT EXISTS (SELECT FROM pg_tables where tablename = $1)", tableName).Scan(&exists)
|
||||
// If the table exists, returns true and the name of the table and schema
|
||||
func tableExists(db *sql.DB, tableName string) (exists bool, schema string, table string, err error) {
|
||||
table, schema, err = tableSchemaName(tableName)
|
||||
if err != nil {
|
||||
return false, "", "", err
|
||||
}
|
||||
|
||||
return exists, err
|
||||
if schema == "" {
|
||||
err = db.
|
||||
QueryRow(`
|
||||
SELECT
|
||||
table_name, table_schema
|
||||
FROM
|
||||
information_schema.tables
|
||||
WHERE
|
||||
table_name = $1`, table).
|
||||
Scan(&table, &schema)
|
||||
} else {
|
||||
err = db.
|
||||
QueryRow(
|
||||
`SELECT
|
||||
table_name, table_schema
|
||||
FROM
|
||||
information_schema.tables
|
||||
WHERE
|
||||
table_schema = $1
|
||||
AND table_name = $2`, schema, table).
|
||||
Scan(&table, &schema)
|
||||
}
|
||||
|
||||
if err != nil && errors.Is(err, sql.ErrNoRows) {
|
||||
return false, "", "", nil
|
||||
} else if err != nil {
|
||||
return false, "", "", fmt.Errorf("failed to check if table %s exists: %w", tableName, err)
|
||||
}
|
||||
return true, schema, table, nil
|
||||
}
|
||||
|
||||
func tableHasExpiredateCol(db *sql.DB, schema string, table string) (colExists bool, err error) {
|
||||
err = db.
|
||||
QueryRow(`SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM
|
||||
information_schema.columns
|
||||
WHERE
|
||||
table_schema = $1
|
||||
AND table_name = $2
|
||||
AND column_name='expiredate'
|
||||
)`, schema, table).
|
||||
Scan(&colExists)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check if table %s.%s has 'expiredate' column: %w", schema, table, err)
|
||||
}
|
||||
return colExists, nil
|
||||
}
|
||||
|
||||
// If the table name includes a schema (e.g. `schema.table`, returns the two parts separately)
|
||||
func tableSchemaName(tableName string) (table string, schema string, err error) {
|
||||
parts := strings.Split(tableName, ".")
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
return parts[0], "", nil
|
||||
case 2:
|
||||
return parts[1], parts[0], nil
|
||||
default:
|
||||
return "", "", errors.New("invalid table name: must be in the format 'table' or 'schema.table'")
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the set requests.
|
||||
|
|
Loading…
Reference in New Issue