diff --git a/state/mysql/mySQLFactory.go b/state/mysql/mySQLFactory.go index cfa804340..806487b66 100644 --- a/state/mysql/mySQLFactory.go +++ b/state/mysql/mySQLFactory.go @@ -45,12 +45,10 @@ func (m *mySQLFactory) RegisterTLSConfig(pemPath string) error { if readErr != nil { m.logger.Error("Error reading PEM file from " + pemPath) - return readErr } ok := rootCertPool.AppendCertsFromPEM(pem) - if !ok { return fmt.Errorf("failed to append PEM") } diff --git a/state/mysql/mysql.go b/state/mysql/mysql.go index bf6e63204..bee62f7c7 100644 --- a/state/mysql/mysql.go +++ b/state/mysql/mysql.go @@ -21,12 +21,11 @@ import ( "fmt" "strings" - "github.com/agrea/ptr" "github.com/google/uuid" "github.com/dapr/components-contrib/state" - "github.com/dapr/components-contrib/state/utils" "github.com/dapr/kit/logger" + "github.com/dapr/kit/ptr" ) // Optimistic Concurrency is implemented using a string column that stores @@ -65,12 +64,10 @@ const ( // MySQL state store. type MySQL struct { - // Name of the table to store state. If the table does not exist it will - // be created. + // Name of the table to store state. If the table does not exist it will be created. tableName string - // Name of the table to create to store state. If the table does not exist - // it will be created. + // Name of the table to create to store state. If the table does not exist it will be created. schemaName string connectionString string @@ -116,8 +113,11 @@ func (m *MySQL) Init(metadata state.Metadata) error { m.logger.Debug("Initializing MySql state store") val, ok := metadata.Properties[tableNameKey] - if ok && val != "" { + // Sanitize the table name + if !validIdentifier(val) { + return fmt.Errorf("table name '%s' is not valid", val) + } m.tableName = val } else { // Default to the constant @@ -125,8 +125,11 @@ func (m *MySQL) Init(metadata state.Metadata) error { } val, ok = metadata.Properties[schemaNameKey] - if ok && val != "" { + // Sanitize the schema name + if !validIdentifier(val) { + return fmt.Errorf("schema name '%s' is not valid", val) + } m.schemaName = val } else { // Default to the constant @@ -134,28 +137,28 @@ func (m *MySQL) Init(metadata state.Metadata) error { } m.connectionString, ok = metadata.Properties[connectionStringKey] - if !ok || m.connectionString == "" { m.logger.Error("Missing MySql connection string") - return fmt.Errorf(errMissingConnectionString) } val, ok = metadata.Properties[pemPathKey] - if ok && val != "" { err := m.factory.RegisterTLSConfig(val) if err != nil { m.logger.Error(err) - return err } } db, err := m.factory.Open(m.connectionString) + if err != nil { + m.logger.Error(err) + return err + } // will be nil if everything is good or an err that needs to be returned - return m.finishInit(db, err) + return m.finishInit(db) } // Features returns the features available in this state store. @@ -164,29 +167,19 @@ func (m *MySQL) Features() []state.Feature { } // Separated out to make this portion of code testable. -func (m *MySQL) finishInit(db *sql.DB, err error) error { +func (m *MySQL) finishInit(db *sql.DB) error { + m.db = db + + err := m.ensureStateSchema() if err != nil { m.logger.Error(err) - return err } - m.db = db - - schemaErr := m.ensureStateSchema() - - if schemaErr != nil { - m.logger.Error(schemaErr) - - return schemaErr - } - - pingErr := m.db.Ping() - - if pingErr != nil { - m.logger.Error(pingErr) - - return pingErr + err = m.db.Ping() + if err != nil { + m.logger.Error(err) + return err } // will be nil if everything is good or an err that needs to be returned @@ -201,9 +194,9 @@ func (m *MySQL) ensureStateSchema() error { if !exists { m.logger.Infof("Creating MySql schema '%s'", m.schemaName) - - _, err = m.db.Exec(`CREATE DATABASE ?`, m.schemaName) - + _, err = m.db.Exec( + fmt.Sprintf("CREATE DATABASE %s;", m.schemaName), + ) if err != nil { return err } @@ -243,7 +236,7 @@ func (m *MySQL) ensureStateTable(stateTableName string) error { // never need to pass it in. // eTag is a UUID stored as a 36 characters string. It needs to be passed // in on inserts and updates and is used for Optimistic Concurrency - + // Note that stateTableName is sanitized //nolint:gosec createTable := fmt.Sprintf(`CREATE TABLE %s ( id VARCHAR(255) NOT NULL PRIMARY KEY, @@ -265,28 +258,22 @@ func (m *MySQL) ensureStateTable(stateTableName string) error { } func schemaExists(db *sql.DB, schemaName string) (bool, error) { + // Returns 1 or 0 as a string if the table exists or not exists := "" - query := `SELECT EXISTS ( SELECT SCHEMA_NAME FROM information_schema.schemata WHERE SCHEMA_NAME = ? - ) AS 'exists'` - - // Returns 1 or 0 as a string if the table exists or not + ) AS 'exists'` err := db.QueryRow(query, schemaName).Scan(&exists) - return exists == "1", err } func tableExists(db *sql.DB, tableName string) (bool, error) { + // Returns 1 or 0 as a string if the table exists or not exists := "" - query := `SELECT EXISTS ( SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_NAME = ? - ) AS 'exists'` - - // Returns 1 or 0 as a string if the table exists or not + ) AS 'exists'` err := db.QueryRow(query, tableName).Scan(&exists) - return exists == "1", err } @@ -370,12 +357,18 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) { return nil, fmt.Errorf("missing key in get operation") } - var eTag, value string - var isBinary bool + var ( + eTag string + value []byte + isBinary bool + ) - err := m.db.QueryRow(fmt.Sprintf( + //nolint:gosec + query := fmt.Sprintf( `SELECT value, eTag, isbinary FROM %s WHERE id = ?`, - m.tableName), req.Key).Scan(&value, &eTag, &isBinary) + m.tableName, // m.tableName is sanitized + ) + err := m.db.QueryRow(query, req.Key).Scan(&value, &eTag, &isBinary) if err != nil { // If no rows exist, return an empty response, otherwise return an error. if errors.Is(err, sql.ErrNoRows) { @@ -386,27 +379,31 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) { } if isBinary { - var s string - var data []byte + var ( + s string + data []byte + ) - if err = json.Unmarshal([]byte(value), &s); err != nil { + err = json.Unmarshal(value, &s) + if err != nil { return nil, err } - if data, err = base64.StdEncoding.DecodeString(s); err != nil { + data, err = base64.StdEncoding.DecodeString(s) + if err != nil { return nil, err } return &state.GetResponse{ Data: data, - ETag: ptr.String(eTag), + ETag: ptr.Of(eTag), Metadata: req.Metadata, }, nil } return &state.GetResponse{ - Data: []byte(value), - ETag: ptr.String(eTag), + Data: value, + ETag: ptr.Of(eTag), Metadata: req.Metadata, }, nil } @@ -428,41 +425,56 @@ func (m *MySQL) setValue(req *state.SetRequest) error { } if req.Key == "" { - return fmt.Errorf("missing key in set operation") + return errors.New("missing key in set operation") } - if v, ok := req.Value.(string); ok && v == "" { - return fmt.Errorf("empty string is not allowed in set operation") + var v any + isBinary := false + switch x := req.Value.(type) { + case string: + if x == "" { + return errors.New("empty string is not allowed in set operation") + } + v = x + case []uint8: + isBinary = true + v = base64.StdEncoding.EncodeToString(x) + default: + v = x } - v := req.Value - byteArray, isBinary := req.Value.([]uint8) - if isBinary { - v = base64.StdEncoding.EncodeToString(byteArray) - } - - // Convert to json string - bt, _ := utils.Marshal(v, json.Marshal) - value := string(bt) - - var result sql.Result + encB, _ := json.Marshal(v) + enc := string(encB) eTag := uuid.New().String() - // Sprintf is required for table name because sql.DB does not substitute - // parameters for table names. - // Other parameters use sql.DB parameter substitution. - if req.ETag == nil || *req.ETag == "" { - // If this is a duplicate MySQL returns that two rows affected - result, err = m.db.Exec(fmt.Sprintf( - `INSERT INTO %s (value, id, eTag, isbinary) - VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`, - m.tableName), value, req.Key, eTag, isBinary, value, eTag, isBinary) - } else { + var result sql.Result + var maxRows int64 = 1 + + if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || *req.ETag == "") { + // With first-write-wins and no etag, we can insert the row only if it doesn't exist + //nolint:gosec + query := fmt.Sprintf( + `INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?);`, + m.tableName, // m.tableName is sanitized + ) + result, err = m.db.Exec(query, enc, req.Key, eTag, isBinary) + } else if req.ETag != nil && *req.ETag != "" { // When an eTag is provided do an update - not insert - result, err = m.db.Exec(fmt.Sprintf( - `UPDATE %s SET value = ?, eTag = ?, isbinary = ? - WHERE id = ? AND eTag = ?;`, - m.tableName), value, eTag, isBinary, req.Key, *req.ETag) + //nolint:gosec + query := fmt.Sprintf( + `UPDATE %s SET value = ?, eTag = ?, isbinary = ? WHERE id = ? AND eTag = ?;`, + m.tableName, // m.tableName is sanitized + ) + result, err = m.db.Exec(query, enc, eTag, isBinary, req.Key, *req.ETag) + } else { + // If this is a duplicate MySQL returns that two rows affected + maxRows = 2 + //nolint:gosec + query := fmt.Sprintf( + `INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`, + m.tableName, // m.tableName is sanitized + ) + result, err = m.db.Exec(query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary) } if err != nil { @@ -482,14 +494,12 @@ func (m *MySQL) setValue(req *state.SetRequest) error { err = fmt.Errorf(`rows affected error: no rows match given key '%s' and eTag '%s'`, req.Key, *req.ETag) err = state.NewETagError(state.ETagMismatch, err) m.logger.Error(err) - return err } - if rows > 2 { - err = fmt.Errorf(`rows affected error: more than 2 row affected, expected 2, actual %d`, rows) + if rows > maxRows { + err = fmt.Errorf(`rows affected error: more than %d row affected; actual %d`, maxRows, rows) m.logger.Error(err) - return err } @@ -507,20 +517,16 @@ func (m *MySQL) BulkSet(req []state.SetRequest) error { } if len(req) > 0 { - for _, s := range req { - sa := s // Fix for goSec G601: Implicit memory aliasing in for loop. - err = m.Set(&sa) + for i := range req { + err = m.Set(&req[i]) if err != nil { tx.Rollback() - return err } } } - err = tx.Commit() - - return err + return tx.Commit() } // Multi handles multiple transactions. @@ -538,26 +544,26 @@ func (m *MySQL) Multi(request *state.TransactionalStateRequest) error { case state.Upsert: setReq, err := m.getSets(req) if err != nil { - tx.Rollback() + _ = tx.Rollback() return err } err = m.Set(&setReq) if err != nil { - tx.Rollback() + _ = tx.Rollback() return err } case state.Delete: delReq, err := m.getDeletes(req) if err != nil { - tx.Rollback() + _ = tx.Rollback() return err } err = m.Delete(&delReq) if err != nil { - tx.Rollback() + _ = tx.Rollback() return err } @@ -591,7 +597,7 @@ func (m *MySQL) getDeletes(req state.TransactionalStateOperation) (state.DeleteR } if delReq.Key == "" { - return delReq, fmt.Errorf("missing key in upsert operation") + return delReq, fmt.Errorf("missing key in delete operation") } return delReq, nil @@ -612,3 +618,24 @@ func (m *MySQL) Close() error { return nil } + +// Validates an identifier, such as table or DB name. +// This is based on the rules for allowed unquoted identifiers (https://dev.mysql.com/doc/refman/8.0/en/identifiers.html), but more restrictive as it doesn't allow non-ASCII characters or the $ sign +func validIdentifier(v string) bool { + if v == "" { + return false + } + + // Loop through the string as byte slice as we only care about ASCII characters + b := []byte(v) + for i := 0; i < len(b); i++ { + if (b[i] >= '0' && b[i] <= '9') || + (b[i] >= 'a' && b[i] <= 'z') || + (b[i] >= 'A' && b[i] <= 'Z') || + b[i] == '_' { + continue + } + return false + } + return true +} diff --git a/state/mysql/mysql_integration_test.go b/state/mysql/mysql_integration_test.go index 739a4be82..f70ed0f7f 100644 --- a/state/mysql/mysql_integration_test.go +++ b/state/mysql/mysql_integration_test.go @@ -30,6 +30,7 @@ import ( "github.com/dapr/components-contrib/state" "github.com/dapr/kit/logger" + "github.com/dapr/kit/ptr" ) const ( @@ -44,6 +45,14 @@ type fakeItem struct { Color string } +func (f fakeItem) MarshalJSON() ([]byte, error) { + return json.Marshal(f.Color) +} + +func (f *fakeItem) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &f.Color) +} + func TestMySQLIntegration(t *testing.T) { t.Parallel() @@ -60,7 +69,59 @@ func TestMySQLIntegration(t *testing.T) { } t.Run("Test init configurations", func(t *testing.T) { - testInitConfiguration(t) + // Tests valid and invalid config settings. + logger := logger.NewLogger("test") + + // define a struct the contain the metadata and create + // two instances of it in a tests slice + tests := []struct { + name string + props map[string]string + expectedErr string + }{ + { + name: "Empty", + props: map[string]string{}, + expectedErr: errMissingConnectionString, + }, + { + name: "Valid connection string", + props: map[string]string{ + connectionStringKey: getConnectionString(""), + pemPathKey: getPemPath(), + }, + expectedErr: "", + }, + { + name: "Valid table name", + props: map[string]string{ + connectionStringKey: getConnectionString(""), + pemPathKey: getPemPath(), + tableNameKey: "stateStore", + }, + expectedErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewMySQLStateStore(logger) + defer p.Close() + + metadata := state.Metadata{ + Properties: tt.props, + } + + err := p.Init(metadata) + + if tt.expectedErr == "" { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + assert.Equal(t, err.Error(), tt.expectedErr) + } + }) + } }) pemPath := getPemPath() @@ -81,32 +142,107 @@ func TestMySQLIntegration(t *testing.T) { t.Run("Create table succeeds", func(t *testing.T) { t.Parallel() - testCreateTable(t, mys) + + tableName := "test_state" + + // Drop the table if it already exists + exists, err := tableExists(mys.db, tableName) + assert.Nil(t, err) + if exists { + dropTable(t, mys.db, tableName) + } + + // Create the state table and test for its existence + // There should be no error + err = mys.ensureStateTable(tableName) + assert.Nil(t, err) + + // Now create it and make sure there are no errors + exists, err = tableExists(mys.db, tableName) + assert.Nil(t, err) + assert.True(t, exists) + + // Drop the state table + dropTable(t, mys.db, tableName) }) t.Run("Get Set Delete one item", func(t *testing.T) { t.Parallel() - setGetUpdateDeleteOneItem(t, mys) + + // Validates setting one item, getting it, and deleting it. + key := randomKey() + value := &fakeItem{Color: "yellow"} + + setItem(t, mys, key, value, nil) + + getResponse, outputObject := getItem(t, mys, key) + assert.Equal(t, value, outputObject) + + newValue := &fakeItem{Color: "green"} + setItem(t, mys, key, newValue, getResponse.ETag) + getResponse, outputObject = getItem(t, mys, key) + assert.Equal(t, newValue, outputObject) + + deleteItem(t, mys, key, getResponse.ETag) }) t.Run("Get item that does not exist", func(t *testing.T) { t.Parallel() - getItemThatDoesNotExist(t, mys) + + // Validates the behavior of retrieving an item that does not exist. + key := randomKey() + response, outputObject := getItem(t, mys, key) + assert.Nil(t, response.Data) + assert.Equal(t, "", outputObject.Color) }) t.Run("Get item with no key fails", func(t *testing.T) { t.Parallel() - getItemWithNoKey(t, mys) + + // Validates that attempting a Get operation without providing a key will return an error. + getReq := &state.GetRequest{ + Key: "", + } + + response, getErr := mys.Get(getReq) + assert.NotNil(t, getErr) + assert.Nil(t, response) }) t.Run("Set updates the updatedate field", func(t *testing.T) { t.Parallel() - setUpdatesTheUpdatedateField(t, mys) + + // Proves that the updatedate is set for an + // update, and set upon insert. The updatedate is used as the eTag so must be + // set. It is also auto updated on update by MySQL. + key := randomKey() + value := &fakeItem{Color: "orange"} + setItem(t, mys, key, value, nil) + + // insertdate and updatedate should have a value + _, insertdate, updatedate, eTag := getRowData(t, key) + assert.NotNil(t, insertdate, "insertdate was not set") + assert.NotNil(t, updatedate, "updatedate was not set") + + // insertdate should not change, updatedate should have a value + value = &fakeItem{Color: "aqua"} + setItem(t, mys, key, value, nil) + _, newinsertdate, _, newETag := getRowData(t, key) + assert.Equal(t, insertdate, newinsertdate, "InsertDate was changed") + assert.NotEqual(t, eTag, newETag, "eTag was not updated") + + deleteItem(t, mys, key, nil) }) t.Run("Set item with no key fails", func(t *testing.T) { t.Parallel() - setItemWithNoKey(t, mys) + + setReq := &state.SetRequest{ + Key: "", + } + + err := mys.Set(setReq) + assert.NotNil(t, err, "Error was not nil when setting item with no key.") }) t.Run("Bulk set and bulk delete", func(t *testing.T) { @@ -116,255 +252,300 @@ func TestMySQLIntegration(t *testing.T) { t.Run("Update and delete with eTag succeeds", func(t *testing.T) { t.Parallel() - updateAndDeleteWithETagSucceeds(t, mys) + + // Create and retrieve new item + key := randomKey() + value := &fakeItem{Color: "hazel"} + setItem(t, mys, key, value, nil) + getResponse, _ := getItem(t, mys, key) + assert.NotNil(t, getResponse.ETag) + + // Change the value and compare + value.Color = "purple" + setItem(t, mys, key, value, getResponse.ETag) + updateResponse, updatedItem := getItem(t, mys, key) + assert.Equal(t, value, updatedItem, "Item should have been updated") + assert.NotEqual(t, getResponse.ETag, updateResponse.ETag, + "ETag should change when item is updated") + + // Delete + deleteItem(t, mys, key, updateResponse.ETag) + + assert.False(t, storeItemExists(t, key), "Item is not in the data store") }) t.Run("Update with old eTag fails", func(t *testing.T) { t.Parallel() - updateWithOldETagFails(t, mys) + + // Create and retrieve new item + key := randomKey() + value := &fakeItem{Color: "gray"} + setItem(t, mys, key, value, nil) + + getResponse, _ := getItem(t, mys, key) + assert.NotNil(t, getResponse.ETag) + originalEtag := getResponse.ETag + + // Change the value and get the updated eTag + newValue := &fakeItem{Color: "silver"} + setItem(t, mys, key, newValue, originalEtag) + + _, updatedItem := getItem(t, mys, key) + assert.Equal(t, newValue, updatedItem) + + // Update again with the original eTag - expect update failure + newValue = &fakeItem{Color: "maroon"} + setReq := &state.SetRequest{ + Key: key, + ETag: originalEtag, + Value: newValue, + } + + err := mys.Set(setReq) + assert.NotNil(t, err, "Error was not thrown using old eTag") }) t.Run("Insert with eTag fails", func(t *testing.T) { t.Parallel() - newItemWithEtagFails(t, mys) + + value := &fakeItem{Color: "teal"} + invalidETag := "12345" + + setReq := &state.SetRequest{ + Key: randomKey(), + ETag: &invalidETag, + Value: value, + } + + err := mys.Set(setReq) + assert.NotNil(t, err) }) t.Run("Delete with invalid eTag fails", func(t *testing.T) { t.Parallel() - deleteWithInvalidEtagFails(t, mys) + + // Create new item + key := randomKey() + value := &fakeItem{Color: "mauve"} + setItem(t, mys, key, value, nil) + + eTag := "1234" + + // Delete the item with a fake eTag + deleteReq := &state.DeleteRequest{ + Key: key, + ETag: &eTag, + } + + err := mys.Delete(deleteReq) + assert.NotNil(t, err) }) t.Run("Delete item with no key fails", func(t *testing.T) { t.Parallel() - deleteWithNoKeyFails(t, mys) + + deleteReq := &state.DeleteRequest{ + Key: "", + } + + err := mys.Delete(deleteReq) + assert.NotNil(t, err) }) t.Run("Delete an item that does not exist", func(t *testing.T) { t.Parallel() - deleteItemThatDoesNotExist(t, mys) + + // Delete the item with a key not in the store + deleteReq := &state.DeleteRequest{ + Key: randomKey(), + } + + err := mys.Delete(deleteReq) + assert.Nil(t, err) + }) + + t.Run("Inserts with first-write-wins", func(t *testing.T) { + t.Parallel() + + // Insert without an etag should work on new keys + key := randomKey() + setReq := &state.SetRequest{ + Key: key, + Value: &fakeItem{Color: "teal"}, + Options: state.SetStateOption{ + Concurrency: state.FirstWrite, + }, + } + + err := mys.Set(setReq) + assert.NoError(t, err) + + // Get the etag + getResponse, _ := getItem(t, mys, key) + assert.NotNil(t, getResponse) + assert.NotNil(t, getResponse.ETag) + originalEtag := getResponse.ETag + + // Insert without an etag should fail on existing keys + setReq = &state.SetRequest{ + Key: key, + Value: &fakeItem{Color: "gray or grey"}, + Options: state.SetStateOption{ + Concurrency: state.FirstWrite, + }, + } + + err = mys.Set(setReq) + assert.ErrorContains(t, err, "Duplicate entry") + + // Insert with invalid etag should fail on existing keys + setReq = &state.SetRequest{ + Key: key, + Value: &fakeItem{Color: "pink"}, + ETag: ptr.Of("no-etag"), + Options: state.SetStateOption{ + Concurrency: state.FirstWrite, + }, + } + + err = mys.Set(setReq) + assert.ErrorContains(t, err, "possible etag mismatch") + + // Insert with valid etag should succeed on existing keys + setReq = &state.SetRequest{ + Key: key, + Value: &fakeItem{Color: "scarlet"}, + ETag: originalEtag, + Options: state.SetStateOption{ + Concurrency: state.FirstWrite, + }, + } + + err = mys.Set(setReq) + assert.NoError(t, err) + + // Insert with an etag should fail on new keys + setReq = &state.SetRequest{ + Key: randomKey(), + Value: &fakeItem{Color: "greige"}, + ETag: ptr.Of("myetag"), + Options: state.SetStateOption{ + Concurrency: state.FirstWrite, + }, + } + + err = mys.Set(setReq) + assert.ErrorContains(t, err, "possible etag mismatch") }) t.Run("Multi with delete and set", func(t *testing.T) { t.Parallel() - multiWithDeleteAndSet(t, mys) + + var operations []state.TransactionalStateOperation + var deleteRequests []state.DeleteRequest + for i := 0; i < 3; i++ { + req := state.DeleteRequest{Key: randomKey()} + + // Add the item to the database + setItem(t, mys, req.Key, randomJSON(), nil) + + // Add the item to a slice of delete requests + deleteRequests = append(deleteRequests, req) + + // Add the item to the multi transaction request + operations = append(operations, state.TransactionalStateOperation{ + Operation: state.Delete, + Request: req, + }) + } + + // Create the set requests + var setRequests []state.SetRequest + for i := 0; i < 3; i++ { + req := state.SetRequest{ + Key: randomKey(), + Value: randomJSON(), + } + setRequests = append(setRequests, req) + operations = append(operations, state.TransactionalStateOperation{ + Operation: state.Upsert, + Request: req, + }) + } + + err := mys.Multi(&state.TransactionalStateRequest{ + Operations: operations, + }) + assert.Nil(t, err) + + for _, delete := range deleteRequests { + assert.False(t, storeItemExists(t, delete.Key)) + } + + for _, set := range setRequests { + assert.True(t, storeItemExists(t, set.Key)) + deleteItem(t, mys, set.Key, nil) + } }) t.Run("Multi with delete only", func(t *testing.T) { t.Parallel() - multiWithDeleteOnly(t, mys) + + var operations []state.TransactionalStateOperation + var deleteRequests []state.DeleteRequest + for i := 0; i < 3; i++ { + req := state.DeleteRequest{Key: randomKey()} + + // Add the item to the database + setItem(t, mys, req.Key, randomJSON(), nil) + + // Add the item to a slice of delete requests + deleteRequests = append(deleteRequests, req) + + // Add the item to the multi transaction request + operations = append(operations, state.TransactionalStateOperation{ + Operation: state.Delete, + Request: req, + }) + } + + err := mys.Multi(&state.TransactionalStateRequest{ + Operations: operations, + }) + assert.Nil(t, err) + + for _, delete := range deleteRequests { + assert.False(t, storeItemExists(t, delete.Key)) + } }) t.Run("Multi with set only", func(t *testing.T) { t.Parallel() - multiWithSetOnly(t, mys) - }) -} -func multiWithSetOnly(t *testing.T, mys *MySQL) { - var operations []state.TransactionalStateOperation - var setRequests []state.SetRequest - for i := 0; i < 3; i++ { - req := state.SetRequest{ - Key: randomKey(), - Value: randomJSON(), + var operations []state.TransactionalStateOperation + var setRequests []state.SetRequest + for i := 0; i < 3; i++ { + req := state.SetRequest{ + Key: randomKey(), + Value: randomJSON(), + } + setRequests = append(setRequests, req) + operations = append(operations, state.TransactionalStateOperation{ + Operation: state.Upsert, + Request: req, + }) } - setRequests = append(setRequests, req) - operations = append(operations, state.TransactionalStateOperation{ - Operation: state.Upsert, - Request: req, + + err := mys.Multi(&state.TransactionalStateRequest{ + Operations: operations, }) - } + assert.Nil(t, err) - err := mys.Multi(&state.TransactionalStateRequest{ - Operations: operations, - }) - assert.Nil(t, err) - - for _, set := range setRequests { - assert.True(t, storeItemExists(t, set.Key)) - deleteItem(t, mys, set.Key, nil) - } -} - -func multiWithDeleteOnly(t *testing.T, mys *MySQL) { - var operations []state.TransactionalStateOperation - var deleteRequests []state.DeleteRequest - for i := 0; i < 3; i++ { - req := state.DeleteRequest{Key: randomKey()} - - // Add the item to the database - setItem(t, mys, req.Key, randomJSON(), nil) - - // Add the item to a slice of delete requests - deleteRequests = append(deleteRequests, req) - - // Add the item to the multi transaction request - operations = append(operations, state.TransactionalStateOperation{ - Operation: state.Delete, - Request: req, - }) - } - - err := mys.Multi(&state.TransactionalStateRequest{ - Operations: operations, - }) - assert.Nil(t, err) - - for _, delete := range deleteRequests { - assert.False(t, storeItemExists(t, delete.Key)) - } -} - -func multiWithDeleteAndSet(t *testing.T, mys *MySQL) { - var operations []state.TransactionalStateOperation - var deleteRequests []state.DeleteRequest - for i := 0; i < 3; i++ { - req := state.DeleteRequest{Key: randomKey()} - - // Add the item to the database - setItem(t, mys, req.Key, randomJSON(), nil) - - // Add the item to a slice of delete requests - deleteRequests = append(deleteRequests, req) - - // Add the item to the multi transaction request - operations = append(operations, state.TransactionalStateOperation{ - Operation: state.Delete, - Request: req, - }) - } - - // Create the set requests - var setRequests []state.SetRequest - for i := 0; i < 3; i++ { - req := state.SetRequest{ - Key: randomKey(), - Value: randomJSON(), + for _, set := range setRequests { + assert.True(t, storeItemExists(t, set.Key)) + deleteItem(t, mys, set.Key, nil) } - setRequests = append(setRequests, req) - operations = append(operations, state.TransactionalStateOperation{ - Operation: state.Upsert, - Request: req, - }) - } - - err := mys.Multi(&state.TransactionalStateRequest{ - Operations: operations, }) - assert.Nil(t, err) - - for _, delete := range deleteRequests { - assert.False(t, storeItemExists(t, delete.Key)) - } - - for _, set := range setRequests { - assert.True(t, storeItemExists(t, set.Key)) - deleteItem(t, mys, set.Key, nil) - } -} - -func deleteItemThatDoesNotExist(t *testing.T, mys *MySQL) { - // Delete the item with a key not in the store - deleteReq := &state.DeleteRequest{ - Key: randomKey(), - } - - err := mys.Delete(deleteReq) - assert.Nil(t, err) -} - -func deleteWithNoKeyFails(t *testing.T, mys *MySQL) { - deleteReq := &state.DeleteRequest{ - Key: "", - } - - err := mys.Delete(deleteReq) - assert.NotNil(t, err) -} - -func deleteWithInvalidEtagFails(t *testing.T, mys *MySQL) { - // Create new item - key := randomKey() - value := &fakeItem{Color: "mauve"} - setItem(t, mys, key, value, nil) - - eTag := "1234" - - // Delete the item with a fake eTag - deleteReq := &state.DeleteRequest{ - Key: key, - ETag: &eTag, - } - - err := mys.Delete(deleteReq) - assert.NotNil(t, err) -} - -// newItemWithEtagFails creates a new item and also supplies an ETag, which is -// invalid - expect failure. -func newItemWithEtagFails(t *testing.T, mys *MySQL) { - value := &fakeItem{Color: "teal"} - invalidETag := "12345" - - setReq := &state.SetRequest{ - Key: randomKey(), - ETag: &invalidETag, - Value: value, - } - - err := mys.Set(setReq) - assert.NotNil(t, err) -} - -func updateWithOldETagFails(t *testing.T, mys *MySQL) { - // Create and retrieve new item - key := randomKey() - value := &fakeItem{Color: "gray"} - setItem(t, mys, key, value, nil) - - getResponse, _ := getItem(t, mys, key) - assert.NotNil(t, getResponse.ETag) - originalEtag := getResponse.ETag - - // Change the value and get the updated eTag - newValue := &fakeItem{Color: "silver"} - setItem(t, mys, key, newValue, originalEtag) - - _, updatedItem := getItem(t, mys, key) - assert.Equal(t, newValue, updatedItem) - - // Update again with the original eTag - expect update failure - newValue = &fakeItem{Color: "maroon"} - setReq := &state.SetRequest{ - Key: key, - ETag: originalEtag, - Value: newValue, - } - - err := mys.Set(setReq) - assert.NotNil(t, err, "Error was not thrown using old eTag") -} - -func updateAndDeleteWithETagSucceeds(t *testing.T, mys *MySQL) { - // Create and retrieve new item - key := randomKey() - value := &fakeItem{Color: "hazel"} - setItem(t, mys, key, value, nil) - getResponse, _ := getItem(t, mys, key) - assert.NotNil(t, getResponse.ETag) - - // Change the value and compare - value.Color = "purple" - setItem(t, mys, key, value, getResponse.ETag) - updateResponse, updatedItem := getItem(t, mys, key) - assert.Equal(t, value, updatedItem, "Item should have been updated") - assert.NotEqual(t, getResponse.ETag, updateResponse.ETag, - "ETag should change when item is updated") - - // Delete - deleteItem(t, mys, key, updateResponse.ETag) - - assert.False(t, storeItemExists(t, key), "Item is not in the data store") } // Tests valid bulk sets and deletes. @@ -400,159 +581,6 @@ func testBulkSetAndBulkDelete(t *testing.T, mys *MySQL) { assert.False(t, storeItemExists(t, setReq[1].Key)) } -func setItemWithNoKey(t *testing.T, mys *MySQL) { - setReq := &state.SetRequest{ - Key: "", - } - - err := mys.Set(setReq) - assert.NotNil(t, err, "Error was not nil when setting item with no key.") -} - -// setUpdatesTheUpdatedateField proves that the updatedate is set for an -// update, and set upon insert. The updatedate is used as the eTag so must be -// set. It is also auto updated on update by MySQL. -func setUpdatesTheUpdatedateField(t *testing.T, mys *MySQL) { - key := randomKey() - value := &fakeItem{Color: "orange"} - setItem(t, mys, key, value, nil) - - // insertdate and updatedate should have a value - _, insertdate, updatedate, eTag := getRowData(t, key) - assert.NotNil(t, insertdate, "insertdate was not set") - assert.NotNil(t, updatedate, "updatedate was not set") - - // insertdate should not change, updatedate should have a value - value = &fakeItem{Color: "aqua"} - setItem(t, mys, key, value, nil) - _, newinsertdate, _, newETag := getRowData(t, key) - assert.Equal(t, insertdate, newinsertdate, "InsertDate was changed") - assert.NotEqual(t, eTag, newETag, "eTag was not updated") - - deleteItem(t, mys, key, nil) -} - -// getItemWithNoKey validates that attempting a Get operation without providing -// a key will return an error. -func getItemWithNoKey(t *testing.T, mys *MySQL) { - getReq := &state.GetRequest{ - Key: "", - } - - response, getErr := mys.Get(getReq) - assert.NotNil(t, getErr) - assert.Nil(t, response) -} - -// getItemThatDoesNotExist validates the behavior of retrieving an item that -// does not exist. -func getItemThatDoesNotExist(t *testing.T, mys *MySQL) { - key := randomKey() - response, outputObject := getItem(t, mys, key) - assert.Nil(t, response.Data) - assert.Equal(t, "", outputObject.Color) -} - -// setGetUpdateDeleteOneItem validates setting one item, getting it, and -// deleting it. -func setGetUpdateDeleteOneItem(t *testing.T, mys *MySQL) { - key := randomKey() - value := &fakeItem{Color: "yellow"} - - setItem(t, mys, key, value, nil) - - getResponse, outputObject := getItem(t, mys, key) - assert.Equal(t, value, outputObject) - - newValue := &fakeItem{Color: "green"} - setItem(t, mys, key, newValue, getResponse.ETag) - getResponse, outputObject = getItem(t, mys, key) - assert.Equal(t, newValue, outputObject) - - deleteItem(t, mys, key, getResponse.ETag) -} - -// testCreateTable tests the ability to create the state table. -func testCreateTable(t *testing.T, mys *MySQL) { - tableName := "test_state" - - // Drop the table if it already exists - exists, err := tableExists(mys.db, tableName) - assert.Nil(t, err) - if exists { - dropTable(t, mys.db, tableName) - } - - // Create the state table and test for its existence - // There should be no error - err = mys.ensureStateTable(tableName) - assert.Nil(t, err) - - // Now create it and make sure there are no errors - exists, err = tableExists(mys.db, tableName) - assert.Nil(t, err) - assert.True(t, exists) - - // Drop the state table - dropTable(t, mys.db, tableName) -} - -// testInitConfiguration tests valid and invalid config settings. -func testInitConfiguration(t *testing.T) { - logger := logger.NewLogger("test") - - // define a struct the contain the metadata and create - // two instances of it in a tests slice - tests := []struct { - name string - props map[string]string - expectedErr string - }{ - { - name: "Empty", - props: map[string]string{}, - expectedErr: errMissingConnectionString, - }, - { - name: "Valid connection string", - props: map[string]string{ - connectionStringKey: getConnectionString(""), - pemPathKey: getPemPath(), - }, - expectedErr: "", - }, - { - name: "Valid table name", - props: map[string]string{ - connectionStringKey: getConnectionString(""), - pemPathKey: getPemPath(), - tableNameKey: "stateStore", - }, - expectedErr: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := NewMySQLStateStore(logger) - defer p.Close() - - metadata := state.Metadata{ - Properties: tt.props, - } - - err := p.Init(metadata) - - if tt.expectedErr == "" { - assert.Nil(t, err) - } else { - assert.NotNil(t, err) - assert.Equal(t, err.Error(), tt.expectedErr) - } - }) - } -} - func dropTable(t *testing.T, db *sql.DB, tableName string) { _, err := db.Exec(fmt.Sprintf( `DROP TABLE %s;`, diff --git a/state/mysql/mysql_test.go b/state/mysql/mysql_test.go index 876b9a284..1da90c13f 100644 --- a/state/mysql/mysql_test.go +++ b/state/mysql/mysql_test.go @@ -61,7 +61,7 @@ func TestFinishInitHandlesSchemaExistsError(t *testing.T) { m.mock1.ExpectQuery("SELECT EXISTS").WillReturnError(expectedErr) // Act - actualErr := m.mySQL.finishInit(m.mySQL.db, nil) + actualErr := m.mySQL.finishInit(m.mySQL.db) // Assert assert.NotNil(t, actualErr, "now error returned") @@ -80,26 +80,13 @@ func TestFinishInitHandlesDatabaseCreateError(t *testing.T) { m.mock1.ExpectExec("CREATE DATABASE").WillReturnError(expectedErr) // Act - actualErr := m.mySQL.finishInit(m.mySQL.db, nil) + actualErr := m.mySQL.finishInit(m.mySQL.db) // Assert assert.NotNil(t, actualErr, "now error returned") assert.Equal(t, "createDatabaseError", actualErr.Error(), "wrong error") } -func TestFinishInitHandlesOpenError(t *testing.T) { - // Arrange - m, _ := mockDatabase(t) - defer m.mySQL.Close() - - // Act - err := m.mySQL.finishInit(m.mySQL.db, fmt.Errorf("failed to open database")) - - // Assert - assert.NotNil(t, err, "now error returned") - assert.Equal(t, "failed to open database", err.Error(), "wrong error") -} - func TestFinishInitHandlesPingError(t *testing.T) { // Arrange m, _ := mockDatabase(t) @@ -117,7 +104,7 @@ func TestFinishInitHandlesPingError(t *testing.T) { m.mock2.ExpectPing().WillReturnError(expectedErr) // Act - actualErr := m.mySQL.finishInit(m.mySQL.db, nil) + actualErr := m.mySQL.finishInit(m.mySQL.db) // Assert assert.NotNil(t, actualErr, "now error returned") @@ -145,7 +132,7 @@ func TestFinishInitHandlesTableExistsError(t *testing.T) { m.mock2.ExpectQuery("SELECT EXISTS").WillReturnError(fmt.Errorf("tableExistsError")) // Act - err := m.mySQL.finishInit(m.mySQL.db, nil) + err := m.mySQL.finishInit(m.mySQL.db) // Assert assert.NotNil(t, err, "no error returned") @@ -667,6 +654,21 @@ func TestInitSetsTableName(t *testing.T) { assert.Equal(t, "stateStore", m.mySQL.tableName, "table name did not default") } +func TestInitInvalidTableName(t *testing.T) { + // Arrange + t.Parallel() + m, _ := mockDatabase(t) + metadata := &state.Metadata{ + Properties: map[string]string{connectionStringKey: "", tableNameKey: "🙃"}, + } + + // Act + err := m.mySQL.Init(*metadata) + + // Assert + assert.ErrorContains(t, err, "table name '🙃' is not valid") +} + func TestInitSetsSchemaName(t *testing.T) { // Arrange t.Parallel() @@ -683,6 +685,21 @@ func TestInitSetsSchemaName(t *testing.T) { assert.Equal(t, "stateStoreSchema", m.mySQL.schemaName, "table name did not default") } +func TestInitInvalidSchemaName(t *testing.T) { + // Arrange + t.Parallel() + m, _ := mockDatabase(t) + metadata := &state.Metadata{ + Properties: map[string]string{connectionStringKey: "", schemaNameKey: "?"}, + } + + // Act + err := m.mySQL.Init(*metadata) + + // Assert + assert.ErrorContains(t, err, "schema name '?' is not valid") +} + // This state store does not support BulkGet so it must return false and // nil nil. func TestBulkGetReturnsNil(t *testing.T) { @@ -1018,3 +1035,25 @@ func (f *fakeMySQLFactory) Open(connectionString string) (*sql.DB, error) { func (f *fakeMySQLFactory) RegisterTLSConfig(pemPath string) error { return f.registerErr } + +func TestValidIdentifier(t *testing.T) { + tests := []struct { + name string + arg string + want bool + }{ + {name: "empty string", arg: "", want: false}, + {name: "valid characters only", arg: "acz_039_AZS", want: true}, + {name: "invalid ASCII characters 1", arg: "$", want: false}, + {name: "invalid ASCII characters 2", arg: "*", want: false}, + {name: "invalid ASCII characters 3", arg: "hello world", want: false}, + {name: "non-ASCII characters", arg: "🙃", want: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := validIdentifier(tt.arg); got != tt.want { + t.Errorf("validIdentifier() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/tests/config/state/tests.yml b/tests/config/state/tests.yml index 78e73b6af..2448c0744 100644 --- a/tests/config/state/tests.yml +++ b/tests/config/state/tests.yml @@ -24,7 +24,7 @@ components: operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "query" ] - component: mysql allOperations: false - operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag" ] + operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write" ] - component: azure.tablestorage.storage allOperations: false operations: ["set", "get", "delete", "etag", "bulkset", "bulkdelete", "first-write"]