Add ExecuteInTransaction method for db.SQL (#3309)
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
3da3d783d6
commit
3b0f320025
|
@ -129,7 +129,7 @@ func (g *gc) CleanupExpired() error {
|
|||
|
||||
// Check if the last iteration was too recent
|
||||
// This performs an atomic operation, so allows coordination with other daprd processes too
|
||||
// We do this before beginning the transaction
|
||||
// We do this outside of a the transaction since it's atomic
|
||||
canContinue, err := g.updateLastCleanup(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read last cleanup time from database: %w", err)
|
||||
|
@ -139,23 +139,12 @@ func (g *gc) CleanupExpired() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
tx, err := g.db.Begin(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
rowsAffected, err := tx.Exec(ctx, g.deleteExpiredValuesQuery)
|
||||
// Delete the expired values
|
||||
rowsAffected, err := g.db.Exec(ctx, g.deleteExpiredValuesQuery)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute query: %w", err)
|
||||
}
|
||||
|
||||
// Commit
|
||||
err = tx.Commit(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
g.log.Infof("Removed %d expired rows", rowsAffected)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
/*
|
||||
Copyright 2023 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package transactions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
// ExecuteInTransaction executes a function in a transaction.
|
||||
// If the handler returns an error, the transaction is rolled back automatically.
|
||||
func ExecuteInTransaction[T any](ctx context.Context, log logger.Logger, db *sql.DB, fn func(ctx context.Context, tx *sql.Tx) (T, error)) (res T, err error) {
|
||||
// Start the transaction
|
||||
// Note that the context here is tied to the entire transaction
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return res, fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
|
||||
// Rollback in case of failure
|
||||
var success bool
|
||||
defer func() {
|
||||
if success {
|
||||
return
|
||||
}
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
// Log errors only
|
||||
log.Errorf("Error while attempting to roll back transaction: %v", rollbackErr)
|
||||
}
|
||||
}()
|
||||
|
||||
// Execute the callback
|
||||
res, err = fn(ctx, tx)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return res, fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
success = true
|
||||
|
||||
return res, nil
|
||||
}
|
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
|
||||
commonsql "github.com/dapr/components-contrib/common/component/sql"
|
||||
sqltransactions "github.com/dapr/components-contrib/common/component/sql/transactions"
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
"github.com/dapr/components-contrib/state/utils"
|
||||
|
@ -782,24 +783,16 @@ func (m *MySQL) Multi(ctx context.Context, request *state.TransactionalStateRequ
|
|||
case 1:
|
||||
return m.execMultiOperation(ctx, request.Operations[0], m.db)
|
||||
default:
|
||||
tx, err := m.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
|
||||
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
_, err := sqltransactions.ExecuteInTransaction(ctx, m.logger, m.db, func(ctx context.Context, tx *sql.Tx) (r struct{}, err error) {
|
||||
for _, op := range request.Operations {
|
||||
err = m.execMultiOperation(ctx, op, tx)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, op := range request.Operations {
|
||||
err = m.execMultiOperation(ctx, op, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
return r, nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@ import (
|
|||
|
||||
"github.com/dapr/components-contrib/common/authentication/sqlite"
|
||||
commonsql "github.com/dapr/components-contrib/common/component/sql"
|
||||
sqltransactions "github.com/dapr/components-contrib/common/component/sql/transactions"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
stateutils "github.com/dapr/components-contrib/state/utils"
|
||||
"github.com/dapr/kit/logger"
|
||||
|
@ -452,19 +453,16 @@ func (a *sqliteDBAccess) ExecuteMulti(parentCtx context.Context, reqs []state.Tr
|
|||
case 1:
|
||||
return a.execMultiOperation(parentCtx, reqs[0], a.db)
|
||||
default:
|
||||
tx, err := a.db.BeginTx(parentCtx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, op := range reqs {
|
||||
err = a.execMultiOperation(parentCtx, op, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
_, err := sqltransactions.ExecuteInTransaction(parentCtx, a.logger, a.db, func(ctx context.Context, tx *sql.Tx) (r struct{}, err error) {
|
||||
for _, op := range reqs {
|
||||
err = a.execMultiOperation(parentCtx, op, tx)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
return r, nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"time"
|
||||
|
||||
commonsql "github.com/dapr/components-contrib/common/component/sql"
|
||||
sqltransactions "github.com/dapr/components-contrib/common/component/sql/transactions"
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
"github.com/dapr/components-contrib/state/utils"
|
||||
|
@ -186,24 +187,16 @@ func (s *SQLServer) Multi(ctx context.Context, request *state.TransactionalState
|
|||
case 1:
|
||||
return s.execMultiOperation(ctx, request.Operations[0], s.db)
|
||||
default:
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
|
||||
s.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
|
||||
_, err := sqltransactions.ExecuteInTransaction(ctx, s.logger, s.db, func(ctx context.Context, tx *sql.Tx) (r struct{}, err error) {
|
||||
for _, op := range request.Operations {
|
||||
err = s.execMultiOperation(ctx, op, tx)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, op := range request.Operations {
|
||||
err = s.execMultiOperation(ctx, op, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
return r, nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue