1pc_tesst, async_commit_test: replace pingcap/check with testify (#162)

Signed-off-by: disksing <i@disksing.com>
This commit is contained in:
disksing 2021-06-24 15:55:19 +08:00 committed by GitHub
parent 691f687223
commit ccb3cdb2f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 328 additions and 318 deletions

View File

@ -34,8 +34,9 @@ package tikv_test
import ( import (
"context" "context"
"testing"
. "github.com/pingcap/check" "github.com/stretchr/testify/suite"
"github.com/tikv/client-go/v2/metrics" "github.com/tikv/client-go/v2/metrics"
"github.com/tikv/client-go/v2/mockstore" "github.com/tikv/client-go/v2/mockstore"
"github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/oracle"
@ -43,11 +44,8 @@ import (
"github.com/tikv/client-go/v2/util" "github.com/tikv/client-go/v2/util"
) )
func (s *testAsyncCommitCommon) begin1PC(c *C) tikv.TxnProbe { func TestOnePC(t *testing.T) {
txn, err := s.store.Begin() suite.Run(t, new(testOnePCSuite))
c.Assert(err, IsNil)
txn.SetEnable1PC(true)
return tikv.TxnProbe{KVTxn: txn}
} }
type testOnePCSuite struct { type testOnePCSuite struct {
@ -55,56 +53,54 @@ type testOnePCSuite struct {
bo *tikv.Backoffer bo *tikv.Backoffer
} }
var _ = SerialSuites(&testOnePCSuite{}) func (s *testOnePCSuite) SetupTest() {
s.testAsyncCommitCommon.setUpTest()
func (s *testOnePCSuite) SetUpTest(c *C) {
s.testAsyncCommitCommon.setUpTest(c)
s.bo = tikv.NewBackofferWithVars(context.Background(), 5000, nil) s.bo = tikv.NewBackofferWithVars(context.Background(), 5000, nil)
} }
func (s *testOnePCSuite) Test1PC(c *C) { func (s *testOnePCSuite) Test1PC() {
ctx := context.WithValue(context.Background(), util.SessionID, uint64(1)) ctx := context.WithValue(context.Background(), util.SessionID, uint64(1))
k1 := []byte("k1") k1 := []byte("k1")
v1 := []byte("v1") v1 := []byte("v1")
txn := s.begin1PC(c) txn := s.begin1PC()
err := txn.Set(k1, v1) err := txn.Set(k1, v1)
c.Assert(err, IsNil) s.Nil(err)
err = txn.Commit(ctx) err = txn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(txn.GetCommitter().IsOnePC(), IsTrue) s.True(txn.GetCommitter().IsOnePC())
c.Assert(txn.GetCommitter().GetOnePCCommitTS(), Equals, txn.GetCommitter().GetCommitTS()) s.Equal(txn.GetCommitter().GetOnePCCommitTS(), txn.GetCommitter().GetCommitTS())
c.Assert(txn.GetCommitter().GetOnePCCommitTS(), Greater, txn.StartTS()) s.Greater(txn.GetCommitter().GetOnePCCommitTS(), txn.StartTS())
// ttlManager is not used for 1PC. // ttlManager is not used for 1PC.
c.Assert(txn.GetCommitter().IsTTLUninitialized(), IsTrue) s.True(txn.GetCommitter().IsTTLUninitialized())
// 1PC doesn't work if sessionID == 0 // 1PC doesn't work if sessionID == 0
k2 := []byte("k2") k2 := []byte("k2")
v2 := []byte("v2") v2 := []byte("v2")
txn = s.begin1PC(c) txn = s.begin1PC()
err = txn.Set(k2, v2) err = txn.Set(k2, v2)
c.Assert(err, IsNil) s.Nil(err)
err = txn.Commit(context.Background()) err = txn.Commit(context.Background())
c.Assert(err, IsNil) s.Nil(err)
c.Assert(txn.GetCommitter().IsOnePC(), IsFalse) s.False(txn.GetCommitter().IsOnePC())
c.Assert(txn.GetCommitter().GetOnePCCommitTS(), Equals, uint64(0)) s.Equal(txn.GetCommitter().GetOnePCCommitTS(), uint64(0))
c.Assert(txn.GetCommitter().GetCommitTS(), Greater, txn.StartTS()) s.Greater(txn.GetCommitter().GetCommitTS(), txn.StartTS())
// 1PC doesn't work if system variable not set // 1PC doesn't work if system variable not set
k3 := []byte("k3") k3 := []byte("k3")
v3 := []byte("v3") v3 := []byte("v3")
txn = s.begin(c) txn = s.begin()
err = txn.Set(k3, v3) err = txn.Set(k3, v3)
c.Assert(err, IsNil) s.Nil(err)
err = txn.Commit(ctx) err = txn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(txn.GetCommitter().IsOnePC(), IsFalse) s.False(txn.GetCommitter().IsOnePC())
c.Assert(txn.GetCommitter().GetOnePCCommitTS(), Equals, uint64(0)) s.Equal(txn.GetCommitter().GetOnePCCommitTS(), uint64(0))
c.Assert(txn.GetCommitter().GetCommitTS(), Greater, txn.StartTS()) s.Greater(txn.GetCommitter().GetCommitTS(), txn.StartTS())
// Test multiple keys // Test multiple keys
k4 := []byte("k4") k4 := []byte("k4")
@ -114,89 +110,89 @@ func (s *testOnePCSuite) Test1PC(c *C) {
k6 := []byte("k6") k6 := []byte("k6")
v6 := []byte("v6") v6 := []byte("v6")
txn = s.begin1PC(c) txn = s.begin1PC()
err = txn.Set(k4, v4) err = txn.Set(k4, v4)
c.Assert(err, IsNil) s.Nil(err)
err = txn.Set(k5, v5) err = txn.Set(k5, v5)
c.Assert(err, IsNil) s.Nil(err)
err = txn.Set(k6, v6) err = txn.Set(k6, v6)
c.Assert(err, IsNil) s.Nil(err)
err = txn.Commit(ctx) err = txn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(txn.GetCommitter().IsOnePC(), IsTrue) s.True(txn.GetCommitter().IsOnePC())
c.Assert(txn.GetCommitter().GetOnePCCommitTS(), Equals, txn.GetCommitter().GetCommitTS()) s.Equal(txn.GetCommitter().GetOnePCCommitTS(), txn.GetCommitter().GetCommitTS())
c.Assert(txn.GetCommitter().GetOnePCCommitTS(), Greater, txn.StartTS()) s.Greater(txn.GetCommitter().GetOnePCCommitTS(), txn.StartTS())
// Check keys are committed with the same version // Check keys are committed with the same version
s.mustGetFromSnapshot(c, txn.GetCommitTS(), k4, v4) s.mustGetFromSnapshot(txn.GetCommitTS(), k4, v4)
s.mustGetFromSnapshot(c, txn.GetCommitTS(), k5, v5) s.mustGetFromSnapshot(txn.GetCommitTS(), k5, v5)
s.mustGetFromSnapshot(c, txn.GetCommitTS(), k6, v6) s.mustGetFromSnapshot(txn.GetCommitTS(), k6, v6)
s.mustGetNoneFromSnapshot(c, txn.GetCommitTS()-1, k4) s.mustGetNoneFromSnapshot(txn.GetCommitTS()-1, k4)
s.mustGetNoneFromSnapshot(c, txn.GetCommitTS()-1, k5) s.mustGetNoneFromSnapshot(txn.GetCommitTS()-1, k5)
s.mustGetNoneFromSnapshot(c, txn.GetCommitTS()-1, k6) s.mustGetNoneFromSnapshot(txn.GetCommitTS()-1, k6)
// Overwriting in MVCC // Overwriting in MVCC
v6New := []byte("v6new") v6New := []byte("v6new")
txn = s.begin1PC(c) txn = s.begin1PC()
err = txn.Set(k6, v6New) err = txn.Set(k6, v6New)
c.Assert(err, IsNil) s.Nil(err)
err = txn.Commit(ctx) err = txn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(txn.GetCommitter().IsOnePC(), IsTrue) s.True(txn.GetCommitter().IsOnePC())
c.Assert(txn.GetCommitter().GetOnePCCommitTS(), Equals, txn.GetCommitter().GetCommitTS()) s.Equal(txn.GetCommitter().GetOnePCCommitTS(), txn.GetCommitter().GetCommitTS())
c.Assert(txn.GetCommitter().GetOnePCCommitTS(), Greater, txn.StartTS()) s.Greater(txn.GetCommitter().GetOnePCCommitTS(), txn.StartTS())
s.mustGetFromSnapshot(c, txn.GetCommitTS(), k6, v6New) s.mustGetFromSnapshot(txn.GetCommitTS(), k6, v6New)
s.mustGetFromSnapshot(c, txn.GetCommitTS()-1, k6, v6) s.mustGetFromSnapshot(txn.GetCommitTS()-1, k6, v6)
// Check all keys // Check all keys
keys := [][]byte{k1, k2, k3, k4, k5, k6} keys := [][]byte{k1, k2, k3, k4, k5, k6}
values := [][]byte{v1, v2, v3, v4, v5, v6New} values := [][]byte{v1, v2, v3, v4, v5, v6New}
ver, err := s.store.CurrentTimestamp(oracle.GlobalTxnScope) ver, err := s.store.CurrentTimestamp(oracle.GlobalTxnScope)
c.Assert(err, IsNil) s.Nil(err)
snap := s.store.GetSnapshot(ver) snap := s.store.GetSnapshot(ver)
for i, k := range keys { for i, k := range keys {
v, err := snap.Get(ctx, k) v, err := snap.Get(ctx, k)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(v, BytesEquals, values[i]) s.Equal(v, values[i])
} }
} }
func (s *testOnePCSuite) Test1PCIsolation(c *C) { func (s *testOnePCSuite) Test1PCIsolation() {
ctx := context.WithValue(context.Background(), util.SessionID, uint64(1)) ctx := context.WithValue(context.Background(), util.SessionID, uint64(1))
k := []byte("k") k := []byte("k")
v1 := []byte("v1") v1 := []byte("v1")
txn := s.begin1PC(c) txn := s.begin1PC()
txn.Set(k, v1) txn.Set(k, v1)
err := txn.Commit(ctx) err := txn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
v2 := []byte("v2") v2 := []byte("v2")
txn = s.begin1PC(c) txn = s.begin1PC()
txn.Set(k, v2) txn.Set(k, v2)
// Make `txn`'s commitTs more likely to be less than `txn2`'s startTs if there's bug in commitTs // Make `txn`'s commitTs more likely to be less than `txn2`'s startTs if there's bug in commitTs
// calculation. // calculation.
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
_, err := s.store.GetOracle().GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) _, err := s.store.GetOracle().GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
c.Assert(err, IsNil) s.Nil(err)
} }
txn2 := s.begin1PC(c) txn2 := s.begin1PC()
s.mustGetFromTxn(c, txn2, k, v1) s.mustGetFromTxn(txn2, k, v1)
err = txn.Commit(ctx) err = txn.Commit(ctx)
c.Assert(txn.GetCommitter().IsOnePC(), IsTrue) s.True(txn.GetCommitter().IsOnePC())
c.Assert(err, IsNil) s.Nil(err)
s.mustGetFromTxn(c, txn2, k, v1) s.mustGetFromTxn(txn2, k, v1)
c.Assert(txn2.Rollback(), IsNil) s.Nil(txn2.Rollback())
s.mustGetFromSnapshot(c, txn.GetCommitTS(), k, v2) s.mustGetFromSnapshot(txn.GetCommitTS(), k, v2)
s.mustGetFromSnapshot(c, txn.GetCommitTS()-1, k, v1) s.mustGetFromSnapshot(txn.GetCommitTS()-1, k, v1)
} }
func (s *testOnePCSuite) Test1PCDisallowMultiRegion(c *C) { func (s *testOnePCSuite) Test1PCDisallowMultiRegion() {
// This test doesn't support tikv mode. // This test doesn't support tikv mode.
if *mockstore.WithTiKV { if *mockstore.WithTiKV {
return return
@ -204,127 +200,127 @@ func (s *testOnePCSuite) Test1PCDisallowMultiRegion(c *C) {
ctx := context.WithValue(context.Background(), util.SessionID, uint64(1)) ctx := context.WithValue(context.Background(), util.SessionID, uint64(1))
txn := s.begin1PC(c) txn := s.begin1PC()
keys := []string{"k0", "k1", "k2", "k3"} keys := []string{"k0", "k1", "k2", "k3"}
values := []string{"v0", "v1", "v2", "v3"} values := []string{"v0", "v1", "v2", "v3"}
err := txn.Set([]byte(keys[0]), []byte(values[0])) err := txn.Set([]byte(keys[0]), []byte(values[0]))
c.Assert(err, IsNil) s.Nil(err)
err = txn.Set([]byte(keys[3]), []byte(values[3])) err = txn.Set([]byte(keys[3]), []byte(values[3]))
c.Assert(err, IsNil) s.Nil(err)
err = txn.Commit(ctx) err = txn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
// 1PC doesn't work if it affects multiple regions. // 1PC doesn't work if it affects multiple regions.
loc, err := s.store.GetRegionCache().LocateKey(s.bo, []byte(keys[2])) loc, err := s.store.GetRegionCache().LocateKey(s.bo, []byte(keys[2]))
c.Assert(err, IsNil) s.Nil(err)
newRegionID := s.cluster.AllocID() newRegionID := s.cluster.AllocID()
newPeerID := s.cluster.AllocID() newPeerID := s.cluster.AllocID()
s.cluster.Split(loc.Region.GetID(), newRegionID, []byte(keys[2]), []uint64{newPeerID}, newPeerID) s.cluster.Split(loc.Region.GetID(), newRegionID, []byte(keys[2]), []uint64{newPeerID}, newPeerID)
txn = s.begin1PC(c) txn = s.begin1PC()
err = txn.Set([]byte(keys[1]), []byte(values[1])) err = txn.Set([]byte(keys[1]), []byte(values[1]))
c.Assert(err, IsNil) s.Nil(err)
err = txn.Set([]byte(keys[2]), []byte(values[2])) err = txn.Set([]byte(keys[2]), []byte(values[2]))
c.Assert(err, IsNil) s.Nil(err)
err = txn.Commit(ctx) err = txn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(txn.GetCommitter().IsOnePC(), IsFalse) s.False(txn.GetCommitter().IsOnePC())
c.Assert(txn.GetCommitter().GetOnePCCommitTS(), Equals, uint64(0)) s.Equal(txn.GetCommitter().GetOnePCCommitTS(), uint64(0))
c.Assert(txn.GetCommitter().GetCommitTS(), Greater, txn.StartTS()) s.Greater(txn.GetCommitter().GetCommitTS(), txn.StartTS())
ver, err := s.store.CurrentTimestamp(oracle.GlobalTxnScope) ver, err := s.store.CurrentTimestamp(oracle.GlobalTxnScope)
c.Assert(err, IsNil) s.Nil(err)
snap := s.store.GetSnapshot(ver) snap := s.store.GetSnapshot(ver)
for i, k := range keys { for i, k := range keys {
v, err := snap.Get(ctx, []byte(k)) v, err := snap.Get(ctx, []byte(k))
c.Assert(err, IsNil) s.Nil(err)
c.Assert(v, BytesEquals, []byte(values[i])) s.Equal(v, []byte(values[i]))
} }
} }
// It's just a simple validation of linearizability. // It's just a simple validation of linearizability.
// Extra tests are needed to test this feature with the control of the TiKV cluster. // Extra tests are needed to test this feature with the control of the TiKV cluster.
func (s *testOnePCSuite) Test1PCLinearizability(c *C) { func (s *testOnePCSuite) Test1PCLinearizability() {
t1 := s.begin(c) t1 := s.begin()
t2 := s.begin(c) t2 := s.begin()
err := t1.Set([]byte("a"), []byte("a1")) err := t1.Set([]byte("a"), []byte("a1"))
c.Assert(err, IsNil) s.Nil(err)
err = t2.Set([]byte("b"), []byte("b1")) err = t2.Set([]byte("b"), []byte("b1"))
c.Assert(err, IsNil) s.Nil(err)
ctx := context.WithValue(context.Background(), util.SessionID, uint64(1)) ctx := context.WithValue(context.Background(), util.SessionID, uint64(1))
// t2 commits earlier than t1 // t2 commits earlier than t1
err = t2.Commit(ctx) err = t2.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
err = t1.Commit(ctx) err = t1.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
commitTS1 := t1.GetCommitter().GetCommitTS() commitTS1 := t1.GetCommitter().GetCommitTS()
commitTS2 := t2.GetCommitter().GetCommitTS() commitTS2 := t2.GetCommitter().GetCommitTS()
c.Assert(commitTS2, Less, commitTS1) s.Less(commitTS2, commitTS1)
} }
func (s *testOnePCSuite) Test1PCWithMultiDC(c *C) { func (s *testOnePCSuite) Test1PCWithMultiDC() {
// It requires setting placement rules to run with TiKV // It requires setting placement rules to run with TiKV
if *mockstore.WithTiKV { if *mockstore.WithTiKV {
return return
} }
localTxn := s.begin1PC(c) localTxn := s.begin1PC()
err := localTxn.Set([]byte("a"), []byte("a1")) err := localTxn.Set([]byte("a"), []byte("a1"))
localTxn.SetScope("bj") localTxn.SetScope("bj")
c.Assert(err, IsNil) s.Nil(err)
ctx := context.WithValue(context.Background(), util.SessionID, uint64(1)) ctx := context.WithValue(context.Background(), util.SessionID, uint64(1))
err = localTxn.Commit(ctx) err = localTxn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(localTxn.GetCommitter().IsOnePC(), IsFalse) s.False(localTxn.GetCommitter().IsOnePC())
globalTxn := s.begin1PC(c) globalTxn := s.begin1PC()
err = globalTxn.Set([]byte("b"), []byte("b1")) err = globalTxn.Set([]byte("b"), []byte("b1"))
globalTxn.SetScope(oracle.GlobalTxnScope) globalTxn.SetScope(oracle.GlobalTxnScope)
c.Assert(err, IsNil) s.Nil(err)
err = globalTxn.Commit(ctx) err = globalTxn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(globalTxn.GetCommitter().IsOnePC(), IsTrue) s.True(globalTxn.GetCommitter().IsOnePC())
} }
func (s *testOnePCSuite) TestTxnCommitCounter(c *C) { func (s *testOnePCSuite) TestTxnCommitCounter() {
initial := metrics.GetTxnCommitCounter() initial := metrics.GetTxnCommitCounter()
// 2PC // 2PC
txn := s.begin(c) txn := s.begin()
err := txn.Set([]byte("k"), []byte("v")) err := txn.Set([]byte("k"), []byte("v"))
c.Assert(err, IsNil) s.Nil(err)
ctx := context.WithValue(context.Background(), util.SessionID, uint64(1)) ctx := context.WithValue(context.Background(), util.SessionID, uint64(1))
err = txn.Commit(ctx) err = txn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
curr := metrics.GetTxnCommitCounter() curr := metrics.GetTxnCommitCounter()
diff := curr.Sub(initial) diff := curr.Sub(initial)
c.Assert(diff.TwoPC, Equals, int64(1)) s.Equal(diff.TwoPC, int64(1))
c.Assert(diff.AsyncCommit, Equals, int64(0)) s.Equal(diff.AsyncCommit, int64(0))
c.Assert(diff.OnePC, Equals, int64(0)) s.Equal(diff.OnePC, int64(0))
// AsyncCommit // AsyncCommit
txn = s.beginAsyncCommit(c) txn = s.beginAsyncCommit()
err = txn.Set([]byte("k1"), []byte("v1")) err = txn.Set([]byte("k1"), []byte("v1"))
c.Assert(err, IsNil) s.Nil(err)
err = txn.Commit(ctx) err = txn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
curr = metrics.GetTxnCommitCounter() curr = metrics.GetTxnCommitCounter()
diff = curr.Sub(initial) diff = curr.Sub(initial)
c.Assert(diff.TwoPC, Equals, int64(1)) s.Equal(diff.TwoPC, int64(1))
c.Assert(diff.AsyncCommit, Equals, int64(1)) s.Equal(diff.AsyncCommit, int64(1))
c.Assert(diff.OnePC, Equals, int64(0)) s.Equal(diff.OnePC, int64(0))
// 1PC // 1PC
txn = s.begin1PC(c) txn = s.begin1PC()
err = txn.Set([]byte("k2"), []byte("v2")) err = txn.Set([]byte("k2"), []byte("v2"))
c.Assert(err, IsNil) s.Nil(err)
err = txn.Commit(ctx) err = txn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
curr = metrics.GetTxnCommitCounter() curr = metrics.GetTxnCommitCounter()
diff = curr.Sub(initial) diff = curr.Sub(initial)
c.Assert(diff.TwoPC, Equals, int64(1)) s.Equal(diff.TwoPC, int64(1))
c.Assert(diff.AsyncCommit, Equals, int64(1)) s.Equal(diff.AsyncCommit, int64(1))
c.Assert(diff.OnePC, Equals, int64(1)) s.Equal(diff.OnePC, int64(1))
} }

View File

@ -36,63 +36,66 @@ import (
"bytes" "bytes"
"context" "context"
"sort" "sort"
"testing"
. "github.com/pingcap/check"
"github.com/pingcap/errors" "github.com/pingcap/errors"
"github.com/pingcap/failpoint" "github.com/pingcap/failpoint"
"github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/parser/terror" "github.com/pingcap/parser/terror"
"github.com/stretchr/testify/suite"
tikverr "github.com/tikv/client-go/v2/error" tikverr "github.com/tikv/client-go/v2/error"
"github.com/tikv/client-go/v2/mockstore" "github.com/tikv/client-go/v2/mockstore"
"github.com/tikv/client-go/v2/tikv" "github.com/tikv/client-go/v2/tikv"
"github.com/tikv/client-go/v2/util" "github.com/tikv/client-go/v2/util"
) )
func TestAsyncCommitFail(t *testing.T) {
suite.Run(t, new(testAsyncCommitFailSuite))
}
type testAsyncCommitFailSuite struct { type testAsyncCommitFailSuite struct {
testAsyncCommitCommon testAsyncCommitCommon
} }
var _ = SerialSuites(&testAsyncCommitFailSuite{}) func (s *testAsyncCommitFailSuite) SetupTest() {
s.testAsyncCommitCommon.setUpTest()
func (s *testAsyncCommitFailSuite) SetUpTest(c *C) {
s.testAsyncCommitCommon.setUpTest(c)
} }
// TestFailCommitPrimaryRpcErrors tests rpc errors are handled properly when // TestFailCommitPrimaryRpcErrors tests rpc errors are handled properly when
// committing primary region task. // committing primary region task.
func (s *testAsyncCommitFailSuite) TestFailAsyncCommitPrewriteRpcErrors(c *C) { func (s *testAsyncCommitFailSuite) TestFailAsyncCommitPrewriteRpcErrors() {
// This test doesn't support tikv mode because it needs setting failpoint in unistore. // This test doesn't support tikv mode because it needs setting failpoint in unistore.
if *mockstore.WithTiKV { if *mockstore.WithTiKV {
return return
} }
c.Assert(failpoint.Enable("tikvclient/noRetryOnRpcError", "return(true)"), IsNil) s.Nil(failpoint.Enable("tikvclient/noRetryOnRpcError", "return(true)"))
c.Assert(failpoint.Enable("tikvclient/rpcPrewriteTimeout", `return(true)`), IsNil) s.Nil(failpoint.Enable("tikvclient/rpcPrewriteTimeout", `return(true)`))
defer func() { defer func() {
c.Assert(failpoint.Disable("tikvclient/rpcPrewriteTimeout"), IsNil) s.Nil(failpoint.Disable("tikvclient/rpcPrewriteTimeout"))
c.Assert(failpoint.Disable("tikvclient/noRetryOnRpcError"), IsNil) s.Nil(failpoint.Disable("tikvclient/noRetryOnRpcError"))
}() }()
// The rpc error will be wrapped to ErrResultUndetermined. // The rpc error will be wrapped to ErrResultUndetermined.
t1 := s.beginAsyncCommit(c) t1 := s.beginAsyncCommit()
err := t1.Set([]byte("a"), []byte("a1")) err := t1.Set([]byte("a"), []byte("a1"))
c.Assert(err, IsNil) s.Nil(err)
ctx := context.WithValue(context.Background(), util.SessionID, uint64(1)) ctx := context.WithValue(context.Background(), util.SessionID, uint64(1))
err = t1.Commit(ctx) err = t1.Commit(ctx)
c.Assert(err, NotNil) s.NotNil(err)
c.Assert(terror.ErrorEqual(err, terror.ErrResultUndetermined), IsTrue, Commentf("%s", errors.ErrorStack(err))) s.True(terror.ErrorEqual(err, terror.ErrResultUndetermined), errors.ErrorStack(err))
// We don't need to call "Rollback" after "Commit" fails. // We don't need to call "Rollback" after "Commit" fails.
err = t1.Rollback() err = t1.Rollback()
c.Assert(err, Equals, tikverr.ErrInvalidTxn) s.Equal(err, tikverr.ErrInvalidTxn)
// Create a new transaction to check. The previous transaction should actually commit. // Create a new transaction to check. The previous transaction should actually commit.
t2 := s.beginAsyncCommit(c) t2 := s.beginAsyncCommit()
res, err := t2.Get(context.Background(), []byte("a")) res, err := t2.Get(context.Background(), []byte("a"))
c.Assert(err, IsNil) s.Nil(err)
c.Assert(bytes.Equal(res, []byte("a1")), IsTrue) s.True(bytes.Equal(res, []byte("a1")))
} }
func (s *testAsyncCommitFailSuite) TestAsyncCommitPrewriteCancelled(c *C) { func (s *testAsyncCommitFailSuite) TestAsyncCommitPrewriteCancelled() {
// This test doesn't support tikv mode because it needs setting failpoint in unistore. // This test doesn't support tikv mode because it needs setting failpoint in unistore.
if *mockstore.WithTiKV { if *mockstore.WithTiKV {
return return
@ -102,69 +105,69 @@ func (s *testAsyncCommitFailSuite) TestAsyncCommitPrewriteCancelled(c *C) {
splitKey := "s" splitKey := "s"
bo := tikv.NewBackofferWithVars(context.Background(), 5000, nil) bo := tikv.NewBackofferWithVars(context.Background(), 5000, nil)
loc, err := s.store.GetRegionCache().LocateKey(bo, []byte(splitKey)) loc, err := s.store.GetRegionCache().LocateKey(bo, []byte(splitKey))
c.Assert(err, IsNil) s.Nil(err)
newRegionID := s.cluster.AllocID() newRegionID := s.cluster.AllocID()
newPeerID := s.cluster.AllocID() newPeerID := s.cluster.AllocID()
s.cluster.Split(loc.Region.GetID(), newRegionID, []byte(splitKey), []uint64{newPeerID}, newPeerID) s.cluster.Split(loc.Region.GetID(), newRegionID, []byte(splitKey), []uint64{newPeerID}, newPeerID)
s.store.GetRegionCache().InvalidateCachedRegion(loc.Region) s.store.GetRegionCache().InvalidateCachedRegion(loc.Region)
c.Assert(failpoint.Enable("tikvclient/rpcPrewriteResult", `1*return("writeConflict")->sleep(50)`), IsNil) s.Nil(failpoint.Enable("tikvclient/rpcPrewriteResult", `1*return("writeConflict")->sleep(50)`))
defer func() { defer func() {
c.Assert(failpoint.Disable("tikvclient/rpcPrewriteResult"), IsNil) s.Nil(failpoint.Disable("tikvclient/rpcPrewriteResult"))
}() }()
t1 := s.beginAsyncCommit(c) t1 := s.beginAsyncCommit()
err = t1.Set([]byte("a"), []byte("a")) err = t1.Set([]byte("a"), []byte("a"))
c.Assert(err, IsNil) s.Nil(err)
err = t1.Set([]byte("z"), []byte("z")) err = t1.Set([]byte("z"), []byte("z"))
c.Assert(err, IsNil) s.Nil(err)
ctx := context.WithValue(context.Background(), util.SessionID, uint64(1)) ctx := context.WithValue(context.Background(), util.SessionID, uint64(1))
err = t1.Commit(ctx) err = t1.Commit(ctx)
c.Assert(err, NotNil) s.NotNil(err)
_, ok := errors.Cause(err).(*tikverr.ErrWriteConflict) _, ok := errors.Cause(err).(*tikverr.ErrWriteConflict)
c.Assert(ok, IsTrue, Commentf("%s", errors.ErrorStack(err))) s.True(ok, errors.ErrorStack(err))
} }
func (s *testAsyncCommitFailSuite) TestPointGetWithAsyncCommit(c *C) { func (s *testAsyncCommitFailSuite) TestPointGetWithAsyncCommit() {
s.putAlphabets(c, true) s.putAlphabets(true)
txn := s.beginAsyncCommit(c) txn := s.beginAsyncCommit()
txn.Set([]byte("a"), []byte("v1")) txn.Set([]byte("a"), []byte("v1"))
txn.Set([]byte("b"), []byte("v2")) txn.Set([]byte("b"), []byte("v2"))
s.mustPointGet(c, []byte("a"), []byte("a")) s.mustPointGet([]byte("a"), []byte("a"))
s.mustPointGet(c, []byte("b"), []byte("b")) s.mustPointGet([]byte("b"), []byte("b"))
// PointGet cannot ignore async commit transactions' locks. // PointGet cannot ignore async commit transactions' locks.
c.Assert(failpoint.Enable("tikvclient/asyncCommitDoNothing", "return"), IsNil) s.Nil(failpoint.Enable("tikvclient/asyncCommitDoNothing", "return"))
ctx := context.WithValue(context.Background(), util.SessionID, uint64(1)) ctx := context.WithValue(context.Background(), util.SessionID, uint64(1))
err := txn.Commit(ctx) err := txn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(txn.GetCommitter().IsAsyncCommit(), IsTrue) s.True(txn.GetCommitter().IsAsyncCommit())
s.mustPointGet(c, []byte("a"), []byte("v1")) s.mustPointGet([]byte("a"), []byte("v1"))
s.mustPointGet(c, []byte("b"), []byte("v2")) s.mustPointGet([]byte("b"), []byte("v2"))
c.Assert(failpoint.Disable("tikvclient/asyncCommitDoNothing"), IsNil) s.Nil(failpoint.Disable("tikvclient/asyncCommitDoNothing"))
// PointGet will not push the `max_ts` to its ts which is MaxUint64. // PointGet will not push the `max_ts` to its ts which is MaxUint64.
txn2 := s.beginAsyncCommit(c) txn2 := s.beginAsyncCommit()
s.mustGetFromTxn(c, txn2, []byte("a"), []byte("v1")) s.mustGetFromTxn(txn2, []byte("a"), []byte("v1"))
s.mustGetFromTxn(c, txn2, []byte("b"), []byte("v2")) s.mustGetFromTxn(txn2, []byte("b"), []byte("v2"))
err = txn2.Rollback() err = txn2.Rollback()
c.Assert(err, IsNil) s.Nil(err)
} }
func (s *testAsyncCommitFailSuite) TestSecondaryListInPrimaryLock(c *C) { func (s *testAsyncCommitFailSuite) TestSecondaryListInPrimaryLock() {
// This test doesn't support tikv mode. // This test doesn't support tikv mode.
if *mockstore.WithTiKV { if *mockstore.WithTiKV {
return return
} }
s.putAlphabets(c, true) s.putAlphabets(true)
// Split into several regions. // Split into several regions.
for _, splitKey := range []string{"h", "o", "u"} { for _, splitKey := range []string{"h", "o", "u"} {
bo := tikv.NewBackofferWithVars(context.Background(), 5000, nil) bo := tikv.NewBackofferWithVars(context.Background(), 5000, nil)
loc, err := s.store.GetRegionCache().LocateKey(bo, []byte(splitKey)) loc, err := s.store.GetRegionCache().LocateKey(bo, []byte(splitKey))
c.Assert(err, IsNil) s.Nil(err)
newRegionID := s.cluster.AllocID() newRegionID := s.cluster.AllocID()
newPeerID := s.cluster.AllocID() newPeerID := s.cluster.AllocID()
s.cluster.Split(loc.Region.GetID(), newRegionID, []byte(splitKey), []uint64{newPeerID}, newPeerID) s.cluster.Split(loc.Region.GetID(), newRegionID, []byte(splitKey), []uint64{newPeerID}, newPeerID)
@ -174,37 +177,37 @@ func (s *testAsyncCommitFailSuite) TestSecondaryListInPrimaryLock(c *C) {
// Ensure the region has been split // Ensure the region has been split
bo := tikv.NewBackofferWithVars(context.Background(), 5000, nil) bo := tikv.NewBackofferWithVars(context.Background(), 5000, nil)
loc, err := s.store.GetRegionCache().LocateKey(bo, []byte("i")) loc, err := s.store.GetRegionCache().LocateKey(bo, []byte("i"))
c.Assert(err, IsNil) s.Nil(err)
c.Assert(loc.StartKey, BytesEquals, []byte("h")) s.Equal(loc.StartKey, []byte("h"))
c.Assert(loc.EndKey, BytesEquals, []byte("o")) s.Equal(loc.EndKey, []byte("o"))
loc, err = s.store.GetRegionCache().LocateKey(bo, []byte("p")) loc, err = s.store.GetRegionCache().LocateKey(bo, []byte("p"))
c.Assert(err, IsNil) s.Nil(err)
c.Assert(loc.StartKey, BytesEquals, []byte("o")) s.Equal(loc.StartKey, []byte("o"))
c.Assert(loc.EndKey, BytesEquals, []byte("u")) s.Equal(loc.EndKey, []byte("u"))
var sessionID uint64 = 0 var sessionID uint64 = 0
test := func(keys []string, values []string) { test := func(keys []string, values []string) {
sessionID++ sessionID++
ctx := context.WithValue(context.Background(), util.SessionID, sessionID) ctx := context.WithValue(context.Background(), util.SessionID, sessionID)
txn := s.beginAsyncCommit(c) txn := s.beginAsyncCommit()
for i := range keys { for i := range keys {
txn.Set([]byte(keys[i]), []byte(values[i])) txn.Set([]byte(keys[i]), []byte(values[i]))
} }
c.Assert(failpoint.Enable("tikvclient/asyncCommitDoNothing", "return"), IsNil) s.Nil(failpoint.Enable("tikvclient/asyncCommitDoNothing", "return"))
err = txn.Commit(ctx) err = txn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
primary := txn.GetCommitter().GetPrimaryKey() primary := txn.GetCommitter().GetPrimaryKey()
bo := tikv.NewBackofferWithVars(context.Background(), 5000, nil) bo := tikv.NewBackofferWithVars(context.Background(), 5000, nil)
lockResolver := tikv.LockResolverProbe{LockResolver: s.store.GetLockResolver()} lockResolver := tikv.LockResolverProbe{LockResolver: s.store.GetLockResolver()}
txnStatus, err := lockResolver.GetTxnStatus(bo, txn.StartTS(), primary, 0, 0, false, false, nil) txnStatus, err := lockResolver.GetTxnStatus(bo, txn.StartTS(), primary, 0, 0, false, false, nil)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(txnStatus.IsCommitted(), IsFalse) s.False(txnStatus.IsCommitted())
c.Assert(txnStatus.Action(), Equals, kvrpcpb.Action_NoAction) s.Equal(txnStatus.Action(), kvrpcpb.Action_NoAction)
// Currently when the transaction has no secondary, the `secondaries` field of the txnStatus // Currently when the transaction has no secondary, the `secondaries` field of the txnStatus
// will be set nil. So here initialize the `expectedSecondaries` to nil too. // will be set nil. So here initialize the `expectedSecondaries` to nil too.
var expectedSecondaries [][]byte var expectedSecondaries [][]byte
@ -222,9 +225,9 @@ func (s *testAsyncCommitFailSuite) TestSecondaryListInPrimaryLock(c *C) {
return bytes.Compare(gotSecondaries[i], gotSecondaries[j]) < 0 return bytes.Compare(gotSecondaries[i], gotSecondaries[j]) < 0
}) })
c.Assert(gotSecondaries, DeepEquals, expectedSecondaries) s.Equal(gotSecondaries, expectedSecondaries)
c.Assert(failpoint.Disable("tikvclient/asyncCommitDoNothing"), IsNil) s.Nil(failpoint.Disable("tikvclient/asyncCommitDoNothing"))
txn.GetCommitter().Cleanup(context.Background()) txn.GetCommitter().Cleanup(context.Background())
} }
@ -235,68 +238,68 @@ func (s *testAsyncCommitFailSuite) TestSecondaryListInPrimaryLock(c *C) {
test([]string{"i", "a", "z", "u", "b"}, []string{"i5", "a5", "z5", "u5", "b5"}) test([]string{"i", "a", "z", "u", "b"}, []string{"i5", "a5", "z5", "u5", "b5"})
} }
func (s *testAsyncCommitFailSuite) TestAsyncCommitContextCancelCausingUndetermined(c *C) { func (s *testAsyncCommitFailSuite) TestAsyncCommitContextCancelCausingUndetermined() {
// For an async commit transaction, if RPC returns context.Canceled error when prewriting, the // For an async commit transaction, if RPC returns context.Canceled error when prewriting, the
// transaction should go to undetermined state. // transaction should go to undetermined state.
txn := s.beginAsyncCommit(c) txn := s.beginAsyncCommit()
err := txn.Set([]byte("a"), []byte("va")) err := txn.Set([]byte("a"), []byte("va"))
c.Assert(err, IsNil) s.Nil(err)
c.Assert(failpoint.Enable("tikvclient/rpcContextCancelErr", `return(true)`), IsNil) s.Nil(failpoint.Enable("tikvclient/rpcContextCancelErr", `return(true)`))
defer func() { defer func() {
c.Assert(failpoint.Disable("tikvclient/rpcContextCancelErr"), IsNil) s.Nil(failpoint.Disable("tikvclient/rpcContextCancelErr"))
}() }()
ctx := context.WithValue(context.Background(), util.SessionID, uint64(1)) ctx := context.WithValue(context.Background(), util.SessionID, uint64(1))
err = txn.Commit(ctx) err = txn.Commit(ctx)
c.Assert(err, NotNil) s.NotNil(err)
c.Assert(txn.GetCommitter().GetUndeterminedErr(), NotNil) s.NotNil(txn.GetCommitter().GetUndeterminedErr())
} }
// TestAsyncCommitRPCErrorThenWriteConflict verifies that the determined failure error overwrites undetermined error. // TestAsyncCommitRPCErrorThenWriteConflict verifies that the determined failure error overwrites undetermined error.
func (s *testAsyncCommitFailSuite) TestAsyncCommitRPCErrorThenWriteConflict(c *C) { func (s *testAsyncCommitFailSuite) TestAsyncCommitRPCErrorThenWriteConflict() {
// This test doesn't support tikv mode because it needs setting failpoint in unistore. // This test doesn't support tikv mode because it needs setting failpoint in unistore.
if *mockstore.WithTiKV { if *mockstore.WithTiKV {
return return
} }
txn := s.beginAsyncCommit(c) txn := s.beginAsyncCommit()
err := txn.Set([]byte("a"), []byte("va")) err := txn.Set([]byte("a"), []byte("va"))
c.Assert(err, IsNil) s.Nil(err)
c.Assert(failpoint.Enable("tikvclient/rpcPrewriteResult", `1*return("timeout")->return("writeConflict")`), IsNil) s.Nil(failpoint.Enable("tikvclient/rpcPrewriteResult", `1*return("timeout")->return("writeConflict")`))
defer func() { defer func() {
c.Assert(failpoint.Disable("tikvclient/rpcPrewriteResult"), IsNil) s.Nil(failpoint.Disable("tikvclient/rpcPrewriteResult"))
}() }()
ctx := context.WithValue(context.Background(), util.SessionID, uint64(1)) ctx := context.WithValue(context.Background(), util.SessionID, uint64(1))
err = txn.Commit(ctx) err = txn.Commit(ctx)
c.Assert(err, NotNil) s.NotNil(err)
c.Assert(txn.GetCommitter().GetUndeterminedErr(), IsNil) s.Nil(txn.GetCommitter().GetUndeterminedErr())
} }
// TestAsyncCommitRPCErrorThenWriteConflictInChild verifies that the determined failure error in a child recursion // TestAsyncCommitRPCErrorThenWriteConflictInChild verifies that the determined failure error in a child recursion
// overwrites the undetermined error in the parent. // overwrites the undetermined error in the parent.
func (s *testAsyncCommitFailSuite) TestAsyncCommitRPCErrorThenWriteConflictInChild(c *C) { func (s *testAsyncCommitFailSuite) TestAsyncCommitRPCErrorThenWriteConflictInChild() {
// This test doesn't support tikv mode because it needs setting failpoint in unistore. // This test doesn't support tikv mode because it needs setting failpoint in unistore.
if *mockstore.WithTiKV { if *mockstore.WithTiKV {
return return
} }
txn := s.beginAsyncCommit(c) txn := s.beginAsyncCommit()
err := txn.Set([]byte("a"), []byte("va")) err := txn.Set([]byte("a"), []byte("va"))
c.Assert(err, IsNil) s.Nil(err)
c.Assert(failpoint.Enable("tikvclient/rpcPrewriteResult", `1*return("timeout")->return("writeConflict")`), IsNil) s.Nil(failpoint.Enable("tikvclient/rpcPrewriteResult", `1*return("timeout")->return("writeConflict")`))
c.Assert(failpoint.Enable("tikvclient/forceRecursion", `return`), IsNil) s.Nil(failpoint.Enable("tikvclient/forceRecursion", `return`))
defer func() { defer func() {
c.Assert(failpoint.Disable("tikvclient/rpcPrewriteResult"), IsNil) s.Nil(failpoint.Disable("tikvclient/rpcPrewriteResult"))
c.Assert(failpoint.Disable("tikvclient/forceRecursion"), IsNil) s.Nil(failpoint.Disable("tikvclient/forceRecursion"))
}() }()
ctx := context.WithValue(context.Background(), util.SessionID, uint64(1)) ctx := context.WithValue(context.Background(), util.SessionID, uint64(1))
err = txn.Commit(ctx) err = txn.Commit(ctx)
c.Assert(err, NotNil) s.NotNil(err)
c.Assert(txn.GetCommitter().GetUndeterminedErr(), IsNil) s.Nil(txn.GetCommitter().GetUndeterminedErr())
} }

View File

@ -45,6 +45,7 @@ import (
"github.com/pingcap/errors" "github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/tidb/store/mockstore/unistore" "github.com/pingcap/tidb/store/mockstore/unistore"
"github.com/stretchr/testify/suite"
tikverr "github.com/tikv/client-go/v2/error" tikverr "github.com/tikv/client-go/v2/error"
"github.com/tikv/client-go/v2/mockstore" "github.com/tikv/client-go/v2/mockstore"
"github.com/tikv/client-go/v2/mockstore/cluster" "github.com/tikv/client-go/v2/mockstore/cluster"
@ -59,107 +60,119 @@ func TestT(t *testing.T) {
TestingT(t) TestingT(t)
} }
func TestAsyncCommit(t *testing.T) {
suite.Run(t, new(testAsyncCommitSuite))
}
// testAsyncCommitCommon is used to put common parts that will be both used by // testAsyncCommitCommon is used to put common parts that will be both used by
// testAsyncCommitSuite and testAsyncCommitFailSuite. // testAsyncCommitSuite and testAsyncCommitFailSuite.
type testAsyncCommitCommon struct { type testAsyncCommitCommon struct {
suite.Suite
cluster cluster.Cluster cluster cluster.Cluster
store *tikv.KVStore store *tikv.KVStore
} }
func (s *testAsyncCommitCommon) setUpTest(c *C) { func (s *testAsyncCommitCommon) setUpTest() {
if *mockstore.WithTiKV { if *mockstore.WithTiKV {
s.store = NewTestStore(c) s.store = NewTestStoreT(s.T())
return return
} }
client, pdClient, cluster, err := unistore.New("") client, pdClient, cluster, err := unistore.New("")
c.Assert(err, IsNil) s.Require().Nil(err)
unistore.BootstrapWithSingleStore(cluster) unistore.BootstrapWithSingleStore(cluster)
s.cluster = cluster s.cluster = cluster
store, err := tikv.NewTestTiKVStore(fpClient{Client: client}, pdClient, nil, nil, 0) store, err := tikv.NewTestTiKVStore(fpClient{Client: client}, pdClient, nil, nil, 0)
c.Assert(err, IsNil) s.Require().Nil(err)
s.store = store s.store = store
} }
func (s *testAsyncCommitCommon) putAlphabets(c *C, enableAsyncCommit bool) { func (s *testAsyncCommitCommon) putAlphabets(enableAsyncCommit bool) {
for ch := byte('a'); ch <= byte('z'); ch++ { for ch := byte('a'); ch <= byte('z'); ch++ {
s.putKV(c, []byte{ch}, []byte{ch}, enableAsyncCommit) s.putKV([]byte{ch}, []byte{ch}, enableAsyncCommit)
} }
} }
func (s *testAsyncCommitCommon) putKV(c *C, key, value []byte, enableAsyncCommit bool) (uint64, uint64) { func (s *testAsyncCommitCommon) putKV(key, value []byte, enableAsyncCommit bool) (uint64, uint64) {
txn := s.beginAsyncCommit(c) txn := s.beginAsyncCommit()
err := txn.Set(key, value) err := txn.Set(key, value)
c.Assert(err, IsNil) s.Nil(err)
err = txn.Commit(context.Background()) err = txn.Commit(context.Background())
c.Assert(err, IsNil) s.Nil(err)
return txn.StartTS(), txn.GetCommitTS() return txn.StartTS(), txn.GetCommitTS()
} }
func (s *testAsyncCommitCommon) mustGetFromTxn(c *C, txn tikv.TxnProbe, key, expectedValue []byte) { func (s *testAsyncCommitCommon) mustGetFromTxn(txn tikv.TxnProbe, key, expectedValue []byte) {
v, err := txn.Get(context.Background(), key) v, err := txn.Get(context.Background(), key)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(v, BytesEquals, expectedValue) s.Equal(v, expectedValue)
} }
func (s *testAsyncCommitCommon) mustGetLock(c *C, key []byte) *tikv.Lock { func (s *testAsyncCommitCommon) mustGetLock(key []byte) *tikv.Lock {
ver, err := s.store.CurrentTimestamp(oracle.GlobalTxnScope) ver, err := s.store.CurrentTimestamp(oracle.GlobalTxnScope)
c.Assert(err, IsNil) s.Nil(err)
req := tikvrpc.NewRequest(tikvrpc.CmdGet, &kvrpcpb.GetRequest{ req := tikvrpc.NewRequest(tikvrpc.CmdGet, &kvrpcpb.GetRequest{
Key: key, Key: key,
Version: ver, Version: ver,
}) })
bo := tikv.NewBackofferWithVars(context.Background(), 5000, nil) bo := tikv.NewBackofferWithVars(context.Background(), 5000, nil)
loc, err := s.store.GetRegionCache().LocateKey(bo, key) loc, err := s.store.GetRegionCache().LocateKey(bo, key)
c.Assert(err, IsNil) s.Nil(err)
resp, err := s.store.SendReq(bo, req, loc.Region, time.Second*10) resp, err := s.store.SendReq(bo, req, loc.Region, time.Second*10)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(resp.Resp, NotNil) s.NotNil(resp.Resp)
keyErr := resp.Resp.(*kvrpcpb.GetResponse).GetError() keyErr := resp.Resp.(*kvrpcpb.GetResponse).GetError()
c.Assert(keyErr, NotNil) s.NotNil(keyErr)
var lockutil tikv.LockProbe var lockutil tikv.LockProbe
lock, err := lockutil.ExtractLockFromKeyErr(keyErr) lock, err := lockutil.ExtractLockFromKeyErr(keyErr)
c.Assert(err, IsNil) s.Nil(err)
return lock return lock
} }
func (s *testAsyncCommitCommon) mustPointGet(c *C, key, expectedValue []byte) { func (s *testAsyncCommitCommon) mustPointGet(key, expectedValue []byte) {
snap := s.store.GetSnapshot(math.MaxUint64) snap := s.store.GetSnapshot(math.MaxUint64)
value, err := snap.Get(context.Background(), key) value, err := snap.Get(context.Background(), key)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(value, BytesEquals, expectedValue) s.Equal(value, expectedValue)
} }
func (s *testAsyncCommitCommon) mustGetFromSnapshot(c *C, version uint64, key, expectedValue []byte) { func (s *testAsyncCommitCommon) mustGetFromSnapshot(version uint64, key, expectedValue []byte) {
snap := s.store.GetSnapshot(version) snap := s.store.GetSnapshot(version)
value, err := snap.Get(context.Background(), key) value, err := snap.Get(context.Background(), key)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(value, BytesEquals, expectedValue) s.Equal(value, expectedValue)
} }
func (s *testAsyncCommitCommon) mustGetNoneFromSnapshot(c *C, version uint64, key []byte) { func (s *testAsyncCommitCommon) mustGetNoneFromSnapshot(version uint64, key []byte) {
snap := s.store.GetSnapshot(version) snap := s.store.GetSnapshot(version)
_, err := snap.Get(context.Background(), key) _, err := snap.Get(context.Background(), key)
c.Assert(errors.Cause(err), Equals, tikverr.ErrNotExist) s.Equal(errors.Cause(err), tikverr.ErrNotExist)
} }
func (s *testAsyncCommitCommon) beginAsyncCommitWithLinearizability(c *C) tikv.TxnProbe { func (s *testAsyncCommitCommon) beginAsyncCommitWithLinearizability() tikv.TxnProbe {
txn := s.beginAsyncCommit(c) txn := s.beginAsyncCommit()
txn.SetCausalConsistency(false) txn.SetCausalConsistency(false)
return txn return txn
} }
func (s *testAsyncCommitCommon) beginAsyncCommit(c *C) tikv.TxnProbe { func (s *testAsyncCommitCommon) beginAsyncCommit() tikv.TxnProbe {
txn, err := s.store.Begin() txn, err := s.store.Begin()
c.Assert(err, IsNil) s.Nil(err)
txn.SetEnableAsyncCommit(true) txn.SetEnableAsyncCommit(true)
return tikv.TxnProbe{KVTxn: txn} return tikv.TxnProbe{KVTxn: txn}
} }
func (s *testAsyncCommitCommon) begin(c *C) tikv.TxnProbe { func (s *testAsyncCommitCommon) begin() tikv.TxnProbe {
txn, err := s.store.Begin() txn, err := s.store.Begin()
c.Assert(err, IsNil) s.Nil(err)
return tikv.TxnProbe{KVTxn: txn}
}
func (s *testAsyncCommitCommon) begin1PC() tikv.TxnProbe {
txn, err := s.store.Begin()
s.Nil(err)
txn.SetEnable1PC(true)
return tikv.TxnProbe{KVTxn: txn} return tikv.TxnProbe{KVTxn: txn}
} }
@ -168,16 +181,14 @@ type testAsyncCommitSuite struct {
bo *tikv.Backoffer bo *tikv.Backoffer
} }
var _ = SerialSuites(&testAsyncCommitSuite{}) func (s *testAsyncCommitSuite) SetupTest() {
s.testAsyncCommitCommon.setUpTest()
func (s *testAsyncCommitSuite) SetUpTest(c *C) {
s.testAsyncCommitCommon.setUpTest(c)
s.bo = tikv.NewBackofferWithVars(context.Background(), 5000, nil) s.bo = tikv.NewBackofferWithVars(context.Background(), 5000, nil)
} }
func (s *testAsyncCommitSuite) lockKeysWithAsyncCommit(c *C, keys, values [][]byte, primaryKey, primaryValue []byte, commitPrimary bool) (uint64, uint64) { func (s *testAsyncCommitSuite) lockKeysWithAsyncCommit(keys, values [][]byte, primaryKey, primaryValue []byte, commitPrimary bool) (uint64, uint64) {
txn, err := s.store.Begin() txn, err := s.store.Begin()
c.Assert(err, IsNil) s.Nil(err)
txn.SetEnableAsyncCommit(true) txn.SetEnableAsyncCommit(true)
for i, k := range keys { for i, k := range keys {
if len(values[i]) > 0 { if len(values[i]) > 0 {
@ -185,69 +196,69 @@ func (s *testAsyncCommitSuite) lockKeysWithAsyncCommit(c *C, keys, values [][]by
} else { } else {
err = txn.Delete(k) err = txn.Delete(k)
} }
c.Assert(err, IsNil) s.Nil(err)
} }
if len(primaryValue) > 0 { if len(primaryValue) > 0 {
err = txn.Set(primaryKey, primaryValue) err = txn.Set(primaryKey, primaryValue)
} else { } else {
err = txn.Delete(primaryKey) err = txn.Delete(primaryKey)
} }
c.Assert(err, IsNil) s.Nil(err)
txnProbe := tikv.TxnProbe{KVTxn: txn} txnProbe := tikv.TxnProbe{KVTxn: txn}
tpc, err := txnProbe.NewCommitter(0) tpc, err := txnProbe.NewCommitter(0)
c.Assert(err, IsNil) s.Nil(err)
tpc.SetPrimaryKey(primaryKey) tpc.SetPrimaryKey(primaryKey)
ctx := context.Background() ctx := context.Background()
err = tpc.PrewriteAllMutations(ctx) err = tpc.PrewriteAllMutations(ctx)
c.Assert(err, IsNil) s.Nil(err)
if commitPrimary { if commitPrimary {
commitTS, err := s.store.GetOracle().GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) commitTS, err := s.store.GetOracle().GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
c.Assert(err, IsNil) s.Nil(err)
tpc.SetCommitTS(commitTS) tpc.SetCommitTS(commitTS)
err = tpc.CommitMutations(ctx) err = tpc.CommitMutations(ctx)
c.Assert(err, IsNil) s.Nil(err)
} }
return txn.StartTS(), tpc.GetCommitTS() return txn.StartTS(), tpc.GetCommitTS()
} }
func (s *testAsyncCommitSuite) TestCheckSecondaries(c *C) { func (s *testAsyncCommitSuite) TestCheckSecondaries() {
// This test doesn't support tikv mode. // This test doesn't support tikv mode.
if *mockstore.WithTiKV { if *mockstore.WithTiKV {
return return
} }
s.putAlphabets(c, true) s.putAlphabets(true)
loc, err := s.store.GetRegionCache().LocateKey(s.bo, []byte("a")) loc, err := s.store.GetRegionCache().LocateKey(s.bo, []byte("a"))
c.Assert(err, IsNil) s.Nil(err)
newRegionID, peerID := s.cluster.AllocID(), s.cluster.AllocID() newRegionID, peerID := s.cluster.AllocID(), s.cluster.AllocID()
s.cluster.Split(loc.Region.GetID(), newRegionID, []byte("e"), []uint64{peerID}, peerID) s.cluster.Split(loc.Region.GetID(), newRegionID, []byte("e"), []uint64{peerID}, peerID)
s.store.GetRegionCache().InvalidateCachedRegion(loc.Region) s.store.GetRegionCache().InvalidateCachedRegion(loc.Region)
// No locks to check, only primary key is locked, should be successful. // No locks to check, only primary key is locked, should be successful.
s.lockKeysWithAsyncCommit(c, [][]byte{}, [][]byte{}, []byte("z"), []byte("z"), false) s.lockKeysWithAsyncCommit([][]byte{}, [][]byte{}, []byte("z"), []byte("z"), false)
lock := s.mustGetLock(c, []byte("z")) lock := s.mustGetLock([]byte("z"))
lock.UseAsyncCommit = true lock.UseAsyncCommit = true
ts, err := s.store.GetOracle().GetTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope}) ts, err := s.store.GetOracle().GetTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope})
c.Assert(err, IsNil) s.Nil(err)
var lockutil tikv.LockProbe var lockutil tikv.LockProbe
status := lockutil.NewLockStatus(nil, true, ts) status := lockutil.NewLockStatus(nil, true, ts)
resolver := tikv.LockResolverProbe{LockResolver: s.store.GetLockResolver()} resolver := tikv.LockResolverProbe{LockResolver: s.store.GetLockResolver()}
err = resolver.ResolveLockAsync(s.bo, lock, status) err = resolver.ResolveLockAsync(s.bo, lock, status)
c.Assert(err, IsNil) s.Nil(err)
currentTS, err := s.store.GetOracle().GetTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope}) currentTS, err := s.store.GetOracle().GetTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope})
c.Assert(err, IsNil) s.Nil(err)
status, err = resolver.GetTxnStatus(s.bo, lock.TxnID, []byte("z"), currentTS, currentTS, true, false, nil) status, err = resolver.GetTxnStatus(s.bo, lock.TxnID, []byte("z"), currentTS, currentTS, true, false, nil)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(status.IsCommitted(), IsTrue) s.True(status.IsCommitted())
c.Assert(status.CommitTS(), Equals, ts) s.Equal(status.CommitTS(), ts)
// One key is committed (i), one key is locked (a). Should get committed. // One key is committed (i), one key is locked (a). Should get committed.
ts, err = s.store.GetOracle().GetTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope}) ts, err = s.store.GetOracle().GetTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope})
c.Assert(err, IsNil) s.Nil(err)
commitTs := ts + 10 commitTs := ts + 10
gotCheckA := int64(0) gotCheckA := int64(0)
@ -313,18 +324,18 @@ func (s *testAsyncCommitSuite) TestCheckSecondaries(c *C) {
MinCommitTS: ts + 5, MinCommitTS: ts + 5,
} }
_ = s.beginAsyncCommit(c) _ = s.beginAsyncCommit()
err = resolver.ResolveLockAsync(s.bo, lock, status) err = resolver.ResolveLockAsync(s.bo, lock, status)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(gotCheckA, Equals, int64(1)) s.Equal(gotCheckA, int64(1))
c.Assert(gotCheckB, Equals, int64(1)) s.Equal(gotCheckB, int64(1))
c.Assert(gotOther, Equals, int64(0)) s.Equal(gotOther, int64(0))
c.Assert(gotResolve, Equals, int64(1)) s.Equal(gotResolve, int64(1))
// One key has been rolled back (b), one is locked (a). Should be rolled back. // One key has been rolled back (b), one is locked (a). Should be rolled back.
ts, err = s.store.GetOracle().GetTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope}) ts, err = s.store.GetOracle().GetTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope})
c.Assert(err, IsNil) s.Nil(err)
commitTs = ts + 10 commitTs = ts + 10
gotCheckA = int64(0) gotCheckA = int64(0)
@ -353,45 +364,45 @@ func (s *testAsyncCommitSuite) TestCheckSecondaries(c *C) {
lock.MinCommitTS = ts + 5 lock.MinCommitTS = ts + 5
err = resolver.ResolveLockAsync(s.bo, lock, status) err = resolver.ResolveLockAsync(s.bo, lock, status)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(gotCheckA, Equals, int64(1)) s.Equal(gotCheckA, int64(1))
c.Assert(gotCheckB, Equals, int64(1)) s.Equal(gotCheckB, int64(1))
c.Assert(gotResolve, Equals, int64(1)) s.Equal(gotResolve, int64(1))
c.Assert(gotOther, Equals, int64(0)) s.Equal(gotOther, int64(0))
} }
func (s *testAsyncCommitSuite) TestRepeatableRead(c *C) { func (s *testAsyncCommitSuite) TestRepeatableRead() {
var sessionID uint64 = 0 var sessionID uint64 = 0
test := func(isPessimistic bool) { test := func(isPessimistic bool) {
s.putKV(c, []byte("k1"), []byte("v1"), true) s.putKV([]byte("k1"), []byte("v1"), true)
sessionID++ sessionID++
ctx := context.WithValue(context.Background(), util.SessionID, sessionID) ctx := context.WithValue(context.Background(), util.SessionID, sessionID)
txn1 := s.beginAsyncCommit(c) txn1 := s.beginAsyncCommit()
txn1.SetPessimistic(isPessimistic) txn1.SetPessimistic(isPessimistic)
s.mustGetFromTxn(c, txn1, []byte("k1"), []byte("v1")) s.mustGetFromTxn(txn1, []byte("k1"), []byte("v1"))
txn1.Set([]byte("k1"), []byte("v2")) txn1.Set([]byte("k1"), []byte("v2"))
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
_, err := s.store.GetOracle().GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) _, err := s.store.GetOracle().GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
c.Assert(err, IsNil) s.Nil(err)
} }
txn2 := s.beginAsyncCommit(c) txn2 := s.beginAsyncCommit()
s.mustGetFromTxn(c, txn2, []byte("k1"), []byte("v1")) s.mustGetFromTxn(txn2, []byte("k1"), []byte("v1"))
err := txn1.Commit(ctx) err := txn1.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
// Check txn1 is committed in async commit. // Check txn1 is committed in async commit.
c.Assert(txn1.IsAsyncCommit(), IsTrue) s.True(txn1.IsAsyncCommit())
s.mustGetFromTxn(c, txn2, []byte("k1"), []byte("v1")) s.mustGetFromTxn(txn2, []byte("k1"), []byte("v1"))
err = txn2.Rollback() err = txn2.Rollback()
c.Assert(err, IsNil) s.Nil(err)
txn3 := s.beginAsyncCommit(c) txn3 := s.beginAsyncCommit()
s.mustGetFromTxn(c, txn3, []byte("k1"), []byte("v2")) s.mustGetFromTxn(txn3, []byte("k1"), []byte("v2"))
err = txn3.Rollback() err = txn3.Rollback()
c.Assert(err, IsNil) s.Nil(err)
} }
test(false) test(false)
@ -400,69 +411,69 @@ func (s *testAsyncCommitSuite) TestRepeatableRead(c *C) {
// It's just a simple validation of linearizability. // It's just a simple validation of linearizability.
// Extra tests are needed to test this feature with the control of the TiKV cluster. // Extra tests are needed to test this feature with the control of the TiKV cluster.
func (s *testAsyncCommitSuite) TestAsyncCommitLinearizability(c *C) { func (s *testAsyncCommitSuite) TestAsyncCommitLinearizability() {
t1 := s.beginAsyncCommitWithLinearizability(c) t1 := s.beginAsyncCommitWithLinearizability()
t2 := s.beginAsyncCommitWithLinearizability(c) t2 := s.beginAsyncCommitWithLinearizability()
err := t1.Set([]byte("a"), []byte("a1")) err := t1.Set([]byte("a"), []byte("a1"))
c.Assert(err, IsNil) s.Nil(err)
err = t2.Set([]byte("b"), []byte("b1")) err = t2.Set([]byte("b"), []byte("b1"))
c.Assert(err, IsNil) s.Nil(err)
ctx := context.WithValue(context.Background(), util.SessionID, uint64(1)) ctx := context.WithValue(context.Background(), util.SessionID, uint64(1))
// t2 commits earlier than t1 // t2 commits earlier than t1
err = t2.Commit(ctx) err = t2.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
err = t1.Commit(ctx) err = t1.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
commitTS1 := t1.GetCommitTS() commitTS1 := t1.GetCommitTS()
commitTS2 := t2.GetCommitTS() commitTS2 := t2.GetCommitTS()
c.Assert(commitTS2, Less, commitTS1) s.Less(commitTS2, commitTS1)
} }
// TestAsyncCommitWithMultiDC tests that async commit can only be enabled in global transactions // TestAsyncCommitWithMultiDC tests that async commit can only be enabled in global transactions
func (s *testAsyncCommitSuite) TestAsyncCommitWithMultiDC(c *C) { func (s *testAsyncCommitSuite) TestAsyncCommitWithMultiDC() {
// It requires setting placement rules to run with TiKV // It requires setting placement rules to run with TiKV
if *mockstore.WithTiKV { if *mockstore.WithTiKV {
return return
} }
localTxn := s.beginAsyncCommit(c) localTxn := s.beginAsyncCommit()
err := localTxn.Set([]byte("a"), []byte("a1")) err := localTxn.Set([]byte("a"), []byte("a1"))
localTxn.SetScope("bj") localTxn.SetScope("bj")
c.Assert(err, IsNil) s.Nil(err)
ctx := context.WithValue(context.Background(), util.SessionID, uint64(1)) ctx := context.WithValue(context.Background(), util.SessionID, uint64(1))
err = localTxn.Commit(ctx) err = localTxn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(localTxn.IsAsyncCommit(), IsFalse) s.False(localTxn.IsAsyncCommit())
globalTxn := s.beginAsyncCommit(c) globalTxn := s.beginAsyncCommit()
err = globalTxn.Set([]byte("b"), []byte("b1")) err = globalTxn.Set([]byte("b"), []byte("b1"))
globalTxn.SetScope(oracle.GlobalTxnScope) globalTxn.SetScope(oracle.GlobalTxnScope)
c.Assert(err, IsNil) s.Nil(err)
err = globalTxn.Commit(ctx) err = globalTxn.Commit(ctx)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(globalTxn.IsAsyncCommit(), IsTrue) s.True(globalTxn.IsAsyncCommit())
} }
func (s *testAsyncCommitSuite) TestResolveTxnFallbackFromAsyncCommit(c *C) { func (s *testAsyncCommitSuite) TestResolveTxnFallbackFromAsyncCommit() {
keys := [][]byte{[]byte("k0"), []byte("k1")} keys := [][]byte{[]byte("k0"), []byte("k1")}
values := [][]byte{[]byte("v00"), []byte("v10")} values := [][]byte{[]byte("v00"), []byte("v10")}
initTest := func() tikv.CommitterProbe { initTest := func() tikv.CommitterProbe {
t0 := s.begin(c) t0 := s.begin()
err := t0.Set(keys[0], values[0]) err := t0.Set(keys[0], values[0])
c.Assert(err, IsNil) s.Nil(err)
err = t0.Set(keys[1], values[1]) err = t0.Set(keys[1], values[1])
c.Assert(err, IsNil) s.Nil(err)
err = t0.Commit(context.Background()) err = t0.Commit(context.Background())
c.Assert(err, IsNil) s.Nil(err)
t1 := s.beginAsyncCommit(c) t1 := s.beginAsyncCommit()
err = t1.Set(keys[0], []byte("v01")) err = t1.Set(keys[0], []byte("v01"))
c.Assert(err, IsNil) s.Nil(err)
err = t1.Set(keys[1], []byte("v11")) err = t1.Set(keys[1], []byte("v11"))
c.Assert(err, IsNil) s.Nil(err)
committer, err := t1.NewCommitter(1) committer, err := t1.NewCommitter(1)
c.Assert(err, IsNil) s.Nil(err)
committer.SetLockTTL(1) committer.SetLockTTL(1)
committer.SetUseAsyncCommit() committer.SetUseAsyncCommit()
return committer return committer
@ -470,23 +481,23 @@ func (s *testAsyncCommitSuite) TestResolveTxnFallbackFromAsyncCommit(c *C) {
prewriteKey := func(committer tikv.CommitterProbe, idx int, fallback bool) { prewriteKey := func(committer tikv.CommitterProbe, idx int, fallback bool) {
bo := tikv.NewBackofferWithVars(context.Background(), 5000, nil) bo := tikv.NewBackofferWithVars(context.Background(), 5000, nil)
loc, err := s.store.GetRegionCache().LocateKey(bo, keys[idx]) loc, err := s.store.GetRegionCache().LocateKey(bo, keys[idx])
c.Assert(err, IsNil) s.Nil(err)
req := committer.BuildPrewriteRequest(loc.Region.GetID(), loc.Region.GetConfVer(), loc.Region.GetVer(), req := committer.BuildPrewriteRequest(loc.Region.GetID(), loc.Region.GetConfVer(), loc.Region.GetVer(),
committer.GetMutations().Slice(idx, idx+1), 1) committer.GetMutations().Slice(idx, idx+1), 1)
if fallback { if fallback {
req.Req.(*kvrpcpb.PrewriteRequest).MaxCommitTs = 1 req.Req.(*kvrpcpb.PrewriteRequest).MaxCommitTs = 1
} }
resp, err := s.store.SendReq(bo, req, loc.Region, 5000) resp, err := s.store.SendReq(bo, req, loc.Region, 5000)
c.Assert(err, IsNil) s.Nil(err)
c.Assert(resp.Resp, NotNil) s.NotNil(resp.Resp)
} }
readKey := func(idx int) { readKey := func(idx int) {
t2 := s.begin(c) t2 := s.begin()
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
val, err := t2.Get(ctx, keys[idx]) val, err := t2.Get(ctx, keys[idx])
c.Assert(err, IsNil) s.Nil(err)
c.Assert(val, DeepEquals, values[idx]) s.Equal(val, values[idx])
} }
// Case 1: Fallback primary, read primary // Case 1: Fallback primary, read primary