client: fix the bug of attach-context in async api (#1747)

Signed-off-by: zyguan <zhongyangguan@gmail.com>
zyguan 2025-08-28 14:52:01 +08:00 committed by GitHub
parent 82ff387182
commit 2d0237392a
2 changed files with 57 additions and 6 deletions


@@ -49,12 +49,6 @@ func (c *RPCClient) SendRequestAsync(ctx context.Context, addr string, req *tikv
		return
	}
	batchReq := req.ToBatchCommandsRequest()
	if batchReq == nil {
		cb.Invoke(nil, errors.New("unsupported request type: "+req.Type.String()))
		return
	}
	regionRPC := trace.StartRegion(ctx, req.Type.String())
	spanRPC := opentracing.SpanFromContext(ctx)
	if spanRPC != nil && spanRPC.Tracer() != nil {
@@ -72,6 +66,13 @@ func (c *RPCClient) SendRequestAsync(ctx context.Context, addr string, req *tikv
	}
	tikvrpc.AttachContext(req, req.Context)
	// ToBatchCommandsRequest should be called after all modifications to req are done.
	batchReq := req.ToBatchCommandsRequest()
	if batchReq == nil {
		cb.Invoke(nil, errors.New("unsupported request type: "+req.Type.String()))
		return
	}
	// TODO(zyguan): If the client created `WithGRPCDialOptions(grpc.WithBlock())`, `getConnPool` might be blocked for
	// a while when the corresponding conn array is uninitialized. However, since tidb won't set this option, we just
	// keep `getConnPool` synchronous for now.
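
Taken together, the change above makes SendRequestAsync attach req.Context to the inner kv request before converting it into a batch-commands entry, so routing information set on req.Context (for example via tikvrpc.SetContextNoAttach) is what actually reaches TiKV. A condensed caller-side sketch of that flow, written as it would appear inside the client package tests (imports as in the test file below; ctx and addr are assumed to point at a running mock TiKV, as in the new test) and not a verbatim excerpt from client-go:

rl := async.NewRunLoop()
cli := NewRPCClient()
defer cli.Close()

// Only req.Context carries the routing info here; with this fix, SendRequestAsync
// attaches it to the inner kvrpcpb request before calling ToBatchCommandsRequest.
req := tikvrpc.NewRequest(tikvrpc.CmdGet, &kvrpcpb.GetRequest{Key: []byte("foo"), Version: math.MaxUint64})
tikvrpc.SetContextNoAttach(req, &metapb.Region{Id: 1}, &metapb.Peer{})

cli.SendRequestAsync(ctx, addr, req, async.NewCallback(rl, func(resp *tikvrpc.Response, err error) {
	// With the fix, the server sees RegionId == 1 rather than an empty context.
}))
rl.Exec(ctx) // drive the run loop so the callback gets executed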


@@ -16,12 +16,15 @@ package client
import (
	"context"
	"math"
	"sync/atomic"
	"testing"
	"time"

	"github.com/pingcap/failpoint"
	"github.com/pingcap/kvproto/pkg/errorpb"
	"github.com/pingcap/kvproto/pkg/kvrpcpb"
	"github.com/pingcap/kvproto/pkg/metapb"
	"github.com/pingcap/kvproto/pkg/tikvpb"
	"github.com/pkg/errors"
	"github.com/stretchr/testify/require"
@@ -97,6 +100,53 @@ func TestSendRequestAsyncBasic(t *testing.T) {
	})
}

func TestSendRequestAsyncAttachContext(t *testing.T) {
	ctx := context.Background()
	srv, port := mockserver.StartMockTikvService()
	require.True(t, port > 0)
	require.True(t, srv.IsRunning())
	addr := srv.Addr()
	cli := NewRPCClient()
	defer func() {
		cli.Close()
		srv.Stop()
	}()

	handle := func(req *tikvpb.BatchCommandsRequest) (*tikvpb.BatchCommandsResponse, error) {
		ids := req.GetRequestIds()
		require.Len(t, ids, 1)
		getReq := req.GetRequests()[0].GetGet()
		var getResp *kvrpcpb.GetResponse
		if getReq.GetContext().GetRegionId() == 0 {
			getResp = &kvrpcpb.GetResponse{RegionError: &errorpb.Error{RegionNotFound: &errorpb.RegionNotFound{}}}
		} else {
			getResp = &kvrpcpb.GetResponse{Value: getReq.Key}
		}
		return &tikvpb.BatchCommandsResponse{RequestIds: ids, Responses: []*tikvpb.BatchCommandsResponse_Response{{Cmd: &tikvpb.BatchCommandsResponse_Response_Get{Get: getResp}}}}, nil
	}
	srv.OnBatchCommandsRequest.Store(&handle)

	called := false
	rl := async.NewRunLoop()
	cb := async.NewCallback(rl, func(resp *tikvrpc.Response, err error) {
		called = true
		require.NoError(t, err)
		getResp := resp.Resp.(*kvrpcpb.GetResponse)
		require.Nil(t, getResp.GetRegionError())
		require.Equal(t, []byte("foo"), getResp.Value)
	})

	req := tikvrpc.NewRequest(tikvrpc.CmdGet, &kvrpcpb.GetRequest{Key: []byte("foo"), Version: math.MaxUint64})
	require.Zero(t, req.Context.RegionId)
	tikvrpc.AttachContext(req, req.Context)
	tikvrpc.SetContextNoAttach(req, &metapb.Region{Id: 1}, &metapb.Peer{})

	cli.SendRequestAsync(ctx, addr, req, cb)
	rl.Exec(ctx)
	require.True(t, called)
}

func TestSendRequestAsyncTimeout(t *testing.T) {
	ctx := context.Background()
	srv, port := mockserver.StartMockTikvService()