Use regular transactions

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2023-02-03 22:40:03 +00:00
parent 89720a909d
commit 38f29a6ff2
1 changed files with 20 additions and 36 deletions

View File

@ -34,33 +34,24 @@ type migrations struct {
// Perform the required migrations // Perform the required migrations
func (m *migrations) Perform(ctx context.Context) error { func (m *migrations) Perform(ctx context.Context) error {
// Begin an exclusive transaction // Begin a transaction
// We can't use Begin because that doesn't allow us setting the level of transaction tx, err := m.Conn.Begin()
queryCtx, cancel := context.WithTimeout(ctx, time.Minute)
_, err := m.Conn.ExecContext(queryCtx, "BEGIN EXCLUSIVE TRANSACTION")
cancel()
if err != nil { if err != nil {
return fmt.Errorf("faild to begin transaction: %w", err) return fmt.Errorf("faild to begin transaction: %w", err)
} }
// Rollback the transaction in a deferred statement to catch errors // Rollback the transaction in a deferred statement to catch errors
success := false
defer func() { defer func() {
if success { err = tx.Rollback()
return if err != nil && err != sql.ErrTxDone {
}
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
_, err = m.Conn.ExecContext(queryCtx, "ROLLBACK TRANSACTION")
cancel()
if err != nil {
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around // Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
m.Logger.Fatalf("Failed to rollback transaction: %v", err) m.Logger.Fatalf("Failed to rollback transaction: %v", err)
} }
}() }()
// Check if the metadata table exists, which we also use to store the migration level // Check if the metadata table exists, which we also use to store the migration level
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
exists, err := m.tableExists(queryCtx, m.MetadataTableName) exists, err := m.tableExists(queryCtx, tx, m.MetadataTableName)
cancel() cancel()
if err != nil { if err != nil {
return fmt.Errorf("failed to check if the metadata table exists: %w", err) return fmt.Errorf("failed to check if the metadata table exists: %w", err)
@ -69,7 +60,7 @@ func (m *migrations) Perform(ctx context.Context) error {
// If the table doesn't exist, create it // If the table doesn't exist, create it
if !exists { if !exists {
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
err = m.createMetadataTable(queryCtx) err = m.createMetadataTable(queryCtx, tx)
cancel() cancel()
if err != nil { if err != nil {
return fmt.Errorf("failed to create metadata table: %w", err) return fmt.Errorf("failed to create metadata table: %w", err)
@ -82,10 +73,9 @@ func (m *migrations) Perform(ctx context.Context) error {
migrationLevel int migrationLevel int
) )
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
err = m.Conn. err = tx.QueryRowContext(queryCtx,
QueryRowContext(queryCtx, fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName),
fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName), ).Scan(&migrationLevelStr)
).Scan(&migrationLevelStr)
cancel() cancel()
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
// If there's no row... // If there's no row...
@ -102,13 +92,13 @@ func (m *migrations) Perform(ctx context.Context) error {
// Perform the migrations // Perform the migrations
for i := migrationLevel; i < len(allMigrations); i++ { for i := migrationLevel; i < len(allMigrations); i++ {
m.Logger.Infof("Performing migration %d", i) m.Logger.Infof("Performing migration %d", i)
err = allMigrations[i](ctx, m) err = allMigrations[i](ctx, tx, m)
if err != nil { if err != nil {
return fmt.Errorf("failed to perform migration %d: %w", i, err) return fmt.Errorf("failed to perform migration %d: %w", i, err)
} }
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
_, err = m.Conn.ExecContext(queryCtx, _, err = tx.ExecContext(queryCtx,
fmt.Sprintf(`REPLACE INTO %s (key, value) VALUES ('migrations', ?)`, m.MetadataTableName), fmt.Sprintf(`REPLACE INTO %s (key, value) VALUES ('migrations', ?)`, m.MetadataTableName),
strconv.Itoa(i+1), strconv.Itoa(i+1),
) )
@ -119,21 +109,16 @@ func (m *migrations) Perform(ctx context.Context) error {
} }
// Commit the transaction // Commit the transaction
queryCtx, cancel = context.WithTimeout(ctx, time.Minute) err = tx.Commit()
_, err = m.Conn.ExecContext(queryCtx, "COMMIT TRANSACTION")
cancel()
if err != nil { if err != nil {
return fmt.Errorf("failed to commit transaction") return fmt.Errorf("failed to commit transaction")
} }
// Set success to true so we don't also run a rollback
success = true
return nil return nil
} }
// Returns true if a table exists // Returns true if a table exists
func (m migrations) tableExists(parentCtx context.Context, tableName string) (bool, error) { func (m migrations) tableExists(parentCtx context.Context, db querier, tableName string) (bool, error) {
ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second) ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
defer cancel() defer cancel()
@ -142,17 +127,16 @@ func (m migrations) tableExists(parentCtx context.Context, tableName string) (bo
const q = `SELECT EXISTS ( const q = `SELECT EXISTS (
SELECT name FROM sqlite_master WHERE type='table' AND name = ? SELECT name FROM sqlite_master WHERE type='table' AND name = ?
) AS 'exists'` ) AS 'exists'`
err := m.Conn. err := db.QueryRowContext(ctx, q, m.MetadataTableName).
QueryRowContext(ctx, q, m.MetadataTableName).
Scan(&exists) Scan(&exists)
return exists == "1", err return exists == "1", err
} }
func (m migrations) createMetadataTable(ctx context.Context) error { func (m migrations) createMetadataTable(ctx context.Context, db querier) error {
m.Logger.Infof("Creating metadata table '%s' if it doesn't exist", m.MetadataTableName) m.Logger.Infof("Creating metadata table '%s' if it doesn't exist", m.MetadataTableName)
// Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time // Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time
// In the next step we'll acquire a lock so there won't be issues with concurrency // In the next step we'll acquire a lock so there won't be issues with concurrency
_, err := m.Conn.ExecContext(ctx, fmt.Sprintf( _, err := db.ExecContext(ctx, fmt.Sprintf(
`CREATE TABLE IF NOT EXISTS %s ( `CREATE TABLE IF NOT EXISTS %s (
key text NOT NULL PRIMARY KEY, key text NOT NULL PRIMARY KEY,
value text NOT NULL value text NOT NULL
@ -165,12 +149,12 @@ func (m migrations) createMetadataTable(ctx context.Context) error {
return nil return nil
} }
var allMigrations = [1]func(ctx context.Context, m *migrations) error{ var allMigrations = [1]func(ctx context.Context, db querier, m *migrations) error{
// Migration 0: create the state table // Migration 0: create the state table
func(ctx context.Context, m *migrations) error { func(ctx context.Context, db querier, m *migrations) error {
// We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table // We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table
m.Logger.Infof("Creating state table '%s'", m.StateTableName) m.Logger.Infof("Creating state table '%s'", m.StateTableName)
_, err := m.Conn.ExecContext( _, err := db.ExecContext(
ctx, ctx,
fmt.Sprintf( fmt.Sprintf(
`CREATE TABLE %s ( `CREATE TABLE %s (