Improved concurrency handling for migrations

Also added a certification test for that

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2022-12-05 21:00:31 +00:00
parent 4d9bae154f
commit d4dcc54e29
3 changed files with 98 additions and 24 deletions

View File

@ -35,8 +35,32 @@ type migrations struct {
// Perform the required migrations
func (m *migrations) Perform(ctx context.Context) error {
// Use an advisory lock (with an arbitrary number) to ensure that no one else is performing migrations at the same time
// This is the only way to also ensure we are not running multiple "CREATE TABLE IF NOT EXISTS" at the exact same time
// See: https://www.postgresql.org/message-id/CA+TgmoZAdYVtwBfp1FL2sMZbiHCWT4UPrzRLNnX1Nb30Ku3-gg@mail.gmail.com
const lockId = 42
// Long timeout here as this query may block
queryCtx, cancel := context.WithTimeout(ctx, time.Minute)
_, err := m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_lock($1)", lockId)
cancel()
if err != nil {
return fmt.Errorf("faild to acquire advisory lock: %w", err)
}
// Release the lock
defer func() {
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
_, err = m.Conn.ExecContext(queryCtx, "SELECT pg_advisory_unlock($1)", lockId)
cancel()
if err != nil {
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
m.Logger.Fatalf("Failed to release advisory lock: %v", err)
}
}()
// Check if the metadata table exists, which we also use to store the migration level
queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
exists, _, _, err := m.tableExists(queryCtx, m.MetadataTableName)
cancel()
if err != nil {
@ -53,28 +77,13 @@ func (m *migrations) Perform(ctx context.Context) error {
}
}
// Acquire an exclusive lock on the metadata table, which we will use to ensure no one else is performing migrations
metadataTx, err := m.Conn.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer metadataTx.Rollback()
// Long timeout here to wait for other processes to complete the migrations, since this query blocks
queryCtx, cancel = context.WithTimeout(ctx, 2*time.Minute)
_, err = metadataTx.ExecContext(queryCtx, fmt.Sprintf("LOCK TABLE %s IN SHARE MODE", m.MetadataTableName))
cancel()
if err != nil {
return fmt.Errorf("failed to acquire lock on metadata table: %w", err)
}
// Select the migration level
var (
migrationLevelStr string
migrationLevel int
)
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
err = metadataTx.
err = m.Conn.
QueryRowContext(queryCtx,
fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName),
).Scan(&migrationLevelStr)
@ -100,7 +109,7 @@ func (m *migrations) Perform(ctx context.Context) error {
}
queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
_, err = metadataTx.ExecContext(queryCtx,
_, err = m.Conn.ExecContext(queryCtx,
fmt.Sprintf(`INSERT INTO %s (key, value) VALUES ('migrations', $1) ON CONFLICT (key) DO UPDATE SET value = $1`, m.MetadataTableName),
strconv.Itoa(i+1),
)
@ -110,12 +119,6 @@ func (m *migrations) Perform(ctx context.Context) error {
}
}
// Commit changes to the metadata table, which also releases the lock
err = metadataTx.Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}

View File

@ -18,6 +18,7 @@ Also test the `tableName` and `metadataTableName` metadata properties.
1. Initializes the component with names for tables that don't exist, specifying an explicit schema
1. Initializes the component with all migrations performed (current level is "2")
1. Initializes the component with only the state table, created before the metadata table was added (implied migration level "1")
1. Initializes three components at the same time and ensures no race conditions exist in performing migrations
## Test for CRUD operations

View File

@ -14,11 +14,15 @@ limitations under the License.
package main
import (
"bytes"
"context"
"database/sql"
"errors"
"fmt"
"io"
"os"
"strconv"
"sync/atomic"
"testing"
"time"
@ -214,6 +218,72 @@ func TestPostgreSQL(t *testing.T) {
require.NoError(t, err, "failed to close component")
})
t.Run("initialize components concurrently", func(t *testing.T) {
// Initializes 3 components concurrently using the same table names, and ensure that they perform migrations without conflicts and race conditions
md.Properties["tableName"] = "mystate"
md.Properties["metadataTableName"] = "mymetadata"
errs := make(chan error, 3)
hasLogs := atomic.Int32{}
for i := 0; i < 3; i++ {
go func(i int) {
buf := &bytes.Buffer{}
l := logger.NewLogger("multi-init-" + strconv.Itoa(i))
l.SetOutput(io.MultiWriter(buf, os.Stdout))
// Init and perform the migrations
storeObj := state_postgres.NewPostgreSQLStateStore(l).(*state_postgres.PostgreSQL)
err := storeObj.Init(md)
if err != nil {
errs <- fmt.Errorf("%d failed to init: %w", i, err)
return
}
// One and only one of the loggers should have any message
if buf.Len() > 0 {
hasLogs.Add(1)
}
// Close the component right away
err = storeObj.Close()
if err != nil {
errs <- fmt.Errorf("%d failed to close: %w", i, err)
return
}
errs <- nil
}(i)
}
failed := false
for i := 0; i < 3; i++ {
select {
case err := <-errs:
failed = failed || !assert.NoError(t, err)
case <-time.After(time.Minute):
t.Fatal("timed out waiting for components to initialize")
}
}
if failed {
// Short-circuit
t.FailNow()
}
// Exactly one component should have written logs (which means generated any activity during migrations)
assert.Equal(t, int32(1), hasLogs.Load(), "expected 1 component to log anything to indicate migration activity, but got %d", hasLogs.Load())
// We should have the tables correctly created
err = tableExists(dbClient, "public", "mystate")
assert.NoError(t, err, "state table does not exist")
err = tableExists(dbClient, "public", "mymetadata")
assert.NoError(t, err, "metadata table does not exist")
// Ensure migration level is correct
level, err := getMigrationLevel(dbClient, "mymetadata")
assert.NoError(t, err, "failed to get migration level")
assert.Equal(t, migrationLevel, level, "migration level mismatch: found '%s' but expected '%s'", level, migrationLevel)
})
return nil
}