Fixed: MySQL was not actually using transactions

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2022-11-01 16:53:34 +00:00
parent c5847890ee
commit 8c617817a8
2 changed files with 60 additions and 34 deletions

View File

@ -323,12 +323,12 @@ func tableExists(db *sql.DB, tableName string, timeout time.Duration) (bool, err
// Delete removes an entity from the store // Delete removes an entity from the store
// Store Interface. // Store Interface.
func (m *MySQL) Delete(req *state.DeleteRequest) error { func (m *MySQL) Delete(req *state.DeleteRequest) error {
return state.DeleteWithOptions(m.deleteValue, req) return m.deleteValue(m.db, req)
} }
// deleteValue is an internal implementation of delete to enable passing the // deleteValue is an internal implementation of delete to enable passing the
// logic to state.DeleteWithRetries as a func. // logic to state.DeleteWithRetries as a func.
func (m *MySQL) deleteValue(req *state.DeleteRequest) error { func (m *MySQL) deleteValue(querier querier, req *state.DeleteRequest) error {
m.logger.Debug("Deleting state value from MySql") m.logger.Debug("Deleting state value from MySql")
if req.Key == "" { if req.Key == "" {
@ -343,11 +343,11 @@ func (m *MySQL) deleteValue(req *state.DeleteRequest) error {
defer cancel() defer cancel()
if req.ETag == nil || *req.ETag == "" { if req.ETag == nil || *req.ETag == "" {
result, err = m.db.ExecContext(ctx, fmt.Sprintf( result, err = querier.ExecContext(ctx, fmt.Sprintf(
`DELETE FROM %s WHERE id = ?`, `DELETE FROM %s WHERE id = ?`,
m.tableName), req.Key) m.tableName), req.Key)
} else { } else {
result, err = m.db.ExecContext(ctx, fmt.Sprintf( result, err = querier.ExecContext(ctx, fmt.Sprintf(
`DELETE FROM %s WHERE id = ? and eTag = ?`, `DELETE FROM %s WHERE id = ? and eTag = ?`,
m.tableName), req.Key, *req.ETag) m.tableName), req.Key, *req.ETag)
} }
@ -381,10 +381,12 @@ func (m *MySQL) BulkDelete(req []state.DeleteRequest) error {
if len(req) > 0 { if len(req) > 0 {
for _, d := range req { for _, d := range req {
da := d // Fix for goSec G601: Implicit memory aliasing in for loop. da := d // Fix for goSec G601: Implicit memory aliasing in for loop.
err = m.Delete(&da) err = m.deleteValue(tx, &da)
if err != nil { if err != nil {
tx.Rollback() rollbackErr := tx.Rollback()
if rollbackErr != nil {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
return err return err
} }
} }
@ -460,12 +462,12 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) {
// Set adds/updates an entity on store // Set adds/updates an entity on store
// Store Interface. // Store Interface.
func (m *MySQL) Set(req *state.SetRequest) error { func (m *MySQL) Set(req *state.SetRequest) error {
return state.SetWithOptions(m.setValue, req) return m.setValue(m.db, req)
} }
// setValue is an internal implementation of set to enable passing the logic // setValue is an internal implementation of set to enable passing the logic
// to state.SetWithRetries as a func. // to state.SetWithRetries as a func.
func (m *MySQL) setValue(req *state.SetRequest) error { func (m *MySQL) setValue(querier querier, req *state.SetRequest) error {
m.logger.Debug("Setting state value in MySql") m.logger.Debug("Setting state value in MySql")
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
@ -494,7 +496,12 @@ func (m *MySQL) setValue(req *state.SetRequest) error {
encB, _ := json.Marshal(v) encB, _ := json.Marshal(v)
enc := string(encB) enc := string(encB)
eTag := uuid.New().String()
eTagObj, err := uuid.NewRandom()
if err != nil {
return fmt.Errorf("failed to generate etag: %w", err)
}
eTag := eTagObj.String()
var ( var (
result sql.Result result sql.Result
@ -505,29 +512,26 @@ func (m *MySQL) setValue(req *state.SetRequest) error {
if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || *req.ETag == "") { if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || *req.ETag == "") {
// With first-write-wins and no etag, we can insert the row only if it doesn't exist // With first-write-wins and no etag, we can insert the row only if it doesn't exist
//nolint:gosec
query := fmt.Sprintf( query := fmt.Sprintf(
`INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?);`, `INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?);`,
m.tableName, // m.tableName is sanitized m.tableName, // m.tableName is sanitized
) )
result, err = m.db.ExecContext(ctx, query, enc, req.Key, eTag, isBinary) result, err = querier.ExecContext(ctx, query, enc, req.Key, eTag, isBinary)
} else if req.ETag != nil && *req.ETag != "" { } else if req.ETag != nil && *req.ETag != "" {
// When an eTag is provided do an update - not insert // When an eTag is provided do an update - not insert
//nolint:gosec
query := fmt.Sprintf( query := fmt.Sprintf(
`UPDATE %s SET value = ?, eTag = ?, isbinary = ? WHERE id = ? AND eTag = ?;`, `UPDATE %s SET value = ?, eTag = ?, isbinary = ? WHERE id = ? AND eTag = ?;`,
m.tableName, // m.tableName is sanitized m.tableName, // m.tableName is sanitized
) )
result, err = m.db.ExecContext(ctx, query, enc, eTag, isBinary, req.Key, *req.ETag) result, err = querier.ExecContext(ctx, query, enc, eTag, isBinary, req.Key, *req.ETag)
} else { } else {
// If this is a duplicate MySQL returns that two rows affected // If this is a duplicate MySQL returns that two rows affected
maxRows = 2 maxRows = 2
//nolint:gosec
query := fmt.Sprintf( query := fmt.Sprintf(
`INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`, `INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`,
m.tableName, // m.tableName is sanitized m.tableName, // m.tableName is sanitized
) )
result, err = m.db.ExecContext(ctx, query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary) result, err = querier.ExecContext(ctx, query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary)
} }
if err != nil { if err != nil {
@ -571,9 +575,12 @@ func (m *MySQL) BulkSet(req []state.SetRequest) error {
if len(req) > 0 { if len(req) > 0 {
for i := range req { for i := range req {
err = m.Set(&req[i]) err = m.setValue(tx, &req[i])
if err != nil { if err != nil {
tx.Rollback() rollbackErr := tx.Rollback()
if rollbackErr != nil {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
return err return err
} }
} }
@ -597,26 +604,38 @@ func (m *MySQL) Multi(request *state.TransactionalStateRequest) error {
case state.Upsert: case state.Upsert:
setReq, err := m.getSets(req) setReq, err := m.getSets(req)
if err != nil { if err != nil {
_ = tx.Rollback() rollbackErr := tx.Rollback()
if rollbackErr != nil {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
return err return err
} }
err = m.Set(&setReq) err = m.setValue(tx, &setReq)
if err != nil { if err != nil {
_ = tx.Rollback() rollbackErr := tx.Rollback()
if rollbackErr != nil {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
return err return err
} }
case state.Delete: case state.Delete:
delReq, err := m.getDeletes(req) delReq, err := m.getDeletes(req)
if err != nil { if err != nil {
_ = tx.Rollback() rollbackErr := tx.Rollback()
if rollbackErr != nil {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
return err return err
} }
err = m.Delete(&delReq) err = m.deleteValue(tx, &delReq)
if err != nil { if err != nil {
_ = tx.Rollback() rollbackErr := tx.Rollback()
if rollbackErr != nil {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
return err return err
} }
@ -694,3 +713,10 @@ func validIdentifier(v string) bool {
} }
return true return true
} }
// Interface for both sql.DB and sql.Tx
type querier interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

View File

@ -249,7 +249,7 @@ func TestSetHandlesOptionsError(t *testing.T) {
request.Options.Consistency = "Invalid" request.Options.Consistency = "Invalid"
// Act // Act
err := m.mySQL.setValue(&request) err := m.mySQL.setValue(m.db, &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -284,7 +284,7 @@ func TestSetHandlesUpdate(t *testing.T) {
request.ETag = &eTag request.ETag = &eTag
// Act // Act
err := m.mySQL.setValue(&request) err := m.mySQL.setValue(m.db, &request)
// Assert // Assert
assert.Nil(t, err) assert.Nil(t, err)
@ -303,7 +303,7 @@ func TestSetHandlesErr(t *testing.T) {
request.ETag = &eTag request.ETag = &eTag
// Act // Act
err := m.mySQL.setValue(&request) err := m.mySQL.setValue(m.db, &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -316,7 +316,7 @@ func TestSetHandlesErr(t *testing.T) {
request := createSetRequest() request := createSetRequest()
// Act // Act
err := m.mySQL.setValue(&request) err := m.mySQL.setValue(m.db, &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -328,7 +328,7 @@ func TestSetHandlesErr(t *testing.T) {
request := createSetRequest() request := createSetRequest()
// Act // Act
err := m.mySQL.setValue(&request) err := m.mySQL.setValue(m.db, &request)
// Assert // Assert
assert.Nil(t, err) assert.Nil(t, err)
@ -339,7 +339,7 @@ func TestSetHandlesErr(t *testing.T) {
request := createSetRequest() request := createSetRequest()
// Act // Act
err := m.mySQL.setValue(&request) err := m.mySQL.setValue(m.db, &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -353,7 +353,7 @@ func TestSetHandlesErr(t *testing.T) {
request.ETag = &eTag request.ETag = &eTag
// Act // Act
err := m.mySQL.setValue(&request) err := m.mySQL.setValue(m.db, &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -389,7 +389,7 @@ func TestDeleteWithETag(t *testing.T) {
request.ETag = &eTag request.ETag = &eTag
// Act // Act
err := m.mySQL.deleteValue(&request) err := m.mySQL.deleteValue(m.db, &request)
// Assert // Assert
assert.Nil(t, err) assert.Nil(t, err)
@ -406,7 +406,7 @@ func TestDeleteWithErr(t *testing.T) {
request := createDeleteRequest() request := createDeleteRequest()
// Act // Act
err := m.mySQL.deleteValue(&request) err := m.mySQL.deleteValue(m.db, &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)
@ -421,7 +421,7 @@ func TestDeleteWithErr(t *testing.T) {
request.ETag = &eTag request.ETag = &eTag
// Act // Act
err := m.mySQL.deleteValue(&request) err := m.mySQL.deleteValue(m.db, &request)
// Assert // Assert
assert.NotNil(t, err) assert.NotNil(t, err)