Fixed: MySQL was not actually using transactions
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
c5847890ee
commit
8c617817a8
|
|
@ -323,12 +323,12 @@ func tableExists(db *sql.DB, tableName string, timeout time.Duration) (bool, err
|
|||
// Delete removes an entity from the store
|
||||
// Store Interface.
|
||||
func (m *MySQL) Delete(req *state.DeleteRequest) error {
|
||||
return state.DeleteWithOptions(m.deleteValue, req)
|
||||
return m.deleteValue(m.db, req)
|
||||
}
|
||||
|
||||
// deleteValue is an internal implementation of delete to enable passing the
|
||||
// logic to state.DeleteWithRetries as a func.
|
||||
func (m *MySQL) deleteValue(req *state.DeleteRequest) error {
|
||||
func (m *MySQL) deleteValue(querier querier, req *state.DeleteRequest) error {
|
||||
m.logger.Debug("Deleting state value from MySql")
|
||||
|
||||
if req.Key == "" {
|
||||
|
|
@ -343,11 +343,11 @@ func (m *MySQL) deleteValue(req *state.DeleteRequest) error {
|
|||
defer cancel()
|
||||
|
||||
if req.ETag == nil || *req.ETag == "" {
|
||||
result, err = m.db.ExecContext(ctx, fmt.Sprintf(
|
||||
result, err = querier.ExecContext(ctx, fmt.Sprintf(
|
||||
`DELETE FROM %s WHERE id = ?`,
|
||||
m.tableName), req.Key)
|
||||
} else {
|
||||
result, err = m.db.ExecContext(ctx, fmt.Sprintf(
|
||||
result, err = querier.ExecContext(ctx, fmt.Sprintf(
|
||||
`DELETE FROM %s WHERE id = ? and eTag = ?`,
|
||||
m.tableName), req.Key, *req.ETag)
|
||||
}
|
||||
|
|
@ -381,10 +381,12 @@ func (m *MySQL) BulkDelete(req []state.DeleteRequest) error {
|
|||
if len(req) > 0 {
|
||||
for _, d := range req {
|
||||
da := d // Fix for goSec G601: Implicit memory aliasing in for loop.
|
||||
err = m.Delete(&da)
|
||||
err = m.deleteValue(tx, &da)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -460,12 +462,12 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) {
|
|||
// Set adds/updates an entity on store
|
||||
// Store Interface.
|
||||
func (m *MySQL) Set(req *state.SetRequest) error {
|
||||
return state.SetWithOptions(m.setValue, req)
|
||||
return m.setValue(m.db, req)
|
||||
}
|
||||
|
||||
// setValue is an internal implementation of set to enable passing the logic
|
||||
// to state.SetWithRetries as a func.
|
||||
func (m *MySQL) setValue(req *state.SetRequest) error {
|
||||
func (m *MySQL) setValue(querier querier, req *state.SetRequest) error {
|
||||
m.logger.Debug("Setting state value in MySql")
|
||||
|
||||
err := state.CheckRequestOptions(req.Options)
|
||||
|
|
@ -494,7 +496,12 @@ func (m *MySQL) setValue(req *state.SetRequest) error {
|
|||
|
||||
encB, _ := json.Marshal(v)
|
||||
enc := string(encB)
|
||||
eTag := uuid.New().String()
|
||||
|
||||
eTagObj, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate etag: %w", err)
|
||||
}
|
||||
eTag := eTagObj.String()
|
||||
|
||||
var (
|
||||
result sql.Result
|
||||
|
|
@ -505,29 +512,26 @@ func (m *MySQL) setValue(req *state.SetRequest) error {
|
|||
|
||||
if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || *req.ETag == "") {
|
||||
// With first-write-wins and no etag, we can insert the row only if it doesn't exist
|
||||
//nolint:gosec
|
||||
query := fmt.Sprintf(
|
||||
`INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?);`,
|
||||
m.tableName, // m.tableName is sanitized
|
||||
)
|
||||
result, err = m.db.ExecContext(ctx, query, enc, req.Key, eTag, isBinary)
|
||||
result, err = querier.ExecContext(ctx, query, enc, req.Key, eTag, isBinary)
|
||||
} else if req.ETag != nil && *req.ETag != "" {
|
||||
// When an eTag is provided do an update - not insert
|
||||
//nolint:gosec
|
||||
query := fmt.Sprintf(
|
||||
`UPDATE %s SET value = ?, eTag = ?, isbinary = ? WHERE id = ? AND eTag = ?;`,
|
||||
m.tableName, // m.tableName is sanitized
|
||||
)
|
||||
result, err = m.db.ExecContext(ctx, query, enc, eTag, isBinary, req.Key, *req.ETag)
|
||||
result, err = querier.ExecContext(ctx, query, enc, eTag, isBinary, req.Key, *req.ETag)
|
||||
} else {
|
||||
// If this is a duplicate MySQL returns that two rows affected
|
||||
maxRows = 2
|
||||
//nolint:gosec
|
||||
query := fmt.Sprintf(
|
||||
`INSERT INTO %s (value, id, eTag, isbinary) VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`,
|
||||
m.tableName, // m.tableName is sanitized
|
||||
)
|
||||
result, err = m.db.ExecContext(ctx, query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary)
|
||||
result, err = querier.ExecContext(ctx, query, enc, req.Key, eTag, isBinary, enc, eTag, isBinary)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -571,9 +575,12 @@ func (m *MySQL) BulkSet(req []state.SetRequest) error {
|
|||
|
||||
if len(req) > 0 {
|
||||
for i := range req {
|
||||
err = m.Set(&req[i])
|
||||
err = m.setValue(tx, &req[i])
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -597,26 +604,38 @@ func (m *MySQL) Multi(request *state.TransactionalStateRequest) error {
|
|||
case state.Upsert:
|
||||
setReq, err := m.getSets(req)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.Set(&setReq)
|
||||
err = m.setValue(tx, &setReq)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
case state.Delete:
|
||||
delReq, err := m.getDeletes(req)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.Delete(&delReq)
|
||||
err = m.deleteValue(tx, &delReq)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -694,3 +713,10 @@ func validIdentifier(v string) bool {
|
|||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Interface for both sql.DB and sql.Tx
|
||||
type querier interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
|
|
|||
|
|
@ -249,7 +249,7 @@ func TestSetHandlesOptionsError(t *testing.T) {
|
|||
request.Options.Consistency = "Invalid"
|
||||
|
||||
// Act
|
||||
err := m.mySQL.setValue(&request)
|
||||
err := m.mySQL.setValue(m.db, &request)
|
||||
|
||||
// Assert
|
||||
assert.NotNil(t, err)
|
||||
|
|
@ -284,7 +284,7 @@ func TestSetHandlesUpdate(t *testing.T) {
|
|||
request.ETag = &eTag
|
||||
|
||||
// Act
|
||||
err := m.mySQL.setValue(&request)
|
||||
err := m.mySQL.setValue(m.db, &request)
|
||||
|
||||
// Assert
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -303,7 +303,7 @@ func TestSetHandlesErr(t *testing.T) {
|
|||
request.ETag = &eTag
|
||||
|
||||
// Act
|
||||
err := m.mySQL.setValue(&request)
|
||||
err := m.mySQL.setValue(m.db, &request)
|
||||
|
||||
// Assert
|
||||
assert.NotNil(t, err)
|
||||
|
|
@ -316,7 +316,7 @@ func TestSetHandlesErr(t *testing.T) {
|
|||
request := createSetRequest()
|
||||
|
||||
// Act
|
||||
err := m.mySQL.setValue(&request)
|
||||
err := m.mySQL.setValue(m.db, &request)
|
||||
|
||||
// Assert
|
||||
assert.NotNil(t, err)
|
||||
|
|
@ -328,7 +328,7 @@ func TestSetHandlesErr(t *testing.T) {
|
|||
request := createSetRequest()
|
||||
|
||||
// Act
|
||||
err := m.mySQL.setValue(&request)
|
||||
err := m.mySQL.setValue(m.db, &request)
|
||||
|
||||
// Assert
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -339,7 +339,7 @@ func TestSetHandlesErr(t *testing.T) {
|
|||
request := createSetRequest()
|
||||
|
||||
// Act
|
||||
err := m.mySQL.setValue(&request)
|
||||
err := m.mySQL.setValue(m.db, &request)
|
||||
|
||||
// Assert
|
||||
assert.NotNil(t, err)
|
||||
|
|
@ -353,7 +353,7 @@ func TestSetHandlesErr(t *testing.T) {
|
|||
request.ETag = &eTag
|
||||
|
||||
// Act
|
||||
err := m.mySQL.setValue(&request)
|
||||
err := m.mySQL.setValue(m.db, &request)
|
||||
|
||||
// Assert
|
||||
assert.NotNil(t, err)
|
||||
|
|
@ -389,7 +389,7 @@ func TestDeleteWithETag(t *testing.T) {
|
|||
request.ETag = &eTag
|
||||
|
||||
// Act
|
||||
err := m.mySQL.deleteValue(&request)
|
||||
err := m.mySQL.deleteValue(m.db, &request)
|
||||
|
||||
// Assert
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -406,7 +406,7 @@ func TestDeleteWithErr(t *testing.T) {
|
|||
request := createDeleteRequest()
|
||||
|
||||
// Act
|
||||
err := m.mySQL.deleteValue(&request)
|
||||
err := m.mySQL.deleteValue(m.db, &request)
|
||||
|
||||
// Assert
|
||||
assert.NotNil(t, err)
|
||||
|
|
@ -421,7 +421,7 @@ func TestDeleteWithErr(t *testing.T) {
|
|||
request.ETag = &eTag
|
||||
|
||||
// Act
|
||||
err := m.mySQL.deleteValue(&request)
|
||||
err := m.mySQL.deleteValue(m.db, &request)
|
||||
|
||||
// Assert
|
||||
assert.NotNil(t, err)
|
||||
|
|
|
|||
Loading…
Reference in New Issue