diff --git a/internal/apicodec/codec.go b/internal/apicodec/codec.go
index 7945a5a9..54b5840c 100644
--- a/internal/apicodec/codec.go
+++ b/internal/apicodec/codec.go
@@ -92,28 +92,27 @@ func attachAPICtx(c Codec, req *tikvrpc.Request) *tikvrpc.Request {
 	// Shallow copy the request to avoid concurrent modification.
 	r := *req
 
-	ctx := &r.Context
-	ctx.ApiVersion = c.GetAPIVersion()
-	ctx.KeyspaceId = uint32(c.GetKeyspaceID())
+	r.Context.ApiVersion = c.GetAPIVersion()
+	r.Context.KeyspaceId = uint32(c.GetKeyspaceID())
 
 	switch r.Type {
 	case tikvrpc.CmdMPPTask:
 		mpp := *r.DispatchMPPTask()
 		// Shallow copy the meta to avoid concurrent modification.
 		meta := *mpp.Meta
-		meta.KeyspaceId = ctx.KeyspaceId
-		meta.ApiVersion = ctx.ApiVersion
+		meta.KeyspaceId = r.Context.KeyspaceId
+		meta.ApiVersion = r.Context.ApiVersion
 		mpp.Meta = &meta
 		r.Req = &mpp
 
	case tikvrpc.CmdCompact:
 		compact := *r.Compact()
-		compact.KeyspaceId = ctx.KeyspaceId
-		compact.ApiVersion = ctx.ApiVersion
+		compact.KeyspaceId = r.Context.KeyspaceId
+		compact.ApiVersion = r.Context.ApiVersion
 		r.Req = &compact
 	}
 
-	tikvrpc.AttachContext(&r, ctx)
+	tikvrpc.AttachContext(&r, r.Context)
 
 	return &r
 }
diff --git a/internal/client/client_batch.go b/internal/client/client_batch.go
index a8091094..95aed901 100644
--- a/internal/client/client_batch.go
+++ b/internal/client/client_batch.go
@@ -298,6 +298,9 @@ func (a *batchConn) fetchMorePendingRequests(
 
 const idleTimeout = 3 * time.Minute
 
+// BatchSendLoopPanicCounter is only used for testing.
+var BatchSendLoopPanicCounter int64 = 0
+
 func (a *batchConn) batchSendLoop(cfg config.TiKVClient) {
 	defer func() {
 		if r := recover(); r != nil {
@@ -305,7 +308,8 @@ func (a *batchConn) batchSendLoop(cfg config.TiKVClient) {
 			logutil.BgLogger().Error("batchSendLoop",
 				zap.Any("r", r),
 				zap.Stack("stack"))
-			logutil.BgLogger().Info("restart batchSendLoop")
+			atomic.AddInt64(&BatchSendLoopPanicCounter, 1)
+			logutil.BgLogger().Info("restart batchSendLoop", zap.Int64("count", atomic.LoadInt64(&BatchSendLoopPanicCounter)))
 			go a.batchSendLoop(cfg)
 		}
 	}()
diff --git a/internal/locate/region_request_test.go b/internal/locate/region_request_test.go
index 6c97be87..be7ad625 100644
--- a/internal/locate/region_request_test.go
+++ b/internal/locate/region_request_test.go
@@ -37,8 +37,10 @@ package locate
 import (
 	"context"
 	"fmt"
+	"math/rand"
 	"net"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 	"unsafe"
@@ -733,3 +735,52 @@ func (s *testRegionRequestToSingleStoreSuite) TestKVReadTimeoutWithDisableBatchC
 	s.True(IsFakeRegionError(regionErr))
 	s.Equal(0, bo.GetTotalBackoffTimes()) // use kv read timeout will do fast retry, so backoff times should be 0.
 }
+
+func (s *testRegionRequestToSingleStoreSuite) TestBatchClientSendLoopPanic() {
+	// This test should be run with `go test -race`.
+	config.UpdateGlobal(func(conf *config.Config) {
+		conf.TiKVClient.MaxBatchSize = 128
+	})()
+
+	server, port := mock_server.StartMockTikvService()
+	s.True(port > 0)
+	rpcClient := client.NewRPCClient()
+	fnClient := &fnClient{fn: func(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (response *tikvrpc.Response, err error) {
+		return rpcClient.SendRequest(ctx, server.Addr(), req, timeout)
+	}}
+	tf := func(s *Store, bo *retry.Backoffer) livenessState {
+		return reachable
+	}
+
+	defer func() {
+		rpcClient.Close()
+		server.Stop()
+	}()
+
+	var wg sync.WaitGroup
+	for i := 0; i < 100; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for j := 0; j < 100; j++ {
+				ctx, cancel := context.WithCancel(context.Background())
+				bo := retry.NewBackofferWithVars(ctx, int(client.ReadTimeoutShort.Milliseconds()), nil)
+				region, err := s.cache.LocateRegionByID(bo, s.region)
+				s.Nil(err)
+				s.NotNil(region)
+				go func() {
+					// Mock the query being killed or timing out.
+					time.Sleep(time.Millisecond * time.Duration(rand.Intn(5)+1))
+					cancel()
+				}()
+				req := tikvrpc.NewRequest(tikvrpc.CmdCop, &coprocessor.Request{Data: []byte("a"), StartTs: 1})
+				regionRequestSender := NewRegionRequestSender(s.cache, fnClient)
+				regionRequestSender.regionCache.testingKnobs.mockRequestLiveness.Store((*livenessFunc)(&tf))
+				regionRequestSender.SendReq(bo, req, region.Region, client.ReadTimeoutShort)
+			}
+		}()
+	}
+	wg.Wait()
+	// batchSendLoop should not panic.
+	s.Equal(atomic.LoadInt64(&client.BatchSendLoopPanicCounter), int64(0))
+}
diff --git a/tikvrpc/tikvrpc.go b/tikvrpc/tikvrpc.go
index 83119c3d..13e85706 100644
--- a/tikvrpc/tikvrpc.go
+++ b/tikvrpc/tikvrpc.go
@@ -709,7 +709,9 @@ type MPPStreamResponse struct {
 
 // AttachContext sets the request context to the request,
 // return false if encounter unknown request type.
-func AttachContext(req *Request, ctx *kvrpcpb.Context) bool {
+// Parameter `rpcCtx` uses `kvrpcpb.Context` instead of `*kvrpcpb.Context`, so each request keeps its own shallow copy and avoids concurrent modification.
+func AttachContext(req *Request, rpcCtx kvrpcpb.Context) bool {
+	ctx := &rpcCtx
 	switch req.Type {
 	case CmdGet:
 		req.Get().Context = ctx
@@ -807,13 +809,14 @@ func AttachContext(req *Request, ctx *kvrpcpb.Context) bool {
 
 // SetContext set the Context field for the given req to the specified ctx.
 func SetContext(req *Request, region *metapb.Region, peer *metapb.Peer) error {
-	ctx := &req.Context
 	if region != nil {
-		ctx.RegionId = region.Id
-		ctx.RegionEpoch = region.RegionEpoch
+		req.Context.RegionId = region.Id
+		req.Context.RegionEpoch = region.RegionEpoch
 	}
-	ctx.Peer = peer
-	if !AttachContext(req, ctx) {
+	req.Context.Peer = peer
+
+	// Shallow copy the context to avoid concurrent modification.
+	if !AttachContext(req, req.Context) {
 		return errors.Errorf("invalid request type %v", req.Type)
 	}
 	return nil
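
Note (not part of the patch): the heart of the change is that AttachContext now takes kvrpcpb.Context by value, so the concrete protobuf message ends up pointing at a private shallow copy instead of at the shared req.Context that a retry may rewrite while the batch connection is still marshalling the request. The sketch below illustrates that effect with simplified stand-in types; Request, Context, GetRequest and the attach helpers here are placeholders for illustration, not the real tikvrpc or kvproto definitions.

package main

import "fmt"

// Simplified stand-ins for tikvrpc.Request, kvrpcpb.Context and a concrete
// request message; placeholders only, not the real kvproto definitions.
type Context struct{ RegionId uint64 }

type Request struct {
	Context // shared by every retry of the same logical request
}

type GetRequest struct{ Context *Context }

// attachByPointer mirrors the old behaviour: the inner message keeps a pointer
// into the shared req.Context, so a retry that rewrites req.Context also
// mutates a message that may still be in flight on the batch connection.
func attachByPointer(req *Request, get *GetRequest) { get.Context = &req.Context }

// attachByValue mirrors the new behaviour: the context is passed by value, so
// the inner message points at a private shallow copy.
func attachByValue(get *GetRequest, ctx Context) { get.Context = &ctx }

func main() {
	req := &Request{Context: Context{RegionId: 1}}
	get := &GetRequest{}

	attachByValue(get, req.Context)
	req.Context.RegionId = 2 // e.g. a retry re-targets the request to another region

	// With attachByValue the already-attached message still sees region 1;
	// with attachByPointer it would have changed underneath the sender
	// (and raced if the rewrite happened during gRPC marshalling).
	fmt.Println(get.Context.RegionId) // prints 1
}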