client: fix the bug of attach-context in async api (#1747)

Signed-off-by: zyguan <zhongyangguan@gmail.com>
This commit is contained in:
zyguan 2025-08-28 14:52:01 +08:00 committed by GitHub
parent 82ff387182
commit 2d0237392a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 57 additions and 6 deletions

View File

@ -49,12 +49,6 @@ func (c *RPCClient) SendRequestAsync(ctx context.Context, addr string, req *tikv
return return
} }
batchReq := req.ToBatchCommandsRequest()
if batchReq == nil {
cb.Invoke(nil, errors.New("unsupported request type: "+req.Type.String()))
return
}
regionRPC := trace.StartRegion(ctx, req.Type.String()) regionRPC := trace.StartRegion(ctx, req.Type.String())
spanRPC := opentracing.SpanFromContext(ctx) spanRPC := opentracing.SpanFromContext(ctx)
if spanRPC != nil && spanRPC.Tracer() != nil { if spanRPC != nil && spanRPC.Tracer() != nil {
@ -72,6 +66,13 @@ func (c *RPCClient) SendRequestAsync(ctx context.Context, addr string, req *tikv
} }
tikvrpc.AttachContext(req, req.Context) tikvrpc.AttachContext(req, req.Context)
// ToBatchCommandsRequest should be called after all modifications to req are done.
batchReq := req.ToBatchCommandsRequest()
if batchReq == nil {
cb.Invoke(nil, errors.New("unsupported request type: "+req.Type.String()))
return
}
// TODO(zyguan): If the client created `WithGRPCDialOptions(grpc.WithBlock())`, `getConnPool` might be blocked for // TODO(zyguan): If the client created `WithGRPCDialOptions(grpc.WithBlock())`, `getConnPool` might be blocked for
// a while when the corresponding conn array is uninitialized. However, since tidb won't set this option, we just // a while when the corresponding conn array is uninitialized. However, since tidb won't set this option, we just
// keep `getConnPool` synchronous for now. // keep `getConnPool` synchronous for now.

View File

@ -16,12 +16,15 @@ package client
import ( import (
"context" "context"
"math"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
"github.com/pingcap/failpoint" "github.com/pingcap/failpoint"
"github.com/pingcap/kvproto/pkg/errorpb"
"github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/kvproto/pkg/metapb"
"github.com/pingcap/kvproto/pkg/tikvpb" "github.com/pingcap/kvproto/pkg/tikvpb"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -97,6 +100,53 @@ func TestSendRequestAsyncBasic(t *testing.T) {
}) })
} }
func TestSendRequestAsyncAttachContext(t *testing.T) {
ctx := context.Background()
srv, port := mockserver.StartMockTikvService()
require.True(t, port > 0)
require.True(t, srv.IsRunning())
addr := srv.Addr()
cli := NewRPCClient()
defer func() {
cli.Close()
srv.Stop()
}()
handle := func(req *tikvpb.BatchCommandsRequest) (*tikvpb.BatchCommandsResponse, error) {
ids := req.GetRequestIds()
require.Len(t, ids, 1)
getReq := req.GetRequests()[0].GetGet()
var getResp *kvrpcpb.GetResponse
if getReq.GetContext().GetRegionId() == 0 {
getResp = &kvrpcpb.GetResponse{RegionError: &errorpb.Error{RegionNotFound: &errorpb.RegionNotFound{}}}
} else {
getResp = &kvrpcpb.GetResponse{Value: getReq.Key}
}
return &tikvpb.BatchCommandsResponse{RequestIds: ids, Responses: []*tikvpb.BatchCommandsResponse_Response{{Cmd: &tikvpb.BatchCommandsResponse_Response_Get{Get: getResp}}}}, nil
}
srv.OnBatchCommandsRequest.Store(&handle)
called := false
rl := async.NewRunLoop()
cb := async.NewCallback(rl, func(resp *tikvrpc.Response, err error) {
called = true
require.NoError(t, err)
getResp := resp.Resp.(*kvrpcpb.GetResponse)
require.Nil(t, getResp.GetRegionError())
require.Equal(t, []byte("foo"), getResp.Value)
})
req := tikvrpc.NewRequest(tikvrpc.CmdGet, &kvrpcpb.GetRequest{Key: []byte("foo"), Version: math.MaxUint64})
require.Zero(t, req.Context.RegionId)
tikvrpc.AttachContext(req, req.Context)
tikvrpc.SetContextNoAttach(req, &metapb.Region{Id: 1}, &metapb.Peer{})
cli.SendRequestAsync(ctx, addr, req, cb)
rl.Exec(ctx)
require.True(t, called)
}
func TestSendRequestAsyncTimeout(t *testing.T) { func TestSendRequestAsyncTimeout(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srv, port := mockserver.StartMockTikvService() srv, port := mockserver.StartMockTikvService()