Add MySQL TTL and Cleanup (#2641)

Signed-off-by: Deepanshu Agarwal <deepanshu.agarwal1984@gmail.com>
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Bernd Verst <github@bernd.dev>
This commit is contained in:
Deepanshu Agarwal 2023-03-31 04:15:57 +05:30 committed by GitHub
parent 04d1e71ce7
commit d6ce7bb5c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 636 additions and 167 deletions

View File

@ -224,7 +224,7 @@ func (g *gc) updateLastCleanup(ctx context.Context) (bool, error) {
}
n = res.RowsAffected()
} else {
res, err := g.dbSQL.ExecContext(ctx, g.updateLastCleanupQuery, sql.Named("Interval", g.cleanupInterval.Milliseconds()-100))
res, err := g.dbSQL.ExecContext(ctx, g.updateLastCleanupQuery, g.cleanupInterval.Milliseconds()-100)
if err != nil {
return false, fmt.Errorf("error updating last cleanup time: %w", err)
}

View File

@ -1,22 +0,0 @@
/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mysql
import "database/sql"
// This interface is used to help improve testing.
type iMySQLFactory interface {
Open(connectionString string) (*sql.DB, error)
RegisterTLSConfig(pemPath string) error
}

View File

@ -25,6 +25,12 @@ import (
"github.com/dapr/kit/logger"
)
// This interface is used to help improve testing.
type iMySQLFactory interface {
Open(connectionString string) (*sql.DB, error)
RegisterTLSConfig(pemPath string) error
}
type mySQLFactory struct {
logger logger.Logger
}

View File

@ -27,28 +27,20 @@ import (
"github.com/google/uuid"
sqlCleanup "github.com/dapr/components-contrib/internal/component/sql"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
"github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/ptr"
)
// Optimistic Concurrency is implemented using a string column that stores a UUID.
const (
// The key name in the metadata if the user wants a different table name
// than the defaultTableName.
keyTableName = "tableName"
// The key name in the metadata if the user wants a different database name
// than the defaultSchemaName.
keySchemaName = "schemaName"
// The key name in the metadata for the timeout of operations, in seconds.
keyTimeoutInSeconds = "timeoutInSeconds"
// The key for the mandatory connection string of the metadata.
keyConnectionString = "connectionString"
// To connect to MySQL running in Azure over SSL you have to download a
// SSL certificate. If this is provided the driver will connect using
// SSL. If you have disable SSL you can leave this empty.
@ -69,14 +61,27 @@ const (
// Standard error message if not connection string is provided.
errMissingConnectionString = "missing connection string"
// Key name to configure interval at which entries with TTL are cleaned up.
// This is parsed as a Go duration.
cleanupIntervalKey = "cleanupInterval"
// Used if the user does not configure a metadata table name in the metadata.
// In terms of TTL, it is required to store value for 'last-cleanup' id.
defaultMetadataTableName = "dapr_metadata"
// Used if the user does not configure a cleanup interval in the metadata.
defaultCleanupInterval = time.Hour
)
// MySQL state store.
type MySQL struct {
tableName string
schemaName string
connectionString string
timeout time.Duration
tableName string
metadataTableName string
cleanupInterval *time.Duration
schemaName string
connectionString string
timeout time.Duration
// Instance of the database to issue commands to
db *sql.DB
@ -85,14 +90,17 @@ type MySQL struct {
logger logger.Logger
factory iMySQLFactory
gc sqlCleanup.GarbageCollector
}
type mySQLMetadata struct {
TableName string
SchemaName string
ConnectionString string
Timeout int
PemPath string
TableName string
SchemaName string
ConnectionString string
Timeout int
PemPath string
MetadataTableName string
CleanupInterval *time.Duration
}
// NewMySQLStateStore creates a new instance of MySQL state store.
@ -141,9 +149,12 @@ func (m *MySQL) Init(ctx context.Context, metadata state.Metadata) error {
func (m *MySQL) parseMetadata(md map[string]string) error {
meta := mySQLMetadata{
TableName: defaultTableName,
SchemaName: defaultSchemaName,
TableName: defaultTableName,
SchemaName: defaultSchemaName,
MetadataTableName: defaultMetadataTableName,
CleanupInterval: ptr.Of(defaultCleanupInterval),
}
err := metadata.DecodeMetadata(md, &meta)
if err != nil {
return err
@ -157,6 +168,14 @@ func (m *MySQL) parseMetadata(md map[string]string) error {
}
m.tableName = meta.TableName
if meta.MetadataTableName != "" {
// Sanitize the metadata table name
if !validIdentifier(meta.MetadataTableName) {
return fmt.Errorf("metadata table name '%s' is not valid", meta.MetadataTableName)
}
}
m.metadataTableName = meta.MetadataTableName
if meta.SchemaName != "" {
// Sanitize the schema name
if !validIdentifier(meta.SchemaName) {
@ -171,6 +190,21 @@ func (m *MySQL) parseMetadata(md map[string]string) error {
}
m.connectionString = meta.ConnectionString
// Cleanup interval
if meta.CleanupInterval != nil {
// Non-positive value from meta means disable auto cleanup.
if *meta.CleanupInterval <= 0 {
if md[cleanupIntervalKey] == "" {
// unfortunately the mapstructure decoder decodes an empty string to 0, a missing key would be nil however
meta.CleanupInterval = ptr.Of(defaultCleanupInterval)
} else {
meta.CleanupInterval = nil
}
}
m.cleanupInterval = meta.CleanupInterval
}
if meta.PemPath != "" {
err := m.factory.RegisterTLSConfig(meta.PemPath)
if err != nil {
@ -231,7 +265,35 @@ func (m *MySQL) finishInit(ctx context.Context, db *sql.DB) error {
}
// will be nil if everything is good or an err that needs to be returned
return m.ensureStateTable(ctx, m.tableName)
if err = m.ensureStateTable(ctx, m.schemaName, m.tableName); err != nil {
return err
}
if err = m.ensureMetadataTable(ctx, m.schemaName, m.metadataTableName); err != nil {
return err
}
if m.cleanupInterval != nil {
gc, err := sqlCleanup.ScheduleGarbageCollector(sqlCleanup.GCOptions{
Logger: m.logger,
UpdateLastCleanupQuery: fmt.Sprintf(`INSERT INTO %[1]s (id, value)
VALUES ('last-cleanup', CURRENT_TIMESTAMP)
ON DUPLICATE KEY UPDATE
value = IF(CURRENT_TIMESTAMP > DATE_ADD(value, INTERVAL ?*1000 MICROSECOND), CURRENT_TIMESTAMP, value)`,
m.metadataTableName),
DeleteExpiredValuesQuery: fmt.Sprintf(
`DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate <= CURRENT_TIMESTAMP`,
m.tableName,
),
CleanupInterval: *m.cleanupInterval,
DBSql: m.db,
})
if err != nil {
return err
}
m.gc = gc
}
return nil
}
func (m *MySQL) ensureStateSchema(ctx context.Context) error {
@ -273,13 +335,13 @@ func (m *MySQL) ensureStateSchema(ctx context.Context) error {
return err
}
func (m *MySQL) ensureStateTable(ctx context.Context, stateTableName string) error {
exists, err := tableExists(ctx, m.db, stateTableName, m.timeout)
func (m *MySQL) ensureStateTable(ctx context.Context, schemaName, stateTableName string) error {
tableExists, err := tableExists(ctx, m.db, schemaName, stateTableName, m.timeout)
if err != nil {
return err
}
if !exists {
if !tableExists {
m.logger.Infof("Creating MySql state table '%s'", stateTableName)
// updateDate is updated automactically on every UPDATE commands so you
@ -288,18 +350,85 @@ func (m *MySQL) ensureStateTable(ctx context.Context, stateTableName string) err
// in on inserts and updates and is used for Optimistic Concurrency
// Note that stateTableName is sanitized
//nolint:gosec
createTable := fmt.Sprintf(`CREATE TABLE %s (
createTable := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
id VARCHAR(255) NOT NULL PRIMARY KEY,
value JSON NOT NULL,
isbinary BOOLEAN NOT NULL,
insertDate TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updateDate TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
eTag VARCHAR(36) NOT NULL
eTag VARCHAR(36) NOT NULL,
expiredate TIMESTAMP NULL,
INDEX expiredate_idx(expiredate)
);`, stateTableName)
execCtx, execCancel := context.WithTimeout(ctx, m.timeout)
defer execCancel()
_, err = m.db.ExecContext(execCtx, createTable)
_, err = m.db.ExecContext(ctx, createTable)
if err != nil {
return err
}
}
// Check if expiredate column exists - to cater cases when table was created before v1.11.
columnExists, err := columnExists(ctx, m.db, schemaName, stateTableName, "expiredate", m.timeout)
if err != nil {
return err
}
if !columnExists {
m.logger.Infof("Adding expiredate column to MySql state table '%s'", stateTableName)
_, err = m.db.ExecContext(ctx, fmt.Sprintf(
`ALTER TABLE %s ADD COLUMN IF NOT EXISTS expiredate TIMESTAMP NULL;`, stateTableName))
if err != nil {
return err
}
_, err = m.db.ExecContext(ctx, fmt.Sprintf(
`CREATE INDEX IF NOT EXISTS expiredate_idx ON %s (expiredate);`, stateTableName))
if err != nil {
return err
}
}
// Create the DaprSaveFirstWriteV1 stored procedure
_, err = m.db.ExecContext(ctx, `CREATE PROCEDURE IF NOT EXISTS DaprSaveFirstWriteV1(tableName VARCHAR(255), id VARCHAR(255), value JSON, etag VARCHAR(36), isbinary BOOLEAN, expiredateToken TEXT)
LANGUAGE SQL
MODIFIES SQL DATA
BEGIN
SET @id = id;
SET @value = value;
SET @etag = etag;
SET @isbinary = isbinary;
SET @selectQuery = concat('SELECT COUNT(id) INTO @count FROM ', tableName ,' WHERE id = ? AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)');
PREPARE select_stmt FROM @selectQuery;
EXECUTE select_stmt USING @id;
DEALLOCATE PREPARE select_stmt;
IF @count < 1 THEN
SET @upsertQuery = concat('INSERT INTO ', tableName, ' SET id=?, value=?, eTag=?, isbinary=?, expiredate=', expiredateToken, ' ON DUPLICATE KEY UPDATE value=?, eTag=?, isbinary=?, expiredate=', expiredateToken);
PREPARE upsert_stmt FROM @upsertQuery;
EXECUTE upsert_stmt USING @id, @value, @etag, @isbinary, @value, @etag, @isbinary;
DEALLOCATE PREPARE upsert_stmt;
ELSE
SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'Row already exists';
END IF;
END`)
if err != nil {
return err
}
return nil
}
func (m *MySQL) ensureMetadataTable(ctx context.Context, schemaName, metaTableName string) error {
exists, err := tableExists(ctx, m.db, schemaName, metaTableName, m.timeout)
if err != nil {
return err
}
if !exists {
m.logger.Info("Creating MySQL metadata table")
_, err = m.db.ExecContext(ctx, fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
id VARCHAR(255) NOT NULL PRIMARY KEY, value TEXT NOT NULL);`, metaTableName))
if err != nil {
return err
}
@ -321,16 +450,31 @@ func schemaExists(ctx context.Context, db *sql.DB, schemaName string, timeout ti
return exists == 1, err
}
func tableExists(ctx context.Context, db *sql.DB, tableName string, timeout time.Duration) (bool, error) {
func tableExists(ctx context.Context, db *sql.DB, schemaName, tableName string, timeout time.Duration) (bool, error) {
tableCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// Returns 1 or 0 if the table exists or not
var exists int
query := `SELECT EXISTS (
SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_NAME = ?
SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
) AS 'exists'`
err := db.QueryRowContext(tableCtx, query, tableName).Scan(&exists)
err := db.QueryRowContext(tableCtx, query, schemaName, tableName).Scan(&exists)
return exists == 1, err
}
// columnExists returns true if the column exists in the table
func columnExists(ctx context.Context, db *sql.DB, schemaName, tableName, columnName string, timeout time.Duration) (bool, error) {
columnCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// Returns 1 or 0 if the column exists or not
var exists int
query := `SELECT count(*) AS 'exists' FROM information_schema.columns
WHERE table_schema = ?
AND table_name = ?
AND column_name = ?`
err := db.QueryRowContext(columnCtx, query, schemaName, tableName, columnName).Scan(&exists)
return exists == 1, err
}
@ -343,8 +487,6 @@ func (m *MySQL) Delete(ctx context.Context, req *state.DeleteRequest) error {
// deleteValue is an internal implementation of delete to enable passing the
// logic to state.DeleteWithRetries as a func.
func (m *MySQL) deleteValue(parentCtx context.Context, querier querier, req *state.DeleteRequest) error {
m.logger.Debug("Deleting state value from MySql")
if req.Key == "" {
return fmt.Errorf("missing key in delete operation")
}
@ -363,7 +505,7 @@ func (m *MySQL) deleteValue(parentCtx context.Context, querier querier, req *sta
m.tableName), req.Key)
} else {
result, err = querier.ExecContext(execCtx, fmt.Sprintf(
`DELETE FROM %s WHERE id = ? and eTag = ?`,
`DELETE FROM %s WHERE id = ? AND eTag = ?`,
m.tableName), req.Key, *req.ETag)
}
@ -386,37 +528,33 @@ func (m *MySQL) deleteValue(parentCtx context.Context, querier querier, req *sta
// BulkDelete removes multiple entries from the store
// Store Interface.
func (m *MySQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
m.logger.Debug("Executing BulkDelete request")
tx, err := m.db.Begin()
if err != nil {
return err
}
defer func() {
rollbackErr := tx.Rollback()
if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
}()
if len(req) > 0 {
for _, d := range req {
da := d // Fix for goSec G601: Implicit memory aliasing in for loop.
err = m.deleteValue(ctx, tx, &da)
if err != nil {
rollbackErr := tx.Rollback()
if rollbackErr != nil {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
return err
}
}
}
err = tx.Commit()
return err
return tx.Commit()
}
// Get returns an entity from store
// Store Interface.
func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
m.logger.Debug("Getting state value from MySql")
if req.Key == "" {
return nil, fmt.Errorf("missing key in get operation")
}
@ -431,7 +569,8 @@ func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.Ge
defer cancel()
//nolint:gosec
query := fmt.Sprintf(
`SELECT value, eTag, isbinary FROM %s WHERE id = ?`,
`SELECT value, eTag, isbinary FROM %s WHERE id = ?
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`,
m.tableName, // m.tableName is sanitized
)
err := m.db.QueryRowContext(ctx, query, req.Key).Scan(&value, &eTag, &isBinary)
@ -483,8 +622,6 @@ func (m *MySQL) Set(ctx context.Context, req *state.SetRequest) error {
// setValue is an internal implementation of set to enable passing the logic
// to state.SetWithRetries as a func.
func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.SetRequest) error {
m.logger.Debug("Setting state value in MySql")
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
@ -494,6 +631,24 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
return errors.New("missing key in set operation")
}
// TTL
var ttlSeconds int
ttl, err := utils.ParseTTL(req.Metadata)
if err != nil {
return fmt.Errorf("error parsing TTL: %w", err)
}
if ttl != nil {
ttlSeconds = *ttl
}
var (
query string
ttlQuery string
params []any
result sql.Result
maxRows int64 = 1
)
var v any
isBinary := false
switch x := req.Value.(type) {
@ -513,62 +668,91 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
}
eTag := eTagObj.String()
var (
result sql.Result
maxRows int64 = 1
)
if ttlSeconds > 0 {
ttlQuery = "CURRENT_TIMESTAMP + INTERVAL " + strconv.Itoa(ttlSeconds) + " SECOND"
} else {
ttlQuery = "NULL"
}
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
defer cancel()
mustCommit := false
hasEtag := req.ETag != nil && *req.ETag != ""
if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || *req.ETag == "") {
// With first-write-wins and no etag, we can insert the row only if it doesn't exist
query := fmt.Sprintf(
`INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?);`,
m.tableName, // m.tableName is sanitized
)
result, err = querier.ExecContext(ctx, query, enc, req.Key, eTag, isBinary)
} else if req.ETag != nil && *req.ETag != "" {
if hasEtag {
// When an eTag is provided do an update - not insert
query := fmt.Sprintf(
`UPDATE %s SET value = ?, eTag = ?, isbinary = ? WHERE id = ? AND eTag = ?;`,
m.tableName, // m.tableName is sanitized
)
result, err = querier.ExecContext(ctx, query, enc, eTag, isBinary, req.Key, *req.ETag)
query = `UPDATE ` + m.tableName + `
SET value = ?, eTag = ?, isbinary = ?, expiredate = ` + ttlQuery + `
WHERE id = ?
AND eTag = ?
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
params = []any{enc, eTag, isBinary, req.Key, *req.ETag}
} else if req.Options.Concurrency == state.FirstWrite {
// If we're not in a transaction already, start one as we need to ensure consistency
if querier == m.db {
querier, err = m.db.BeginTx(parentCtx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer querier.(*sql.Tx).Rollback()
mustCommit = true
}
// With first-write-wins and no etag, we can insert the row only if it doesn't exist
// Things get a bit tricky when the row exists but it is expired, so it just hasn't been garbage-collected yet
// What we can do in that case is to first check if the row doesn't exist or has expired, and then perform an upsert
// To do that, we use a stored procedure
query = "CALL DaprSaveFirstWriteV1(?, ?, ?, ?, ?, ?)"
params = []any{m.tableName, req.Key, enc, eTag, isBinary, ttlQuery}
} else {
// If this is a duplicate MySQL returns that two rows affected
maxRows = 2
query := fmt.Sprintf(
`INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`,
m.tableName, // m.tableName is sanitized
)
result, err = querier.ExecContext(ctx, query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary)
query = `INSERT INTO ` + m.tableName + ` (id, value, eTag, isbinary, expiredate)
VALUES (?, ?, ?, ?, ` + ttlQuery + `)
ON DUPLICATE KEY UPDATE
value=?, eTag=?, isbinary=?, expiredate=` + ttlQuery
params = []any{req.Key, enc, eTag, isBinary, enc, eTag, isBinary}
}
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
defer cancel()
result, err = querier.ExecContext(ctx, query, params...)
if err != nil {
if req.ETag != nil && *req.ETag != "" {
if hasEtag {
return state.NewETagError(state.ETagMismatch, err)
}
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
// Do not count affected rows when using first-write
// Conflicts are handled separately
if hasEtag || req.Options.Concurrency != state.FirstWrite {
var rows int64
rows, err = result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
err = errors.New("rows affected error: no rows match given key and eTag")
err = state.NewETagError(state.ETagMismatch, err)
m.logger.Error(err)
return err
}
if rows > maxRows {
err = fmt.Errorf("rows affected error: more than %d row affected; actual %d", maxRows, rows)
m.logger.Error(err)
return err
}
}
if rows == 0 {
err = errors.New(`rows affected error: no rows match given key and eTag`)
err = state.NewETagError(state.ETagMismatch, err)
m.logger.Error(err)
return err
}
if rows > maxRows {
err = fmt.Errorf(`rows affected error: more than %d row affected; actual %d`, maxRows, rows)
m.logger.Error(err)
return err
// Commit the transaction if needed
if mustCommit {
err = querier.(*sql.Tx).Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
}
return nil
@ -577,21 +761,21 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
// BulkSet adds/updates multiple entities on store
// Store Interface.
func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error {
m.logger.Debug("Executing BulkSet request")
tx, err := m.db.Begin()
if err != nil {
return err
}
defer func() {
rollbackErr := tx.Rollback()
if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
}()
if len(req) > 0 {
for i := range req {
err = m.setValue(ctx, tx, &req[i])
if err != nil {
rollbackErr := tx.Rollback()
if rollbackErr != nil {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
return err
}
}
@ -603,50 +787,38 @@ func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error {
// Multi handles multiple transactions.
// TransactionalStore Interface.
func (m *MySQL) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
m.logger.Debug("Executing Multi request")
tx, err := m.db.Begin()
if err != nil {
return err
}
defer func() {
rollbackErr := tx.Rollback()
if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
}()
for _, req := range request.Operations {
switch req.Operation {
case state.Upsert:
setReq, err := m.getSets(req)
if err != nil {
rollbackErr := tx.Rollback()
if rollbackErr != nil {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
return err
}
err = m.setValue(ctx, tx, &setReq)
if err != nil {
rollbackErr := tx.Rollback()
if rollbackErr != nil {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
return err
}
case state.Delete:
delReq, err := m.getDeletes(req)
if err != nil {
rollbackErr := tx.Rollback()
if rollbackErr != nil {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
return err
}
err = m.deleteValue(ctx, tx, &delReq)
if err != nil {
rollbackErr := tx.Rollback()
if rollbackErr != nil {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
return err
}
@ -662,11 +834,11 @@ func (m *MySQL) Multi(ctx context.Context, request *state.TransactionalStateRequ
func (m *MySQL) getSets(req state.TransactionalStateOperation) (state.SetRequest, error) {
setReq, ok := req.Request.(state.SetRequest)
if !ok {
return setReq, fmt.Errorf("expecting set request")
return setReq, errors.New("expecting set request")
}
if setReq.Key == "" {
return setReq, fmt.Errorf("missing key in upsert operation")
return setReq, errors.New("missing key in upsert operation")
}
return setReq, nil
@ -676,11 +848,11 @@ func (m *MySQL) getSets(req state.TransactionalStateOperation) (state.SetRequest
func (m *MySQL) getDeletes(req state.TransactionalStateOperation) (state.DeleteRequest, error) {
delReq, ok := req.Request.(state.DeleteRequest)
if !ok {
return delReq, fmt.Errorf("expecting delete request")
return delReq, errors.New("expecting delete request")
}
if delReq.Key == "" {
return delReq, fmt.Errorf("missing key in delete operation")
return delReq, errors.New("missing key in delete operation")
}
return delReq, nil
@ -701,6 +873,10 @@ func (m *MySQL) Close() error {
err := m.db.Close()
m.db = nil
if m.gc != nil {
return errors.Join(err, m.gc.Close())
}
return err
}

View File

@ -19,6 +19,7 @@ package mysql
import (
"database/sql"
"time"
)
// GetConnection returns the database connection.
@ -35,3 +36,12 @@ func (m *MySQL) SchemaName() string {
func (m *MySQL) TableName() string {
return m.tableName
}
// CleanupInterval returns the value of the cleanupInterval property.
func (m *MySQL) CleanupInterval() *time.Duration {
return m.cleanupInterval
}
func (m *MySQL) CleanupExpired() error {
return m.gc.CleanupExpired()
}

View File

@ -149,7 +149,7 @@ func TestMySQLIntegration(t *testing.T) {
tableName := "test_state"
// Drop the table if it already exists
exists, err := tableExists(context.Background(), mys.db, tableName, 10*time.Second)
exists, err := tableExists(context.Background(), mys.db, "dapr_state_store", tableName, 10*time.Second)
assert.Nil(t, err)
if exists {
dropTable(t, mys.db, tableName)
@ -157,11 +157,11 @@ func TestMySQLIntegration(t *testing.T) {
// Create the state table and test for its existence
// There should be no error
err = mys.ensureStateTable(context.Background(), tableName)
err = mys.ensureStateTable(context.Background(), "dapr_state_store", tableName)
assert.Nil(t, err)
// Now create it and make sure there are no errors
exists, err = tableExists(context.Background(), mys.db, tableName, 10*time.Second)
exists, err = tableExists(context.Background(), mys.db, "dapr_state_store", tableName, 10*time.Second)
assert.Nil(t, err)
assert.True(t, exists)

View File

@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mysql
import (
@ -35,6 +36,9 @@ import (
const (
fakeConnectionString = "not a real connection"
keyTableName = "tableName"
keyConnectionString = "connectionString"
keySchemaName = "schemaName"
)
func TestEnsureStateSchemaHandlesShortConnectionString(t *testing.T) {
@ -67,7 +71,7 @@ func TestFinishInitHandlesSchemaExistsError(t *testing.T) {
actualErr := m.mySQL.finishInit(context.Background(), m.mySQL.db)
// Assert
assert.NotNil(t, actualErr, "now error returned")
assert.Error(t, actualErr, "now error returned")
assert.Equal(t, "existsError", actualErr.Error(), "wrong error")
}
@ -86,7 +90,7 @@ func TestFinishInitHandlesDatabaseCreateError(t *testing.T) {
actualErr := m.mySQL.finishInit(context.Background(), m.mySQL.db)
// Assert
assert.NotNil(t, actualErr, "now error returned")
assert.Error(t, actualErr, "now error returned")
assert.Equal(t, "createDatabaseError", actualErr.Error(), "wrong error")
}
@ -541,7 +545,7 @@ func TestTableExists(t *testing.T) {
m.mock1.ExpectQuery("SELECT EXISTS").WillReturnRows(rows)
// Act
actual, err := tableExists(context.Background(), m.mySQL.db, "store", 10*time.Second)
actual, err := tableExists(context.Background(), m.mySQL.db, "dapr_state_store", "store", 10*time.Second)
// Assert
assert.Nil(t, err, `error was returned`)
@ -559,7 +563,7 @@ func TestEnsureStateTableHandlesCreateTableError(t *testing.T) {
m.mock1.ExpectExec("CREATE TABLE").WillReturnError(fmt.Errorf("CreateTableError"))
// Act
err := m.mySQL.ensureStateTable(context.Background(), "state")
err := m.mySQL.ensureStateTable(context.Background(), "dapr_state_store", "state")
// Assert
assert.NotNil(t, err, "no error returned")
@ -578,12 +582,15 @@ func TestEnsureStateTableCreatesTable(t *testing.T) {
rows := sqlmock.NewRows([]string{"exists"}).AddRow(0)
m.mock1.ExpectQuery("SELECT EXISTS").WillReturnRows(rows)
m.mock1.ExpectExec("CREATE TABLE").WillReturnResult(sqlmock.NewResult(1, 1))
rows = sqlmock.NewRows([]string{"exists"}).AddRow(1)
m.mock1.ExpectQuery("SELECT count(/*)").WillReturnRows(rows)
m.mock1.ExpectExec("CREATE PROCEDURE").WillReturnResult(sqlmock.NewResult(1, 1))
// Act
err := m.mySQL.ensureStateTable(context.Background(), "state")
err := m.mySQL.ensureStateTable(context.Background(), "dapr_state_store", "state")
// Assert
assert.Nil(t, err)
assert.NoError(t, err)
}
// Verify that the call to MySQL init get passed through

View File

@ -39,5 +39,16 @@ d. Get and validate eTag, which should not have changed.
1. Without `schemaName`, check that the default one is used
2. Without `tableName`, check that the default one is used
3. Instantiate a component with a custom `schemaName` and validate it's used
4. Instantiate a component with a custom `tableName` and validate it's used
3. Without `metadataTableName`, check that the default one is used
4. Instantiate a component with a custom `schemaName` and validate it's used
5. Instantiate a component with a custom `tableName` and validate it's used
6. Instantiate a component with a custom `metadataTableName` and validate it's used
## TTLs and cleanups
1. Correctly parse the `cleanupIntervalInSeconds` metadata property:
- No value uses the default value (3600 seconds)
- A positive value sets the interval to the given number of seconds
- A zero or negative value disables the cleanup
2. The cleanup method deletes expired records and updates the metadata table with the last time it ran
3. The cleanup method doesn't run if the last iteration was less than `cleanupIntervalInSeconds` or if another process is doing the cleanup

View File

@ -62,6 +62,9 @@ require (
github.com/hashicorp/golang-lru/v2 v2.0.1 // indirect
github.com/hashicorp/serf v0.10.1 // indirect
github.com/imdario/mergo v0.3.13 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jhump/protoreflect v1.14.1 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
@ -108,6 +111,7 @@ require (
go.opentelemetry.io/otel/sdk v1.11.2 // indirect
go.opentelemetry.io/otel/trace v1.11.2 // indirect
go.opentelemetry.io/proto/otlp v0.19.0 // indirect
golang.org/x/crypto v0.7.0 // indirect
golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 // indirect
golang.org/x/mod v0.9.0 // indirect
golang.org/x/net v0.8.0 // indirect

View File

@ -291,6 +291,12 @@ github.com/hashicorp/serf v0.10.1/go.mod h1:yL2t6BqATOLGc5HF7qbFkTfXoPIY0WZdWHfE
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/imdario/mergo v0.3.13 h1:lFzP57bqS/wsqKssCGmtLAb8A0wKjLGrve2q3PPVcBk=
github.com/imdario/mergo v0.3.13/go.mod h1:4lJ1jqUDcsbIECGy0RUJAXNIhg+6ocWgb1ALK2O4oXg=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
@ -530,6 +536,8 @@ golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=

View File

@ -16,7 +16,9 @@ package main
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strconv"
"sync/atomic"
"testing"
@ -45,11 +47,17 @@ const (
certificationTestPrefix = "stable-certification-"
timeout = 5 * time.Second
defaultSchemaName = "dapr_state_store"
defaultTableName = "state"
defaultSchemaName = "dapr_state_store"
defaultTableName = "state"
defaultMetadataTableName = "dapr_metadata"
mysqlConnString = "root:root@tcp(localhost:3306)/?allowNativePasswords=true"
mariadbConnString = "root:root@tcp(localhost:3307)/"
keyConnectionString = "connectionString"
keyCleanupInterval = "cleanupInterval"
keyTableName = "tableName"
keyMetadatTableName = "metadataTableName"
)
func TestMySQL(t *testing.T) {
@ -315,7 +323,7 @@ func TestMySQL(t *testing.T) {
}
// checks that metadata options schemaName and tableName behave correctly
metadataTest := func(connString string, schemaName string, tableName string) func(ctx flow.Context) error {
metadataTest := func(connString, schemaName, tableName, metadataTableName string) func(ctx flow.Context) error {
return func(ctx flow.Context) (err error) {
properties := map[string]string{
"connectionString": connString,
@ -332,6 +340,11 @@ func TestMySQL(t *testing.T) {
} else {
tableName = defaultTableName
}
if metadataTableName != "" {
properties["metadataTableName"] = metadataTableName
} else {
metadataTableName = defaultMetadataTableName
}
// Init the component
component := stateMysql.NewMySQLStateStore(log).(*stateMysql.MySQL)
@ -359,9 +372,23 @@ func TestMySQL(t *testing.T) {
// Check that the table exists
query = `SELECT EXISTS (
SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_NAME = ?
SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
) AS 'exists'`
err = conn.QueryRow(query, tableName).Scan(&exists)
err = conn.QueryRow(query, schemaName, tableName).Scan(&exists)
require.NoError(t, err)
assert.Equal(t, 1, exists)
// Check that the expiredate column exists
query = `SELECT count(*) AS 'exists' FROM information_schema.columns
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND COLUMN_NAME = ?`
err = conn.QueryRow(query, schemaName, tableName, "expiredate").Scan(&exists)
require.NoError(t, err)
assert.Equal(t, 1, exists)
// Check that the metadata table exists
query = `SELECT count(*) AS 'exists' FROM information_schema.tables
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?`
err = conn.QueryRow(query, schemaName, tableName).Scan(&exists)
require.NoError(t, err)
assert.Equal(t, 1, exists)
@ -373,6 +400,165 @@ func TestMySQL(t *testing.T) {
}
}
// Validates TTLs and garbage collections
ttlTest := func(connString string) func(ctx flow.Context) error {
return func(ctx flow.Context) (err error) {
md := state.Metadata{
Base: metadata.Base{
Name: "ttltest",
Properties: map[string]string{
keyConnectionString: connString,
keyTableName: "ttl_state",
keyMetadatTableName: "ttl_metadata",
},
},
}
t.Run("parse cleanupInterval", func(t *testing.T) {
t.Run("default value", func(t *testing.T) {
// Default value is 1 hr
md.Properties[keyCleanupInterval] = ""
storeObj := stateMysql.NewMySQLStateStore(log).(*stateMysql.MySQL)
err := storeObj.Init(ctx, md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
cleanupInterval := storeObj.CleanupInterval()
_ = assert.NotNil(t, cleanupInterval) &&
assert.Equal(t, time.Duration(1*time.Hour), *cleanupInterval)
})
t.Run("positive value", func(t *testing.T) {
md.Properties[keyCleanupInterval] = "10s"
storeObj := stateMysql.NewMySQLStateStore(log).(*stateMysql.MySQL)
err := storeObj.Init(ctx, md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
cleanupInterval := storeObj.CleanupInterval()
_ = assert.NotNil(t, cleanupInterval) &&
assert.Equal(t, time.Duration(10*time.Second), *cleanupInterval)
})
t.Run("disabled", func(t *testing.T) {
// A value of <=0 means that the cleanup is disabled
md.Properties[keyCleanupInterval] = "0"
storeObj := stateMysql.NewMySQLStateStore(log).(*stateMysql.MySQL)
err := storeObj.Init(ctx, md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
cleanupInterval := storeObj.CleanupInterval()
_ = assert.Nil(t, cleanupInterval)
})
})
t.Run("cleanup", func(t *testing.T) {
md := state.Metadata{
Base: metadata.Base{
Name: "ttltest",
Properties: map[string]string{
keyConnectionString: connString,
keyTableName: "ttl_state",
keyMetadatTableName: "ttl_metadata",
},
},
}
t.Run("automatically delete expired records", func(t *testing.T) {
// Run every second
md.Properties[keyCleanupInterval] = "1s"
storeObj := stateMysql.NewMySQLStateStore(log).(*stateMysql.MySQL)
err := storeObj.Init(ctx, md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
conn := storeObj.GetConnection()
// Seed the database with some records
err = populateTTLRecords(ctx, conn)
require.NoError(t, err, "failed to seed records")
// Wait 2 seconds then verify we have only 10 rows left
time.Sleep(2 * time.Second)
count, err := countRowsInTable(ctx, conn, "ttl_state")
require.NoError(t, err, "failed to run query to count rows")
assert.Equal(t, 10, count)
// The "last-cleanup" value should be <= 2 seconds (+ a bit of buffer)
lastCleanup, err := loadLastCleanupInterval(ctx, conn, "ttl_metadata")
require.NoError(t, err, "failed to load value for 'last-cleanup'")
assert.LessOrEqual(t, lastCleanup, 2)
// Wait 6 more seconds and verify there are no more rows left
time.Sleep(6 * time.Second)
count, err = countRowsInTable(ctx, conn, "ttl_state")
require.NoError(t, err, "failed to run query to count rows")
assert.Equal(t, 0, count)
// The "last-cleanup" value should be <= 2 seconds (+ a bit of buffer)
lastCleanup, err = loadLastCleanupInterval(ctx, conn, "ttl_metadata")
require.NoError(t, err, "failed to load value for 'last-cleanup'")
assert.LessOrEqual(t, lastCleanup, 2)
})
t.Run("cleanup concurrency", func(t *testing.T) {
// Set to run every hour
// (we'll manually trigger more frequent iterations)
md.Properties[keyCleanupInterval] = "1h"
storeObj := stateMysql.NewMySQLStateStore(log).(*stateMysql.MySQL)
err := storeObj.Init(ctx, md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
conn := storeObj.GetConnection()
// Seed the database with some records
err = populateTTLRecords(ctx, conn)
require.NoError(t, err, "failed to seed records")
// Validate that 20 records are present
count, err := countRowsInTable(ctx, conn, "ttl_state")
require.NoError(t, err, "failed to run query to count rows")
assert.Equal(t, 20, count)
// Set last-cleanup to 1s ago (less than 3600s)
err = setValueInMetadataTable(ctx, conn, "ttl_metadata", "'last-cleanup'", "CURRENT_TIMESTAMP - INTERVAL 1 SECOND")
require.NoError(t, err, "failed to set last-cleanup")
// The "last-cleanup" value should be ~2 seconds (+ a bit of buffer)
lastCleanup, err := loadLastCleanupInterval(ctx, conn, "ttl_metadata")
require.NoError(t, err, "failed to load value for 'last-cleanup'")
assert.LessOrEqual(t, lastCleanup, 2)
lastCleanupValueOrig, err := getValueFromMetadataTable(ctx, conn, "ttl_metadata", "last-cleanup")
require.NoError(t, err, "failed to load absolute value for 'last-cleanup'")
require.NotEmpty(t, lastCleanupValueOrig)
// Trigger the background cleanup, which should do nothing because the last cleanup was < 3600s
err = storeObj.CleanupExpired()
require.NoError(t, err, "CleanupExpired returned an error")
// Validate that 20 records are still present
count, err = countRowsInTable(ctx, conn, "ttl_state")
require.NoError(t, err, "failed to run query to count rows")
assert.Equal(t, 20, count)
// The "last-cleanup" value should not have been changed
lastCleanupValue, err := getValueFromMetadataTable(ctx, conn, "ttl_metadata", "last-cleanup")
require.NoError(t, err, "failed to load absolute value for 'last-cleanup'")
assert.Equal(t, lastCleanupValueOrig, lastCleanupValue)
})
})
return nil
}
}
flow.New(t, "Run tests").
Step(dockercompose.Run("db", dockerComposeYAML)).
Step("Wait for databases to start", flow.Sleep(30*time.Second)).
@ -389,6 +575,8 @@ func TestMySQL(t *testing.T) {
Step("Run eTag test on mariadb", eTagTest("mariadb")).
Step("Run transactions test", transactionsTest("mysql")).
Step("Run transactions test", transactionsTest("mariadb")).
Step("Run TTL test on mysql", ttlTest(mysqlConnString)).
Step("Run TTL test on mariadb", ttlTest(mariadbConnString)).
Step("Run SQL injection test on mysql", verifySQLInjectionTest("mysql")).
Step("Run SQL injection test on mariadb", verifySQLInjectionTest("mariadb")).
//Step("Interrupt network and simulate timeouts", timeoutTest).
@ -408,10 +596,91 @@ func TestMySQL(t *testing.T) {
Step("Close database connection 1", closeTest(0)).
Step("Close database connection 2", closeTest(1)).
// Metadata
Step("Default schemaName and tableName on mysql", metadataTest(mysqlConnString, "", "")).
Step("Custom schemaName and tableName on mysql", metadataTest(mysqlConnString, "mydaprdb", "mytable")).
Step("Default schemaName and tableName on mariadb", metadataTest(mariadbConnString, "", "")).
Step("Custom schemaName and tableName on mariadb", metadataTest(mariadbConnString, "mydaprdb", "mytable")).
Step("Default schemaName, tableName and metadataTableName on mysql", metadataTest(mysqlConnString, "", "", "")).
Step("Custom schemaName, tableName and metadataTableName on mysql", metadataTest(mysqlConnString, "mydaprdb", "mytable", "metadatatable")).
Step("Default schemaName, tableName and metadataTableName on mariadb", metadataTest(mariadbConnString, "", "", "")).
Step("Custom schemaName, tableName and metadataTableName on mariadb", metadataTest(mariadbConnString, "mydaprdb", "mytable", "metadatatable")).
// Run tests
Run()
}
func populateTTLRecords(ctx context.Context, dbClient *sql.DB) error {
// Insert 10 records that have expired, and 10 that will expire in 4 seconds
exp := "DATE_SUB(CURRENT_TIMESTAMP, INTERVAL 1 MINUTE)"
rows := make([][]any, 20)
for i := 0; i < 10; i++ {
rows[i] = []any{
fmt.Sprintf("expired_%d", i),
json.RawMessage(fmt.Sprintf(`"value_%d"`, i)),
false,
exp,
}
}
exp = "DATE_ADD(CURRENT_TIMESTAMP, INTERVAL 4 second)"
for i := 0; i < 10; i++ {
rows[i+10] = []any{
fmt.Sprintf("notexpired_%d", i),
json.RawMessage(fmt.Sprintf(`"value_%d"`, i)),
false,
exp,
}
}
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
for _, row := range rows {
query := fmt.Sprintf("INSERT INTO ttl_state (id, value, isbinary, eTag, expiredate) VALUES (?, ?, ?, '', %s)", row[3])
_, err := dbClient.ExecContext(queryCtx, query, row[0], row[1], row[2])
if err != nil {
return err
}
}
return nil
}
func countRowsInTable(ctx context.Context, dbClient *sql.DB, table string) (count int, err error) {
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
err = dbClient.QueryRowContext(queryCtx, "SELECT COUNT(id) FROM "+table).Scan(&count)
cancel()
return
}
func loadLastCleanupInterval(ctx context.Context, dbClient *sql.DB, table string) (lastCleanup int, err error) {
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
var lastCleanupf float64
err = dbClient.
QueryRowContext(queryCtx,
fmt.Sprintf("SELECT UNIX_TIMESTAMP(CURRENT_TIMESTAMP) - UNIX_TIMESTAMP(value) AS lastCleanupf FROM %s WHERE id = 'last-cleanup'", table),
).
Scan(&lastCleanupf)
lastCleanup = int(lastCleanupf)
cancel()
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
func setValueInMetadataTable(ctx context.Context, dbClient *sql.DB, table, id, value string) error {
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
_, err := dbClient.ExecContext(queryCtx,
//nolint:gosec
fmt.Sprintf(`INSERT INTO %[1]s (id, value) VALUES (%[2]s, %[3]s) ON DUPLICATE KEY UPDATE
value = %[3]s`, table, id, value),
)
cancel()
return err
}
func getValueFromMetadataTable(ctx context.Context, dbClient *sql.DB, table, id string) (value string, err error) {
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
err = dbClient.
QueryRowContext(queryCtx, fmt.Sprintf("SELECT value FROM %s WHERE id = ?", table), id).
Scan(&value)
cancel()
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}

View File

@ -33,10 +33,10 @@ components:
operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
- component: mysql.mysql
allOperations: false
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write" ]
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
- component: mysql.mariadb
allOperations: false
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write" ]
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
- component: azure.tablestorage.storage
allOperations: false
operations: ["set", "get", "delete", "etag", "bulkset", "bulkdelete", "first-write"]

View File

@ -859,19 +859,19 @@ func assertDataEquals(t *testing.T, expect any, actual []byte) {
case intValueType:
// Custom type requires case mapping
if err := json.Unmarshal(actual, &v); err != nil {
assert.Failf(t, "unmarshal error", "error: %w, json: %s", err, string(actual))
assert.Failf(t, "unmarshal error", "error: %v, json: %s", err, string(actual))
}
assert.Equal(t, expect, v)
case ValueType:
// Custom type requires case mapping
if err := json.Unmarshal(actual, &v); err != nil {
assert.Failf(t, "unmarshal error", "error: %w, json: %s", err, string(actual))
assert.Failf(t, "unmarshal error", "error: %v, json: %s", err, string(actual))
}
assert.Equal(t, expect, v)
case int:
// json.Unmarshal to float64 by default, case mapping to int coerces to int type
if err := json.Unmarshal(actual, &v); err != nil {
assert.Failf(t, "unmarshal error", "error: %w, json: %s", err, string(actual))
assert.Failf(t, "unmarshal error", "error: %v, json: %s", err, string(actual))
}
assert.Equal(t, expect, v)
case []byte:
@ -879,7 +879,7 @@ func assertDataEquals(t *testing.T, expect any, actual []byte) {
default:
// Other golang primitive types (string, bool ...)
if err := json.Unmarshal(actual, &v); err != nil {
assert.Failf(t, "unmarshal error", "error: %w, json: %s", err, string(actual))
assert.Failf(t, "unmarshal error", "error: %v, json: %s", err, string(actual))
}
assert.Equal(t, expect, v)
}