diff --git a/internal/client/client_async.go b/internal/client/client_async.go index bbd14cc0..3c52f91b 100644 --- a/internal/client/client_async.go +++ b/internal/client/client_async.go @@ -49,12 +49,6 @@ func (c *RPCClient) SendRequestAsync(ctx context.Context, addr string, req *tikv return } - batchReq := req.ToBatchCommandsRequest() - if batchReq == nil { - cb.Invoke(nil, errors.New("unsupported request type: "+req.Type.String())) - return - } - regionRPC := trace.StartRegion(ctx, req.Type.String()) spanRPC := opentracing.SpanFromContext(ctx) if spanRPC != nil && spanRPC.Tracer() != nil { @@ -72,6 +66,13 @@ func (c *RPCClient) SendRequestAsync(ctx context.Context, addr string, req *tikv } tikvrpc.AttachContext(req, req.Context) + // ToBatchCommandsRequest should be called after all modifications to req are done. + batchReq := req.ToBatchCommandsRequest() + if batchReq == nil { + cb.Invoke(nil, errors.New("unsupported request type: "+req.Type.String())) + return + } + // TODO(zyguan): If the client created `WithGRPCDialOptions(grpc.WithBlock())`, `getConnPool` might be blocked for // a while when the corresponding conn array is uninitialized. However, since tidb won't set this option, we just // keep `getConnPool` synchronous for now. diff --git a/internal/client/client_async_test.go b/internal/client/client_async_test.go index 9a7bf5c0..020a6409 100644 --- a/internal/client/client_async_test.go +++ b/internal/client/client_async_test.go @@ -16,12 +16,15 @@ package client import ( "context" + "math" "sync/atomic" "testing" "time" "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/errorpb" "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/tikvpb" "github.com/pkg/errors" "github.com/stretchr/testify/require" @@ -97,6 +100,53 @@ func TestSendRequestAsyncBasic(t *testing.T) { }) } +func TestSendRequestAsyncAttachContext(t *testing.T) { + ctx := context.Background() + srv, port := mockserver.StartMockTikvService() + require.True(t, port > 0) + require.True(t, srv.IsRunning()) + addr := srv.Addr() + + cli := NewRPCClient() + defer func() { + cli.Close() + srv.Stop() + }() + + handle := func(req *tikvpb.BatchCommandsRequest) (*tikvpb.BatchCommandsResponse, error) { + ids := req.GetRequestIds() + require.Len(t, ids, 1) + getReq := req.GetRequests()[0].GetGet() + var getResp *kvrpcpb.GetResponse + if getReq.GetContext().GetRegionId() == 0 { + getResp = &kvrpcpb.GetResponse{RegionError: &errorpb.Error{RegionNotFound: &errorpb.RegionNotFound{}}} + } else { + getResp = &kvrpcpb.GetResponse{Value: getReq.Key} + } + return &tikvpb.BatchCommandsResponse{RequestIds: ids, Responses: []*tikvpb.BatchCommandsResponse_Response{{Cmd: &tikvpb.BatchCommandsResponse_Response_Get{Get: getResp}}}}, nil + } + srv.OnBatchCommandsRequest.Store(&handle) + + called := false + rl := async.NewRunLoop() + cb := async.NewCallback(rl, func(resp *tikvrpc.Response, err error) { + called = true + require.NoError(t, err) + getResp := resp.Resp.(*kvrpcpb.GetResponse) + require.Nil(t, getResp.GetRegionError()) + require.Equal(t, []byte("foo"), getResp.Value) + }) + req := tikvrpc.NewRequest(tikvrpc.CmdGet, &kvrpcpb.GetRequest{Key: []byte("foo"), Version: math.MaxUint64}) + + require.Zero(t, req.Context.RegionId) + tikvrpc.AttachContext(req, req.Context) + tikvrpc.SetContextNoAttach(req, &metapb.Region{Id: 1}, &metapb.Peer{}) + + cli.SendRequestAsync(ctx, addr, req, cb) + rl.Exec(ctx) + require.True(t, called) +} + func TestSendRequestAsyncTimeout(t *testing.T) { ctx := context.Background() srv, port := mockserver.StartMockTikvService()