From f0ea91749956d00288d10f534e2f207808b006c1 Mon Sep 17 00:00:00 2001 From: YangKeao Date: Wed, 4 Sep 2024 17:21:32 +0800 Subject: [PATCH] transaction: add hook for async commit to track the life cycle of the async-commit goroutine and secondary lock cleanup goroutine (#1432) Signed-off-by: Yang Keao --- integration_tests/2pc_test.go | 63 ++++++++++++++++++++++++++ integration_tests/async_commit_test.go | 31 +++++++++++++ txnkv/transaction/2pc.go | 21 +++------ txnkv/transaction/test_probe.go | 5 ++ txnkv/transaction/txn.go | 53 ++++++++++++++++++++++ 5 files changed, 159 insertions(+), 14 deletions(-) diff --git a/integration_tests/2pc_test.go b/integration_tests/2pc_test.go index d907baf3..08de42e0 100644 --- a/integration_tests/2pc_test.go +++ b/integration_tests/2pc_test.go @@ -2512,3 +2512,66 @@ func (s *testCommitterSuite) TestKillSignal() { err = txn.Commit(context.Background()) s.ErrorContains(err, "query interrupted") } + +func (s *testCommitterSuite) Test2PCLifecycleHooks() { + reachedPre := atomic.Bool{} + reachedPost := atomic.Bool{} + + var wg sync.WaitGroup + + t1 := s.begin() + t1.SetBackgroundGoroutineLifecycleHooks(transaction.LifecycleHooks{ + Pre: func() { + wg.Add(1) + + reachedPre.Store(true) + }, + Post: func() { + s.Equal(reachedPre.Load(), true) + reachedPost.Store(true) + + wg.Done() + }, + }) + t1.Set([]byte("a"), []byte("a")) + t1.Set([]byte("z"), []byte("z")) + s.Nil(t1.Commit(context.Background())) + + s.Equal(reachedPre.Load(), true) + s.Equal(reachedPost.Load(), false) + wg.Wait() + s.Equal(reachedPost.Load(), true) +} + +func (s *testCommitterSuite) Test2PCCleanupLifecycleHooks() { + reachedPre := atomic.Bool{} + reachedPost := atomic.Bool{} + + var wg sync.WaitGroup + + t1 := s.begin() + t1.SetBackgroundGoroutineLifecycleHooks(transaction.LifecycleHooks{ + Pre: func() { + wg.Add(1) + + reachedPre.Store(true) + }, + Post: func() { + s.Equal(reachedPre.Load(), true) + reachedPost.Store(true) + + wg.Done() + }, + }) + t1.Set([]byte("a"), []byte("a")) + t1.Set([]byte("z"), []byte("z")) + committer, err := t1.NewCommitter(0) + s.Nil(err) + + committer.CleanupWithoutWait(context.Background()) + + s.Equal(reachedPre.Load(), true) + s.Equal(reachedPost.Load(), false) + wg.Wait() + s.Equal(reachedPost.Load(), true) +} diff --git a/integration_tests/async_commit_test.go b/integration_tests/async_commit_test.go index 8aa36fe9..54109843 100644 --- a/integration_tests/async_commit_test.go +++ b/integration_tests/async_commit_test.go @@ -39,6 +39,7 @@ import ( "context" "fmt" "math" + "sync" "sync/atomic" "testing" "time" @@ -632,3 +633,33 @@ func (s *testAsyncCommitSuite) TestRollbackAsyncCommitEnforcesFallback() { committer.PrewriteMutations(context.Background(), committer.GetMutations().Slice(1, 2)) s.False(committer.IsAsyncCommit()) } + +func (s *testAsyncCommitSuite) TestAsyncCommitLifecycleHooks() { + reachedPre := atomic.Bool{} + reachedPost := atomic.Bool{} + + var wg sync.WaitGroup + + t1 := s.beginAsyncCommit() + t1.SetBackgroundGoroutineLifecycleHooks(transaction.LifecycleHooks{ + Pre: func() { + wg.Add(1) + + reachedPre.Store(true) + }, + Post: func() { + s.Equal(reachedPre.Load(), true) + reachedPost.Store(true) + + wg.Done() + }, + }) + t1.Set([]byte("a"), []byte("a")) + t1.Set([]byte("z"), []byte("z")) + s.Nil(t1.Commit(context.Background())) + + s.Equal(reachedPre.Load(), true) + s.Equal(reachedPost.Load(), false) + wg.Wait() + s.Equal(reachedPost.Load(), true) +} diff --git a/txnkv/transaction/2pc.go b/txnkv/transaction/2pc.go index a772a0e8..fbc8a85a 100644 --- a/txnkv/transaction/2pc.go +++ b/txnkv/transaction/2pc.go @@ -1018,9 +1018,7 @@ func (c *twoPhaseCommitter) doActionOnGroupMutations(bo *retry.Backoffer, action zap.Uint64("sessionID", c.sessionID)) return nil } - c.store.WaitGroup().Add(1) - err = c.store.Go(func() { - defer c.store.WaitGroup().Done() + err = c.txn.spawnWithStorePool(func() { if c.sessionID > 0 { if v, err := util.EvalFailpoint("beforeCommitSecondaries"); err == nil { if s, ok := v.(string); !ok { @@ -1045,7 +1043,6 @@ func (c *twoPhaseCommitter) doActionOnGroupMutations(bo *retry.Backoffer, action } }) if err != nil { - c.store.WaitGroup().Done() logutil.BgLogger().Error("fail to create goroutine", zap.Uint64("session", c.sessionID), zap.Stringer("action type", action), @@ -1414,13 +1411,12 @@ func (c *twoPhaseCommitter) cleanup(ctx context.Context) { return } c.cleanWg.Add(1) - c.store.WaitGroup().Add(1) - go func() { - defer c.store.WaitGroup().Done() + c.txn.spawn(func() { + defer c.cleanWg.Done() + if _, err := util.EvalFailpoint("commitFailedSkipCleanup"); err == nil { logutil.Logger(ctx).Info("[failpoint] injected skip cleanup secondaries on failure", zap.Uint64("txnStartTS", c.startTS)) - c.cleanWg.Done() return } @@ -1443,8 +1439,7 @@ func (c *twoPhaseCommitter) cleanup(ctx context.Context) { zap.Uint64("txnStartTS", c.startTS), zap.Bool("isPessimistic", c.isPessimistic), zap.Bool("isOnePC", c.isOnePC())) } - c.cleanWg.Done() - }() + }) } // execute executes the two-phase commit protocol. @@ -1758,9 +1753,7 @@ func (c *twoPhaseCommitter) execute(ctx context.Context) (err error) { zap.Uint64("sessionID", c.sessionID)) return nil } - c.store.WaitGroup().Add(1) - go func() { - defer c.store.WaitGroup().Done() + c.txn.spawn(func() { if _, err := util.EvalFailpoint("asyncCommitDoNothing"); err == nil { return } @@ -1770,7 +1763,7 @@ func (c *twoPhaseCommitter) execute(ctx context.Context) (err error) { logutil.Logger(ctx).Warn("2PC async commit failed", zap.Uint64("sessionID", c.sessionID), zap.Uint64("startTS", c.startTS), zap.Uint64("commitTS", c.commitTS), zap.Error(err)) } - }() + }) return nil } return c.commitTxn(ctx, commitDetail) diff --git a/txnkv/transaction/test_probe.go b/txnkv/transaction/test_probe.go index 713502db..598bd70c 100644 --- a/txnkv/transaction/test_probe.go +++ b/txnkv/transaction/test_probe.go @@ -289,6 +289,11 @@ func (c CommitterProbe) Cleanup(ctx context.Context) { c.cleanWg.Wait() } +// CleanupWithoutWait cleans dirty data of a committer without waiting. +func (c CommitterProbe) CleanupWithoutWait(ctx context.Context) { + c.cleanup(ctx) +} + // WaitCleanup waits for the committer to complete. func (c CommitterProbe) WaitCleanup() { c.cleanWg.Wait() diff --git a/txnkv/transaction/txn.go b/txnkv/transaction/txn.go index 674a242a..dfa1461e 100644 --- a/txnkv/transaction/txn.go +++ b/txnkv/transaction/txn.go @@ -138,6 +138,11 @@ type KVTxn struct { // commitCallback is called after current transaction gets committed commitCallback func(info string, err error) + // backgroundGoroutineLifecycleHooks tracks the lifecycle of background goroutines of a + // transaction. The `.Pre` will be executed before the start of each background goroutine, + // and the `.Post` will be called after the background goroutine exits. + backgroundGoroutineLifecycleHooks LifecycleHooks + binlog BinlogExecutor schemaLeaseChecker SchemaLeaseChecker syncLog bool @@ -350,6 +355,47 @@ func (txn *KVTxn) SetCommitCallback(f func(string, error)) { txn.commitCallback = f } +// SetBackgroundGoroutineLifecycleHooks sets up the hooks to track the lifecycle of the background goroutines of a transaction. +func (txn *KVTxn) SetBackgroundGoroutineLifecycleHooks(hooks LifecycleHooks) { + txn.backgroundGoroutineLifecycleHooks = hooks +} + +// spawn starts a goroutine to run the given function. +func (txn *KVTxn) spawn(f func()) { + if txn.backgroundGoroutineLifecycleHooks.Pre != nil { + txn.backgroundGoroutineLifecycleHooks.Pre() + } + txn.store.WaitGroup().Add(1) + go func() { + if txn.backgroundGoroutineLifecycleHooks.Post != nil { + defer txn.backgroundGoroutineLifecycleHooks.Post() + } + defer txn.store.WaitGroup().Done() + + f() + }() +} + +// spawnWithStorePool starts a goroutine to run the given function with the store's goroutine pool. +func (txn *KVTxn) spawnWithStorePool(f func()) error { + if txn.backgroundGoroutineLifecycleHooks.Pre != nil { + txn.backgroundGoroutineLifecycleHooks.Pre() + } + txn.store.WaitGroup().Add(1) + err := txn.store.Go(func() { + if txn.backgroundGoroutineLifecycleHooks.Post != nil { + defer txn.backgroundGoroutineLifecycleHooks.Post() + } + defer txn.store.WaitGroup().Done() + + f() + }) + if err != nil { + txn.store.WaitGroup().Done() + } + return err +} + // SetEnableAsyncCommit indicates if the transaction will try to use async commit. func (txn *KVTxn) SetEnableAsyncCommit(b bool) { txn.enableAsyncCommit = b @@ -1708,3 +1754,10 @@ func (txn *KVTxn) SetExplicitRequestSourceType(tp string) { func (txn *KVTxn) MemHookSet() bool { return txn.us.GetMemBuffer().MemHookSet() } + +// LifecycleHooks is a struct that contains hooks for a background goroutine. +// The `Pre` is called before the goroutine starts, and the `Post` is called after the goroutine finishes. +type LifecycleHooks struct { + Pre func() + Post func() +}