Add MySQL TTL and Cleanup (#2641)
Signed-off-by: Deepanshu Agarwal <deepanshu.agarwal1984@gmail.com> Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: Bernd Verst <github@bernd.dev>
This commit is contained in:
parent
04d1e71ce7
commit
d6ce7bb5c3
|
@ -224,7 +224,7 @@ func (g *gc) updateLastCleanup(ctx context.Context) (bool, error) {
|
|||
}
|
||||
n = res.RowsAffected()
|
||||
} else {
|
||||
res, err := g.dbSQL.ExecContext(ctx, g.updateLastCleanupQuery, sql.Named("Interval", g.cleanupInterval.Milliseconds()-100))
|
||||
res, err := g.dbSQL.ExecContext(ctx, g.updateLastCleanupQuery, g.cleanupInterval.Milliseconds()-100)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error updating last cleanup time: %w", err)
|
||||
}
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
/*
|
||||
Copyright 2021 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package mysql
|
||||
|
||||
import "database/sql"
|
||||
|
||||
// This interface is used to help improve testing.
|
||||
type iMySQLFactory interface {
|
||||
Open(connectionString string) (*sql.DB, error)
|
||||
RegisterTLSConfig(pemPath string) error
|
||||
}
|
|
@ -25,6 +25,12 @@ import (
|
|||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
// This interface is used to help improve testing.
|
||||
type iMySQLFactory interface {
|
||||
Open(connectionString string) (*sql.DB, error)
|
||||
RegisterTLSConfig(pemPath string) error
|
||||
}
|
||||
|
||||
type mySQLFactory struct {
|
||||
logger logger.Logger
|
||||
}
|
||||
|
|
|
@ -27,28 +27,20 @@ import (
|
|||
|
||||
"github.com/google/uuid"
|
||||
|
||||
sqlCleanup "github.com/dapr/components-contrib/internal/component/sql"
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
"github.com/dapr/components-contrib/state/utils"
|
||||
"github.com/dapr/kit/logger"
|
||||
"github.com/dapr/kit/ptr"
|
||||
)
|
||||
|
||||
// Optimistic Concurrency is implemented using a string column that stores a UUID.
|
||||
|
||||
const (
|
||||
// The key name in the metadata if the user wants a different table name
|
||||
// than the defaultTableName.
|
||||
keyTableName = "tableName"
|
||||
|
||||
// The key name in the metadata if the user wants a different database name
|
||||
// than the defaultSchemaName.
|
||||
keySchemaName = "schemaName"
|
||||
|
||||
// The key name in the metadata for the timeout of operations, in seconds.
|
||||
keyTimeoutInSeconds = "timeoutInSeconds"
|
||||
|
||||
// The key for the mandatory connection string of the metadata.
|
||||
keyConnectionString = "connectionString"
|
||||
|
||||
// To connect to MySQL running in Azure over SSL you have to download a
|
||||
// SSL certificate. If this is provided the driver will connect using
|
||||
// SSL. If you have disable SSL you can leave this empty.
|
||||
|
@ -69,14 +61,27 @@ const (
|
|||
|
||||
// Standard error message if not connection string is provided.
|
||||
errMissingConnectionString = "missing connection string"
|
||||
|
||||
// Key name to configure interval at which entries with TTL are cleaned up.
|
||||
// This is parsed as a Go duration.
|
||||
cleanupIntervalKey = "cleanupInterval"
|
||||
|
||||
// Used if the user does not configure a metadata table name in the metadata.
|
||||
// In terms of TTL, it is required to store value for 'last-cleanup' id.
|
||||
defaultMetadataTableName = "dapr_metadata"
|
||||
|
||||
// Used if the user does not configure a cleanup interval in the metadata.
|
||||
defaultCleanupInterval = time.Hour
|
||||
)
|
||||
|
||||
// MySQL state store.
|
||||
type MySQL struct {
|
||||
tableName string
|
||||
schemaName string
|
||||
connectionString string
|
||||
timeout time.Duration
|
||||
tableName string
|
||||
metadataTableName string
|
||||
cleanupInterval *time.Duration
|
||||
schemaName string
|
||||
connectionString string
|
||||
timeout time.Duration
|
||||
|
||||
// Instance of the database to issue commands to
|
||||
db *sql.DB
|
||||
|
@ -85,14 +90,17 @@ type MySQL struct {
|
|||
logger logger.Logger
|
||||
|
||||
factory iMySQLFactory
|
||||
gc sqlCleanup.GarbageCollector
|
||||
}
|
||||
|
||||
type mySQLMetadata struct {
|
||||
TableName string
|
||||
SchemaName string
|
||||
ConnectionString string
|
||||
Timeout int
|
||||
PemPath string
|
||||
TableName string
|
||||
SchemaName string
|
||||
ConnectionString string
|
||||
Timeout int
|
||||
PemPath string
|
||||
MetadataTableName string
|
||||
CleanupInterval *time.Duration
|
||||
}
|
||||
|
||||
// NewMySQLStateStore creates a new instance of MySQL state store.
|
||||
|
@ -141,9 +149,12 @@ func (m *MySQL) Init(ctx context.Context, metadata state.Metadata) error {
|
|||
|
||||
func (m *MySQL) parseMetadata(md map[string]string) error {
|
||||
meta := mySQLMetadata{
|
||||
TableName: defaultTableName,
|
||||
SchemaName: defaultSchemaName,
|
||||
TableName: defaultTableName,
|
||||
SchemaName: defaultSchemaName,
|
||||
MetadataTableName: defaultMetadataTableName,
|
||||
CleanupInterval: ptr.Of(defaultCleanupInterval),
|
||||
}
|
||||
|
||||
err := metadata.DecodeMetadata(md, &meta)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -157,6 +168,14 @@ func (m *MySQL) parseMetadata(md map[string]string) error {
|
|||
}
|
||||
m.tableName = meta.TableName
|
||||
|
||||
if meta.MetadataTableName != "" {
|
||||
// Sanitize the metadata table name
|
||||
if !validIdentifier(meta.MetadataTableName) {
|
||||
return fmt.Errorf("metadata table name '%s' is not valid", meta.MetadataTableName)
|
||||
}
|
||||
}
|
||||
m.metadataTableName = meta.MetadataTableName
|
||||
|
||||
if meta.SchemaName != "" {
|
||||
// Sanitize the schema name
|
||||
if !validIdentifier(meta.SchemaName) {
|
||||
|
@ -171,6 +190,21 @@ func (m *MySQL) parseMetadata(md map[string]string) error {
|
|||
}
|
||||
m.connectionString = meta.ConnectionString
|
||||
|
||||
// Cleanup interval
|
||||
if meta.CleanupInterval != nil {
|
||||
// Non-positive value from meta means disable auto cleanup.
|
||||
if *meta.CleanupInterval <= 0 {
|
||||
if md[cleanupIntervalKey] == "" {
|
||||
// unfortunately the mapstructure decoder decodes an empty string to 0, a missing key would be nil however
|
||||
meta.CleanupInterval = ptr.Of(defaultCleanupInterval)
|
||||
} else {
|
||||
meta.CleanupInterval = nil
|
||||
}
|
||||
}
|
||||
|
||||
m.cleanupInterval = meta.CleanupInterval
|
||||
}
|
||||
|
||||
if meta.PemPath != "" {
|
||||
err := m.factory.RegisterTLSConfig(meta.PemPath)
|
||||
if err != nil {
|
||||
|
@ -231,7 +265,35 @@ func (m *MySQL) finishInit(ctx context.Context, db *sql.DB) error {
|
|||
}
|
||||
|
||||
// will be nil if everything is good or an err that needs to be returned
|
||||
return m.ensureStateTable(ctx, m.tableName)
|
||||
if err = m.ensureStateTable(ctx, m.schemaName, m.tableName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = m.ensureMetadataTable(ctx, m.schemaName, m.metadataTableName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.cleanupInterval != nil {
|
||||
gc, err := sqlCleanup.ScheduleGarbageCollector(sqlCleanup.GCOptions{
|
||||
Logger: m.logger,
|
||||
UpdateLastCleanupQuery: fmt.Sprintf(`INSERT INTO %[1]s (id, value)
|
||||
VALUES ('last-cleanup', CURRENT_TIMESTAMP)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
value = IF(CURRENT_TIMESTAMP > DATE_ADD(value, INTERVAL ?*1000 MICROSECOND), CURRENT_TIMESTAMP, value)`,
|
||||
m.metadataTableName),
|
||||
DeleteExpiredValuesQuery: fmt.Sprintf(
|
||||
`DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate <= CURRENT_TIMESTAMP`,
|
||||
m.tableName,
|
||||
),
|
||||
CleanupInterval: *m.cleanupInterval,
|
||||
DBSql: m.db,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.gc = gc
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQL) ensureStateSchema(ctx context.Context) error {
|
||||
|
@ -273,13 +335,13 @@ func (m *MySQL) ensureStateSchema(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (m *MySQL) ensureStateTable(ctx context.Context, stateTableName string) error {
|
||||
exists, err := tableExists(ctx, m.db, stateTableName, m.timeout)
|
||||
func (m *MySQL) ensureStateTable(ctx context.Context, schemaName, stateTableName string) error {
|
||||
tableExists, err := tableExists(ctx, m.db, schemaName, stateTableName, m.timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
if !tableExists {
|
||||
m.logger.Infof("Creating MySql state table '%s'", stateTableName)
|
||||
|
||||
// updateDate is updated automactically on every UPDATE commands so you
|
||||
|
@ -288,18 +350,85 @@ func (m *MySQL) ensureStateTable(ctx context.Context, stateTableName string) err
|
|||
// in on inserts and updates and is used for Optimistic Concurrency
|
||||
// Note that stateTableName is sanitized
|
||||
//nolint:gosec
|
||||
createTable := fmt.Sprintf(`CREATE TABLE %s (
|
||||
createTable := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
|
||||
id VARCHAR(255) NOT NULL PRIMARY KEY,
|
||||
value JSON NOT NULL,
|
||||
isbinary BOOLEAN NOT NULL,
|
||||
insertDate TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updateDate TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
eTag VARCHAR(36) NOT NULL
|
||||
eTag VARCHAR(36) NOT NULL,
|
||||
expiredate TIMESTAMP NULL,
|
||||
INDEX expiredate_idx(expiredate)
|
||||
);`, stateTableName)
|
||||
|
||||
execCtx, execCancel := context.WithTimeout(ctx, m.timeout)
|
||||
defer execCancel()
|
||||
_, err = m.db.ExecContext(execCtx, createTable)
|
||||
_, err = m.db.ExecContext(ctx, createTable)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check if expiredate column exists - to cater cases when table was created before v1.11.
|
||||
columnExists, err := columnExists(ctx, m.db, schemaName, stateTableName, "expiredate", m.timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !columnExists {
|
||||
m.logger.Infof("Adding expiredate column to MySql state table '%s'", stateTableName)
|
||||
_, err = m.db.ExecContext(ctx, fmt.Sprintf(
|
||||
`ALTER TABLE %s ADD COLUMN IF NOT EXISTS expiredate TIMESTAMP NULL;`, stateTableName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = m.db.ExecContext(ctx, fmt.Sprintf(
|
||||
`CREATE INDEX IF NOT EXISTS expiredate_idx ON %s (expiredate);`, stateTableName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Create the DaprSaveFirstWriteV1 stored procedure
|
||||
_, err = m.db.ExecContext(ctx, `CREATE PROCEDURE IF NOT EXISTS DaprSaveFirstWriteV1(tableName VARCHAR(255), id VARCHAR(255), value JSON, etag VARCHAR(36), isbinary BOOLEAN, expiredateToken TEXT)
|
||||
LANGUAGE SQL
|
||||
MODIFIES SQL DATA
|
||||
BEGIN
|
||||
SET @id = id;
|
||||
SET @value = value;
|
||||
SET @etag = etag;
|
||||
SET @isbinary = isbinary;
|
||||
|
||||
SET @selectQuery = concat('SELECT COUNT(id) INTO @count FROM ', tableName ,' WHERE id = ? AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)');
|
||||
PREPARE select_stmt FROM @selectQuery;
|
||||
EXECUTE select_stmt USING @id;
|
||||
DEALLOCATE PREPARE select_stmt;
|
||||
|
||||
IF @count < 1 THEN
|
||||
SET @upsertQuery = concat('INSERT INTO ', tableName, ' SET id=?, value=?, eTag=?, isbinary=?, expiredate=', expiredateToken, ' ON DUPLICATE KEY UPDATE value=?, eTag=?, isbinary=?, expiredate=', expiredateToken);
|
||||
PREPARE upsert_stmt FROM @upsertQuery;
|
||||
EXECUTE upsert_stmt USING @id, @value, @etag, @isbinary, @value, @etag, @isbinary;
|
||||
DEALLOCATE PREPARE upsert_stmt;
|
||||
ELSE
|
||||
SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'Row already exists';
|
||||
END IF;
|
||||
|
||||
END`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQL) ensureMetadataTable(ctx context.Context, schemaName, metaTableName string) error {
|
||||
exists, err := tableExists(ctx, m.db, schemaName, metaTableName, m.timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
m.logger.Info("Creating MySQL metadata table")
|
||||
_, err = m.db.ExecContext(ctx, fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
|
||||
id VARCHAR(255) NOT NULL PRIMARY KEY, value TEXT NOT NULL);`, metaTableName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -321,16 +450,31 @@ func schemaExists(ctx context.Context, db *sql.DB, schemaName string, timeout ti
|
|||
return exists == 1, err
|
||||
}
|
||||
|
||||
func tableExists(ctx context.Context, db *sql.DB, tableName string, timeout time.Duration) (bool, error) {
|
||||
func tableExists(ctx context.Context, db *sql.DB, schemaName, tableName string, timeout time.Duration) (bool, error) {
|
||||
tableCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// Returns 1 or 0 if the table exists or not
|
||||
var exists int
|
||||
query := `SELECT EXISTS (
|
||||
SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_NAME = ?
|
||||
SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
|
||||
) AS 'exists'`
|
||||
err := db.QueryRowContext(tableCtx, query, tableName).Scan(&exists)
|
||||
err := db.QueryRowContext(tableCtx, query, schemaName, tableName).Scan(&exists)
|
||||
return exists == 1, err
|
||||
}
|
||||
|
||||
// columnExists returns true if the column exists in the table
|
||||
func columnExists(ctx context.Context, db *sql.DB, schemaName, tableName, columnName string, timeout time.Duration) (bool, error) {
|
||||
columnCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// Returns 1 or 0 if the column exists or not
|
||||
var exists int
|
||||
query := `SELECT count(*) AS 'exists' FROM information_schema.columns
|
||||
WHERE table_schema = ?
|
||||
AND table_name = ?
|
||||
AND column_name = ?`
|
||||
err := db.QueryRowContext(columnCtx, query, schemaName, tableName, columnName).Scan(&exists)
|
||||
return exists == 1, err
|
||||
}
|
||||
|
||||
|
@ -343,8 +487,6 @@ func (m *MySQL) Delete(ctx context.Context, req *state.DeleteRequest) error {
|
|||
// deleteValue is an internal implementation of delete to enable passing the
|
||||
// logic to state.DeleteWithRetries as a func.
|
||||
func (m *MySQL) deleteValue(parentCtx context.Context, querier querier, req *state.DeleteRequest) error {
|
||||
m.logger.Debug("Deleting state value from MySql")
|
||||
|
||||
if req.Key == "" {
|
||||
return fmt.Errorf("missing key in delete operation")
|
||||
}
|
||||
|
@ -363,7 +505,7 @@ func (m *MySQL) deleteValue(parentCtx context.Context, querier querier, req *sta
|
|||
m.tableName), req.Key)
|
||||
} else {
|
||||
result, err = querier.ExecContext(execCtx, fmt.Sprintf(
|
||||
`DELETE FROM %s WHERE id = ? and eTag = ?`,
|
||||
`DELETE FROM %s WHERE id = ? AND eTag = ?`,
|
||||
m.tableName), req.Key, *req.ETag)
|
||||
}
|
||||
|
||||
|
@ -386,37 +528,33 @@ func (m *MySQL) deleteValue(parentCtx context.Context, querier querier, req *sta
|
|||
// BulkDelete removes multiple entries from the store
|
||||
// Store Interface.
|
||||
func (m *MySQL) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
|
||||
m.logger.Debug("Executing BulkDelete request")
|
||||
|
||||
tx, err := m.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(req) > 0 {
|
||||
for _, d := range req {
|
||||
da := d // Fix for goSec G601: Implicit memory aliasing in for loop.
|
||||
err = m.deleteValue(ctx, tx, &da)
|
||||
if err != nil {
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
|
||||
return err
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// Get returns an entity from store
|
||||
// Store Interface.
|
||||
func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
|
||||
m.logger.Debug("Getting state value from MySql")
|
||||
|
||||
if req.Key == "" {
|
||||
return nil, fmt.Errorf("missing key in get operation")
|
||||
}
|
||||
|
@ -431,7 +569,8 @@ func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.Ge
|
|||
defer cancel()
|
||||
//nolint:gosec
|
||||
query := fmt.Sprintf(
|
||||
`SELECT value, eTag, isbinary FROM %s WHERE id = ?`,
|
||||
`SELECT value, eTag, isbinary FROM %s WHERE id = ?
|
||||
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`,
|
||||
m.tableName, // m.tableName is sanitized
|
||||
)
|
||||
err := m.db.QueryRowContext(ctx, query, req.Key).Scan(&value, &eTag, &isBinary)
|
||||
|
@ -483,8 +622,6 @@ func (m *MySQL) Set(ctx context.Context, req *state.SetRequest) error {
|
|||
// setValue is an internal implementation of set to enable passing the logic
|
||||
// to state.SetWithRetries as a func.
|
||||
func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.SetRequest) error {
|
||||
m.logger.Debug("Setting state value in MySql")
|
||||
|
||||
err := state.CheckRequestOptions(req.Options)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -494,6 +631,24 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
|
|||
return errors.New("missing key in set operation")
|
||||
}
|
||||
|
||||
// TTL
|
||||
var ttlSeconds int
|
||||
ttl, err := utils.ParseTTL(req.Metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing TTL: %w", err)
|
||||
}
|
||||
if ttl != nil {
|
||||
ttlSeconds = *ttl
|
||||
}
|
||||
|
||||
var (
|
||||
query string
|
||||
ttlQuery string
|
||||
params []any
|
||||
result sql.Result
|
||||
maxRows int64 = 1
|
||||
)
|
||||
|
||||
var v any
|
||||
isBinary := false
|
||||
switch x := req.Value.(type) {
|
||||
|
@ -513,62 +668,91 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
|
|||
}
|
||||
eTag := eTagObj.String()
|
||||
|
||||
var (
|
||||
result sql.Result
|
||||
maxRows int64 = 1
|
||||
)
|
||||
if ttlSeconds > 0 {
|
||||
ttlQuery = "CURRENT_TIMESTAMP + INTERVAL " + strconv.Itoa(ttlSeconds) + " SECOND"
|
||||
} else {
|
||||
ttlQuery = "NULL"
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
|
||||
defer cancel()
|
||||
mustCommit := false
|
||||
hasEtag := req.ETag != nil && *req.ETag != ""
|
||||
|
||||
if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || *req.ETag == "") {
|
||||
// With first-write-wins and no etag, we can insert the row only if it doesn't exist
|
||||
query := fmt.Sprintf(
|
||||
`INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?);`,
|
||||
m.tableName, // m.tableName is sanitized
|
||||
)
|
||||
result, err = querier.ExecContext(ctx, query, enc, req.Key, eTag, isBinary)
|
||||
} else if req.ETag != nil && *req.ETag != "" {
|
||||
if hasEtag {
|
||||
// When an eTag is provided do an update - not insert
|
||||
query := fmt.Sprintf(
|
||||
`UPDATE %s SET value = ?, eTag = ?, isbinary = ? WHERE id = ? AND eTag = ?;`,
|
||||
m.tableName, // m.tableName is sanitized
|
||||
)
|
||||
result, err = querier.ExecContext(ctx, query, enc, eTag, isBinary, req.Key, *req.ETag)
|
||||
query = `UPDATE ` + m.tableName + `
|
||||
SET value = ?, eTag = ?, isbinary = ?, expiredate = ` + ttlQuery + `
|
||||
WHERE id = ?
|
||||
AND eTag = ?
|
||||
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
|
||||
params = []any{enc, eTag, isBinary, req.Key, *req.ETag}
|
||||
} else if req.Options.Concurrency == state.FirstWrite {
|
||||
// If we're not in a transaction already, start one as we need to ensure consistency
|
||||
if querier == m.db {
|
||||
querier, err = m.db.BeginTx(parentCtx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer querier.(*sql.Tx).Rollback()
|
||||
mustCommit = true
|
||||
}
|
||||
|
||||
// With first-write-wins and no etag, we can insert the row only if it doesn't exist
|
||||
// Things get a bit tricky when the row exists but it is expired, so it just hasn't been garbage-collected yet
|
||||
// What we can do in that case is to first check if the row doesn't exist or has expired, and then perform an upsert
|
||||
// To do that, we use a stored procedure
|
||||
query = "CALL DaprSaveFirstWriteV1(?, ?, ?, ?, ?, ?)"
|
||||
params = []any{m.tableName, req.Key, enc, eTag, isBinary, ttlQuery}
|
||||
} else {
|
||||
// If this is a duplicate MySQL returns that two rows affected
|
||||
maxRows = 2
|
||||
query := fmt.Sprintf(
|
||||
`INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`,
|
||||
m.tableName, // m.tableName is sanitized
|
||||
)
|
||||
result, err = querier.ExecContext(ctx, query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary)
|
||||
query = `INSERT INTO ` + m.tableName + ` (id, value, eTag, isbinary, expiredate)
|
||||
VALUES (?, ?, ?, ?, ` + ttlQuery + `)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
value=?, eTag=?, isbinary=?, expiredate=` + ttlQuery
|
||||
params = []any{req.Key, enc, eTag, isBinary, enc, eTag, isBinary}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
|
||||
defer cancel()
|
||||
result, err = querier.ExecContext(ctx, query, params...)
|
||||
|
||||
if err != nil {
|
||||
if req.ETag != nil && *req.ETag != "" {
|
||||
if hasEtag {
|
||||
return state.NewETagError(state.ETagMismatch, err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
// Do not count affected rows when using first-write
|
||||
// Conflicts are handled separately
|
||||
if hasEtag || req.Options.Concurrency != state.FirstWrite {
|
||||
var rows int64
|
||||
rows, err = result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
err = errors.New("rows affected error: no rows match given key and eTag")
|
||||
err = state.NewETagError(state.ETagMismatch, err)
|
||||
m.logger.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if rows > maxRows {
|
||||
err = fmt.Errorf("rows affected error: more than %d row affected; actual %d", maxRows, rows)
|
||||
m.logger.Error(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
err = errors.New(`rows affected error: no rows match given key and eTag`)
|
||||
err = state.NewETagError(state.ETagMismatch, err)
|
||||
m.logger.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if rows > maxRows {
|
||||
err = fmt.Errorf(`rows affected error: more than %d row affected; actual %d`, maxRows, rows)
|
||||
m.logger.Error(err)
|
||||
return err
|
||||
// Commit the transaction if needed
|
||||
if mustCommit {
|
||||
err = querier.(*sql.Tx).Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -577,21 +761,21 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
|
|||
// BulkSet adds/updates multiple entities on store
|
||||
// Store Interface.
|
||||
func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error {
|
||||
m.logger.Debug("Executing BulkSet request")
|
||||
|
||||
tx, err := m.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(req) > 0 {
|
||||
for i := range req {
|
||||
err = m.setValue(ctx, tx, &req[i])
|
||||
if err != nil {
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -603,50 +787,38 @@ func (m *MySQL) BulkSet(ctx context.Context, req []state.SetRequest) error {
|
|||
// Multi handles multiple transactions.
|
||||
// TransactionalStore Interface.
|
||||
func (m *MySQL) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
|
||||
m.logger.Debug("Executing Multi request")
|
||||
|
||||
tx, err := m.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, req := range request.Operations {
|
||||
switch req.Operation {
|
||||
case state.Upsert:
|
||||
setReq, err := m.getSets(req)
|
||||
if err != nil {
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.setValue(ctx, tx, &setReq)
|
||||
if err != nil {
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
case state.Delete:
|
||||
delReq, err := m.getDeletes(req)
|
||||
if err != nil {
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.deleteValue(ctx, tx, &delReq)
|
||||
if err != nil {
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -662,11 +834,11 @@ func (m *MySQL) Multi(ctx context.Context, request *state.TransactionalStateRequ
|
|||
func (m *MySQL) getSets(req state.TransactionalStateOperation) (state.SetRequest, error) {
|
||||
setReq, ok := req.Request.(state.SetRequest)
|
||||
if !ok {
|
||||
return setReq, fmt.Errorf("expecting set request")
|
||||
return setReq, errors.New("expecting set request")
|
||||
}
|
||||
|
||||
if setReq.Key == "" {
|
||||
return setReq, fmt.Errorf("missing key in upsert operation")
|
||||
return setReq, errors.New("missing key in upsert operation")
|
||||
}
|
||||
|
||||
return setReq, nil
|
||||
|
@ -676,11 +848,11 @@ func (m *MySQL) getSets(req state.TransactionalStateOperation) (state.SetRequest
|
|||
func (m *MySQL) getDeletes(req state.TransactionalStateOperation) (state.DeleteRequest, error) {
|
||||
delReq, ok := req.Request.(state.DeleteRequest)
|
||||
if !ok {
|
||||
return delReq, fmt.Errorf("expecting delete request")
|
||||
return delReq, errors.New("expecting delete request")
|
||||
}
|
||||
|
||||
if delReq.Key == "" {
|
||||
return delReq, fmt.Errorf("missing key in delete operation")
|
||||
return delReq, errors.New("missing key in delete operation")
|
||||
}
|
||||
|
||||
return delReq, nil
|
||||
|
@ -701,6 +873,10 @@ func (m *MySQL) Close() error {
|
|||
|
||||
err := m.db.Close()
|
||||
m.db = nil
|
||||
if m.gc != nil {
|
||||
return errors.Join(err, m.gc.Close())
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ package mysql
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetConnection returns the database connection.
|
||||
|
@ -35,3 +36,12 @@ func (m *MySQL) SchemaName() string {
|
|||
func (m *MySQL) TableName() string {
|
||||
return m.tableName
|
||||
}
|
||||
|
||||
// CleanupInterval returns the value of the cleanupInterval property.
|
||||
func (m *MySQL) CleanupInterval() *time.Duration {
|
||||
return m.cleanupInterval
|
||||
}
|
||||
|
||||
func (m *MySQL) CleanupExpired() error {
|
||||
return m.gc.CleanupExpired()
|
||||
}
|
||||
|
|
|
@ -149,7 +149,7 @@ func TestMySQLIntegration(t *testing.T) {
|
|||
tableName := "test_state"
|
||||
|
||||
// Drop the table if it already exists
|
||||
exists, err := tableExists(context.Background(), mys.db, tableName, 10*time.Second)
|
||||
exists, err := tableExists(context.Background(), mys.db, "dapr_state_store", tableName, 10*time.Second)
|
||||
assert.Nil(t, err)
|
||||
if exists {
|
||||
dropTable(t, mys.db, tableName)
|
||||
|
@ -157,11 +157,11 @@ func TestMySQLIntegration(t *testing.T) {
|
|||
|
||||
// Create the state table and test for its existence
|
||||
// There should be no error
|
||||
err = mys.ensureStateTable(context.Background(), tableName)
|
||||
err = mys.ensureStateTable(context.Background(), "dapr_state_store", tableName)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Now create it and make sure there are no errors
|
||||
exists, err = tableExists(context.Background(), mys.db, tableName, 10*time.Second)
|
||||
exists, err = tableExists(context.Background(), mys.db, "dapr_state_store", tableName, 10*time.Second)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
|
@ -35,6 +36,9 @@ import (
|
|||
|
||||
const (
|
||||
fakeConnectionString = "not a real connection"
|
||||
keyTableName = "tableName"
|
||||
keyConnectionString = "connectionString"
|
||||
keySchemaName = "schemaName"
|
||||
)
|
||||
|
||||
func TestEnsureStateSchemaHandlesShortConnectionString(t *testing.T) {
|
||||
|
@ -67,7 +71,7 @@ func TestFinishInitHandlesSchemaExistsError(t *testing.T) {
|
|||
actualErr := m.mySQL.finishInit(context.Background(), m.mySQL.db)
|
||||
|
||||
// Assert
|
||||
assert.NotNil(t, actualErr, "now error returned")
|
||||
assert.Error(t, actualErr, "now error returned")
|
||||
assert.Equal(t, "existsError", actualErr.Error(), "wrong error")
|
||||
}
|
||||
|
||||
|
@ -86,7 +90,7 @@ func TestFinishInitHandlesDatabaseCreateError(t *testing.T) {
|
|||
actualErr := m.mySQL.finishInit(context.Background(), m.mySQL.db)
|
||||
|
||||
// Assert
|
||||
assert.NotNil(t, actualErr, "now error returned")
|
||||
assert.Error(t, actualErr, "now error returned")
|
||||
assert.Equal(t, "createDatabaseError", actualErr.Error(), "wrong error")
|
||||
}
|
||||
|
||||
|
@ -541,7 +545,7 @@ func TestTableExists(t *testing.T) {
|
|||
m.mock1.ExpectQuery("SELECT EXISTS").WillReturnRows(rows)
|
||||
|
||||
// Act
|
||||
actual, err := tableExists(context.Background(), m.mySQL.db, "store", 10*time.Second)
|
||||
actual, err := tableExists(context.Background(), m.mySQL.db, "dapr_state_store", "store", 10*time.Second)
|
||||
|
||||
// Assert
|
||||
assert.Nil(t, err, `error was returned`)
|
||||
|
@ -559,7 +563,7 @@ func TestEnsureStateTableHandlesCreateTableError(t *testing.T) {
|
|||
m.mock1.ExpectExec("CREATE TABLE").WillReturnError(fmt.Errorf("CreateTableError"))
|
||||
|
||||
// Act
|
||||
err := m.mySQL.ensureStateTable(context.Background(), "state")
|
||||
err := m.mySQL.ensureStateTable(context.Background(), "dapr_state_store", "state")
|
||||
|
||||
// Assert
|
||||
assert.NotNil(t, err, "no error returned")
|
||||
|
@ -578,12 +582,15 @@ func TestEnsureStateTableCreatesTable(t *testing.T) {
|
|||
rows := sqlmock.NewRows([]string{"exists"}).AddRow(0)
|
||||
m.mock1.ExpectQuery("SELECT EXISTS").WillReturnRows(rows)
|
||||
m.mock1.ExpectExec("CREATE TABLE").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
rows = sqlmock.NewRows([]string{"exists"}).AddRow(1)
|
||||
m.mock1.ExpectQuery("SELECT count(/*)").WillReturnRows(rows)
|
||||
m.mock1.ExpectExec("CREATE PROCEDURE").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
// Act
|
||||
err := m.mySQL.ensureStateTable(context.Background(), "state")
|
||||
err := m.mySQL.ensureStateTable(context.Background(), "dapr_state_store", "state")
|
||||
|
||||
// Assert
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify that the call to MySQL init get passed through
|
||||
|
|
|
@ -39,5 +39,16 @@ d. Get and validate eTag, which should not have changed.
|
|||
|
||||
1. Without `schemaName`, check that the default one is used
|
||||
2. Without `tableName`, check that the default one is used
|
||||
3. Instantiate a component with a custom `schemaName` and validate it's used
|
||||
4. Instantiate a component with a custom `tableName` and validate it's used
|
||||
3. Without `metadataTableName`, check that the default one is used
|
||||
4. Instantiate a component with a custom `schemaName` and validate it's used
|
||||
5. Instantiate a component with a custom `tableName` and validate it's used
|
||||
6. Instantiate a component with a custom `metadataTableName` and validate it's used
|
||||
|
||||
## TTLs and cleanups
|
||||
|
||||
1. Correctly parse the `cleanupIntervalInSeconds` metadata property:
|
||||
- No value uses the default value (3600 seconds)
|
||||
- A positive value sets the interval to the given number of seconds
|
||||
- A zero or negative value disables the cleanup
|
||||
2. The cleanup method deletes expired records and updates the metadata table with the last time it ran
|
||||
3. The cleanup method doesn't run if the last iteration was less than `cleanupIntervalInSeconds` or if another process is doing the cleanup
|
|
@ -62,6 +62,9 @@ require (
|
|||
github.com/hashicorp/golang-lru/v2 v2.0.1 // indirect
|
||||
github.com/hashicorp/serf v0.10.1 // indirect
|
||||
github.com/imdario/mergo v0.3.13 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.3.1 // indirect
|
||||
github.com/jhump/protoreflect v1.14.1 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
|
@ -108,6 +111,7 @@ require (
|
|||
go.opentelemetry.io/otel/sdk v1.11.2 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.11.2 // indirect
|
||||
go.opentelemetry.io/proto/otlp v0.19.0 // indirect
|
||||
golang.org/x/crypto v0.7.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 // indirect
|
||||
golang.org/x/mod v0.9.0 // indirect
|
||||
golang.org/x/net v0.8.0 // indirect
|
||||
|
|
|
@ -291,6 +291,12 @@ github.com/hashicorp/serf v0.10.1/go.mod h1:yL2t6BqATOLGc5HF7qbFkTfXoPIY0WZdWHfE
|
|||
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
|
||||
github.com/imdario/mergo v0.3.13 h1:lFzP57bqS/wsqKssCGmtLAb8A0wKjLGrve2q3PPVcBk=
|
||||
github.com/imdario/mergo v0.3.13/go.mod h1:4lJ1jqUDcsbIECGy0RUJAXNIhg+6ocWgb1ALK2O4oXg=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
|
||||
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
|
||||
github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
|
||||
|
@ -530,6 +536,8 @@ golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3
|
|||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
|
||||
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
|
|
|
@ -16,7 +16,9 @@ package main
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
@ -45,11 +47,17 @@ const (
|
|||
certificationTestPrefix = "stable-certification-"
|
||||
timeout = 5 * time.Second
|
||||
|
||||
defaultSchemaName = "dapr_state_store"
|
||||
defaultTableName = "state"
|
||||
defaultSchemaName = "dapr_state_store"
|
||||
defaultTableName = "state"
|
||||
defaultMetadataTableName = "dapr_metadata"
|
||||
|
||||
mysqlConnString = "root:root@tcp(localhost:3306)/?allowNativePasswords=true"
|
||||
mariadbConnString = "root:root@tcp(localhost:3307)/"
|
||||
|
||||
keyConnectionString = "connectionString"
|
||||
keyCleanupInterval = "cleanupInterval"
|
||||
keyTableName = "tableName"
|
||||
keyMetadatTableName = "metadataTableName"
|
||||
)
|
||||
|
||||
func TestMySQL(t *testing.T) {
|
||||
|
@ -315,7 +323,7 @@ func TestMySQL(t *testing.T) {
|
|||
}
|
||||
|
||||
// checks that metadata options schemaName and tableName behave correctly
|
||||
metadataTest := func(connString string, schemaName string, tableName string) func(ctx flow.Context) error {
|
||||
metadataTest := func(connString, schemaName, tableName, metadataTableName string) func(ctx flow.Context) error {
|
||||
return func(ctx flow.Context) (err error) {
|
||||
properties := map[string]string{
|
||||
"connectionString": connString,
|
||||
|
@ -332,6 +340,11 @@ func TestMySQL(t *testing.T) {
|
|||
} else {
|
||||
tableName = defaultTableName
|
||||
}
|
||||
if metadataTableName != "" {
|
||||
properties["metadataTableName"] = metadataTableName
|
||||
} else {
|
||||
metadataTableName = defaultMetadataTableName
|
||||
}
|
||||
|
||||
// Init the component
|
||||
component := stateMysql.NewMySQLStateStore(log).(*stateMysql.MySQL)
|
||||
|
@ -359,9 +372,23 @@ func TestMySQL(t *testing.T) {
|
|||
|
||||
// Check that the table exists
|
||||
query = `SELECT EXISTS (
|
||||
SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_NAME = ?
|
||||
SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
|
||||
) AS 'exists'`
|
||||
err = conn.QueryRow(query, tableName).Scan(&exists)
|
||||
err = conn.QueryRow(query, schemaName, tableName).Scan(&exists)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, exists)
|
||||
|
||||
// Check that the expiredate column exists
|
||||
query = `SELECT count(*) AS 'exists' FROM information_schema.columns
|
||||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND COLUMN_NAME = ?`
|
||||
err = conn.QueryRow(query, schemaName, tableName, "expiredate").Scan(&exists)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, exists)
|
||||
|
||||
// Check that the metadata table exists
|
||||
query = `SELECT count(*) AS 'exists' FROM information_schema.tables
|
||||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?`
|
||||
err = conn.QueryRow(query, schemaName, tableName).Scan(&exists)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, exists)
|
||||
|
||||
|
@ -373,6 +400,165 @@ func TestMySQL(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Validates TTLs and garbage collections
|
||||
ttlTest := func(connString string) func(ctx flow.Context) error {
|
||||
return func(ctx flow.Context) (err error) {
|
||||
md := state.Metadata{
|
||||
Base: metadata.Base{
|
||||
Name: "ttltest",
|
||||
Properties: map[string]string{
|
||||
keyConnectionString: connString,
|
||||
keyTableName: "ttl_state",
|
||||
keyMetadatTableName: "ttl_metadata",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("parse cleanupInterval", func(t *testing.T) {
|
||||
t.Run("default value", func(t *testing.T) {
|
||||
// Default value is 1 hr
|
||||
md.Properties[keyCleanupInterval] = ""
|
||||
storeObj := stateMysql.NewMySQLStateStore(log).(*stateMysql.MySQL)
|
||||
|
||||
err := storeObj.Init(ctx, md)
|
||||
require.NoError(t, err, "failed to init")
|
||||
defer storeObj.Close()
|
||||
|
||||
cleanupInterval := storeObj.CleanupInterval()
|
||||
_ = assert.NotNil(t, cleanupInterval) &&
|
||||
assert.Equal(t, time.Duration(1*time.Hour), *cleanupInterval)
|
||||
})
|
||||
|
||||
t.Run("positive value", func(t *testing.T) {
|
||||
md.Properties[keyCleanupInterval] = "10s"
|
||||
storeObj := stateMysql.NewMySQLStateStore(log).(*stateMysql.MySQL)
|
||||
|
||||
err := storeObj.Init(ctx, md)
|
||||
require.NoError(t, err, "failed to init")
|
||||
defer storeObj.Close()
|
||||
|
||||
cleanupInterval := storeObj.CleanupInterval()
|
||||
_ = assert.NotNil(t, cleanupInterval) &&
|
||||
assert.Equal(t, time.Duration(10*time.Second), *cleanupInterval)
|
||||
})
|
||||
|
||||
t.Run("disabled", func(t *testing.T) {
|
||||
// A value of <=0 means that the cleanup is disabled
|
||||
md.Properties[keyCleanupInterval] = "0"
|
||||
storeObj := stateMysql.NewMySQLStateStore(log).(*stateMysql.MySQL)
|
||||
|
||||
err := storeObj.Init(ctx, md)
|
||||
require.NoError(t, err, "failed to init")
|
||||
defer storeObj.Close()
|
||||
|
||||
cleanupInterval := storeObj.CleanupInterval()
|
||||
_ = assert.Nil(t, cleanupInterval)
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
t.Run("cleanup", func(t *testing.T) {
|
||||
md := state.Metadata{
|
||||
Base: metadata.Base{
|
||||
Name: "ttltest",
|
||||
Properties: map[string]string{
|
||||
keyConnectionString: connString,
|
||||
keyTableName: "ttl_state",
|
||||
keyMetadatTableName: "ttl_metadata",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("automatically delete expired records", func(t *testing.T) {
|
||||
// Run every second
|
||||
md.Properties[keyCleanupInterval] = "1s"
|
||||
|
||||
storeObj := stateMysql.NewMySQLStateStore(log).(*stateMysql.MySQL)
|
||||
err := storeObj.Init(ctx, md)
|
||||
require.NoError(t, err, "failed to init")
|
||||
defer storeObj.Close()
|
||||
|
||||
conn := storeObj.GetConnection()
|
||||
// Seed the database with some records
|
||||
err = populateTTLRecords(ctx, conn)
|
||||
require.NoError(t, err, "failed to seed records")
|
||||
|
||||
// Wait 2 seconds then verify we have only 10 rows left
|
||||
time.Sleep(2 * time.Second)
|
||||
count, err := countRowsInTable(ctx, conn, "ttl_state")
|
||||
require.NoError(t, err, "failed to run query to count rows")
|
||||
assert.Equal(t, 10, count)
|
||||
|
||||
// The "last-cleanup" value should be <= 2 seconds (+ a bit of buffer)
|
||||
lastCleanup, err := loadLastCleanupInterval(ctx, conn, "ttl_metadata")
|
||||
require.NoError(t, err, "failed to load value for 'last-cleanup'")
|
||||
assert.LessOrEqual(t, lastCleanup, 2)
|
||||
|
||||
// Wait 6 more seconds and verify there are no more rows left
|
||||
time.Sleep(6 * time.Second)
|
||||
count, err = countRowsInTable(ctx, conn, "ttl_state")
|
||||
require.NoError(t, err, "failed to run query to count rows")
|
||||
assert.Equal(t, 0, count)
|
||||
|
||||
// The "last-cleanup" value should be <= 2 seconds (+ a bit of buffer)
|
||||
lastCleanup, err = loadLastCleanupInterval(ctx, conn, "ttl_metadata")
|
||||
require.NoError(t, err, "failed to load value for 'last-cleanup'")
|
||||
assert.LessOrEqual(t, lastCleanup, 2)
|
||||
})
|
||||
|
||||
t.Run("cleanup concurrency", func(t *testing.T) {
|
||||
// Set to run every hour
|
||||
// (we'll manually trigger more frequent iterations)
|
||||
md.Properties[keyCleanupInterval] = "1h"
|
||||
|
||||
storeObj := stateMysql.NewMySQLStateStore(log).(*stateMysql.MySQL)
|
||||
err := storeObj.Init(ctx, md)
|
||||
require.NoError(t, err, "failed to init")
|
||||
defer storeObj.Close()
|
||||
|
||||
conn := storeObj.GetConnection()
|
||||
|
||||
// Seed the database with some records
|
||||
err = populateTTLRecords(ctx, conn)
|
||||
require.NoError(t, err, "failed to seed records")
|
||||
|
||||
// Validate that 20 records are present
|
||||
count, err := countRowsInTable(ctx, conn, "ttl_state")
|
||||
require.NoError(t, err, "failed to run query to count rows")
|
||||
assert.Equal(t, 20, count)
|
||||
|
||||
// Set last-cleanup to 1s ago (less than 3600s)
|
||||
err = setValueInMetadataTable(ctx, conn, "ttl_metadata", "'last-cleanup'", "CURRENT_TIMESTAMP - INTERVAL 1 SECOND")
|
||||
require.NoError(t, err, "failed to set last-cleanup")
|
||||
|
||||
// The "last-cleanup" value should be ~2 seconds (+ a bit of buffer)
|
||||
lastCleanup, err := loadLastCleanupInterval(ctx, conn, "ttl_metadata")
|
||||
require.NoError(t, err, "failed to load value for 'last-cleanup'")
|
||||
assert.LessOrEqual(t, lastCleanup, 2)
|
||||
lastCleanupValueOrig, err := getValueFromMetadataTable(ctx, conn, "ttl_metadata", "last-cleanup")
|
||||
require.NoError(t, err, "failed to load absolute value for 'last-cleanup'")
|
||||
require.NotEmpty(t, lastCleanupValueOrig)
|
||||
|
||||
// Trigger the background cleanup, which should do nothing because the last cleanup was < 3600s
|
||||
err = storeObj.CleanupExpired()
|
||||
require.NoError(t, err, "CleanupExpired returned an error")
|
||||
|
||||
// Validate that 20 records are still present
|
||||
count, err = countRowsInTable(ctx, conn, "ttl_state")
|
||||
require.NoError(t, err, "failed to run query to count rows")
|
||||
assert.Equal(t, 20, count)
|
||||
|
||||
// The "last-cleanup" value should not have been changed
|
||||
lastCleanupValue, err := getValueFromMetadataTable(ctx, conn, "ttl_metadata", "last-cleanup")
|
||||
require.NoError(t, err, "failed to load absolute value for 'last-cleanup'")
|
||||
assert.Equal(t, lastCleanupValueOrig, lastCleanupValue)
|
||||
})
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
flow.New(t, "Run tests").
|
||||
Step(dockercompose.Run("db", dockerComposeYAML)).
|
||||
Step("Wait for databases to start", flow.Sleep(30*time.Second)).
|
||||
|
@ -389,6 +575,8 @@ func TestMySQL(t *testing.T) {
|
|||
Step("Run eTag test on mariadb", eTagTest("mariadb")).
|
||||
Step("Run transactions test", transactionsTest("mysql")).
|
||||
Step("Run transactions test", transactionsTest("mariadb")).
|
||||
Step("Run TTL test on mysql", ttlTest(mysqlConnString)).
|
||||
Step("Run TTL test on mariadb", ttlTest(mariadbConnString)).
|
||||
Step("Run SQL injection test on mysql", verifySQLInjectionTest("mysql")).
|
||||
Step("Run SQL injection test on mariadb", verifySQLInjectionTest("mariadb")).
|
||||
//Step("Interrupt network and simulate timeouts", timeoutTest).
|
||||
|
@ -408,10 +596,91 @@ func TestMySQL(t *testing.T) {
|
|||
Step("Close database connection 1", closeTest(0)).
|
||||
Step("Close database connection 2", closeTest(1)).
|
||||
// Metadata
|
||||
Step("Default schemaName and tableName on mysql", metadataTest(mysqlConnString, "", "")).
|
||||
Step("Custom schemaName and tableName on mysql", metadataTest(mysqlConnString, "mydaprdb", "mytable")).
|
||||
Step("Default schemaName and tableName on mariadb", metadataTest(mariadbConnString, "", "")).
|
||||
Step("Custom schemaName and tableName on mariadb", metadataTest(mariadbConnString, "mydaprdb", "mytable")).
|
||||
Step("Default schemaName, tableName and metadataTableName on mysql", metadataTest(mysqlConnString, "", "", "")).
|
||||
Step("Custom schemaName, tableName and metadataTableName on mysql", metadataTest(mysqlConnString, "mydaprdb", "mytable", "metadatatable")).
|
||||
Step("Default schemaName, tableName and metadataTableName on mariadb", metadataTest(mariadbConnString, "", "", "")).
|
||||
Step("Custom schemaName, tableName and metadataTableName on mariadb", metadataTest(mariadbConnString, "mydaprdb", "mytable", "metadatatable")).
|
||||
// Run tests
|
||||
Run()
|
||||
}
|
||||
|
||||
func populateTTLRecords(ctx context.Context, dbClient *sql.DB) error {
|
||||
// Insert 10 records that have expired, and 10 that will expire in 4 seconds
|
||||
exp := "DATE_SUB(CURRENT_TIMESTAMP, INTERVAL 1 MINUTE)"
|
||||
rows := make([][]any, 20)
|
||||
for i := 0; i < 10; i++ {
|
||||
rows[i] = []any{
|
||||
fmt.Sprintf("expired_%d", i),
|
||||
json.RawMessage(fmt.Sprintf(`"value_%d"`, i)),
|
||||
false,
|
||||
exp,
|
||||
}
|
||||
}
|
||||
exp = "DATE_ADD(CURRENT_TIMESTAMP, INTERVAL 4 second)"
|
||||
for i := 0; i < 10; i++ {
|
||||
rows[i+10] = []any{
|
||||
fmt.Sprintf("notexpired_%d", i),
|
||||
json.RawMessage(fmt.Sprintf(`"value_%d"`, i)),
|
||||
false,
|
||||
exp,
|
||||
}
|
||||
}
|
||||
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
for _, row := range rows {
|
||||
query := fmt.Sprintf("INSERT INTO ttl_state (id, value, isbinary, eTag, expiredate) VALUES (?, ?, ?, '', %s)", row[3])
|
||||
_, err := dbClient.ExecContext(queryCtx, query, row[0], row[1], row[2])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func countRowsInTable(ctx context.Context, dbClient *sql.DB, table string) (count int, err error) {
|
||||
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
err = dbClient.QueryRowContext(queryCtx, "SELECT COUNT(id) FROM "+table).Scan(&count)
|
||||
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
func loadLastCleanupInterval(ctx context.Context, dbClient *sql.DB, table string) (lastCleanup int, err error) {
|
||||
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
var lastCleanupf float64
|
||||
err = dbClient.
|
||||
QueryRowContext(queryCtx,
|
||||
fmt.Sprintf("SELECT UNIX_TIMESTAMP(CURRENT_TIMESTAMP) - UNIX_TIMESTAMP(value) AS lastCleanupf FROM %s WHERE id = 'last-cleanup'", table),
|
||||
).
|
||||
Scan(&lastCleanupf)
|
||||
lastCleanup = int(lastCleanupf)
|
||||
cancel()
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func setValueInMetadataTable(ctx context.Context, dbClient *sql.DB, table, id, value string) error {
|
||||
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
_, err := dbClient.ExecContext(queryCtx,
|
||||
//nolint:gosec
|
||||
fmt.Sprintf(`INSERT INTO %[1]s (id, value) VALUES (%[2]s, %[3]s) ON DUPLICATE KEY UPDATE
|
||||
value = %[3]s`, table, id, value),
|
||||
)
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
|
||||
func getValueFromMetadataTable(ctx context.Context, dbClient *sql.DB, table, id string) (value string, err error) {
|
||||
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
err = dbClient.
|
||||
QueryRowContext(queryCtx, fmt.Sprintf("SELECT value FROM %s WHERE id = ?", table), id).
|
||||
Scan(&value)
|
||||
cancel()
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -33,10 +33,10 @@ components:
|
|||
operations: [ "set", "get", "delete", "bulkget", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
|
||||
- component: mysql.mysql
|
||||
allOperations: false
|
||||
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write" ]
|
||||
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
|
||||
- component: mysql.mariadb
|
||||
allOperations: false
|
||||
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write" ]
|
||||
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write", "ttl" ]
|
||||
- component: azure.tablestorage.storage
|
||||
allOperations: false
|
||||
operations: ["set", "get", "delete", "etag", "bulkset", "bulkdelete", "first-write"]
|
||||
|
|
|
@ -859,19 +859,19 @@ func assertDataEquals(t *testing.T, expect any, actual []byte) {
|
|||
case intValueType:
|
||||
// Custom type requires case mapping
|
||||
if err := json.Unmarshal(actual, &v); err != nil {
|
||||
assert.Failf(t, "unmarshal error", "error: %w, json: %s", err, string(actual))
|
||||
assert.Failf(t, "unmarshal error", "error: %v, json: %s", err, string(actual))
|
||||
}
|
||||
assert.Equal(t, expect, v)
|
||||
case ValueType:
|
||||
// Custom type requires case mapping
|
||||
if err := json.Unmarshal(actual, &v); err != nil {
|
||||
assert.Failf(t, "unmarshal error", "error: %w, json: %s", err, string(actual))
|
||||
assert.Failf(t, "unmarshal error", "error: %v, json: %s", err, string(actual))
|
||||
}
|
||||
assert.Equal(t, expect, v)
|
||||
case int:
|
||||
// json.Unmarshal to float64 by default, case mapping to int coerces to int type
|
||||
if err := json.Unmarshal(actual, &v); err != nil {
|
||||
assert.Failf(t, "unmarshal error", "error: %w, json: %s", err, string(actual))
|
||||
assert.Failf(t, "unmarshal error", "error: %v, json: %s", err, string(actual))
|
||||
}
|
||||
assert.Equal(t, expect, v)
|
||||
case []byte:
|
||||
|
@ -879,7 +879,7 @@ func assertDataEquals(t *testing.T, expect any, actual []byte) {
|
|||
default:
|
||||
// Other golang primitive types (string, bool ...)
|
||||
if err := json.Unmarshal(actual, &v); err != nil {
|
||||
assert.Failf(t, "unmarshal error", "error: %w, json: %s", err, string(actual))
|
||||
assert.Failf(t, "unmarshal error", "error: %v, json: %s", err, string(actual))
|
||||
}
|
||||
assert.Equal(t, expect, v)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue