Add migration utils for SQL Server (#3280)

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Deepanshu Agarwal <deepanshu.agarwal1984@gmail.com>
This commit is contained in:
Alessandro (Ale) Segala 2024-01-02 09:00:17 -08:00 committed by GitHub
parent e903af18bd
commit 56579c6d47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 508 additions and 12 deletions

View File

@ -0,0 +1,149 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlservermigrations
import (
"context"
"database/sql"
"fmt"
"time"
commonsql "github.com/dapr/components-contrib/common/component/sql"
"github.com/dapr/kit/logger"
)
// Migrations performs migrations for the database schema
type Migrations struct {
DB *sql.DB
Logger logger.Logger
Schema string
MetadataTableName string
MetadataKey string
tableName string
}
// Perform the required migrations
func (m *Migrations) Perform(ctx context.Context, migrationFns []commonsql.MigrationFn) (err error) {
// Setting a short-hand since it's going to be used a lot
m.tableName = fmt.Sprintf("[%s].[%s]", m.Schema, m.MetadataTableName)
// Ensure the metadata table exists
err = m.ensureMetadataTable(ctx)
if err != nil {
return fmt.Errorf("failed to ensure metadata table exists: %w", err)
}
// In order to acquire a row-level lock, we need to have a row in the metadata table
// So, we're going to write a row in there (not using a transaction, as that causes a table-level lock to be created), ignoring duplicates
const lockKey = "lock"
m.Logger.Debugf("Ensuring lock row '%s' exists in metadata table", lockKey)
queryCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
_, err = m.DB.ExecContext(queryCtx, fmt.Sprintf(`
INSERT INTO %[1]s
([Key], [Value])
SELECT @Key, @Value
WHERE NOT EXISTS (
SELECT 1
FROM %[1]s
WHERE [Key] = @Key
);
`, m.tableName), sql.Named("Key", lockKey), sql.Named("Value", lockKey))
cancel()
if err != nil {
return fmt.Errorf("failed to ensure lock row '%s' exists: %w", lockKey, err)
}
// Now, let's use a transaction on a row in the metadata table as a lock
m.Logger.Debug("Starting transaction pre-migration")
tx, err := m.DB.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
// Always rollback the transaction at the end to release the lock, since the value doesn't really matter
defer func() {
m.Logger.Debug("Releasing migration lock")
rollbackErr := tx.Rollback()
if rollbackErr != nil {
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
m.Logger.Fatalf("Failed to roll back transaction: %v", rollbackErr)
}
}()
// Now, perform a SELECT with FOR UPDATE to lock the row used for locking, and only that row
// We use a long timeout here as this query may block
m.Logger.Debug("Acquiring migration lock")
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
var lock string
//nolint:gosec
q := fmt.Sprintf(`
SELECT [Value]
FROM %s
WITH (XLOCK, ROWLOCK)
WHERE [key] = @Key
`, m.tableName)
err = tx.QueryRowContext(queryCtx, q, sql.Named("Key", lockKey)).Scan(&lock)
cancel()
if err != nil {
return fmt.Errorf("failed to acquire migration lock (row-level lock on key '%s'): %w", lockKey, err)
}
m.Logger.Debug("Migration lock acquired")
// Perform the migrations
// Here we pass the database connection and not the transaction, since the transaction is only used to acquire the lock
err = commonsql.Migrate(ctx, commonsql.AdaptDatabaseSQLConn(m.DB), commonsql.MigrationOptions{
Logger: m.Logger,
// Yes, we are using fmt.Sprintf for adding a value in a query.
// This comes from a constant hardcoded at development-time, and cannot be influenced by users. So, no risk of SQL injections here.
GetVersionQuery: fmt.Sprintf(`SELECT [Value] FROM %s WHERE [Key] = '%s'`, m.tableName, m.MetadataKey),
UpdateVersionQuery: func(version string) (string, any) {
return fmt.Sprintf(`
MERGE
%[1]s WITH (HOLDLOCK) AS t
USING (SELECT '%[2]s' AS [Key]) AS s
ON [t].[Key] = [s].[Key]
WHEN MATCHED THEN
UPDATE SET [Value] = @Value
WHEN NOT MATCHED THEN
INSERT ([Key], [Value]) VALUES ('%[2]s', @Value)
;
`,
m.tableName, m.MetadataKey,
), sql.Named("Value", version)
},
Migrations: migrationFns,
})
if err != nil {
return err
}
return nil
}
func (m Migrations) ensureMetadataTable(ctx context.Context) error {
m.Logger.Infof("Ensuring metadata table '%s' exists", m.tableName)
_, err := m.DB.ExecContext(ctx, fmt.Sprintf(`
IF OBJECT_ID('%[1]s', 'U') IS NULL
CREATE TABLE %[1]s (
[Key] VARCHAR(255) COLLATE Latin1_General_100_BIN2 NOT NULL PRIMARY KEY,
[Value] VARCHAR(max) COLLATE Latin1_General_100_BIN2 NOT NULL
)`,
m.tableName,
))
if err != nil {
return fmt.Errorf("failed to create metadata table: %w", err)
}
return nil
}

View File

@ -0,0 +1,240 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlservermigrations
import (
"bytes"
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"fmt"
"io"
"os"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
// Blank import for the SQL Server driver
_ "github.com/microsoft/go-mssqldb"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
commonsql "github.com/dapr/components-contrib/common/component/sql"
commonsqlserver "github.com/dapr/components-contrib/common/component/sqlserver"
"github.com/dapr/kit/logger"
)
// connectionStringEnvKey defines the env key containing the integration test connection string
// To use Docker: `server=localhost;user id=sa;password=Pass@Word1;port=1433;database=dapr_test;`
// To use Azure SQL: `server=<your-db-server-name>.database.windows.net;user id=<your-db-user>;port=1433;password=<your-password>;database=dapr_test;`
const connectionStringEnvKey = "DAPR_TEST_SQL_CONNSTRING"
// Disable gosec in this test as we use string concatenation a lot with queries
func TestMigration(t *testing.T) {
connectionString := os.Getenv(connectionStringEnvKey)
if connectionString == "" {
t.Skipf(`SQLServer migration test skipped. To enable this test, define the connection string using environment variable '%[1]s' (example 'export %[1]s="server=localhost;user id=sa;password=Pass@Word1;port=1433;database=dapr_test;")'`, connectionStringEnvKey)
}
log := logger.NewLogger("migration-test")
log.SetOutputLevel(logger.DebugLevel)
// Connect to the database
db, err := sql.Open("sqlserver", connectionString)
require.NoError(t, err, "Failed to connect to database")
t.Cleanup(func() {
db.Close()
})
// Create a new schema for testing
schema := getUniqueDBSchema(t)
_, err = db.Exec(fmt.Sprintf("CREATE SCHEMA [%s]", schema))
require.NoError(t, err, "Failed to create schema")
t.Cleanup(func() {
err = commonsqlserver.DropSchema(context.Background(), db, schema)
require.NoError(t, err, "Failed to drop schema")
})
t.Run("Metadata table", func(t *testing.T) {
m := &Migrations{
DB: db,
Logger: log,
Schema: schema,
MetadataTableName: "metadata_1",
MetadataKey: "migrations",
}
t.Run("Create new", func(t *testing.T) {
err = m.Perform(context.Background(), []commonsql.MigrationFn{})
require.NoError(t, err)
assertTableExists(t, db, schema, "metadata_1")
})
t.Run("Already exists", func(t *testing.T) {
err = m.Perform(context.Background(), []commonsql.MigrationFn{})
require.NoError(t, err)
assertTableExists(t, db, schema, "metadata_1")
})
})
t.Run("Perform migrations", func(t *testing.T) {
m := &Migrations{
DB: db,
Logger: log,
Schema: schema,
MetadataTableName: "metadata_2",
MetadataKey: "migrations",
}
fn1 := func(ctx context.Context) error {
_, err = m.DB.Exec(fmt.Sprintf("CREATE TABLE [%s].[TestTable] ([Key] INTEGER NOT NULL PRIMARY KEY)", schema))
return err
}
t.Run("First migration", func(t *testing.T) {
err = m.Perform(context.Background(), []commonsql.MigrationFn{fn1})
require.NoError(t, err)
assertTableExists(t, db, schema, "TestTable")
assertMigrationsLevel(t, db, schema, "metadata_2", "migrations", "1")
})
t.Run("Second migration", func(t *testing.T) {
var called bool
fn2 := func(ctx context.Context) error {
// We don't actually have to do anything here, we just care that the migration level has increased
called = true
return nil
}
err = m.Perform(context.Background(), []commonsql.MigrationFn{fn1, fn2})
require.NoError(t, err)
assert.True(t, called)
assertMigrationsLevel(t, db, schema, "metadata_2", "migrations", "2")
})
t.Run("Already has migrated", func(t *testing.T) {
var called bool
fn2 := func(ctx context.Context) error {
// We don't actually have to do anything here, we just care that the migration level has increased
called = true
return nil
}
err = m.Perform(context.Background(), []commonsql.MigrationFn{fn1, fn2})
require.NoError(t, err)
assert.False(t, called)
assertMigrationsLevel(t, db, schema, "metadata_2", "migrations", "2")
})
})
t.Run("Perform migrations concurrently", func(t *testing.T) {
counter := atomic.Uint32{}
fn := func(ctx context.Context) error {
// This migration doesn't actually do anything
counter.Add(1)
return nil
}
const parallel = 5
errs := make(chan error, parallel)
hasLogs := atomic.Uint32{}
for i := 0; i < parallel; i++ {
go func(i int) {
// Collect logs
collectLog := logger.NewLogger("concurrent-" + strconv.Itoa(i))
collectLog.SetOutputLevel(logger.DebugLevel)
buf := &bytes.Buffer{}
collectLog.SetOutput(io.MultiWriter(buf, os.Stdout))
m := &Migrations{
DB: db,
Logger: collectLog,
Schema: schema,
MetadataTableName: "metadata_2",
MetadataKey: "migrations_concurrent",
}
migrateErr := m.Perform(context.Background(), []commonsql.MigrationFn{fn})
if migrateErr != nil {
errs <- fmt.Errorf("migration failed in handler %d: %w", i, migrateErr)
}
// One and only one of the loggers should have any message including "Performing migration"
if strings.Contains(buf.String(), "Performing migration") {
hasLogs.Add(1)
}
errs <- nil
}(i)
}
for i := 0; i < parallel; i++ {
select {
case err := <-errs:
assert.NoError(t, err) //nolint:testifylint
case <-time.After(30 * time.Second):
t.Fatal("timed out waiting for migrations to complete")
}
}
if t.Failed() {
// Short-circuit
t.FailNow()
}
// Handler should have been invoked just once
assert.Equal(t, uint32(1), counter.Load(), "Migrations handler invoked more than once")
assert.Equal(t, uint32(1), hasLogs.Load(), "More than one logger indicated a migration")
})
}
func getUniqueDBSchema(t *testing.T) string {
t.Helper()
b := make([]byte, 4)
_, err := io.ReadFull(rand.Reader, b)
require.NoError(t, err)
return fmt.Sprintf("m%s", hex.EncodeToString(b))
}
func assertTableExists(t *testing.T, db *sql.DB, schema, table string) {
t.Helper()
var found int
err := db.QueryRow(
fmt.Sprintf("SELECT 1 WHERE OBJECT_ID('[%s].[%s]', 'U') IS NOT NULL", schema, table),
).Scan(&found)
require.NoErrorf(t, err, "Table %s not found", table)
require.Equalf(t, 1, found, "Table %s not found", table)
}
func assertMigrationsLevel(t *testing.T, db *sql.DB, schema, table, key, expectLevel string) {
t.Helper()
var foundLevel string
err := db.QueryRow(
fmt.Sprintf("SELECT [Value] FROM [%s].[%s] WHERE [Key] = @Key", schema, table),
sql.Named("Key", key),
).Scan(&foundLevel)
require.NoError(t, err, "Failed to load migrations level")
require.Equal(t, expectLevel, foundLevel, "Migration level does not match")
}

View File

@ -0,0 +1,105 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlserver
import (
"context"
"database/sql"
)
// DropSchema drops a schema from a SQL Server database, including all resources that were created inside
// Adapted from: https://stackoverflow.com/a/76742742/192024
func DropSchema(ctx context.Context, db *sql.DB, schema string) error {
_, err := db.ExecContext(ctx, `
DECLARE @command NVARCHAR(MAX) = '';
WITH Schemas AS (
SELECT
s.schema_id,
s.name AS schema_name,
IIF(s.Name = 'dbo', 1, 0) schema_predefined
FROM sys.schemas s
INNER JOIN sys.sysusers u ON u.uid = s.principal_id
WHERE u.issqlrole = 0
AND s.Name = @Schema
AND u.name NOT IN ('sys', 'guest', 'INFORMATION_SCHEMA')
),
Commands(Command) AS (
-- Procedures
SELECT 'DROP PROCEDURE [' + schema_name + '].[' + name + ']'
FROM sys.procedures o
JOIN Schemas schemas ON o.schema_id = schemas.schema_id
-- Functions
UNION ALL
SELECT 'DROP FUNCTION [' + schema_name + '].[' + name + ']'
FROM sys.objects o
JOIN Schemas schemas ON o.schema_id = schemas.schema_id
WHERE type IN ('FN', 'IF', 'TF')
-- Views
UNION ALL
SELECT 'DROP VIEW [' + schema_name + '].[' + name + ']'
FROM sys.views o
JOIN Schemas schemas ON o.schema_id = schemas.schema_id
-- Check constraints
UNION ALL
SELECT
'ALTER TABLE [' + schema_name + '].[' + object_name(parent_object_id) + '] ' +
'DROP CONSTRAINT [' + name + ']'
FROM sys.check_constraints o
JOIN Schemas schemas ON o.schema_id = schemas.schema_id
-- Foreign keys
UNION ALL
SELECT
'ALTER TABLE [' + schema_name + '].[' + object_name(parent_object_id) + '] ' +
'DROP CONSTRAINT [' + name + ']'
FROM sys.foreign_keys o
JOIN Schemas schemas ON o.schema_id = schemas.schema_id
-- Tables
UNION ALL
SELECT 'DROP TABLE [' + schema_name + '].[' + name + ']'
FROM sys.tables o
JOIN Schemas schemas ON o.schema_id = schemas.schema_id
-- Sequences
UNION ALL
SELECT 'DROP SEQUENCE [' + schema_name + '].[' + name + ']'
FROM sys.sequences o
JOIN Schemas schemas ON o.schema_id = schemas.schema_id
-- User defined types
UNION ALL
SELECT 'DROP TYPE [' + schema_name + '].[' + name + ']'
FROM sys.types o
JOIN Schemas schemas ON o.schema_id = schemas.schema_id
WHERE is_user_defined = 1
-- Schemas
UNION ALL
SELECT 'DROP SCHEMA [' + schema_name + ']'
FROM Schemas
WHERE schema_predefined = 0
)
SELECT @command = STRING_AGG(Command, CHAR(10))
FROM Commands
PRINT @command
EXEC sp_executesql @command
`, sql.Named("Schema", schema))
return err
}

View File

@ -17,13 +17,15 @@ package sqlserver
import (
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"math"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
@ -42,10 +44,10 @@ const (
// connectionStringEnvKey defines the key containing the integration test connection string
// To use docker, server=localhost;user id=sa;password=Pass@Word1;port=1433;
// To use Azure SQL, server=<your-db-server-name>.database.windows.net;user id=<your-db-user>;port=1433;password=<your-password>;database=dapr_test;.
connectionStringEnvKey = "DAPR_TEST_SQL_CONNSTRING"
usersTableName = "Users"
beverageTea = "tea"
invalidEtag string = "FFFFFFFFFFFFFFFF"
connectionStringEnvKey = "DAPR_TEST_SQL_CONNSTRING"
usersTableName = "Users"
beverageTea = "tea"
invalidEtag = "FFFFFFFFFFFFFFFF"
)
type user struct {
@ -67,7 +69,7 @@ type userWithEtag struct {
func TestIntegrationCases(t *testing.T) {
connectionString := os.Getenv(connectionStringEnvKey)
if connectionString == "" {
t.Skipf("SQLServer state integration tests skipped. To enable define the connection string using environment variable '%s' (example 'export %s=\"server=localhost;user id=sa;password=Pass@Word1;port=1433;\")", connectionStringEnvKey, connectionStringEnvKey)
t.Skipf(`SQLServer state integration tests skipped. To enable this test, define the connection string using environment variable '%[1]s' (example 'export %[1]s="server=localhost;user id=sa;password=Pass@Word1;port=1433;")'`, connectionStringEnvKey)
}
t.Run("Single operations", testSingleOperations)
@ -84,11 +86,11 @@ func TestIntegrationCases(t *testing.T) {
}
}
func getUniqueDBSchema() string {
uuid := uuid.New().String()
uuid = strings.ReplaceAll(uuid, "-", "")
return fmt.Sprintf("v%s", uuid)
func getUniqueDBSchema(t *testing.T) string {
b := make([]byte, 4)
_, err := io.ReadFull(rand.Reader, b)
require.NoError(t, err)
return fmt.Sprintf("v%s", hex.EncodeToString(b))
}
func createMetadata(schema string, kt KeyType, indexedProperties string) state.Metadata {
@ -116,7 +118,7 @@ func getTestStore(t *testing.T, indexedProperties string) *SQLServer {
}
func getTestStoreWithKeyType(t *testing.T, kt KeyType, indexedProperties string) *SQLServer {
schema := getUniqueDBSchema()
schema := getUniqueDBSchema(t)
metadata := createMetadata(schema, kt, indexedProperties)
store := &SQLServer{
logger: logger.NewLogger("test"),