diff --git a/integration_tests/split_test.go b/integration_tests/split_test.go index 56957bf0..b3841db5 100644 --- a/integration_tests/split_test.go +++ b/integration_tests/split_test.go @@ -39,6 +39,7 @@ import ( "sync" "testing" + "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/keyspacepb" "github.com/pingcap/kvproto/pkg/meta_storagepb" "github.com/pingcap/kvproto/pkg/metapb" @@ -172,6 +173,13 @@ func (s *testSplitSuite) TestBatchGetUsingAsyncAPI() { s.Equal([]byte("a"), m[string([]byte{'a'})]) s.Equal([]byte("c"), m[string([]byte{'c'})]) s.NotContains(m, string([]byte{'b'})) + + // inject an error on sending request. + failpoint.Enable("tikvclient/tikvStoreSendReqResult", `1*return("timeout")`) + defer failpoint.Disable("tikvclient/tikvStoreSendReqResult") + txn = s.begin() + _, err = txn.GetSnapshot().BatchGet(context.TODO(), [][]byte{{'a'}, {'b'}, {'c'}}) + s.Error(err) } func (s *testSplitSuite) TestStaleEpoch() { diff --git a/internal/locate/region_request.go b/internal/locate/region_request.go index 313fcb54..345970ca 100644 --- a/internal/locate/region_request.go +++ b/internal/locate/region_request.go @@ -470,6 +470,14 @@ func (s *RegionRequestSender) SendReqAsync( cb async.Callback[*tikvrpc.ResponseExt], opts ...StoreSelectorOption, ) { + if resp, err := failpointSendReqResult(req, tikvrpc.TiKV); err != nil || resp != nil { + var re *tikvrpc.ResponseExt + if resp != nil { + re = &tikvrpc.ResponseExt{Response: *resp} + } + cb.Invoke(re, err) + return + } if err := s.validateReadTS(bo.GetCtx(), req); err != nil { logutil.Logger(bo.GetCtx()).Error("validate read ts failed for request", zap.Stringer("reqType", req.Type), zap.Stringer("req", req.Req.(fmt.Stringer)), zap.Stringer("context", &req.Context), zap.Stack("stack"), zap.Error(err)) cb.Invoke(nil, err) diff --git a/txnkv/txnsnapshot/snapshot_async.go b/txnkv/txnsnapshot/snapshot_async.go index b8934b57..c2813594 100644 --- a/txnkv/txnsnapshot/snapshot_async.go +++ b/txnkv/txnsnapshot/snapshot_async.go @@ -75,8 +75,8 @@ func (s *KVSnapshot) asyncBatchGetByRegions( })) } for completed < len(batches) { - _, err = runloop.Exec(bo.GetCtx()) - if err != nil { + if _, e := runloop.Exec(bo.GetCtx()); e != nil { + err = errors.WithStack(e) break } }