Use regular transactions
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
89720a909d
commit
38f29a6ff2
|
@ -34,33 +34,24 @@ type migrations struct {
|
||||||
|
|
||||||
// Perform the required migrations
|
// Perform the required migrations
|
||||||
func (m *migrations) Perform(ctx context.Context) error {
|
func (m *migrations) Perform(ctx context.Context) error {
|
||||||
// Begin an exclusive transaction
|
// Begin a transaction
|
||||||
// We can't use Begin because that doesn't allow us setting the level of transaction
|
tx, err := m.Conn.Begin()
|
||||||
queryCtx, cancel := context.WithTimeout(ctx, time.Minute)
|
|
||||||
_, err := m.Conn.ExecContext(queryCtx, "BEGIN EXCLUSIVE TRANSACTION")
|
|
||||||
cancel()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("faild to begin transaction: %w", err)
|
return fmt.Errorf("faild to begin transaction: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rollback the transaction in a deferred statement to catch errors
|
// Rollback the transaction in a deferred statement to catch errors
|
||||||
success := false
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if success {
|
err = tx.Rollback()
|
||||||
return
|
if err != nil && err != sql.ErrTxDone {
|
||||||
}
|
|
||||||
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
|
|
||||||
_, err = m.Conn.ExecContext(queryCtx, "ROLLBACK TRANSACTION")
|
|
||||||
cancel()
|
|
||||||
if err != nil {
|
|
||||||
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
|
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
|
||||||
m.Logger.Fatalf("Failed to rollback transaction: %v", err)
|
m.Logger.Fatalf("Failed to rollback transaction: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Check if the metadata table exists, which we also use to store the migration level
|
// Check if the metadata table exists, which we also use to store the migration level
|
||||||
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
exists, err := m.tableExists(queryCtx, m.MetadataTableName)
|
exists, err := m.tableExists(queryCtx, tx, m.MetadataTableName)
|
||||||
cancel()
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to check if the metadata table exists: %w", err)
|
return fmt.Errorf("failed to check if the metadata table exists: %w", err)
|
||||||
|
@ -69,7 +60,7 @@ func (m *migrations) Perform(ctx context.Context) error {
|
||||||
// If the table doesn't exist, create it
|
// If the table doesn't exist, create it
|
||||||
if !exists {
|
if !exists {
|
||||||
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||||
err = m.createMetadataTable(queryCtx)
|
err = m.createMetadataTable(queryCtx, tx)
|
||||||
cancel()
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create metadata table: %w", err)
|
return fmt.Errorf("failed to create metadata table: %w", err)
|
||||||
|
@ -82,10 +73,9 @@ func (m *migrations) Perform(ctx context.Context) error {
|
||||||
migrationLevel int
|
migrationLevel int
|
||||||
)
|
)
|
||||||
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||||
err = m.Conn.
|
err = tx.QueryRowContext(queryCtx,
|
||||||
QueryRowContext(queryCtx,
|
fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName),
|
||||||
fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName),
|
).Scan(&migrationLevelStr)
|
||||||
).Scan(&migrationLevelStr)
|
|
||||||
cancel()
|
cancel()
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
// If there's no row...
|
// If there's no row...
|
||||||
|
@ -102,13 +92,13 @@ func (m *migrations) Perform(ctx context.Context) error {
|
||||||
// Perform the migrations
|
// Perform the migrations
|
||||||
for i := migrationLevel; i < len(allMigrations); i++ {
|
for i := migrationLevel; i < len(allMigrations); i++ {
|
||||||
m.Logger.Infof("Performing migration %d", i)
|
m.Logger.Infof("Performing migration %d", i)
|
||||||
err = allMigrations[i](ctx, m)
|
err = allMigrations[i](ctx, tx, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to perform migration %d: %w", i, err)
|
return fmt.Errorf("failed to perform migration %d: %w", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||||
_, err = m.Conn.ExecContext(queryCtx,
|
_, err = tx.ExecContext(queryCtx,
|
||||||
fmt.Sprintf(`REPLACE INTO %s (key, value) VALUES ('migrations', ?)`, m.MetadataTableName),
|
fmt.Sprintf(`REPLACE INTO %s (key, value) VALUES ('migrations', ?)`, m.MetadataTableName),
|
||||||
strconv.Itoa(i+1),
|
strconv.Itoa(i+1),
|
||||||
)
|
)
|
||||||
|
@ -119,21 +109,16 @@ func (m *migrations) Perform(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Commit the transaction
|
// Commit the transaction
|
||||||
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
|
err = tx.Commit()
|
||||||
_, err = m.Conn.ExecContext(queryCtx, "COMMIT TRANSACTION")
|
|
||||||
cancel()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to commit transaction")
|
return fmt.Errorf("failed to commit transaction")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set success to true so we don't also run a rollback
|
|
||||||
success = true
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if a table exists
|
// Returns true if a table exists
|
||||||
func (m migrations) tableExists(parentCtx context.Context, tableName string) (bool, error) {
|
func (m migrations) tableExists(parentCtx context.Context, db querier, tableName string) (bool, error) {
|
||||||
ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
|
ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
@ -142,17 +127,16 @@ func (m migrations) tableExists(parentCtx context.Context, tableName string) (bo
|
||||||
const q = `SELECT EXISTS (
|
const q = `SELECT EXISTS (
|
||||||
SELECT name FROM sqlite_master WHERE type='table' AND name = ?
|
SELECT name FROM sqlite_master WHERE type='table' AND name = ?
|
||||||
) AS 'exists'`
|
) AS 'exists'`
|
||||||
err := m.Conn.
|
err := db.QueryRowContext(ctx, q, m.MetadataTableName).
|
||||||
QueryRowContext(ctx, q, m.MetadataTableName).
|
|
||||||
Scan(&exists)
|
Scan(&exists)
|
||||||
return exists == "1", err
|
return exists == "1", err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m migrations) createMetadataTable(ctx context.Context) error {
|
func (m migrations) createMetadataTable(ctx context.Context, db querier) error {
|
||||||
m.Logger.Infof("Creating metadata table '%s' if it doesn't exist", m.MetadataTableName)
|
m.Logger.Infof("Creating metadata table '%s' if it doesn't exist", m.MetadataTableName)
|
||||||
// Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time
|
// Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time
|
||||||
// In the next step we'll acquire a lock so there won't be issues with concurrency
|
// In the next step we'll acquire a lock so there won't be issues with concurrency
|
||||||
_, err := m.Conn.ExecContext(ctx, fmt.Sprintf(
|
_, err := db.ExecContext(ctx, fmt.Sprintf(
|
||||||
`CREATE TABLE IF NOT EXISTS %s (
|
`CREATE TABLE IF NOT EXISTS %s (
|
||||||
key text NOT NULL PRIMARY KEY,
|
key text NOT NULL PRIMARY KEY,
|
||||||
value text NOT NULL
|
value text NOT NULL
|
||||||
|
@ -165,12 +149,12 @@ func (m migrations) createMetadataTable(ctx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var allMigrations = [1]func(ctx context.Context, m *migrations) error{
|
var allMigrations = [1]func(ctx context.Context, db querier, m *migrations) error{
|
||||||
// Migration 0: create the state table
|
// Migration 0: create the state table
|
||||||
func(ctx context.Context, m *migrations) error {
|
func(ctx context.Context, db querier, m *migrations) error {
|
||||||
// We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table
|
// We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table
|
||||||
m.Logger.Infof("Creating state table '%s'", m.StateTableName)
|
m.Logger.Infof("Creating state table '%s'", m.StateTableName)
|
||||||
_, err := m.Conn.ExecContext(
|
_, err := db.ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
fmt.Sprintf(
|
fmt.Sprintf(
|
||||||
`CREATE TABLE %s (
|
`CREATE TABLE %s (
|
||||||
|
|
Loading…
Reference in New Issue