diff --git a/state/mysql/mysql.go b/state/mysql/mysql.go index 56c60fc24..408306f91 100644 --- a/state/mysql/mysql.go +++ b/state/mysql/mysql.go @@ -341,7 +341,28 @@ func (m *MySQL) deleteValue(req *state.DeleteRequest) error { // BulkDelete removes multiple entries from the store // Store Interface. func (m *MySQL) BulkDelete(req []state.DeleteRequest) error { - return m.executeMulti(nil, req) + m.logger.Debug("Executing BulkDelete request") + + tx, err := m.db.Begin() + if err != nil { + return err + } + + if len(req) > 0 { + for _, d := range req { + da := d // Fix for goSec G601: Implicit memory aliasing in for loop. + err = m.Delete(&da) + if err != nil { + tx.Rollback() + + return err + } + } + } + + err = tx.Commit() + + return err } // Get returns an entity from store @@ -482,33 +503,66 @@ func (m *MySQL) setValue(req *state.SetRequest) error { // BulkSet adds/updates multiple entities on store // Store Interface. func (m *MySQL) BulkSet(req []state.SetRequest) error { - return m.executeMulti(req, nil) + m.logger.Debug("Executing BulkSet request") + + tx, err := m.db.Begin() + if err != nil { + return err + } + + if len(req) > 0 { + for _, s := range req { + sa := s // Fix for goSec G601: Implicit memory aliasing in for loop. + err = m.Set(&sa) + if err != nil { + tx.Rollback() + + return err + } + } + } + + err = tx.Commit() + + return err } // Multi handles multiple transactions. // TransactionalStore Interface. func (m *MySQL) Multi(request *state.TransactionalStateRequest) error { - var sets []state.SetRequest - var deletes []state.DeleteRequest + m.logger.Debug("Executing Multi request") + + tx, err := m.db.Begin() + if err != nil { + return err + } for _, req := range request.Operations { switch req.Operation { case state.Upsert: - setReq, ok := req.Request.(state.SetRequest) + setReq, err := m.getSets(req) + if err != nil { + tx.Rollback() + return err + } - if ok { - sets = append(sets, setReq) - } else { - return fmt.Errorf("expecting set request") + err = m.Set(&setReq) + if err != nil { + tx.Rollback() + return err } case state.Delete: - delReq, ok := req.Request.(state.DeleteRequest) + delReq, err := m.getDeletes(req) + if err != nil { + tx.Rollback() + return err + } - if ok { - deletes = append(deletes, delReq) - } else { - return fmt.Errorf("expecting delete request") + err = m.Delete(&delReq) + if err != nil { + tx.Rollback() + return err } default: @@ -516,11 +570,35 @@ func (m *MySQL) Multi(request *state.TransactionalStateRequest) error { } } - if len(sets) > 0 || len(deletes) > 0 { - return m.executeMulti(sets, deletes) + return tx.Commit() +} + +// Returns the set requests. +func (m *MySQL) getSets(req state.TransactionalStateOperation) (state.SetRequest, error) { + setReq, ok := req.Request.(state.SetRequest) + if !ok { + return setReq, fmt.Errorf("expecting set request") } - return nil + if setReq.Key == "" { + return setReq, fmt.Errorf("missing key in upsert operation") + } + + return setReq, nil +} + +// Returns the delete requests. +func (m *MySQL) getDeletes(req state.TransactionalStateOperation) (state.DeleteRequest, error) { + delReq, ok := req.Request.(state.DeleteRequest) + if !ok { + return delReq, fmt.Errorf("expecting delete request") + } + + if delReq.Key == "" { + return delReq, fmt.Errorf("missing key in upsert operation") + } + + return delReq, nil } // BulkGet performs a bulks get operations. @@ -538,40 +616,3 @@ func (m *MySQL) Close() error { return nil } - -func (m *MySQL) executeMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error { - m.logger.Debug("Executing multiple MySql operations") - - tx, err := m.db.Begin() - if err != nil { - return err - } - - if len(deletes) > 0 { - for _, d := range deletes { - da := d // Fix for goSec G601: Implicit memory aliasing in for loop. - err = m.Delete(&da) - if err != nil { - tx.Rollback() - - return err - } - } - } - - if len(sets) > 0 { - for _, s := range sets { - sa := s // Fix for goSec G601: Implicit memory aliasing in for loop. - err = m.Set(&sa) - if err != nil { - tx.Rollback() - - return err - } - } - } - - err = tx.Commit() - - return err -} diff --git a/state/mysql/mysql_test.go b/state/mysql/mysql_test.go index a546b40ce..5ba5ca0e1 100644 --- a/state/mysql/mysql_test.go +++ b/state/mysql/mysql_test.go @@ -171,7 +171,7 @@ func TestExecuteMultiCannotBeginTransaction(t *testing.T) { m.mock1.ExpectBegin().WillReturnError(fmt.Errorf("beginError")) // Act - err := m.mySQL.executeMulti(nil, nil) + err := m.mySQL.Multi(nil) // Assert assert.NotNil(t, err, "no error returned") @@ -222,15 +222,27 @@ func TestExecuteMultiCommitSetsAndDeletes(t *testing.T) { defer m.mySQL.Close() m.mock1.ExpectBegin() - m.mock1.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(0, 1)) m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1)) + m.mock1.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(0, 1)) m.mock1.ExpectCommit() - sets := []state.SetRequest{createSetRequest()} - deletes := []state.DeleteRequest{createDeleteRequest()} + setOperation := state.TransactionalStateOperation{ + Request: createSetRequest(), + Operation: state.Upsert, + } + + deleteOperation := state.TransactionalStateOperation{ + Request: createDeleteRequest(), + Operation: state.Delete, + } + + request := state.TransactionalStateRequest{ + Operations: []state.TransactionalStateOperation{setOperation, deleteOperation}, + Metadata: map[string]string{}, + } // Act - err := m.mySQL.executeMulti(sets, deletes) + err := m.mySQL.Multi(&request) // Assert assert.Nil(t, err, "error returned") @@ -685,12 +697,16 @@ func TestBulkGetReturnsNil(t *testing.T) { assert.False(t, supported, `returned supported`) } -func TestMultiWithNoRequestsReturnsNil(t *testing.T) { +func TestMultiWithNoRequestsDoesNothing(t *testing.T) { // Arrange t.Parallel() m, _ := mockDatabase(t) var ops []state.TransactionalStateOperation + // no operations expected + m.mock1.ExpectBegin() + m.mock1.ExpectCommit() + // Act err := m.mySQL.Multi(&state.TransactionalStateRequest{ Operations: ops, @@ -780,6 +796,30 @@ func TestInvalidMultiSetRequest(t *testing.T) { assert.NotNil(t, err) } +func TestInvalidMultiSetRequestNoKey(t *testing.T) { + // Arrange + t.Parallel() + m, _ := mockDatabase(t) + var ops []state.TransactionalStateOperation + + ops = append(ops, state.TransactionalStateOperation{ + Operation: state.Upsert, + Request: state.SetRequest{ + // empty key is not valid for Upsert operation + Key: "", + Value: "value1", + }, + }) + + // Act + err := m.mySQL.Multi(&state.TransactionalStateRequest{ + Operations: ops, + }) + + // Assert + assert.NotNil(t, err) +} + func TestValidMultiDeleteRequest(t *testing.T) { // Arrange t.Parallel() @@ -825,6 +865,71 @@ func TestInvalidMultiDeleteRequest(t *testing.T) { assert.NotNil(t, err) } +func TestInvalidMultiDeleteRequestNoKey(t *testing.T) { + // Arrange + t.Parallel() + m, _ := mockDatabase(t) + var ops []state.TransactionalStateOperation + + ops = append(ops, state.TransactionalStateOperation{ + Operation: state.Delete, + Request: state.DeleteRequest{ + // empty key is not valid for Delete operation + Key: "", + }, + }) + + // Act + err := m.mySQL.Multi(&state.TransactionalStateRequest{ + Operations: ops, + }) + + // Assert + assert.NotNil(t, err) +} + +func TestMultiOperationOrder(t *testing.T) { + // Arrange + t.Parallel() + m, _ := mockDatabase(t) + var ops []state.TransactionalStateOperation + + // In a transaction with multiple operations, + // the order of operations must be respected. + ops = append(ops, + state.TransactionalStateOperation{ + Operation: state.Upsert, + Request: state.SetRequest{Key: "k1", Value: "v1"}, + }, + state.TransactionalStateOperation{ + Operation: state.Delete, + Request: state.DeleteRequest{Key: "k1"}, + }, + state.TransactionalStateOperation{ + Operation: state.Upsert, + Request: state.SetRequest{Key: "k2", Value: "v2"}, + }, + ) + + // expected to run the operations in sequence + m.mock1.ExpectBegin() + m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1)) + m.mock1.ExpectExec("DELETE FROM").WithArgs("k1").WillReturnResult(sqlmock.NewResult(0, 1)) + m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1)) + m.mock1.ExpectCommit() + + // Act + err := m.mySQL.Multi(&state.TransactionalStateRequest{ + Operations: ops, + }) + + // Assert + assert.Nil(t, err) + + err = m.mock1.ExpectationsWereMet() + assert.Nil(t, err) +} + func createSetRequest() state.SetRequest { return state.SetRequest{ Key: randomKey(),