Add ExecuteInTransaction method for db.SQL (#3309)

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
Alessandro (Ale) Segala 2024-01-16 08:27:46 -08:00 committed by GitHub
parent 3da3d783d6
commit 3b0f320025
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 94 additions and 60 deletions

View File

@ -129,7 +129,7 @@ func (g *gc) CleanupExpired() error {
// Check if the last iteration was too recent
// This performs an atomic operation, so allows coordination with other daprd processes too
// We do this before beginning the transaction
// We do this outside of a the transaction since it's atomic
canContinue, err := g.updateLastCleanup(ctx)
if err != nil {
return fmt.Errorf("failed to read last cleanup time from database: %w", err)
@ -139,23 +139,12 @@ func (g *gc) CleanupExpired() error {
return nil
}
tx, err := g.db.Begin(ctx)
if err != nil {
return fmt.Errorf("failed to start transaction: %w", err)
}
defer tx.Rollback(ctx)
rowsAffected, err := tx.Exec(ctx, g.deleteExpiredValuesQuery)
// Delete the expired values
rowsAffected, err := g.db.Exec(ctx, g.deleteExpiredValuesQuery)
if err != nil {
return fmt.Errorf("failed to execute query: %w", err)
}
// Commit
err = tx.Commit(ctx)
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
g.log.Infof("Removed %d expired rows", rowsAffected)
return nil
}

View File

@ -0,0 +1,61 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package transactions
import (
"context"
"database/sql"
"fmt"
"github.com/dapr/kit/logger"
)
// ExecuteInTransaction executes a function in a transaction.
// If the handler returns an error, the transaction is rolled back automatically.
func ExecuteInTransaction[T any](ctx context.Context, log logger.Logger, db *sql.DB, fn func(ctx context.Context, tx *sql.Tx) (T, error)) (res T, err error) {
// Start the transaction
// Note that the context here is tied to the entire transaction
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return res, fmt.Errorf("failed to begin transaction: %w", err)
}
// Rollback in case of failure
var success bool
defer func() {
if success {
return
}
rollbackErr := tx.Rollback()
if rollbackErr != nil {
// Log errors only
log.Errorf("Error while attempting to roll back transaction: %v", rollbackErr)
}
}()
// Execute the callback
res, err = fn(ctx, tx)
if err != nil {
return res, err
}
// Commit the transaction
err = tx.Commit()
if err != nil {
return res, fmt.Errorf("failed to commit transaction: %w", err)
}
success = true
return res, nil
}

View File

@ -28,6 +28,7 @@ import (
"github.com/google/uuid"
commonsql "github.com/dapr/components-contrib/common/component/sql"
sqltransactions "github.com/dapr/components-contrib/common/component/sql/transactions"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
"github.com/dapr/components-contrib/state/utils"
@ -782,24 +783,16 @@ func (m *MySQL) Multi(ctx context.Context, request *state.TransactionalStateRequ
case 1:
return m.execMultiOperation(ctx, request.Operations[0], m.db)
default:
tx, err := m.db.Begin()
if err != nil {
return err
}
defer func() {
rollbackErr := tx.Rollback()
if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
_, err := sqltransactions.ExecuteInTransaction(ctx, m.logger, m.db, func(ctx context.Context, tx *sql.Tx) (r struct{}, err error) {
for _, op := range request.Operations {
err = m.execMultiOperation(ctx, op, tx)
if err != nil {
return r, err
}
}
}()
for _, op := range request.Operations {
err = m.execMultiOperation(ctx, op, tx)
if err != nil {
return err
}
}
return tx.Commit()
return r, nil
})
return err
}
}

View File

@ -31,6 +31,7 @@ import (
"github.com/dapr/components-contrib/common/authentication/sqlite"
commonsql "github.com/dapr/components-contrib/common/component/sql"
sqltransactions "github.com/dapr/components-contrib/common/component/sql/transactions"
"github.com/dapr/components-contrib/state"
stateutils "github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger"
@ -452,19 +453,16 @@ func (a *sqliteDBAccess) ExecuteMulti(parentCtx context.Context, reqs []state.Tr
case 1:
return a.execMultiOperation(parentCtx, reqs[0], a.db)
default:
tx, err := a.db.BeginTx(parentCtx, nil)
if err != nil {
return err
}
defer tx.Rollback()
for _, op := range reqs {
err = a.execMultiOperation(parentCtx, op, tx)
if err != nil {
return err
_, err := sqltransactions.ExecuteInTransaction(parentCtx, a.logger, a.db, func(ctx context.Context, tx *sql.Tx) (r struct{}, err error) {
for _, op := range reqs {
err = a.execMultiOperation(parentCtx, op, tx)
if err != nil {
return r, err
}
}
}
return tx.Commit()
return r, nil
})
return err
}
}

View File

@ -24,6 +24,7 @@ import (
"time"
commonsql "github.com/dapr/components-contrib/common/component/sql"
sqltransactions "github.com/dapr/components-contrib/common/component/sql/transactions"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
"github.com/dapr/components-contrib/state/utils"
@ -186,24 +187,16 @@ func (s *SQLServer) Multi(ctx context.Context, request *state.TransactionalState
case 1:
return s.execMultiOperation(ctx, request.Operations[0], s.db)
default:
tx, err := s.db.Begin()
if err != nil {
return err
}
defer func() {
rollbackErr := tx.Rollback()
if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
s.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
_, err := sqltransactions.ExecuteInTransaction(ctx, s.logger, s.db, func(ctx context.Context, tx *sql.Tx) (r struct{}, err error) {
for _, op := range request.Operations {
err = s.execMultiOperation(ctx, op, tx)
if err != nil {
return r, err
}
}
}()
for _, op := range request.Operations {
err = s.execMultiOperation(ctx, op, tx)
if err != nil {
return err
}
}
return tx.Commit()
return r, nil
})
return err
}
}