Fix support for custom state store table names (#2724)
Signed-off-by: Bernd Verst <github@bernd.dev>
This commit is contained in:
parent
b4cdc7a3ca
commit
9c371b6362
|
@ -39,7 +39,7 @@ import (
|
|||
const (
|
||||
connectionStringKey = "connectionString"
|
||||
errMissingConnectionString = "missing connection string"
|
||||
tableName = "state"
|
||||
defaultTableName = "state"
|
||||
defaultMaxConnectionAttempts = 5 // A bad driver connection error can occur inside the sql code so this essentially allows for more retries since the sql code does not allow that to be changed
|
||||
)
|
||||
|
||||
|
@ -70,7 +70,9 @@ func newCockroachDBAccess(logger logger.Logger) *cockroachDBAccess {
|
|||
}
|
||||
|
||||
func parseMetadata(meta state.Metadata) (*cockroachDBMetadata, error) {
|
||||
m := cockroachDBMetadata{}
|
||||
m := cockroachDBMetadata{
|
||||
TableName: defaultTableName,
|
||||
}
|
||||
metadata.DecodeMetadata(meta.Properties, &m)
|
||||
|
||||
if m.ConnectionString == "" {
|
||||
|
@ -111,7 +113,7 @@ func (p *cockroachDBAccess) Init(metadata state.Metadata) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if err = p.ensureStateTable(tableName); err != nil {
|
||||
if err = p.ensureStateTable(p.metadata.TableName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -141,7 +143,7 @@ func (p *cockroachDBAccess) Set(ctx context.Context, req *state.SetRequest) erro
|
|||
result, err = p.db.ExecContext(ctx, fmt.Sprintf(
|
||||
`INSERT INTO %s (key, value, isbinary, etag) VALUES ($1, $2, $3, 1)
|
||||
ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW(), etag = EXCLUDED.etag + 1;`,
|
||||
tableName), req.Key, value, isBinary)
|
||||
p.metadata.TableName), req.Key, value, isBinary)
|
||||
} else {
|
||||
var etag64 uint64
|
||||
etag64, err = strconv.ParseUint(*req.ETag, 10, 32)
|
||||
|
@ -154,7 +156,7 @@ func (p *cockroachDBAccess) Set(ctx context.Context, req *state.SetRequest) erro
|
|||
result, err = p.db.ExecContext(ctx, fmt.Sprintf(
|
||||
`UPDATE %s SET value = $1, isbinary = $2, updatedate = NOW(), etag = etag + 1
|
||||
WHERE key = $3 AND etag = $4;`,
|
||||
tableName), value, isBinary, req.Key, etag)
|
||||
p.metadata.TableName), value, isBinary, req.Key, etag)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -208,7 +210,7 @@ func (p *cockroachDBAccess) Get(ctx context.Context, req *state.GetRequest) (*st
|
|||
var value string
|
||||
var isBinary bool
|
||||
var etag int
|
||||
err := p.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, isbinary, etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &isBinary, &etag)
|
||||
err := p.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, isbinary, etag FROM %s WHERE key = $1", p.metadata.TableName), req.Key).Scan(&value, &isBinary, &etag)
|
||||
if err != nil {
|
||||
// If no rows exist, return an empty response, otherwise return the error.
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
|
@ -258,7 +260,7 @@ func (p *cockroachDBAccess) Delete(ctx context.Context, req *state.DeleteRequest
|
|||
var err error
|
||||
|
||||
if req.ETag == nil {
|
||||
result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1", req.Key)
|
||||
result, err = p.db.ExecContext(ctx, "DELETE FROM "+p.metadata.TableName+" WHERE key = $1", req.Key) //nolint:gosec
|
||||
} else {
|
||||
var etag64 uint64
|
||||
etag64, err = strconv.ParseUint(*req.ETag, 10, 32)
|
||||
|
@ -267,7 +269,7 @@ func (p *cockroachDBAccess) Delete(ctx context.Context, req *state.DeleteRequest
|
|||
}
|
||||
etag := uint32(etag64)
|
||||
|
||||
result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1 and etag = $2", req.Key, etag)
|
||||
result, err = p.db.ExecContext(ctx, "DELETE FROM "+p.metadata.TableName+" WHERE key = $1 and etag = $2", req.Key, etag) //nolint:gosec
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
|
|
@ -34,7 +34,7 @@ const (
|
|||
)
|
||||
|
||||
type fakeItem struct {
|
||||
Color string
|
||||
Color string `json:"color"`
|
||||
}
|
||||
|
||||
func TestCockroachDBIntegration(t *testing.T) {
|
||||
|
@ -699,7 +699,7 @@ func storeItemExists(t *testing.T, key string) bool {
|
|||
defer databaseConnection.Close()
|
||||
|
||||
exists := false
|
||||
statement := fmt.Sprintf(`SELECT EXISTS (SELECT * FROM %s WHERE key = $1)`, tableName)
|
||||
statement := fmt.Sprintf(`SELECT EXISTS (SELECT * FROM %s WHERE key = $1)`, defaultTableName)
|
||||
err = databaseConnection.QueryRow(statement, key).Scan(&exists)
|
||||
assert.Nil(t, err)
|
||||
|
||||
|
@ -713,7 +713,7 @@ func getRowData(t *testing.T, key string) (returnValue string, insertdate sql.Nu
|
|||
assert.Nil(t, err)
|
||||
defer databaseConnection.Close()
|
||||
|
||||
err = databaseConnection.QueryRow(fmt.Sprintf("SELECT value, insertdate, updatedate FROM %s WHERE key = $1", tableName), key).Scan(&returnValue, &insertdate, &updatedate)
|
||||
err = databaseConnection.QueryRow(fmt.Sprintf("SELECT value, insertdate, updatedate FROM %s WHERE key = $1", defaultTableName), key).Scan(&returnValue, &insertdate, &updatedate)
|
||||
assert.Nil(t, err)
|
||||
|
||||
return returnValue, insertdate, updatedate
|
||||
|
|
|
@ -96,7 +96,7 @@ func (q *Query) VisitOR(filter *query.OR) (string, error) {
|
|||
}
|
||||
|
||||
func (q *Query) Finalize(filters string, storeQuery *query.Query) error {
|
||||
q.query = fmt.Sprintf("SELECT key, value, etag FROM %s", tableName)
|
||||
q.query = fmt.Sprintf("SELECT key, value, etag FROM %s", defaultTableName)
|
||||
|
||||
if filters != "" {
|
||||
q.query += fmt.Sprintf(" WHERE %s", filters)
|
||||
|
|
|
@ -38,7 +38,7 @@ const (
|
|||
)
|
||||
|
||||
type fakeItem struct {
|
||||
Color string
|
||||
Color string `json:"color"`
|
||||
}
|
||||
|
||||
func TestOracleDatabaseIntegration(t *testing.T) {
|
||||
|
@ -817,7 +817,7 @@ func storeItemExists(t *testing.T, key string) bool {
|
|||
assert.Nil(t, err)
|
||||
defer db.Close()
|
||||
var rowCount int32
|
||||
statement := fmt.Sprintf(`SELECT count(key) FROM %s WHERE key = :key`, tableName)
|
||||
statement := fmt.Sprintf(`SELECT count(key) FROM %s WHERE key = :key`, defaultTableName)
|
||||
err = db.QueryRow(statement, key).Scan(&rowCount)
|
||||
assert.Nil(t, err)
|
||||
exists := rowCount > 0
|
||||
|
@ -832,7 +832,7 @@ func getRowData(t *testing.T, key string) (returnValue string, insertdate sql.Nu
|
|||
db, err := sql.Open("oracle", connectionString)
|
||||
assert.Nil(t, err)
|
||||
defer db.Close()
|
||||
err = db.QueryRow(fmt.Sprintf("SELECT value, creation_time, update_time FROM %s WHERE key = :key", tableName), key).Scan(&returnValue, &insertdate, &updatedate)
|
||||
err = db.QueryRow(fmt.Sprintf("SELECT value, creation_time, update_time FROM %s WHERE key = :key", defaultTableName), key).Scan(&returnValue, &insertdate, &updatedate)
|
||||
assert.Nil(t, err)
|
||||
|
||||
return returnValue, insertdate, updatedate
|
||||
|
@ -846,7 +846,7 @@ func getTimesForRow(t *testing.T, key string) (insertdate sql.NullString, update
|
|||
db, err := sql.Open("oracle", connectionString)
|
||||
assert.Nil(t, err)
|
||||
defer db.Close()
|
||||
err = db.QueryRow(fmt.Sprintf("SELECT creation_time, update_time, expiration_time FROM %s WHERE key = :key", tableName), key).Scan(&insertdate, &updatedate, &expirationtime)
|
||||
err = db.QueryRow(fmt.Sprintf("SELECT creation_time, update_time, expiration_time FROM %s WHERE key = :key", defaultTableName), key).Scan(&insertdate, &updatedate, &expirationtime)
|
||||
assert.Nil(t, err)
|
||||
|
||||
return insertdate, updatedate, expirationtime
|
||||
|
|
|
@ -36,7 +36,7 @@ const (
|
|||
connectionStringKey = "connectionString"
|
||||
oracleWalletLocationKey = "oracleWalletLocation"
|
||||
errMissingConnectionString = "missing connection string"
|
||||
tableName = "state"
|
||||
defaultTableName = "state"
|
||||
)
|
||||
|
||||
// oracleDatabaseAccess implements dbaccess.
|
||||
|
@ -69,7 +69,7 @@ func (o *oracleDatabaseAccess) Ping() error {
|
|||
|
||||
func parseMetadata(meta map[string]string) (oracleDatabaseMetadata, error) {
|
||||
m := oracleDatabaseMetadata{
|
||||
TableName: "state",
|
||||
TableName: defaultTableName,
|
||||
}
|
||||
err := metadata.DecodeMetadata(meta, &m)
|
||||
return m, err
|
||||
|
@ -105,7 +105,7 @@ func (o *oracleDatabaseAccess) Init(metadata state.Metadata) error {
|
|||
if pingErr := db.Ping(); pingErr != nil {
|
||||
return pingErr
|
||||
}
|
||||
err = o.ensureStateTable(tableName)
|
||||
err = o.ensureStateTable(o.metadata.TableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -168,20 +168,22 @@ func (o *oracleDatabaseAccess) Set(ctx context.Context, req *state.SetRequest) e
|
|||
// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
|
||||
// Other parameters use sql.DB parameter substitution.
|
||||
// As per Discord Thread https://discord.com/channels/778680217417809931/901141713089863710/938520959562952735 expiration time is reset in case of an update.
|
||||
//nolint:gosec
|
||||
mergeStatement := fmt.Sprintf(
|
||||
`MERGE INTO %s t using (select :key key, :value value, :binary_yn binary_yn, :etag etag , :ttl_in_seconds ttl_in_seconds from dual) new_state_to_store
|
||||
ON (t.key = new_state_to_store.key )
|
||||
WHEN MATCHED THEN UPDATE SET value = new_state_to_store.value, binary_yn = new_state_to_store.binary_yn, update_time = systimestamp, etag = new_state_to_store.etag, t.expiration_time = case when new_state_to_store.ttl_in_seconds >0 then systimestamp + numtodsinterval(new_state_to_store.ttl_in_seconds, 'SECOND') end
|
||||
WHEN NOT MATCHED THEN INSERT (t.key, t.value, t.binary_yn, t.etag, t.expiration_time) values (new_state_to_store.key, new_state_to_store.value, new_state_to_store.binary_yn, new_state_to_store.etag, case when new_state_to_store.ttl_in_seconds >0 then systimestamp + numtodsinterval(new_state_to_store.ttl_in_seconds, 'SECOND') end ) `,
|
||||
tableName)
|
||||
o.metadata.TableName)
|
||||
result, err = tx.ExecContext(ctx, mergeStatement, req.Key, value, binaryYN, etag, ttlSeconds)
|
||||
} else {
|
||||
// when first write policy is indicated, an existing record has to be updated - one that has the etag provided.
|
||||
// TODO: Needs to update ttl_in_seconds
|
||||
//nolint:gosec
|
||||
updateStatement := fmt.Sprintf(
|
||||
`UPDATE %s SET value = :value, binary_yn = :binary_yn, etag = :new_etag
|
||||
WHERE key = :key AND etag = :etag`,
|
||||
tableName)
|
||||
o.metadata.TableName)
|
||||
result, err = tx.ExecContext(ctx, updateStatement, value, binaryYN, etag, req.Key, *req.ETag)
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -215,7 +217,7 @@ func (o *oracleDatabaseAccess) Get(ctx context.Context, req *state.GetRequest) (
|
|||
var value string
|
||||
var binaryYN string
|
||||
var etag string
|
||||
err := o.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, binary_yn, etag FROM %s WHERE key = :key and (expiration_time is null or expiration_time > systimestamp)", tableName), req.Key).Scan(&value, &binaryYN, &etag)
|
||||
err := o.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, binary_yn, etag FROM %s WHERE key = :key and (expiration_time is null or expiration_time > systimestamp)", o.metadata.TableName), req.Key).Scan(&value, &binaryYN, &etag)
|
||||
if err != nil {
|
||||
// If no rows exist, return an empty response, otherwise return the error.
|
||||
if err == sql.ErrNoRows {
|
||||
|
@ -268,9 +270,11 @@ func (o *oracleDatabaseAccess) Delete(ctx context.Context, req *state.DeleteRequ
|
|||
}
|
||||
// QUESTION: only check for etag if FirstWrite specified - or always when etag is supplied??
|
||||
if req.Options.Concurrency != state.FirstWrite {
|
||||
result, err = tx.ExecContext(ctx, "DELETE FROM state WHERE key = :key", req.Key)
|
||||
//nolint:gosec
|
||||
result, err = tx.ExecContext(ctx, "DELETE FROM "+o.metadata.TableName+" WHERE key = :key", req.Key)
|
||||
} else {
|
||||
result, err = tx.ExecContext(ctx, "DELETE FROM state WHERE key = :key and etag = :etag", req.Key, *req.ETag)
|
||||
//nolint:gosec
|
||||
result, err = tx.ExecContext(ctx, "DELETE FROM "+o.metadata.TableName+" WHERE key = :key and etag = :etag", req.Key, *req.ETag)
|
||||
}
|
||||
if err != nil {
|
||||
if o.tx == nil { // not joining a preexisting transaction.
|
||||
|
|
|
@ -323,7 +323,7 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
|
|||
|
||||
var result pgconn.CommandTag
|
||||
if req.ETag == nil || *req.ETag == "" {
|
||||
result, err = db.Exec(parentCtx, "DELETE FROM state WHERE key = $1", req.Key)
|
||||
result, err = db.Exec(parentCtx, "DELETE FROM "+p.metadata.TableName+" WHERE key = $1", req.Key)
|
||||
} else {
|
||||
// Convert req.ETag to uint32 for postgres XID compatibility
|
||||
var etag64 uint64
|
||||
|
@ -332,7 +332,7 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
|
|||
return state.NewETagError(state.ETagInvalid, err)
|
||||
}
|
||||
|
||||
result, err = db.Exec(parentCtx, "DELETE FROM state WHERE key = $1 AND xmin = $2", req.Key, uint32(etag64))
|
||||
result, err = db.Exec(parentCtx, "DELETE FROM "+p.metadata.TableName+" WHERE key = $1 AND xmin = $2", req.Key, uint32(etag64))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue