Fixes to MySQL state store (#1978)

This commit is contained in:
Alessandro (Ale) Segala 2022-08-18 11:50:53 -07:00 committed by GitHub
parent e87cd5e4cb
commit 8b48210e3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 583 additions and 491 deletions

View File

@ -45,12 +45,10 @@ func (m *mySQLFactory) RegisterTLSConfig(pemPath string) error {
if readErr != nil {
m.logger.Error("Error reading PEM file from " + pemPath)
return readErr
}
ok := rootCertPool.AppendCertsFromPEM(pem)
if !ok {
return fmt.Errorf("failed to append PEM")
}

View File

@ -21,12 +21,11 @@ import (
"fmt"
"strings"
"github.com/agrea/ptr"
"github.com/google/uuid"
"github.com/dapr/components-contrib/state"
"github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/ptr"
)
// Optimistic Concurrency is implemented using a string column that stores
@ -65,12 +64,10 @@ const (
// MySQL state store.
type MySQL struct {
// Name of the table to store state. If the table does not exist it will
// be created.
// Name of the table to store state. If the table does not exist it will be created.
tableName string
// Name of the table to create to store state. If the table does not exist
// it will be created.
// Name of the table to create to store state. If the table does not exist it will be created.
schemaName string
connectionString string
@ -116,8 +113,11 @@ func (m *MySQL) Init(metadata state.Metadata) error {
m.logger.Debug("Initializing MySql state store")
val, ok := metadata.Properties[tableNameKey]
if ok && val != "" {
// Sanitize the table name
if !validIdentifier(val) {
return fmt.Errorf("table name '%s' is not valid", val)
}
m.tableName = val
} else {
// Default to the constant
@ -125,8 +125,11 @@ func (m *MySQL) Init(metadata state.Metadata) error {
}
val, ok = metadata.Properties[schemaNameKey]
if ok && val != "" {
// Sanitize the schema name
if !validIdentifier(val) {
return fmt.Errorf("schema name '%s' is not valid", val)
}
m.schemaName = val
} else {
// Default to the constant
@ -134,28 +137,28 @@ func (m *MySQL) Init(metadata state.Metadata) error {
}
m.connectionString, ok = metadata.Properties[connectionStringKey]
if !ok || m.connectionString == "" {
m.logger.Error("Missing MySql connection string")
return fmt.Errorf(errMissingConnectionString)
}
val, ok = metadata.Properties[pemPathKey]
if ok && val != "" {
err := m.factory.RegisterTLSConfig(val)
if err != nil {
m.logger.Error(err)
return err
}
}
db, err := m.factory.Open(m.connectionString)
if err != nil {
m.logger.Error(err)
return err
}
// will be nil if everything is good or an err that needs to be returned
return m.finishInit(db, err)
return m.finishInit(db)
}
// Features returns the features available in this state store.
@ -164,29 +167,19 @@ func (m *MySQL) Features() []state.Feature {
}
// Separated out to make this portion of code testable.
func (m *MySQL) finishInit(db *sql.DB, err error) error {
func (m *MySQL) finishInit(db *sql.DB) error {
m.db = db
err := m.ensureStateSchema()
if err != nil {
m.logger.Error(err)
return err
}
m.db = db
schemaErr := m.ensureStateSchema()
if schemaErr != nil {
m.logger.Error(schemaErr)
return schemaErr
}
pingErr := m.db.Ping()
if pingErr != nil {
m.logger.Error(pingErr)
return pingErr
err = m.db.Ping()
if err != nil {
m.logger.Error(err)
return err
}
// will be nil if everything is good or an err that needs to be returned
@ -201,9 +194,9 @@ func (m *MySQL) ensureStateSchema() error {
if !exists {
m.logger.Infof("Creating MySql schema '%s'", m.schemaName)
_, err = m.db.Exec(`CREATE DATABASE ?`, m.schemaName)
_, err = m.db.Exec(
fmt.Sprintf("CREATE DATABASE %s;", m.schemaName),
)
if err != nil {
return err
}
@ -243,7 +236,7 @@ func (m *MySQL) ensureStateTable(stateTableName string) error {
// never need to pass it in.
// eTag is a UUID stored as a 36 characters string. It needs to be passed
// in on inserts and updates and is used for Optimistic Concurrency
// Note that stateTableName is sanitized
//nolint:gosec
createTable := fmt.Sprintf(`CREATE TABLE %s (
id VARCHAR(255) NOT NULL PRIMARY KEY,
@ -265,28 +258,22 @@ func (m *MySQL) ensureStateTable(stateTableName string) error {
}
func schemaExists(db *sql.DB, schemaName string) (bool, error) {
// Returns 1 or 0 as a string if the table exists or not
exists := ""
query := `SELECT EXISTS (
SELECT SCHEMA_NAME FROM information_schema.schemata WHERE SCHEMA_NAME = ?
) AS 'exists'`
// Returns 1 or 0 as a string if the table exists or not
) AS 'exists'`
err := db.QueryRow(query, schemaName).Scan(&exists)
return exists == "1", err
}
func tableExists(db *sql.DB, tableName string) (bool, error) {
// Returns 1 or 0 as a string if the table exists or not
exists := ""
query := `SELECT EXISTS (
SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_NAME = ?
) AS 'exists'`
// Returns 1 or 0 as a string if the table exists or not
) AS 'exists'`
err := db.QueryRow(query, tableName).Scan(&exists)
return exists == "1", err
}
@ -370,12 +357,18 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) {
return nil, fmt.Errorf("missing key in get operation")
}
var eTag, value string
var isBinary bool
var (
eTag string
value []byte
isBinary bool
)
err := m.db.QueryRow(fmt.Sprintf(
//nolint:gosec
query := fmt.Sprintf(
`SELECT value, eTag, isbinary FROM %s WHERE id = ?`,
m.tableName), req.Key).Scan(&value, &eTag, &isBinary)
m.tableName, // m.tableName is sanitized
)
err := m.db.QueryRow(query, req.Key).Scan(&value, &eTag, &isBinary)
if err != nil {
// If no rows exist, return an empty response, otherwise return an error.
if errors.Is(err, sql.ErrNoRows) {
@ -386,27 +379,31 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) {
}
if isBinary {
var s string
var data []byte
var (
s string
data []byte
)
if err = json.Unmarshal([]byte(value), &s); err != nil {
err = json.Unmarshal(value, &s)
if err != nil {
return nil, err
}
if data, err = base64.StdEncoding.DecodeString(s); err != nil {
data, err = base64.StdEncoding.DecodeString(s)
if err != nil {
return nil, err
}
return &state.GetResponse{
Data: data,
ETag: ptr.String(eTag),
ETag: ptr.Of(eTag),
Metadata: req.Metadata,
}, nil
}
return &state.GetResponse{
Data: []byte(value),
ETag: ptr.String(eTag),
Data: value,
ETag: ptr.Of(eTag),
Metadata: req.Metadata,
}, nil
}
@ -428,41 +425,56 @@ func (m *MySQL) setValue(req *state.SetRequest) error {
}
if req.Key == "" {
return fmt.Errorf("missing key in set operation")
return errors.New("missing key in set operation")
}
if v, ok := req.Value.(string); ok && v == "" {
return fmt.Errorf("empty string is not allowed in set operation")
var v any
isBinary := false
switch x := req.Value.(type) {
case string:
if x == "" {
return errors.New("empty string is not allowed in set operation")
}
v = x
case []uint8:
isBinary = true
v = base64.StdEncoding.EncodeToString(x)
default:
v = x
}
v := req.Value
byteArray, isBinary := req.Value.([]uint8)
if isBinary {
v = base64.StdEncoding.EncodeToString(byteArray)
}
// Convert to json string
bt, _ := utils.Marshal(v, json.Marshal)
value := string(bt)
var result sql.Result
encB, _ := json.Marshal(v)
enc := string(encB)
eTag := uuid.New().String()
// Sprintf is required for table name because sql.DB does not substitute
// parameters for table names.
// Other parameters use sql.DB parameter substitution.
if req.ETag == nil || *req.ETag == "" {
// If this is a duplicate MySQL returns that two rows affected
result, err = m.db.Exec(fmt.Sprintf(
`INSERT INTO %s (value, id, eTag, isbinary)
VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`,
m.tableName), value, req.Key, eTag, isBinary, value, eTag, isBinary)
} else {
var result sql.Result
var maxRows int64 = 1
if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || *req.ETag == "") {
// With first-write-wins and no etag, we can insert the row only if it doesn't exist
//nolint:gosec
query := fmt.Sprintf(
`INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?);`,
m.tableName, // m.tableName is sanitized
)
result, err = m.db.Exec(query, enc, req.Key, eTag, isBinary)
} else if req.ETag != nil && *req.ETag != "" {
// When an eTag is provided do an update - not insert
result, err = m.db.Exec(fmt.Sprintf(
`UPDATE %s SET value = ?, eTag = ?, isbinary = ?
WHERE id = ? AND eTag = ?;`,
m.tableName), value, eTag, isBinary, req.Key, *req.ETag)
//nolint:gosec
query := fmt.Sprintf(
`UPDATE %s SET value = ?, eTag = ?, isbinary = ? WHERE id = ? AND eTag = ?;`,
m.tableName, // m.tableName is sanitized
)
result, err = m.db.Exec(query, enc, eTag, isBinary, req.Key, *req.ETag)
} else {
// If this is a duplicate MySQL returns that two rows affected
maxRows = 2
//nolint:gosec
query := fmt.Sprintf(
`INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`,
m.tableName, // m.tableName is sanitized
)
result, err = m.db.Exec(query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary)
}
if err != nil {
@ -482,14 +494,12 @@ func (m *MySQL) setValue(req *state.SetRequest) error {
err = fmt.Errorf(`rows affected error: no rows match given key '%s' and eTag '%s'`, req.Key, *req.ETag)
err = state.NewETagError(state.ETagMismatch, err)
m.logger.Error(err)
return err
}
if rows > 2 {
err = fmt.Errorf(`rows affected error: more than 2 row affected, expected 2, actual %d`, rows)
if rows > maxRows {
err = fmt.Errorf(`rows affected error: more than %d row affected; actual %d`, maxRows, rows)
m.logger.Error(err)
return err
}
@ -507,20 +517,16 @@ func (m *MySQL) BulkSet(req []state.SetRequest) error {
}
if len(req) > 0 {
for _, s := range req {
sa := s // Fix for goSec G601: Implicit memory aliasing in for loop.
err = m.Set(&sa)
for i := range req {
err = m.Set(&req[i])
if err != nil {
tx.Rollback()
return err
}
}
}
err = tx.Commit()
return err
return tx.Commit()
}
// Multi handles multiple transactions.
@ -538,26 +544,26 @@ func (m *MySQL) Multi(request *state.TransactionalStateRequest) error {
case state.Upsert:
setReq, err := m.getSets(req)
if err != nil {
tx.Rollback()
_ = tx.Rollback()
return err
}
err = m.Set(&setReq)
if err != nil {
tx.Rollback()
_ = tx.Rollback()
return err
}
case state.Delete:
delReq, err := m.getDeletes(req)
if err != nil {
tx.Rollback()
_ = tx.Rollback()
return err
}
err = m.Delete(&delReq)
if err != nil {
tx.Rollback()
_ = tx.Rollback()
return err
}
@ -591,7 +597,7 @@ func (m *MySQL) getDeletes(req state.TransactionalStateOperation) (state.DeleteR
}
if delReq.Key == "" {
return delReq, fmt.Errorf("missing key in upsert operation")
return delReq, fmt.Errorf("missing key in delete operation")
}
return delReq, nil
@ -612,3 +618,24 @@ func (m *MySQL) Close() error {
return nil
}
// Validates an identifier, such as table or DB name.
// This is based on the rules for allowed unquoted identifiers (https://dev.mysql.com/doc/refman/8.0/en/identifiers.html), but more restrictive as it doesn't allow non-ASCII characters or the $ sign
func validIdentifier(v string) bool {
if v == "" {
return false
}
// Loop through the string as byte slice as we only care about ASCII characters
b := []byte(v)
for i := 0; i < len(b); i++ {
if (b[i] >= '0' && b[i] <= '9') ||
(b[i] >= 'a' && b[i] <= 'z') ||
(b[i] >= 'A' && b[i] <= 'Z') ||
b[i] == '_' {
continue
}
return false
}
return true
}

View File

@ -30,6 +30,7 @@ import (
"github.com/dapr/components-contrib/state"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/ptr"
)
const (
@ -44,6 +45,14 @@ type fakeItem struct {
Color string
}
func (f fakeItem) MarshalJSON() ([]byte, error) {
return json.Marshal(f.Color)
}
func (f *fakeItem) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &f.Color)
}
func TestMySQLIntegration(t *testing.T) {
t.Parallel()
@ -60,7 +69,59 @@ func TestMySQLIntegration(t *testing.T) {
}
t.Run("Test init configurations", func(t *testing.T) {
testInitConfiguration(t)
// Tests valid and invalid config settings.
logger := logger.NewLogger("test")
// define a struct the contain the metadata and create
// two instances of it in a tests slice
tests := []struct {
name string
props map[string]string
expectedErr string
}{
{
name: "Empty",
props: map[string]string{},
expectedErr: errMissingConnectionString,
},
{
name: "Valid connection string",
props: map[string]string{
connectionStringKey: getConnectionString(""),
pemPathKey: getPemPath(),
},
expectedErr: "",
},
{
name: "Valid table name",
props: map[string]string{
connectionStringKey: getConnectionString(""),
pemPathKey: getPemPath(),
tableNameKey: "stateStore",
},
expectedErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := NewMySQLStateStore(logger)
defer p.Close()
metadata := state.Metadata{
Properties: tt.props,
}
err := p.Init(metadata)
if tt.expectedErr == "" {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
assert.Equal(t, err.Error(), tt.expectedErr)
}
})
}
})
pemPath := getPemPath()
@ -81,32 +142,107 @@ func TestMySQLIntegration(t *testing.T) {
t.Run("Create table succeeds", func(t *testing.T) {
t.Parallel()
testCreateTable(t, mys)
tableName := "test_state"
// Drop the table if it already exists
exists, err := tableExists(mys.db, tableName)
assert.Nil(t, err)
if exists {
dropTable(t, mys.db, tableName)
}
// Create the state table and test for its existence
// There should be no error
err = mys.ensureStateTable(tableName)
assert.Nil(t, err)
// Now create it and make sure there are no errors
exists, err = tableExists(mys.db, tableName)
assert.Nil(t, err)
assert.True(t, exists)
// Drop the state table
dropTable(t, mys.db, tableName)
})
t.Run("Get Set Delete one item", func(t *testing.T) {
t.Parallel()
setGetUpdateDeleteOneItem(t, mys)
// Validates setting one item, getting it, and deleting it.
key := randomKey()
value := &fakeItem{Color: "yellow"}
setItem(t, mys, key, value, nil)
getResponse, outputObject := getItem(t, mys, key)
assert.Equal(t, value, outputObject)
newValue := &fakeItem{Color: "green"}
setItem(t, mys, key, newValue, getResponse.ETag)
getResponse, outputObject = getItem(t, mys, key)
assert.Equal(t, newValue, outputObject)
deleteItem(t, mys, key, getResponse.ETag)
})
t.Run("Get item that does not exist", func(t *testing.T) {
t.Parallel()
getItemThatDoesNotExist(t, mys)
// Validates the behavior of retrieving an item that does not exist.
key := randomKey()
response, outputObject := getItem(t, mys, key)
assert.Nil(t, response.Data)
assert.Equal(t, "", outputObject.Color)
})
t.Run("Get item with no key fails", func(t *testing.T) {
t.Parallel()
getItemWithNoKey(t, mys)
// Validates that attempting a Get operation without providing a key will return an error.
getReq := &state.GetRequest{
Key: "",
}
response, getErr := mys.Get(getReq)
assert.NotNil(t, getErr)
assert.Nil(t, response)
})
t.Run("Set updates the updatedate field", func(t *testing.T) {
t.Parallel()
setUpdatesTheUpdatedateField(t, mys)
// Proves that the updatedate is set for an
// update, and set upon insert. The updatedate is used as the eTag so must be
// set. It is also auto updated on update by MySQL.
key := randomKey()
value := &fakeItem{Color: "orange"}
setItem(t, mys, key, value, nil)
// insertdate and updatedate should have a value
_, insertdate, updatedate, eTag := getRowData(t, key)
assert.NotNil(t, insertdate, "insertdate was not set")
assert.NotNil(t, updatedate, "updatedate was not set")
// insertdate should not change, updatedate should have a value
value = &fakeItem{Color: "aqua"}
setItem(t, mys, key, value, nil)
_, newinsertdate, _, newETag := getRowData(t, key)
assert.Equal(t, insertdate, newinsertdate, "InsertDate was changed")
assert.NotEqual(t, eTag, newETag, "eTag was not updated")
deleteItem(t, mys, key, nil)
})
t.Run("Set item with no key fails", func(t *testing.T) {
t.Parallel()
setItemWithNoKey(t, mys)
setReq := &state.SetRequest{
Key: "",
}
err := mys.Set(setReq)
assert.NotNil(t, err, "Error was not nil when setting item with no key.")
})
t.Run("Bulk set and bulk delete", func(t *testing.T) {
@ -116,255 +252,300 @@ func TestMySQLIntegration(t *testing.T) {
t.Run("Update and delete with eTag succeeds", func(t *testing.T) {
t.Parallel()
updateAndDeleteWithETagSucceeds(t, mys)
// Create and retrieve new item
key := randomKey()
value := &fakeItem{Color: "hazel"}
setItem(t, mys, key, value, nil)
getResponse, _ := getItem(t, mys, key)
assert.NotNil(t, getResponse.ETag)
// Change the value and compare
value.Color = "purple"
setItem(t, mys, key, value, getResponse.ETag)
updateResponse, updatedItem := getItem(t, mys, key)
assert.Equal(t, value, updatedItem, "Item should have been updated")
assert.NotEqual(t, getResponse.ETag, updateResponse.ETag,
"ETag should change when item is updated")
// Delete
deleteItem(t, mys, key, updateResponse.ETag)
assert.False(t, storeItemExists(t, key), "Item is not in the data store")
})
t.Run("Update with old eTag fails", func(t *testing.T) {
t.Parallel()
updateWithOldETagFails(t, mys)
// Create and retrieve new item
key := randomKey()
value := &fakeItem{Color: "gray"}
setItem(t, mys, key, value, nil)
getResponse, _ := getItem(t, mys, key)
assert.NotNil(t, getResponse.ETag)
originalEtag := getResponse.ETag
// Change the value and get the updated eTag
newValue := &fakeItem{Color: "silver"}
setItem(t, mys, key, newValue, originalEtag)
_, updatedItem := getItem(t, mys, key)
assert.Equal(t, newValue, updatedItem)
// Update again with the original eTag - expect update failure
newValue = &fakeItem{Color: "maroon"}
setReq := &state.SetRequest{
Key: key,
ETag: originalEtag,
Value: newValue,
}
err := mys.Set(setReq)
assert.NotNil(t, err, "Error was not thrown using old eTag")
})
t.Run("Insert with eTag fails", func(t *testing.T) {
t.Parallel()
newItemWithEtagFails(t, mys)
value := &fakeItem{Color: "teal"}
invalidETag := "12345"
setReq := &state.SetRequest{
Key: randomKey(),
ETag: &invalidETag,
Value: value,
}
err := mys.Set(setReq)
assert.NotNil(t, err)
})
t.Run("Delete with invalid eTag fails", func(t *testing.T) {
t.Parallel()
deleteWithInvalidEtagFails(t, mys)
// Create new item
key := randomKey()
value := &fakeItem{Color: "mauve"}
setItem(t, mys, key, value, nil)
eTag := "1234"
// Delete the item with a fake eTag
deleteReq := &state.DeleteRequest{
Key: key,
ETag: &eTag,
}
err := mys.Delete(deleteReq)
assert.NotNil(t, err)
})
t.Run("Delete item with no key fails", func(t *testing.T) {
t.Parallel()
deleteWithNoKeyFails(t, mys)
deleteReq := &state.DeleteRequest{
Key: "",
}
err := mys.Delete(deleteReq)
assert.NotNil(t, err)
})
t.Run("Delete an item that does not exist", func(t *testing.T) {
t.Parallel()
deleteItemThatDoesNotExist(t, mys)
// Delete the item with a key not in the store
deleteReq := &state.DeleteRequest{
Key: randomKey(),
}
err := mys.Delete(deleteReq)
assert.Nil(t, err)
})
t.Run("Inserts with first-write-wins", func(t *testing.T) {
t.Parallel()
// Insert without an etag should work on new keys
key := randomKey()
setReq := &state.SetRequest{
Key: key,
Value: &fakeItem{Color: "teal"},
Options: state.SetStateOption{
Concurrency: state.FirstWrite,
},
}
err := mys.Set(setReq)
assert.NoError(t, err)
// Get the etag
getResponse, _ := getItem(t, mys, key)
assert.NotNil(t, getResponse)
assert.NotNil(t, getResponse.ETag)
originalEtag := getResponse.ETag
// Insert without an etag should fail on existing keys
setReq = &state.SetRequest{
Key: key,
Value: &fakeItem{Color: "gray or grey"},
Options: state.SetStateOption{
Concurrency: state.FirstWrite,
},
}
err = mys.Set(setReq)
assert.ErrorContains(t, err, "Duplicate entry")
// Insert with invalid etag should fail on existing keys
setReq = &state.SetRequest{
Key: key,
Value: &fakeItem{Color: "pink"},
ETag: ptr.Of("no-etag"),
Options: state.SetStateOption{
Concurrency: state.FirstWrite,
},
}
err = mys.Set(setReq)
assert.ErrorContains(t, err, "possible etag mismatch")
// Insert with valid etag should succeed on existing keys
setReq = &state.SetRequest{
Key: key,
Value: &fakeItem{Color: "scarlet"},
ETag: originalEtag,
Options: state.SetStateOption{
Concurrency: state.FirstWrite,
},
}
err = mys.Set(setReq)
assert.NoError(t, err)
// Insert with an etag should fail on new keys
setReq = &state.SetRequest{
Key: randomKey(),
Value: &fakeItem{Color: "greige"},
ETag: ptr.Of("myetag"),
Options: state.SetStateOption{
Concurrency: state.FirstWrite,
},
}
err = mys.Set(setReq)
assert.ErrorContains(t, err, "possible etag mismatch")
})
t.Run("Multi with delete and set", func(t *testing.T) {
t.Parallel()
multiWithDeleteAndSet(t, mys)
var operations []state.TransactionalStateOperation
var deleteRequests []state.DeleteRequest
for i := 0; i < 3; i++ {
req := state.DeleteRequest{Key: randomKey()}
// Add the item to the database
setItem(t, mys, req.Key, randomJSON(), nil)
// Add the item to a slice of delete requests
deleteRequests = append(deleteRequests, req)
// Add the item to the multi transaction request
operations = append(operations, state.TransactionalStateOperation{
Operation: state.Delete,
Request: req,
})
}
// Create the set requests
var setRequests []state.SetRequest
for i := 0; i < 3; i++ {
req := state.SetRequest{
Key: randomKey(),
Value: randomJSON(),
}
setRequests = append(setRequests, req)
operations = append(operations, state.TransactionalStateOperation{
Operation: state.Upsert,
Request: req,
})
}
err := mys.Multi(&state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
for _, delete := range deleteRequests {
assert.False(t, storeItemExists(t, delete.Key))
}
for _, set := range setRequests {
assert.True(t, storeItemExists(t, set.Key))
deleteItem(t, mys, set.Key, nil)
}
})
t.Run("Multi with delete only", func(t *testing.T) {
t.Parallel()
multiWithDeleteOnly(t, mys)
var operations []state.TransactionalStateOperation
var deleteRequests []state.DeleteRequest
for i := 0; i < 3; i++ {
req := state.DeleteRequest{Key: randomKey()}
// Add the item to the database
setItem(t, mys, req.Key, randomJSON(), nil)
// Add the item to a slice of delete requests
deleteRequests = append(deleteRequests, req)
// Add the item to the multi transaction request
operations = append(operations, state.TransactionalStateOperation{
Operation: state.Delete,
Request: req,
})
}
err := mys.Multi(&state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
for _, delete := range deleteRequests {
assert.False(t, storeItemExists(t, delete.Key))
}
})
t.Run("Multi with set only", func(t *testing.T) {
t.Parallel()
multiWithSetOnly(t, mys)
})
}
func multiWithSetOnly(t *testing.T, mys *MySQL) {
var operations []state.TransactionalStateOperation
var setRequests []state.SetRequest
for i := 0; i < 3; i++ {
req := state.SetRequest{
Key: randomKey(),
Value: randomJSON(),
var operations []state.TransactionalStateOperation
var setRequests []state.SetRequest
for i := 0; i < 3; i++ {
req := state.SetRequest{
Key: randomKey(),
Value: randomJSON(),
}
setRequests = append(setRequests, req)
operations = append(operations, state.TransactionalStateOperation{
Operation: state.Upsert,
Request: req,
})
}
setRequests = append(setRequests, req)
operations = append(operations, state.TransactionalStateOperation{
Operation: state.Upsert,
Request: req,
err := mys.Multi(&state.TransactionalStateRequest{
Operations: operations,
})
}
assert.Nil(t, err)
err := mys.Multi(&state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
for _, set := range setRequests {
assert.True(t, storeItemExists(t, set.Key))
deleteItem(t, mys, set.Key, nil)
}
}
func multiWithDeleteOnly(t *testing.T, mys *MySQL) {
var operations []state.TransactionalStateOperation
var deleteRequests []state.DeleteRequest
for i := 0; i < 3; i++ {
req := state.DeleteRequest{Key: randomKey()}
// Add the item to the database
setItem(t, mys, req.Key, randomJSON(), nil)
// Add the item to a slice of delete requests
deleteRequests = append(deleteRequests, req)
// Add the item to the multi transaction request
operations = append(operations, state.TransactionalStateOperation{
Operation: state.Delete,
Request: req,
})
}
err := mys.Multi(&state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
for _, delete := range deleteRequests {
assert.False(t, storeItemExists(t, delete.Key))
}
}
func multiWithDeleteAndSet(t *testing.T, mys *MySQL) {
var operations []state.TransactionalStateOperation
var deleteRequests []state.DeleteRequest
for i := 0; i < 3; i++ {
req := state.DeleteRequest{Key: randomKey()}
// Add the item to the database
setItem(t, mys, req.Key, randomJSON(), nil)
// Add the item to a slice of delete requests
deleteRequests = append(deleteRequests, req)
// Add the item to the multi transaction request
operations = append(operations, state.TransactionalStateOperation{
Operation: state.Delete,
Request: req,
})
}
// Create the set requests
var setRequests []state.SetRequest
for i := 0; i < 3; i++ {
req := state.SetRequest{
Key: randomKey(),
Value: randomJSON(),
for _, set := range setRequests {
assert.True(t, storeItemExists(t, set.Key))
deleteItem(t, mys, set.Key, nil)
}
setRequests = append(setRequests, req)
operations = append(operations, state.TransactionalStateOperation{
Operation: state.Upsert,
Request: req,
})
}
err := mys.Multi(&state.TransactionalStateRequest{
Operations: operations,
})
assert.Nil(t, err)
for _, delete := range deleteRequests {
assert.False(t, storeItemExists(t, delete.Key))
}
for _, set := range setRequests {
assert.True(t, storeItemExists(t, set.Key))
deleteItem(t, mys, set.Key, nil)
}
}
func deleteItemThatDoesNotExist(t *testing.T, mys *MySQL) {
// Delete the item with a key not in the store
deleteReq := &state.DeleteRequest{
Key: randomKey(),
}
err := mys.Delete(deleteReq)
assert.Nil(t, err)
}
func deleteWithNoKeyFails(t *testing.T, mys *MySQL) {
deleteReq := &state.DeleteRequest{
Key: "",
}
err := mys.Delete(deleteReq)
assert.NotNil(t, err)
}
func deleteWithInvalidEtagFails(t *testing.T, mys *MySQL) {
// Create new item
key := randomKey()
value := &fakeItem{Color: "mauve"}
setItem(t, mys, key, value, nil)
eTag := "1234"
// Delete the item with a fake eTag
deleteReq := &state.DeleteRequest{
Key: key,
ETag: &eTag,
}
err := mys.Delete(deleteReq)
assert.NotNil(t, err)
}
// newItemWithEtagFails creates a new item and also supplies an ETag, which is
// invalid - expect failure.
func newItemWithEtagFails(t *testing.T, mys *MySQL) {
value := &fakeItem{Color: "teal"}
invalidETag := "12345"
setReq := &state.SetRequest{
Key: randomKey(),
ETag: &invalidETag,
Value: value,
}
err := mys.Set(setReq)
assert.NotNil(t, err)
}
func updateWithOldETagFails(t *testing.T, mys *MySQL) {
// Create and retrieve new item
key := randomKey()
value := &fakeItem{Color: "gray"}
setItem(t, mys, key, value, nil)
getResponse, _ := getItem(t, mys, key)
assert.NotNil(t, getResponse.ETag)
originalEtag := getResponse.ETag
// Change the value and get the updated eTag
newValue := &fakeItem{Color: "silver"}
setItem(t, mys, key, newValue, originalEtag)
_, updatedItem := getItem(t, mys, key)
assert.Equal(t, newValue, updatedItem)
// Update again with the original eTag - expect update failure
newValue = &fakeItem{Color: "maroon"}
setReq := &state.SetRequest{
Key: key,
ETag: originalEtag,
Value: newValue,
}
err := mys.Set(setReq)
assert.NotNil(t, err, "Error was not thrown using old eTag")
}
func updateAndDeleteWithETagSucceeds(t *testing.T, mys *MySQL) {
// Create and retrieve new item
key := randomKey()
value := &fakeItem{Color: "hazel"}
setItem(t, mys, key, value, nil)
getResponse, _ := getItem(t, mys, key)
assert.NotNil(t, getResponse.ETag)
// Change the value and compare
value.Color = "purple"
setItem(t, mys, key, value, getResponse.ETag)
updateResponse, updatedItem := getItem(t, mys, key)
assert.Equal(t, value, updatedItem, "Item should have been updated")
assert.NotEqual(t, getResponse.ETag, updateResponse.ETag,
"ETag should change when item is updated")
// Delete
deleteItem(t, mys, key, updateResponse.ETag)
assert.False(t, storeItemExists(t, key), "Item is not in the data store")
}
// Tests valid bulk sets and deletes.
@ -400,159 +581,6 @@ func testBulkSetAndBulkDelete(t *testing.T, mys *MySQL) {
assert.False(t, storeItemExists(t, setReq[1].Key))
}
func setItemWithNoKey(t *testing.T, mys *MySQL) {
setReq := &state.SetRequest{
Key: "",
}
err := mys.Set(setReq)
assert.NotNil(t, err, "Error was not nil when setting item with no key.")
}
// setUpdatesTheUpdatedateField proves that the updatedate is set for an
// update, and set upon insert. The updatedate is used as the eTag so must be
// set. It is also auto updated on update by MySQL.
func setUpdatesTheUpdatedateField(t *testing.T, mys *MySQL) {
key := randomKey()
value := &fakeItem{Color: "orange"}
setItem(t, mys, key, value, nil)
// insertdate and updatedate should have a value
_, insertdate, updatedate, eTag := getRowData(t, key)
assert.NotNil(t, insertdate, "insertdate was not set")
assert.NotNil(t, updatedate, "updatedate was not set")
// insertdate should not change, updatedate should have a value
value = &fakeItem{Color: "aqua"}
setItem(t, mys, key, value, nil)
_, newinsertdate, _, newETag := getRowData(t, key)
assert.Equal(t, insertdate, newinsertdate, "InsertDate was changed")
assert.NotEqual(t, eTag, newETag, "eTag was not updated")
deleteItem(t, mys, key, nil)
}
// getItemWithNoKey validates that attempting a Get operation without providing
// a key will return an error.
func getItemWithNoKey(t *testing.T, mys *MySQL) {
getReq := &state.GetRequest{
Key: "",
}
response, getErr := mys.Get(getReq)
assert.NotNil(t, getErr)
assert.Nil(t, response)
}
// getItemThatDoesNotExist validates the behavior of retrieving an item that
// does not exist.
func getItemThatDoesNotExist(t *testing.T, mys *MySQL) {
key := randomKey()
response, outputObject := getItem(t, mys, key)
assert.Nil(t, response.Data)
assert.Equal(t, "", outputObject.Color)
}
// setGetUpdateDeleteOneItem validates setting one item, getting it, and
// deleting it.
func setGetUpdateDeleteOneItem(t *testing.T, mys *MySQL) {
key := randomKey()
value := &fakeItem{Color: "yellow"}
setItem(t, mys, key, value, nil)
getResponse, outputObject := getItem(t, mys, key)
assert.Equal(t, value, outputObject)
newValue := &fakeItem{Color: "green"}
setItem(t, mys, key, newValue, getResponse.ETag)
getResponse, outputObject = getItem(t, mys, key)
assert.Equal(t, newValue, outputObject)
deleteItem(t, mys, key, getResponse.ETag)
}
// testCreateTable tests the ability to create the state table.
func testCreateTable(t *testing.T, mys *MySQL) {
tableName := "test_state"
// Drop the table if it already exists
exists, err := tableExists(mys.db, tableName)
assert.Nil(t, err)
if exists {
dropTable(t, mys.db, tableName)
}
// Create the state table and test for its existence
// There should be no error
err = mys.ensureStateTable(tableName)
assert.Nil(t, err)
// Now create it and make sure there are no errors
exists, err = tableExists(mys.db, tableName)
assert.Nil(t, err)
assert.True(t, exists)
// Drop the state table
dropTable(t, mys.db, tableName)
}
// testInitConfiguration tests valid and invalid config settings.
func testInitConfiguration(t *testing.T) {
logger := logger.NewLogger("test")
// define a struct the contain the metadata and create
// two instances of it in a tests slice
tests := []struct {
name string
props map[string]string
expectedErr string
}{
{
name: "Empty",
props: map[string]string{},
expectedErr: errMissingConnectionString,
},
{
name: "Valid connection string",
props: map[string]string{
connectionStringKey: getConnectionString(""),
pemPathKey: getPemPath(),
},
expectedErr: "",
},
{
name: "Valid table name",
props: map[string]string{
connectionStringKey: getConnectionString(""),
pemPathKey: getPemPath(),
tableNameKey: "stateStore",
},
expectedErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := NewMySQLStateStore(logger)
defer p.Close()
metadata := state.Metadata{
Properties: tt.props,
}
err := p.Init(metadata)
if tt.expectedErr == "" {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
assert.Equal(t, err.Error(), tt.expectedErr)
}
})
}
}
func dropTable(t *testing.T, db *sql.DB, tableName string) {
_, err := db.Exec(fmt.Sprintf(
`DROP TABLE %s;`,

View File

@ -61,7 +61,7 @@ func TestFinishInitHandlesSchemaExistsError(t *testing.T) {
m.mock1.ExpectQuery("SELECT EXISTS").WillReturnError(expectedErr)
// Act
actualErr := m.mySQL.finishInit(m.mySQL.db, nil)
actualErr := m.mySQL.finishInit(m.mySQL.db)
// Assert
assert.NotNil(t, actualErr, "now error returned")
@ -80,26 +80,13 @@ func TestFinishInitHandlesDatabaseCreateError(t *testing.T) {
m.mock1.ExpectExec("CREATE DATABASE").WillReturnError(expectedErr)
// Act
actualErr := m.mySQL.finishInit(m.mySQL.db, nil)
actualErr := m.mySQL.finishInit(m.mySQL.db)
// Assert
assert.NotNil(t, actualErr, "now error returned")
assert.Equal(t, "createDatabaseError", actualErr.Error(), "wrong error")
}
func TestFinishInitHandlesOpenError(t *testing.T) {
// Arrange
m, _ := mockDatabase(t)
defer m.mySQL.Close()
// Act
err := m.mySQL.finishInit(m.mySQL.db, fmt.Errorf("failed to open database"))
// Assert
assert.NotNil(t, err, "now error returned")
assert.Equal(t, "failed to open database", err.Error(), "wrong error")
}
func TestFinishInitHandlesPingError(t *testing.T) {
// Arrange
m, _ := mockDatabase(t)
@ -117,7 +104,7 @@ func TestFinishInitHandlesPingError(t *testing.T) {
m.mock2.ExpectPing().WillReturnError(expectedErr)
// Act
actualErr := m.mySQL.finishInit(m.mySQL.db, nil)
actualErr := m.mySQL.finishInit(m.mySQL.db)
// Assert
assert.NotNil(t, actualErr, "now error returned")
@ -145,7 +132,7 @@ func TestFinishInitHandlesTableExistsError(t *testing.T) {
m.mock2.ExpectQuery("SELECT EXISTS").WillReturnError(fmt.Errorf("tableExistsError"))
// Act
err := m.mySQL.finishInit(m.mySQL.db, nil)
err := m.mySQL.finishInit(m.mySQL.db)
// Assert
assert.NotNil(t, err, "no error returned")
@ -667,6 +654,21 @@ func TestInitSetsTableName(t *testing.T) {
assert.Equal(t, "stateStore", m.mySQL.tableName, "table name did not default")
}
func TestInitInvalidTableName(t *testing.T) {
// Arrange
t.Parallel()
m, _ := mockDatabase(t)
metadata := &state.Metadata{
Properties: map[string]string{connectionStringKey: "", tableNameKey: "🙃"},
}
// Act
err := m.mySQL.Init(*metadata)
// Assert
assert.ErrorContains(t, err, "table name '🙃' is not valid")
}
func TestInitSetsSchemaName(t *testing.T) {
// Arrange
t.Parallel()
@ -683,6 +685,21 @@ func TestInitSetsSchemaName(t *testing.T) {
assert.Equal(t, "stateStoreSchema", m.mySQL.schemaName, "table name did not default")
}
func TestInitInvalidSchemaName(t *testing.T) {
// Arrange
t.Parallel()
m, _ := mockDatabase(t)
metadata := &state.Metadata{
Properties: map[string]string{connectionStringKey: "", schemaNameKey: "?"},
}
// Act
err := m.mySQL.Init(*metadata)
// Assert
assert.ErrorContains(t, err, "schema name '?' is not valid")
}
// This state store does not support BulkGet so it must return false and
// nil nil.
func TestBulkGetReturnsNil(t *testing.T) {
@ -1018,3 +1035,25 @@ func (f *fakeMySQLFactory) Open(connectionString string) (*sql.DB, error) {
func (f *fakeMySQLFactory) RegisterTLSConfig(pemPath string) error {
return f.registerErr
}
func TestValidIdentifier(t *testing.T) {
tests := []struct {
name string
arg string
want bool
}{
{name: "empty string", arg: "", want: false},
{name: "valid characters only", arg: "acz_039_AZS", want: true},
{name: "invalid ASCII characters 1", arg: "$", want: false},
{name: "invalid ASCII characters 2", arg: "*", want: false},
{name: "invalid ASCII characters 3", arg: "hello world", want: false},
{name: "non-ASCII characters", arg: "🙃", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := validIdentifier(tt.arg); got != tt.want {
t.Errorf("validIdentifier() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -24,7 +24,7 @@ components:
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "query" ]
- component: mysql
allOperations: false
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag" ]
operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write" ]
- component: azure.tablestorage.storage
allOperations: false
operations: ["set", "get", "delete", "etag", "bulkset", "bulkdelete", "first-write"]