Use regular transactions

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2023-02-03 22:40:03 +00:00
parent 89720a909d
commit 38f29a6ff2
1 changed files with 20 additions and 36 deletions

View File

@ -34,33 +34,24 @@ type migrations struct {
// Perform the required migrations
func (m *migrations) Perform(ctx context.Context) error {
// Begin an exclusive transaction
// We can't use Begin because that doesn't allow us setting the level of transaction
queryCtx, cancel := context.WithTimeout(ctx, time.Minute)
_, err := m.Conn.ExecContext(queryCtx, "BEGIN EXCLUSIVE TRANSACTION")
cancel()
// Begin a transaction
tx, err := m.Conn.Begin()
if err != nil {
return fmt.Errorf("faild to begin transaction: %w", err)
}
// Rollback the transaction in a deferred statement to catch errors
success := false
defer func() {
if success {
return
}
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
_, err = m.Conn.ExecContext(queryCtx, "ROLLBACK TRANSACTION")
cancel()
if err != nil {
err = tx.Rollback()
if err != nil && err != sql.ErrTxDone {
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
m.Logger.Fatalf("Failed to rollback transaction: %v", err)
}
}()
// Check if the metadata table exists, which we also use to store the migration level
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
exists, err := m.tableExists(queryCtx, m.MetadataTableName)
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
exists, err := m.tableExists(queryCtx, tx, m.MetadataTableName)
cancel()
if err != nil {
return fmt.Errorf("failed to check if the metadata table exists: %w", err)
@ -69,7 +60,7 @@ func (m *migrations) Perform(ctx context.Context) error {
// If the table doesn't exist, create it
if !exists {
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
err = m.createMetadataTable(queryCtx)
err = m.createMetadataTable(queryCtx, tx)
cancel()
if err != nil {
return fmt.Errorf("failed to create metadata table: %w", err)
@ -82,10 +73,9 @@ func (m *migrations) Perform(ctx context.Context) error {
migrationLevel int
)
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
err = m.Conn.
QueryRowContext(queryCtx,
fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName),
).Scan(&migrationLevelStr)
err = tx.QueryRowContext(queryCtx,
fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName),
).Scan(&migrationLevelStr)
cancel()
if errors.Is(err, sql.ErrNoRows) {
// If there's no row...
@ -102,13 +92,13 @@ func (m *migrations) Perform(ctx context.Context) error {
// Perform the migrations
for i := migrationLevel; i < len(allMigrations); i++ {
m.Logger.Infof("Performing migration %d", i)
err = allMigrations[i](ctx, m)
err = allMigrations[i](ctx, tx, m)
if err != nil {
return fmt.Errorf("failed to perform migration %d: %w", i, err)
}
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
_, err = m.Conn.ExecContext(queryCtx,
_, err = tx.ExecContext(queryCtx,
fmt.Sprintf(`REPLACE INTO %s (key, value) VALUES ('migrations', ?)`, m.MetadataTableName),
strconv.Itoa(i+1),
)
@ -119,21 +109,16 @@ func (m *migrations) Perform(ctx context.Context) error {
}
// Commit the transaction
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
_, err = m.Conn.ExecContext(queryCtx, "COMMIT TRANSACTION")
cancel()
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction")
}
// Set success to true so we don't also run a rollback
success = true
return nil
}
// Returns true if a table exists
func (m migrations) tableExists(parentCtx context.Context, tableName string) (bool, error) {
func (m migrations) tableExists(parentCtx context.Context, db querier, tableName string) (bool, error) {
ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
defer cancel()
@ -142,17 +127,16 @@ func (m migrations) tableExists(parentCtx context.Context, tableName string) (bo
const q = `SELECT EXISTS (
SELECT name FROM sqlite_master WHERE type='table' AND name = ?
) AS 'exists'`
err := m.Conn.
QueryRowContext(ctx, q, m.MetadataTableName).
err := db.QueryRowContext(ctx, q, m.MetadataTableName).
Scan(&exists)
return exists == "1", err
}
func (m migrations) createMetadataTable(ctx context.Context) error {
func (m migrations) createMetadataTable(ctx context.Context, db querier) error {
m.Logger.Infof("Creating metadata table '%s' if it doesn't exist", m.MetadataTableName)
// Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time
// In the next step we'll acquire a lock so there won't be issues with concurrency
_, err := m.Conn.ExecContext(ctx, fmt.Sprintf(
_, err := db.ExecContext(ctx, fmt.Sprintf(
`CREATE TABLE IF NOT EXISTS %s (
key text NOT NULL PRIMARY KEY,
value text NOT NULL
@ -165,12 +149,12 @@ func (m migrations) createMetadataTable(ctx context.Context) error {
return nil
}
var allMigrations = [1]func(ctx context.Context, m *migrations) error{
var allMigrations = [1]func(ctx context.Context, db querier, m *migrations) error{
// Migration 0: create the state table
func(ctx context.Context, m *migrations) error {
func(ctx context.Context, db querier, m *migrations) error {
// We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table
m.Logger.Infof("Creating state table '%s'", m.StateTableName)
_, err := m.Conn.ExecContext(
_, err := db.ExecContext(
ctx,
fmt.Sprintf(
`CREATE TABLE %s (