components-contrib/tests/certification/state/sqlserver/sqlserver_test.go

598 lines
22 KiB
Go

/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlserver_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"os"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
// State.
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
state_sqlserver "github.com/dapr/components-contrib/state/sqlserver"
state_loader "github.com/dapr/dapr/pkg/components/state"
"github.com/dapr/kit/logger"
// Secret stores.
secretstore_env "github.com/dapr/components-contrib/secretstores/local/env"
secretstores_loader "github.com/dapr/dapr/pkg/components/secretstores"
// Dapr runtime and Go-SDK
"github.com/dapr/dapr/pkg/runtime"
dapr_testing "github.com/dapr/dapr/pkg/testing"
"github.com/dapr/go-sdk/client"
// Certification testing runnables
"github.com/dapr/components-contrib/tests/certification/embedded"
"github.com/dapr/components-contrib/tests/certification/flow"
"github.com/dapr/components-contrib/tests/certification/flow/dockercompose"
"github.com/dapr/components-contrib/tests/certification/flow/network"
"github.com/dapr/components-contrib/tests/certification/flow/retry"
"github.com/dapr/components-contrib/tests/certification/flow/sidecar"
)
const (
sidecarNamePrefix = "sqlserver-sidecar-"
dockerComposeYAML = "docker-compose.yml"
stateStoreName = "dapr-state-store"
certificationTestPrefix = "stable-certification-"
dockerConnectionString = "server=localhost;user id=sa;password=Pass@Word1;port=1433;"
)
func TestSqlServer(t *testing.T) {
ports, err := dapr_testing.GetFreePorts(2)
assert.NoError(t, err)
currentGrpcPort := ports[0]
currentHTTPPort := ports[1]
basicTest := func(ctx flow.Context) error {
ctx.T.Run("basic test", func(t *testing.T) {
client, err := client.NewClientWithPort(strconv.Itoa(currentGrpcPort))
if err != nil {
panic(err)
}
defer client.Close()
// save state, default options: strong, last-write
err = client.SaveState(ctx, stateStoreName, certificationTestPrefix+"key1", []byte("certificationdata"), nil)
require.NoError(t, err)
// get state
item, err := client.GetState(ctx, stateStoreName, certificationTestPrefix+"key1", nil)
require.NoError(t, err)
assert.Equal(t, "certificationdata", string(item.Value))
// delete state
err = client.DeleteState(ctx, stateStoreName, certificationTestPrefix+"key1", nil)
require.NoError(t, err)
})
return nil
}
basicTTLTest := func(ctx flow.Context) error {
ctx.T.Run("basic TTL test", func(t *testing.T) {
client, err := client.NewClientWithPort(strconv.Itoa(currentGrpcPort))
if err != nil {
panic(err)
}
defer client.Close()
err = client.SaveState(ctx, stateStoreName, certificationTestPrefix+"key2", []byte("certificationdata"), map[string]string{
"ttlInSeconds": "86400",
})
require.NoError(t, err)
// get state
item, err := client.GetState(ctx, stateStoreName, certificationTestPrefix+"key2", nil)
require.NoError(t, err)
assert.Equal(t, "certificationdata", string(item.Value))
assert.Contains(t, item.Metadata, "ttlExpireTime")
expireTime, err := time.Parse(time.RFC3339, item.Metadata["ttlExpireTime"])
_ = assert.NoError(t, err) &&
assert.InDelta(t, time.Now().Add(24*time.Hour).Unix(), expireTime.Unix(), 10)
err = client.SaveState(ctx, stateStoreName, certificationTestPrefix+"key2", []byte("certificationdata"), map[string]string{
"ttlInSeconds": "1",
})
require.NoError(t, err)
time.Sleep(2 * time.Second)
item, err = client.GetState(ctx, stateStoreName, certificationTestPrefix+"key2", nil)
require.NoError(t, err)
assert.Nil(t, item.Value)
assert.Nil(t, item.Metadata)
})
return nil
}
// this test function heavily depends on the values defined in ./components/docker/customschemawithindex
verifyIndexedPopertiesTest := func(ctx flow.Context) error {
// verify indices were created by Dapr as specified in the component metadata
db, err := sql.Open("mssql", fmt.Sprintf("%sdatabase=certificationtest;", dockerConnectionString))
require.NoError(ctx.T, err)
defer db.Close()
rows, err := db.Query("sp_helpindex '[customschema].[mystates]'")
require.NoError(ctx.T, err)
assert.NoError(ctx.T, rows.Err())
defer rows.Close()
indexFoundCount := 0
for rows.Next() {
var indexedField, otherdata1, otherdata2 string
err = rows.Scan(&indexedField, &otherdata1, &otherdata2)
assert.NoError(ctx.T, err)
expectedIndices := []string{"IX_customerid", "IX_transactionid", "PK_mystates"}
for _, item := range expectedIndices {
if item == indexedField {
indexFoundCount++
break
}
}
}
assert.Equal(ctx.T, 3, indexFoundCount)
// write JSON data to the state store (which will automatically be indexed in separate columns)
client, err := client.NewClientWithPort(strconv.Itoa(currentGrpcPort))
if err != nil {
panic(err)
}
defer client.Close()
order := struct {
ID int `json:"id"`
Customer string `json:"customer"`
Description string `json:"description"`
}{123456, "John Doe", "something"}
data, err := json.Marshal(order)
assert.NoError(ctx.T, err)
// save state with the key certificationkey1, default options: strong, last-write
err = client.SaveState(ctx, stateStoreName, certificationTestPrefix+"key1", data, nil)
require.NoError(ctx.T, err)
// get state for key certificationkey1
item, err := client.GetState(ctx, stateStoreName, certificationTestPrefix+"key1", nil)
assert.NoError(ctx.T, err)
assert.Equal(ctx.T, string(data), string(item.Value))
// check that Dapr wrote the indexed properties to separate columns
rows, err = db.Query("SELECT TOP 1 transactionid, customerid FROM [customschema].[mystates];")
assert.NoError(ctx.T, err)
assert.NoError(ctx.T, rows.Err())
defer rows.Close()
if rows.Next() {
var transactionID int
var customerID string
err = rows.Scan(&transactionID, &customerID)
assert.NoError(ctx.T, err)
assert.Equal(ctx.T, transactionID, order.ID)
assert.Equal(ctx.T, customerID, order.Customer)
} else {
assert.Fail(ctx.T, "no rows returned")
}
// delete state for key certificationkey1
err = client.DeleteState(ctx, stateStoreName, certificationTestPrefix+"key1", nil)
assert.NoError(ctx.T, err)
return nil
}
// helper function for testing the use of an existing custom schema
createCustomSchema := func(ctx flow.Context) error {
db, err := sql.Open("mssql", dockerConnectionString)
assert.NoError(ctx.T, err)
_, err = db.Exec("CREATE SCHEMA customschema;")
assert.NoError(ctx.T, err)
db.Close()
return nil
}
// helper function to insure the SQL Server Docker Container is truly ready
checkSQLServerAvailability := func(ctx flow.Context) error {
db, err := sql.Open("mssql", dockerConnectionString)
if err != nil {
return err
}
_, err = db.Exec("SELECT * FROM INFORMATION_SCHEMA.TABLES;")
if err != nil {
return err
}
return nil
}
// checks the state store component is not vulnerable to SQL injection
verifySQLInjectionTest := func(ctx flow.Context) error {
client, err := client.NewClientWithPort(strconv.Itoa(currentGrpcPort))
if err != nil {
panic(err)
}
defer client.Close()
// common SQL injection techniques for SQL Server
sqlInjectionAttempts := []string{
"; DROP states--",
"dapr' OR '1'='1",
}
for _, sqlInjectionAttempt := range sqlInjectionAttempts {
// save state with sqlInjectionAttempt's value as key, default options: strong, last-write
err = client.SaveState(ctx, stateStoreName, sqlInjectionAttempt, []byte(sqlInjectionAttempt), nil)
assert.NoError(ctx.T, err)
// get state for key sqlInjectionAttempt's value
item, err := client.GetState(ctx, stateStoreName, sqlInjectionAttempt, nil)
assert.NoError(ctx.T, err)
assert.Equal(ctx.T, sqlInjectionAttempt, string(item.Value))
// delete state for key sqlInjectionAttempt's value
err = client.DeleteState(ctx, stateStoreName, sqlInjectionAttempt, nil)
assert.NoError(ctx.T, err)
}
return nil
}
// Validates TTLs and garbage collections
ttlTest := func(connString string) func(ctx flow.Context) error {
return func(ctx flow.Context) error {
log := logger.NewLogger("dapr.components")
md := state.Metadata{
Base: metadata.Base{
Name: "ttltest",
Properties: map[string]string{
"connectionString": connString,
"databaseName": "certificationtest",
"tableName": "ttltest",
"metadataTableName": "ttltest_metadata",
"schema": "ttlschema",
},
},
}
ctx.T.Run("parse cleanupIntervalInSeconds", func(t *testing.T) {
t.Run("default value", func(t *testing.T) {
// Default value is 1 hr
md.Properties["cleanupIntervalInSeconds"] = ""
storeObj := state_sqlserver.New(log).(*state_sqlserver.SQLServer)
err := storeObj.Init(context.Background(), md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
cleanupInterval := storeObj.GetCleanupInterval()
require.NotNil(t, cleanupInterval)
assert.Equal(t, time.Duration(1*time.Hour), *cleanupInterval)
})
t.Run("positive value", func(t *testing.T) {
// A positive value is interpreted in seconds
md.Properties["cleanupIntervalInSeconds"] = "10"
storeObj := state_sqlserver.New(log).(*state_sqlserver.SQLServer)
err := storeObj.Init(context.Background(), md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
cleanupInterval := storeObj.GetCleanupInterval()
require.NotNil(t, cleanupInterval)
assert.Equal(t, time.Duration(10*time.Second), *cleanupInterval)
})
t.Run("disabled", func(t *testing.T) {
// A value of <=0 means that the cleanup is disabled
md.Properties["cleanupIntervalInSeconds"] = "0"
storeObj := state_sqlserver.New(log).(*state_sqlserver.SQLServer)
err := storeObj.Init(context.Background(), md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
cleanupInterval := storeObj.GetCleanupInterval()
assert.Nil(t, cleanupInterval)
})
})
ctx.T.Run("cleanup", func(t *testing.T) {
dbClient, err := sql.Open("mssql", connString)
require.NoError(t, err)
t.Run("automatically delete expiredate records", func(t *testing.T) {
// Run every second
md.Properties["cleanupIntervalInSeconds"] = "1"
storeObj := state_sqlserver.New(log).(*state_sqlserver.SQLServer)
err := storeObj.Init(context.Background(), md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
// Seed the database with some records
err = clearTable(ctx, dbClient, "ttlschema", "ttltest")
require.NoError(t, err, "failed to clear table")
err = populateTTLRecords(ctx, dbClient, "ttlschema", "ttltest")
require.NoError(t, err, "failed to seed records")
cleanupInterval := storeObj.GetCleanupInterval()
assert.NotNil(t, cleanupInterval)
assert.Equal(t, time.Duration(time.Second), *cleanupInterval)
// Wait up to 3 seconds then verify we have only 10 rows left
var count int
assert.Eventually(t, func() bool {
count, err = countRowsInTable(ctx, dbClient, "ttlschema", "ttltest")
require.NoError(t, err, "failed to run query to count rows")
return count == 10
}, 3*time.Second, 10*time.Millisecond, "expected 10 rows, got %d", count)
// The "last-cleanup" value should be <= 1 second (+ a bit of buffer)
lastCleanup, err := loadLastCleanupInterval(ctx, dbClient, "ttlschema", "ttltest_metadata")
require.NoError(t, err, "failed to load value for 'last-cleanup'")
assert.LessOrEqual(t, lastCleanup, int64(1200))
// Wait 6 more seconds and verify there are no more rows left
assert.Eventually(t, func() bool {
count, err = countRowsInTable(ctx, dbClient, "ttlschema", "ttltest")
require.NoError(t, err, "failed to run query to count rows")
return count == 0
}, 6*time.Second, 10*time.Millisecond, "expected 0 rows, got %d", count)
// The "last-cleanup" value should be <= 1 second (+ a bit of buffer)
lastCleanup, err = loadLastCleanupInterval(ctx, dbClient, "ttlschema", "ttltest_metadata")
require.NoError(t, err, "failed to load value for 'last-cleanup'")
assert.LessOrEqual(t, lastCleanup, int64(1200))
})
t.Run("cleanup concurrency", func(t *testing.T) {
// Set to run every hour
// (we'll manually trigger more frequent iterations)
md.Properties["cleanupIntervalInSeconds"] = "3600"
storeObj := state_sqlserver.New(log).(*state_sqlserver.SQLServer)
err := storeObj.Init(context.Background(), md)
require.NoError(t, err, "failed to init")
defer storeObj.Close()
cleanupInterval := storeObj.GetCleanupInterval()
assert.NotNil(t, cleanupInterval)
assert.Equal(t, time.Hour, *cleanupInterval)
// Seed the database with some records
err = clearTable(ctx, dbClient, "ttlschema", "ttltest")
require.NoError(t, err, "failed to clear table")
err = populateTTLRecords(ctx, dbClient, "ttlschema", "ttltest")
require.NoError(t, err, "failed to seed records")
// Validate that 20 records are present
count, err := countRowsInTable(ctx, dbClient, "ttlschema", "ttltest")
require.NoError(t, err, "failed to run query to count rows")
assert.Equal(t, 20, count)
// Set last-cleanup to 1s ago
_, err = dbClient.ExecContext(ctx, `UPDATE [ttlschema].[ttltest_metadata] SET [Value] = CONVERT(nvarchar(MAX), DATEADD(second, -1, GETDATE()), 21) WHERE [Key] = 'last-cleanup'`)
require.NoError(t, err, "failed to set last-cleanup")
// The "last-cleanup" value
lastCleanup, err := loadLastCleanupInterval(ctx, dbClient, "ttlschema", "ttltest_metadata")
require.NoError(t, err, "failed to load value for 'last-cleanup'")
assert.LessOrEqual(t, lastCleanup, int64(1200))
lastCleanupValueOrig, err := getValueFromMetadataTable(ctx, dbClient, "ttlschema", "ttltest_metadata", "'last-cleanup'")
require.NoError(t, err, "failed to load absolute value for 'last-cleanup'")
require.NotEmpty(t, lastCleanupValueOrig)
// Trigger the background cleanup, which should do nothing because the last cleanup was < 3600s
require.NoError(t, storeObj.CleanupExpired(), "CleanupExpired returned an error")
// Validate that 20 records are still present
count, err = countRowsInTable(ctx, dbClient, "ttlschema", "ttltest")
require.NoError(t, err, "failed to run query to count rows")
assert.Equal(t, 20, count)
// The "last-cleanup" value should not have been changed
lastCleanupValue, err := getValueFromMetadataTable(ctx, dbClient, "ttlschema", "ttltest_metadata", "'last-cleanup'")
require.NoError(t, err, "failed to load absolute value for 'last-cleanup'")
assert.Equal(t, lastCleanupValueOrig, lastCleanupValue)
})
})
return nil
}
}
t.Run("SQLServer certification using SQL Server Docker", func(t *testing.T) {
flow.New(t, "SQLServer certification using SQL Server Docker").
// Run SQL Server using Docker Compose.
Step(dockercompose.Run("sqlserver", dockerComposeYAML)).
Step("wait for SQL Server readiness", retry.Do(time.Second*3, 10, checkSQLServerAvailability)).
// Run the Dapr sidecar with the SQL Server component.
Step(sidecar.Run(sidecarNamePrefix+"dockerDefault",
embedded.WithoutApp(),
embedded.WithDaprGRPCPort(currentGrpcPort),
embedded.WithDaprHTTPPort(currentHTTPPort),
embedded.WithResourcesPath("components/docker/default"),
embedded.WithProfilingEnabled(false),
componentRuntimeOptions(),
)).
Step("Run basic test", basicTest).
Step("Run basic TTL test", basicTTLTest).
// Introduce network interruption of 10 seconds
// Note: the connection timeout is set to 5 seconds via the component metadata connection string.
Step("interrupt network",
network.InterruptNetwork(10*time.Second, nil, nil, "1433", "1434")).
// Component should recover at this point.
Step("wait", flow.Sleep(5*time.Second)).
Step("Run basic test again to verify reconnection occurred", basicTest).
Step("Run SQL injection test", verifySQLInjectionTest, sidecar.Stop(sidecarNamePrefix+"dockerDefault")).
Step("run TTL test", ttlTest(dockerConnectionString+"database=certificationtest;")).
Step("Stopping SQL Server Docker container", dockercompose.Stop("sqlserver", dockerComposeYAML)).
Run()
})
ports, err = dapr_testing.GetFreePorts(2)
assert.NoError(t, err)
currentGrpcPort = ports[0]
currentHTTPPort = ports[1]
t.Run("Using existing custom schema with indexed data", func(t *testing.T) {
flow.New(t, "Using existing custom schema with indexed data").
// Run SQL Server using Docker Compose.
Step(dockercompose.Run("sqlserver", dockerComposeYAML)).
Step("wait for SQL Server readiness", retry.Do(time.Second*3, 10, checkSQLServerAvailability)).
Step("Creating schema", createCustomSchema).
// Run the Dapr sidecar with the SQL Server component.
Step(sidecar.Run(sidecarNamePrefix+"dockerCustomSchema",
embedded.WithoutApp(),
embedded.WithDaprGRPCPort(currentGrpcPort),
embedded.WithDaprHTTPPort(currentHTTPPort),
embedded.WithResourcesPath("components/docker/customschemawithindex"),
embedded.WithProfilingEnabled(false),
componentRuntimeOptions(),
)).
Step("Run indexed properties verification test", verifyIndexedPopertiesTest, sidecar.Stop(sidecarNamePrefix+"dockerCustomSchema")).
Step("Stopping SQL Server Docker container", dockercompose.Stop("sqlserver", dockerComposeYAML)).
Run()
})
ports, err = dapr_testing.GetFreePorts(2)
assert.NoError(t, err)
currentGrpcPort = ports[0]
currentHTTPPort = ports[1]
t.Run("SQL Server certification using Azure SQL", func(t *testing.T) {
flow.New(t, "SQL Server certification using Azure SQL").
// Run the Dapr sidecar with the SQL Server component.
Step(sidecar.Run(sidecarNamePrefix+"azure",
embedded.WithoutApp(),
embedded.WithDaprGRPCPort(currentGrpcPort),
embedded.WithDaprHTTPPort(currentHTTPPort),
embedded.WithResourcesPath("components/azure"),
embedded.WithProfilingEnabled(false),
componentRuntimeOptions(),
)).
Step("Run basic test", basicTest).
Step("Run basic TTL test", basicTTLTest).
Step("interrupt network",
network.InterruptNetwork(15*time.Second, nil, nil, "1433", "1434")).
// Component should recover at this point.
Step("wait", flow.Sleep(10*time.Second)).
Step("Run basic test again to verify reconnection occurred", basicTest).
Step("Run SQL injection test", verifySQLInjectionTest, sidecar.Stop(sidecarNamePrefix+"azure")).
Step("run TTL test", ttlTest(os.Getenv("AzureSqlServerConnectionString")+"database=stablecertification;")).
Run()
})
}
func componentRuntimeOptions() []runtime.Option {
log := logger.NewLogger("dapr.components")
stateRegistry := state_loader.NewRegistry()
stateRegistry.Logger = log
stateRegistry.RegisterComponent(state_sqlserver.New, "sqlserver")
secretstoreRegistry := secretstores_loader.NewRegistry()
secretstoreRegistry.Logger = log
secretstoreRegistry.RegisterComponent(secretstore_env.NewEnvSecretStore, "local.env")
return []runtime.Option{
runtime.WithStates(stateRegistry),
runtime.WithSecretStores(secretstoreRegistry),
}
}
func countRowsInTable(ctx context.Context, dbClient *sql.DB, schema, table string) (count int, err error) {
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
err = dbClient.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM [%s].[%s]", schema, table)).Scan(&count)
return
}
func clearTable(ctx context.Context, dbClient *sql.DB, schema, table string) error {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
_, err := dbClient.ExecContext(ctx, fmt.Sprintf("DELETE FROM [%s].[%s]", schema, table))
return err
}
func loadLastCleanupInterval(ctx context.Context, dbClient *sql.DB, schema, table string) (lastCleanup int64, err error) {
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
err = dbClient.
QueryRowContext(ctx,
fmt.Sprintf("SELECT DATEDIFF(MILLISECOND, CAST([Value] AS DATETIME2), GETDATE()) FROM [%s].[%s] WHERE [Key] = 'last-cleanup'", schema, table),
).
Scan(&lastCleanup)
return
}
func populateTTLRecords(ctx context.Context, dbClient *sql.DB, schema, table string) error {
// Insert 10 records that have expired, and 10 that will expire in 4
// seconds.
rows := make([][]any, 20)
for i := 0; i < 10; i++ {
rows[i] = []any{
fmt.Sprintf("'expired_%d'", i),
json.RawMessage(fmt.Sprintf("'value_%d'", i)),
"DATEADD(MINUTE, -1, GETDATE())",
}
}
for i := 0; i < 10; i++ {
rows[i+10] = []any{
fmt.Sprintf("'notexpired_%d'", i),
json.RawMessage(fmt.Sprintf(`'value_%d'`, i)),
"DATEADD(SECOND, 4, GETDATE())",
}
}
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
for _, row := range rows {
_, err := dbClient.ExecContext(ctx, fmt.Sprintf(
"INSERT INTO [%[1]s].[%[2]s] ([Key], [Data], [ExpireDate]) VALUES (%[3]s, %[4]s, %[5]s)",
schema, table, row[0], row[1], row[2]),
)
if err != nil {
return err
}
}
return nil
}
func getValueFromMetadataTable(ctx context.Context, dbClient *sql.DB, schema, table, key string) (value string, err error) {
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
err = dbClient.
QueryRowContext(ctx, fmt.Sprintf("SELECT [Value] FROM [%[1]s].[%[2]s] WHERE [Key] = %[3]s", schema, table, key)).
Scan(&value)
return
}