Fix support for custom state store table names (#2724)

Signed-off-by: Bernd Verst <github@bernd.dev>
This commit is contained in:
Bernd Verst 2023-03-29 23:46:00 -07:00 committed by GitHub
parent b4cdc7a3ca
commit 9c371b6362
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 32 additions and 26 deletions

View File

@ -39,7 +39,7 @@ import (
const (
connectionStringKey = "connectionString"
errMissingConnectionString = "missing connection string"
tableName = "state"
defaultTableName = "state"
defaultMaxConnectionAttempts = 5 // A bad driver connection error can occur inside the sql code so this essentially allows for more retries since the sql code does not allow that to be changed
)
@ -70,7 +70,9 @@ func newCockroachDBAccess(logger logger.Logger) *cockroachDBAccess {
}
func parseMetadata(meta state.Metadata) (*cockroachDBMetadata, error) {
m := cockroachDBMetadata{}
m := cockroachDBMetadata{
TableName: defaultTableName,
}
metadata.DecodeMetadata(meta.Properties, &m)
if m.ConnectionString == "" {
@ -111,7 +113,7 @@ func (p *cockroachDBAccess) Init(metadata state.Metadata) error {
return err
}
if err = p.ensureStateTable(tableName); err != nil {
if err = p.ensureStateTable(p.metadata.TableName); err != nil {
return err
}
@ -141,7 +143,7 @@ func (p *cockroachDBAccess) Set(ctx context.Context, req *state.SetRequest) erro
result, err = p.db.ExecContext(ctx, fmt.Sprintf(
`INSERT INTO %s (key, value, isbinary, etag) VALUES ($1, $2, $3, 1)
ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW(), etag = EXCLUDED.etag + 1;`,
tableName), req.Key, value, isBinary)
p.metadata.TableName), req.Key, value, isBinary)
} else {
var etag64 uint64
etag64, err = strconv.ParseUint(*req.ETag, 10, 32)
@ -154,7 +156,7 @@ func (p *cockroachDBAccess) Set(ctx context.Context, req *state.SetRequest) erro
result, err = p.db.ExecContext(ctx, fmt.Sprintf(
`UPDATE %s SET value = $1, isbinary = $2, updatedate = NOW(), etag = etag + 1
WHERE key = $3 AND etag = $4;`,
tableName), value, isBinary, req.Key, etag)
p.metadata.TableName), value, isBinary, req.Key, etag)
}
if err != nil {
@ -208,7 +210,7 @@ func (p *cockroachDBAccess) Get(ctx context.Context, req *state.GetRequest) (*st
var value string
var isBinary bool
var etag int
err := p.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, isbinary, etag FROM %s WHERE key = $1", tableName), req.Key).Scan(&value, &isBinary, &etag)
err := p.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, isbinary, etag FROM %s WHERE key = $1", p.metadata.TableName), req.Key).Scan(&value, &isBinary, &etag)
if err != nil {
// If no rows exist, return an empty response, otherwise return the error.
if errors.Is(err, sql.ErrNoRows) {
@ -258,7 +260,7 @@ func (p *cockroachDBAccess) Delete(ctx context.Context, req *state.DeleteRequest
var err error
if req.ETag == nil {
result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1", req.Key)
result, err = p.db.ExecContext(ctx, "DELETE FROM "+p.metadata.TableName+" WHERE key = $1", req.Key) //nolint:gosec
} else {
var etag64 uint64
etag64, err = strconv.ParseUint(*req.ETag, 10, 32)
@ -267,7 +269,7 @@ func (p *cockroachDBAccess) Delete(ctx context.Context, req *state.DeleteRequest
}
etag := uint32(etag64)
result, err = p.db.ExecContext(ctx, "DELETE FROM state WHERE key = $1 and etag = $2", req.Key, etag)
result, err = p.db.ExecContext(ctx, "DELETE FROM "+p.metadata.TableName+" WHERE key = $1 and etag = $2", req.Key, etag) //nolint:gosec
}
if err != nil {

View File

@ -34,7 +34,7 @@ const (
)
type fakeItem struct {
Color string
Color string `json:"color"`
}
func TestCockroachDBIntegration(t *testing.T) {
@ -699,7 +699,7 @@ func storeItemExists(t *testing.T, key string) bool {
defer databaseConnection.Close()
exists := false
statement := fmt.Sprintf(`SELECT EXISTS (SELECT * FROM %s WHERE key = $1)`, tableName)
statement := fmt.Sprintf(`SELECT EXISTS (SELECT * FROM %s WHERE key = $1)`, defaultTableName)
err = databaseConnection.QueryRow(statement, key).Scan(&exists)
assert.Nil(t, err)
@ -713,7 +713,7 @@ func getRowData(t *testing.T, key string) (returnValue string, insertdate sql.Nu
assert.Nil(t, err)
defer databaseConnection.Close()
err = databaseConnection.QueryRow(fmt.Sprintf("SELECT value, insertdate, updatedate FROM %s WHERE key = $1", tableName), key).Scan(&returnValue, &insertdate, &updatedate)
err = databaseConnection.QueryRow(fmt.Sprintf("SELECT value, insertdate, updatedate FROM %s WHERE key = $1", defaultTableName), key).Scan(&returnValue, &insertdate, &updatedate)
assert.Nil(t, err)
return returnValue, insertdate, updatedate

View File

@ -96,7 +96,7 @@ func (q *Query) VisitOR(filter *query.OR) (string, error) {
}
func (q *Query) Finalize(filters string, storeQuery *query.Query) error {
q.query = fmt.Sprintf("SELECT key, value, etag FROM %s", tableName)
q.query = fmt.Sprintf("SELECT key, value, etag FROM %s", defaultTableName)
if filters != "" {
q.query += fmt.Sprintf(" WHERE %s", filters)

View File

@ -38,7 +38,7 @@ const (
)
type fakeItem struct {
Color string
Color string `json:"color"`
}
func TestOracleDatabaseIntegration(t *testing.T) {
@ -817,7 +817,7 @@ func storeItemExists(t *testing.T, key string) bool {
assert.Nil(t, err)
defer db.Close()
var rowCount int32
statement := fmt.Sprintf(`SELECT count(key) FROM %s WHERE key = :key`, tableName)
statement := fmt.Sprintf(`SELECT count(key) FROM %s WHERE key = :key`, defaultTableName)
err = db.QueryRow(statement, key).Scan(&rowCount)
assert.Nil(t, err)
exists := rowCount > 0
@ -832,7 +832,7 @@ func getRowData(t *testing.T, key string) (returnValue string, insertdate sql.Nu
db, err := sql.Open("oracle", connectionString)
assert.Nil(t, err)
defer db.Close()
err = db.QueryRow(fmt.Sprintf("SELECT value, creation_time, update_time FROM %s WHERE key = :key", tableName), key).Scan(&returnValue, &insertdate, &updatedate)
err = db.QueryRow(fmt.Sprintf("SELECT value, creation_time, update_time FROM %s WHERE key = :key", defaultTableName), key).Scan(&returnValue, &insertdate, &updatedate)
assert.Nil(t, err)
return returnValue, insertdate, updatedate
@ -846,7 +846,7 @@ func getTimesForRow(t *testing.T, key string) (insertdate sql.NullString, update
db, err := sql.Open("oracle", connectionString)
assert.Nil(t, err)
defer db.Close()
err = db.QueryRow(fmt.Sprintf("SELECT creation_time, update_time, expiration_time FROM %s WHERE key = :key", tableName), key).Scan(&insertdate, &updatedate, &expirationtime)
err = db.QueryRow(fmt.Sprintf("SELECT creation_time, update_time, expiration_time FROM %s WHERE key = :key", defaultTableName), key).Scan(&insertdate, &updatedate, &expirationtime)
assert.Nil(t, err)
return insertdate, updatedate, expirationtime

View File

@ -36,7 +36,7 @@ const (
connectionStringKey = "connectionString"
oracleWalletLocationKey = "oracleWalletLocation"
errMissingConnectionString = "missing connection string"
tableName = "state"
defaultTableName = "state"
)
// oracleDatabaseAccess implements dbaccess.
@ -69,7 +69,7 @@ func (o *oracleDatabaseAccess) Ping() error {
func parseMetadata(meta map[string]string) (oracleDatabaseMetadata, error) {
m := oracleDatabaseMetadata{
TableName: "state",
TableName: defaultTableName,
}
err := metadata.DecodeMetadata(meta, &m)
return m, err
@ -105,7 +105,7 @@ func (o *oracleDatabaseAccess) Init(metadata state.Metadata) error {
if pingErr := db.Ping(); pingErr != nil {
return pingErr
}
err = o.ensureStateTable(tableName)
err = o.ensureStateTable(o.metadata.TableName)
if err != nil {
return err
}
@ -168,20 +168,22 @@ func (o *oracleDatabaseAccess) Set(ctx context.Context, req *state.SetRequest) e
// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
// Other parameters use sql.DB parameter substitution.
// As per Discord Thread https://discord.com/channels/778680217417809931/901141713089863710/938520959562952735 expiration time is reset in case of an update.
//nolint:gosec
mergeStatement := fmt.Sprintf(
`MERGE INTO %s t using (select :key key, :value value, :binary_yn binary_yn, :etag etag , :ttl_in_seconds ttl_in_seconds from dual) new_state_to_store
ON (t.key = new_state_to_store.key )
WHEN MATCHED THEN UPDATE SET value = new_state_to_store.value, binary_yn = new_state_to_store.binary_yn, update_time = systimestamp, etag = new_state_to_store.etag, t.expiration_time = case when new_state_to_store.ttl_in_seconds >0 then systimestamp + numtodsinterval(new_state_to_store.ttl_in_seconds, 'SECOND') end
WHEN NOT MATCHED THEN INSERT (t.key, t.value, t.binary_yn, t.etag, t.expiration_time) values (new_state_to_store.key, new_state_to_store.value, new_state_to_store.binary_yn, new_state_to_store.etag, case when new_state_to_store.ttl_in_seconds >0 then systimestamp + numtodsinterval(new_state_to_store.ttl_in_seconds, 'SECOND') end ) `,
tableName)
o.metadata.TableName)
result, err = tx.ExecContext(ctx, mergeStatement, req.Key, value, binaryYN, etag, ttlSeconds)
} else {
// when first write policy is indicated, an existing record has to be updated - one that has the etag provided.
// TODO: Needs to update ttl_in_seconds
//nolint:gosec
updateStatement := fmt.Sprintf(
`UPDATE %s SET value = :value, binary_yn = :binary_yn, etag = :new_etag
WHERE key = :key AND etag = :etag`,
tableName)
o.metadata.TableName)
result, err = tx.ExecContext(ctx, updateStatement, value, binaryYN, etag, req.Key, *req.ETag)
}
if err != nil {
@ -215,7 +217,7 @@ func (o *oracleDatabaseAccess) Get(ctx context.Context, req *state.GetRequest) (
var value string
var binaryYN string
var etag string
err := o.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, binary_yn, etag FROM %s WHERE key = :key and (expiration_time is null or expiration_time > systimestamp)", tableName), req.Key).Scan(&value, &binaryYN, &etag)
err := o.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, binary_yn, etag FROM %s WHERE key = :key and (expiration_time is null or expiration_time > systimestamp)", o.metadata.TableName), req.Key).Scan(&value, &binaryYN, &etag)
if err != nil {
// If no rows exist, return an empty response, otherwise return the error.
if err == sql.ErrNoRows {
@ -268,9 +270,11 @@ func (o *oracleDatabaseAccess) Delete(ctx context.Context, req *state.DeleteRequ
}
// QUESTION: only check for etag if FirstWrite specified - or always when etag is supplied??
if req.Options.Concurrency != state.FirstWrite {
result, err = tx.ExecContext(ctx, "DELETE FROM state WHERE key = :key", req.Key)
//nolint:gosec
result, err = tx.ExecContext(ctx, "DELETE FROM "+o.metadata.TableName+" WHERE key = :key", req.Key)
} else {
result, err = tx.ExecContext(ctx, "DELETE FROM state WHERE key = :key and etag = :etag", req.Key, *req.ETag)
//nolint:gosec
result, err = tx.ExecContext(ctx, "DELETE FROM "+o.metadata.TableName+" WHERE key = :key and etag = :etag", req.Key, *req.ETag)
}
if err != nil {
if o.tx == nil { // not joining a preexisting transaction.

View File

@ -323,7 +323,7 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
var result pgconn.CommandTag
if req.ETag == nil || *req.ETag == "" {
result, err = db.Exec(parentCtx, "DELETE FROM state WHERE key = $1", req.Key)
result, err = db.Exec(parentCtx, "DELETE FROM "+p.metadata.TableName+" WHERE key = $1", req.Key)
} else {
// Convert req.ETag to uint32 for postgres XID compatibility
var etag64 uint64
@ -332,7 +332,7 @@ func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req
return state.NewETagError(state.ETagInvalid, err)
}
result, err = db.Exec(parentCtx, "DELETE FROM state WHERE key = $1 AND xmin = $2", req.Key, uint32(etag64))
result, err = db.Exec(parentCtx, "DELETE FROM "+p.metadata.TableName+" WHERE key = $1 AND xmin = $2", req.Key, uint32(etag64))
}
if err != nil {