diff --git a/state/mysql/mysql.go b/state/mysql/mysql.go index 8115bc44f..a80bfc414 100644 --- a/state/mysql/mysql.go +++ b/state/mysql/mysql.go @@ -323,12 +323,12 @@ func tableExists(db *sql.DB, tableName string, timeout time.Duration) (bool, err // Delete removes an entity from the store // Store Interface. func (m *MySQL) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(m.deleteValue, req) + return m.deleteValue(m.db, req) } // deleteValue is an internal implementation of delete to enable passing the // logic to state.DeleteWithRetries as a func. -func (m *MySQL) deleteValue(req *state.DeleteRequest) error { +func (m *MySQL) deleteValue(querier querier, req *state.DeleteRequest) error { m.logger.Debug("Deleting state value from MySql") if req.Key == "" { @@ -343,11 +343,11 @@ func (m *MySQL) deleteValue(req *state.DeleteRequest) error { defer cancel() if req.ETag == nil || *req.ETag == "" { - result, err = m.db.ExecContext(ctx, fmt.Sprintf( + result, err = querier.ExecContext(ctx, fmt.Sprintf( `DELETE FROM %s WHERE id = ?`, m.tableName), req.Key) } else { - result, err = m.db.ExecContext(ctx, fmt.Sprintf( + result, err = querier.ExecContext(ctx, fmt.Sprintf( `DELETE FROM %s WHERE id = ? and eTag = ?`, m.tableName), req.Key, *req.ETag) } @@ -381,10 +381,12 @@ func (m *MySQL) BulkDelete(req []state.DeleteRequest) error { if len(req) > 0 { for _, d := range req { da := d // Fix for goSec G601: Implicit memory aliasing in for loop. - err = m.Delete(&da) + err = m.deleteValue(tx, &da) if err != nil { - tx.Rollback() - + rollbackErr := tx.Rollback() + if rollbackErr != nil { + m.logger.Errorf("Error rolling back transaction: %v", rollbackErr) + } return err } } @@ -460,12 +462,12 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) { // Set adds/updates an entity on store // Store Interface. func (m *MySQL) Set(req *state.SetRequest) error { - return state.SetWithOptions(m.setValue, req) + return m.setValue(m.db, req) } // setValue is an internal implementation of set to enable passing the logic // to state.SetWithRetries as a func. -func (m *MySQL) setValue(req *state.SetRequest) error { +func (m *MySQL) setValue(querier querier, req *state.SetRequest) error { m.logger.Debug("Setting state value in MySql") err := state.CheckRequestOptions(req.Options) @@ -494,7 +496,12 @@ func (m *MySQL) setValue(req *state.SetRequest) error { encB, _ := json.Marshal(v) enc := string(encB) - eTag := uuid.New().String() + + eTagObj, err := uuid.NewRandom() + if err != nil { + return fmt.Errorf("failed to generate etag: %w", err) + } + eTag := eTagObj.String() var ( result sql.Result @@ -505,29 +512,26 @@ func (m *MySQL) setValue(req *state.SetRequest) error { if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || *req.ETag == "") { // With first-write-wins and no etag, we can insert the row only if it doesn't exist - //nolint:gosec query := fmt.Sprintf( `INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?);`, m.tableName, // m.tableName is sanitized ) - result, err = m.db.ExecContext(ctx, query, enc, req.Key, eTag, isBinary) + result, err = querier.ExecContext(ctx, query, enc, req.Key, eTag, isBinary) } else if req.ETag != nil && *req.ETag != "" { // When an eTag is provided do an update - not insert - //nolint:gosec query := fmt.Sprintf( `UPDATE %s SET value = ?, eTag = ?, isbinary = ? WHERE id = ? AND eTag = ?;`, m.tableName, // m.tableName is sanitized ) - result, err = m.db.ExecContext(ctx, query, enc, eTag, isBinary, req.Key, *req.ETag) + result, err = querier.ExecContext(ctx, query, enc, eTag, isBinary, req.Key, *req.ETag) } else { // If this is a duplicate MySQL returns that two rows affected maxRows = 2 - //nolint:gosec query := fmt.Sprintf( `INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`, m.tableName, // m.tableName is sanitized ) - result, err = m.db.ExecContext(ctx, query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary) + result, err = querier.ExecContext(ctx, query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary) } if err != nil { @@ -571,9 +575,12 @@ func (m *MySQL) BulkSet(req []state.SetRequest) error { if len(req) > 0 { for i := range req { - err = m.Set(&req[i]) + err = m.setValue(tx, &req[i]) if err != nil { - tx.Rollback() + rollbackErr := tx.Rollback() + if rollbackErr != nil { + m.logger.Errorf("Error rolling back transaction: %v", rollbackErr) + } return err } } @@ -597,26 +604,38 @@ func (m *MySQL) Multi(request *state.TransactionalStateRequest) error { case state.Upsert: setReq, err := m.getSets(req) if err != nil { - _ = tx.Rollback() + rollbackErr := tx.Rollback() + if rollbackErr != nil { + m.logger.Errorf("Error rolling back transaction: %v", rollbackErr) + } return err } - err = m.Set(&setReq) + err = m.setValue(tx, &setReq) if err != nil { - _ = tx.Rollback() + rollbackErr := tx.Rollback() + if rollbackErr != nil { + m.logger.Errorf("Error rolling back transaction: %v", rollbackErr) + } return err } case state.Delete: delReq, err := m.getDeletes(req) if err != nil { - _ = tx.Rollback() + rollbackErr := tx.Rollback() + if rollbackErr != nil { + m.logger.Errorf("Error rolling back transaction: %v", rollbackErr) + } return err } - err = m.Delete(&delReq) + err = m.deleteValue(tx, &delReq) if err != nil { - _ = tx.Rollback() + rollbackErr := tx.Rollback() + if rollbackErr != nil { + m.logger.Errorf("Error rolling back transaction: %v", rollbackErr) + } return err } @@ -694,3 +713,10 @@ func validIdentifier(v string) bool { } return true } + +// Interface for both sql.DB and sql.Tx +type querier interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} diff --git a/state/mysql/mysql_test.go b/state/mysql/mysql_test.go index d6f7b4f10..cdbab501e 100644 --- a/state/mysql/mysql_test.go +++ b/state/mysql/mysql_test.go @@ -249,7 +249,7 @@ func TestSetHandlesOptionsError(t *testing.T) { request.Options.Consistency = "Invalid" // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(m.db, &request) // Assert assert.NotNil(t, err) @@ -284,7 +284,7 @@ func TestSetHandlesUpdate(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(m.db, &request) // Assert assert.Nil(t, err) @@ -303,7 +303,7 @@ func TestSetHandlesErr(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(m.db, &request) // Assert assert.NotNil(t, err) @@ -316,7 +316,7 @@ func TestSetHandlesErr(t *testing.T) { request := createSetRequest() // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(m.db, &request) // Assert assert.NotNil(t, err) @@ -328,7 +328,7 @@ func TestSetHandlesErr(t *testing.T) { request := createSetRequest() // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(m.db, &request) // Assert assert.Nil(t, err) @@ -339,7 +339,7 @@ func TestSetHandlesErr(t *testing.T) { request := createSetRequest() // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(m.db, &request) // Assert assert.NotNil(t, err) @@ -353,7 +353,7 @@ func TestSetHandlesErr(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.setValue(&request) + err := m.mySQL.setValue(m.db, &request) // Assert assert.NotNil(t, err) @@ -389,7 +389,7 @@ func TestDeleteWithETag(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.deleteValue(&request) + err := m.mySQL.deleteValue(m.db, &request) // Assert assert.Nil(t, err) @@ -406,7 +406,7 @@ func TestDeleteWithErr(t *testing.T) { request := createDeleteRequest() // Act - err := m.mySQL.deleteValue(&request) + err := m.mySQL.deleteValue(m.db, &request) // Assert assert.NotNil(t, err) @@ -421,7 +421,7 @@ func TestDeleteWithErr(t *testing.T) { request.ETag = &eTag // Act - err := m.mySQL.deleteValue(&request) + err := m.mySQL.deleteValue(m.db, &request) // Assert assert.NotNil(t, err)