client: split large files and rename unclear fields (#1700)

Signed-off-by: Lynn <zimu_xia@126.com>
Lynn 2025-07-10 12:19:15 +08:00 committed by GitHub
parent 4fd3c42d69
commit e60fec1b25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 767 additions and 687 deletions

View File

@ -47,7 +47,6 @@ import (
"sync/atomic"
"time"
grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/opentracing/opentracing-go"
"github.com/pingcap/kvproto/pkg/coprocessor"
"github.com/pingcap/kvproto/pkg/debugpb"
@ -57,7 +56,6 @@ import (
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/tikv/client-go/v2/config"
tikverr "github.com/tikv/client-go/v2/error"
"github.com/tikv/client-go/v2/internal/apicodec"
"github.com/tikv/client-go/v2/internal/logutil"
"github.com/tikv/client-go/v2/metrics"
@ -66,13 +64,7 @@ import (
"github.com/tikv/client-go/v2/util/async"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/encoding/gzip"
"google.golang.org/grpc/experimental"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
)
@ -144,306 +136,14 @@ func (e *ErrConn) Unwrap() error {
return e.Err
}
func WrapErrConn(err error, conn *connArray) error {
func WrapErrConn(err error, pool *connPool) error {
if err == nil {
return nil
}
return &ErrConn{
Err: err,
Addr: conn.target,
Ver: conn.ver,
}
}
type connArray struct {
// The target host.
target string
// version of the connection array, increase by 1 when reconnect.
ver uint64
index uint32
v []*monitoredConn
// streamTimeout binds with a background goroutine to process coprocessor streaming timeout.
streamTimeout chan *tikvrpc.Lease
dialTimeout time.Duration
// batchConn is not null when batch is enabled.
*batchConn
done chan struct{}
monitor *connMonitor
metrics struct {
rpcLatHist *rpcMetrics
rpcSrcLatSum sync.Map
rpcNetLatExternal prometheus.Observer
rpcNetLatInternal prometheus.Observer
}
}
func newConnArray(maxSize uint, addr string, ver uint64, security config.Security,
idleNotify *uint32, enableBatch bool, dialTimeout time.Duration, m *connMonitor, eventListener *atomic.Pointer[ClientEventListener], opts []grpc.DialOption) (*connArray, error) {
a := &connArray{
ver: ver,
index: 0,
v: make([]*monitoredConn, maxSize),
streamTimeout: make(chan *tikvrpc.Lease, 1024),
done: make(chan struct{}),
dialTimeout: dialTimeout,
monitor: m,
}
a.metrics.rpcLatHist = deriveRPCMetrics(metrics.TiKVSendReqHistogram.MustCurryWith(prometheus.Labels{metrics.LblStore: addr}))
a.metrics.rpcNetLatExternal = metrics.TiKVRPCNetLatencyHistogram.WithLabelValues(addr, "false")
a.metrics.rpcNetLatInternal = metrics.TiKVRPCNetLatencyHistogram.WithLabelValues(addr, "true")
if err := a.Init(addr, security, idleNotify, enableBatch, eventListener, opts...); err != nil {
return nil, err
}
return a, nil
}
type connMonitor struct {
m sync.Map
loopOnce sync.Once
stopOnce sync.Once
stop chan struct{}
}
func (c *connMonitor) AddConn(conn *monitoredConn) {
c.m.Store(conn.Name, conn)
}
func (c *connMonitor) RemoveConn(conn *monitoredConn) {
c.m.Delete(conn.Name)
for state := connectivity.Idle; state <= connectivity.Shutdown; state++ {
metrics.TiKVGrpcConnectionState.WithLabelValues(conn.Name, conn.Target(), state.String()).Set(0)
}
}
func (c *connMonitor) Start() {
c.loopOnce.Do(
func() {
c.stop = make(chan struct{})
go c.start()
},
)
}
func (c *connMonitor) Stop() {
c.stopOnce.Do(
func() {
if c.stop != nil {
close(c.stop)
}
},
)
}
func (c *connMonitor) start() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
c.m.Range(func(_, value interface{}) bool {
conn := value.(*monitoredConn)
nowState := conn.GetState()
for state := connectivity.Idle; state <= connectivity.Shutdown; state++ {
if state == nowState {
metrics.TiKVGrpcConnectionState.WithLabelValues(conn.Name, conn.Target(), nowState.String()).Set(1)
} else {
metrics.TiKVGrpcConnectionState.WithLabelValues(conn.Name, conn.Target(), state.String()).Set(0)
}
}
return true
})
case <-c.stop:
return
}
}
}
type monitoredConn struct {
*grpc.ClientConn
Name string
}
func (a *connArray) monitoredDial(ctx context.Context, connName, target string, opts ...grpc.DialOption) (conn *monitoredConn, err error) {
conn = &monitoredConn{
Name: connName,
}
conn.ClientConn, err = grpc.DialContext(ctx, target, opts...)
if err != nil {
return nil, err
}
a.monitor.AddConn(conn)
return conn, nil
}
func (c *monitoredConn) Close() error {
if c.ClientConn != nil {
err := c.ClientConn.Close()
logutil.BgLogger().Debug("close gRPC connection", zap.String("target", c.Name), zap.Error(err))
return err
}
return nil
}
func (a *connArray) Init(addr string, security config.Security, idleNotify *uint32, enableBatch bool, eventListener *atomic.Pointer[ClientEventListener], opts ...grpc.DialOption) error {
a.target = addr
opt := grpc.WithTransportCredentials(insecure.NewCredentials())
if len(security.ClusterSSLCA) != 0 {
tlsConfig, err := security.ToTLSConfig()
if err != nil {
return errors.WithStack(err)
}
opt = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))
}
cfg := config.GetGlobalConfig()
var (
unaryInterceptor grpc.UnaryClientInterceptor
streamInterceptor grpc.StreamClientInterceptor
)
if cfg.OpenTracingEnable {
unaryInterceptor = grpc_opentracing.UnaryClientInterceptor()
streamInterceptor = grpc_opentracing.StreamClientInterceptor()
}
allowBatch := (cfg.TiKVClient.MaxBatchSize > 0) && enableBatch
if allowBatch {
a.batchConn = newBatchConn(uint(len(a.v)), cfg.TiKVClient.MaxBatchSize, idleNotify)
a.batchConn.initMetrics(a.target)
}
keepAlive := cfg.TiKVClient.GrpcKeepAliveTime
for i := range a.v {
ctx, cancel := context.WithTimeout(context.Background(), a.dialTimeout)
var callOptions []grpc.CallOption
callOptions = append(callOptions, grpc.MaxCallRecvMsgSize(MaxRecvMsgSize))
if cfg.TiKVClient.GrpcCompressionType == gzip.Name {
callOptions = append(callOptions, grpc.UseCompressor(gzip.Name))
}
opts = append([]grpc.DialOption{
opt,
grpc.WithInitialWindowSize(cfg.TiKVClient.GrpcInitialWindowSize),
grpc.WithInitialConnWindowSize(cfg.TiKVClient.GrpcInitialConnWindowSize),
grpc.WithUnaryInterceptor(unaryInterceptor),
grpc.WithStreamInterceptor(streamInterceptor),
grpc.WithDefaultCallOptions(callOptions...),
grpc.WithConnectParams(grpc.ConnectParams{
Backoff: backoff.Config{
BaseDelay: 100 * time.Millisecond, // Default was 1s.
Multiplier: 1.6, // Default
Jitter: 0.2, // Default
MaxDelay: 3 * time.Second, // Default was 120s.
},
MinConnectTimeout: a.dialTimeout,
}),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: time.Duration(keepAlive) * time.Second,
Timeout: cfg.TiKVClient.GetGrpcKeepAliveTimeout(),
}),
}, opts...)
if cfg.TiKVClient.GrpcSharedBufferPool {
opts = append(opts, experimental.WithRecvBufferPool(grpc.NewSharedBufferPool()))
}
conn, err := a.monitoredDial(
ctx,
fmt.Sprintf("%s-%d", a.target, i),
addr,
opts...,
)
cancel()
if err != nil {
// Cleanup if the initialization fails.
a.Close()
return errors.WithStack(err)
}
a.v[i] = conn
if allowBatch {
batchClient := &batchCommandsClient{
target: a.target,
conn: conn.ClientConn,
forwardedClients: make(map[string]*batchCommandsStream),
batched: sync.Map{},
epoch: 0,
closed: 0,
tikvClientCfg: cfg.TiKVClient,
tikvLoad: &a.tikvTransportLayerLoad,
dialTimeout: a.dialTimeout,
tryLock: tryLock{sync.NewCond(new(sync.Mutex)), false},
eventListener: eventListener,
metrics: &a.batchConn.metrics,
}
batchClient.maxConcurrencyRequestLimit.Store(cfg.TiKVClient.MaxConcurrencyRequestLimit)
a.batchCommandsClients = append(a.batchCommandsClients, batchClient)
}
}
go tikvrpc.CheckStreamTimeoutLoop(a.streamTimeout, a.done)
if allowBatch {
go a.batchSendLoop(cfg.TiKVClient)
}
return nil
}
func (a *connArray) Get() *grpc.ClientConn {
next := atomic.AddUint32(&a.index, 1) % uint32(len(a.v))
return a.v[next].ClientConn
}
func (a *connArray) Close() {
if a.batchConn != nil {
a.batchConn.Close()
}
for _, c := range a.v {
if c != nil {
err := c.Close()
tikverr.Log(err)
if err == nil {
a.monitor.RemoveConn(c)
}
}
}
close(a.done)
}
func (a *connArray) updateRPCMetrics(req *tikvrpc.Request, resp *tikvrpc.Response, latency time.Duration) {
seconds := latency.Seconds()
stale := req.GetStaleRead()
source := req.GetRequestSource()
internal := util.IsInternalRequest(req.GetRequestSource())
a.metrics.rpcLatHist.get(req.Type, stale, internal).Observe(seconds)
srcLatSum, ok := a.metrics.rpcSrcLatSum.Load(source)
if !ok {
srcLatSum = deriveRPCMetrics(metrics.TiKVSendReqSummary.MustCurryWith(
prometheus.Labels{metrics.LblStore: a.target, metrics.LblSource: source}))
a.metrics.rpcSrcLatSum.Store(source, srcLatSum)
}
srcLatSum.(*rpcMetrics).get(req.Type, stale, internal).Observe(seconds)
if execDetail := resp.GetExecDetailsV2(); execDetail != nil {
var totalRpcWallTimeNs uint64
if execDetail.TimeDetailV2 != nil {
totalRpcWallTimeNs = execDetail.TimeDetailV2.TotalRpcWallTimeNs
} else if execDetail.TimeDetail != nil {
totalRpcWallTimeNs = execDetail.TimeDetail.TotalRpcWallTimeNs
}
if totalRpcWallTimeNs > 0 {
lat := latency - time.Duration(totalRpcWallTimeNs)
if internal {
a.metrics.rpcNetLatInternal.Observe(lat.Seconds())
} else {
a.metrics.rpcNetLatExternal.Observe(lat.Seconds())
}
}
Addr: pool.target,
Ver: pool.ver,
}
}
@ -485,7 +185,7 @@ func WithCodec(codec apicodec.Codec) Opt {
type RPCClient struct {
sync.RWMutex
conns map[string]*connArray
connPools map[string]*connPool
vers map[string]uint64
option *option
@ -505,7 +205,7 @@ var _ Client = &RPCClient{}
// NewRPCClient creates a client that manages connections and rpc calls with tikv-servers.
func NewRPCClient(opts ...Opt) *RPCClient {
cli := &RPCClient{
conns: make(map[string]*connArray),
connPools: make(map[string]*connPool),
vers: make(map[string]uint64),
option: &option{
dialTimeout: dialTimeout,
@ -520,35 +220,35 @@ func NewRPCClient(opts ...Opt) *RPCClient {
return cli
}
func (c *RPCClient) getConnArray(addr string, enableBatch bool, opt ...func(cfg *config.TiKVClient)) (*connArray, error) {
func (c *RPCClient) getConnPool(addr string, enableBatch bool, opt ...func(cfg *config.TiKVClient)) (*connPool, error) {
c.RLock()
if c.isClosed {
c.RUnlock()
return nil, errors.Errorf("rpcClient is closed")
}
array, ok := c.conns[addr]
pool, ok := c.connPools[addr]
c.RUnlock()
if !ok {
var err error
array, err = c.createConnArray(addr, enableBatch, opt...)
pool, err = c.createConnPool(addr, enableBatch, opt...)
if err != nil {
return nil, err
}
}
// An idle connArray will not change to active again, this avoid the race condition
// An idle connPool will not change back to active; this avoids the race condition
// where recycling an idle connection would unexpectedly close an active one (idle -> active).
if array.batchConn != nil && array.isIdle() {
if pool.batchConn != nil && pool.isIdle() {
return nil, errors.Errorf("rpcClient is idle")
}
return array, nil
return pool, nil
}
func (c *RPCClient) createConnArray(addr string, enableBatch bool, opts ...func(cfg *config.TiKVClient)) (*connArray, error) {
func (c *RPCClient) createConnPool(addr string, enableBatch bool, opts ...func(cfg *config.TiKVClient)) (*connPool, error) {
c.Lock()
defer c.Unlock()
array, ok := c.conns[addr]
pool, ok := c.connPools[addr]
if !ok {
var err error
client := config.GetGlobalConfig().TiKVClient
@ -556,7 +256,7 @@ func (c *RPCClient) createConnArray(addr string, enableBatch bool, opts ...func(
opt(&client)
}
ver := c.vers[addr] + 1
array, err = newConnArray(
pool, err = newConnPool(
client.GrpcConnectionCount,
addr,
ver,
@ -571,10 +271,10 @@ func (c *RPCClient) createConnArray(addr string, enableBatch bool, opts ...func(
if err != nil {
return nil, err
}
c.conns[addr] = array
c.connPools[addr] = pool
c.vers[addr] = ver
}
return array, nil
return pool, nil
}
func (c *RPCClient) closeConns() {
@ -582,8 +282,8 @@ func (c *RPCClient) closeConns() {
if !c.isClosed {
c.isClosed = true
// close all connections
for _, array := range c.conns {
array.Close()
for _, pool := range c.connPools {
pool.Close()
}
}
c.Unlock()
@ -595,10 +295,10 @@ func (c *RPCClient) recycleIdleConnArray() {
var addrs []string
var vers []uint64
c.RLock()
for _, conn := range c.conns {
if conn.batchConn != nil && conn.isIdle() {
addrs = append(addrs, conn.target)
vers = append(vers, conn.ver)
for _, pool := range c.connPools {
if pool.batchConn != nil && pool.isIdle() {
addrs = append(addrs, pool.target)
vers = append(vers, pool.ver)
}
}
c.RUnlock()
@ -627,19 +327,19 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
// TiDB will not send batch commands to TiFlash, to resolve the conflict with Batch Cop Request.
// tiflash/tiflash_mpp/tidb don't use BatchCommand.
enableBatch := req.StoreTp == tikvrpc.TiKV
connArray, err := c.getConnArray(addr, enableBatch)
connPool, err := c.getConnPool(addr, enableBatch)
if err != nil {
return nil, err
}
wrapErrConn := func(resp *tikvrpc.Response, err error) (*tikvrpc.Response, error) {
return resp, WrapErrConn(err, connArray)
return resp, WrapErrConn(err, connPool)
}
start := time.Now()
defer func() {
elapsed := time.Since(start)
connArray.updateRPCMetrics(req, resp, elapsed)
connPool.updateRPCMetrics(req, resp, elapsed)
if spanRPC != nil && util.TraceExecDetailsEnabled(ctx) {
if si := buildSpanInfoFromResp(resp); si != nil {
@ -654,11 +354,11 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
if config.GetGlobalConfig().TiKVClient.MaxBatchSize > 0 && enableBatch {
if batchReq := req.ToBatchCommandsRequest(); batchReq != nil {
defer trace.StartRegion(ctx, req.Type.String()).End()
return wrapErrConn(sendBatchRequest(ctx, addr, req.ForwardedHost, connArray.batchConn, batchReq, timeout, pri))
return wrapErrConn(sendBatchRequest(ctx, addr, req.ForwardedHost, connPool.batchConn, batchReq, timeout, pri))
}
}
clientConn := connArray.Get()
clientConn := connPool.Get()
if state := clientConn.GetState(); state == connectivity.TransientFailure {
storeID := strconv.FormatUint(req.Context.GetPeer().GetStoreId(), 10)
metrics.TiKVGRPCConnTransientFailureCounter.WithLabelValues(addr, storeID).Inc()
@ -679,11 +379,11 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
}
switch req.Type {
case tikvrpc.CmdBatchCop:
return wrapErrConn(c.getBatchCopStreamResponse(ctx, client, req, timeout, connArray))
return wrapErrConn(c.getBatchCopStreamResponse(ctx, client, req, timeout, connPool))
case tikvrpc.CmdCopStream:
return wrapErrConn(c.getCopStreamResponse(ctx, client, req, timeout, connArray))
return wrapErrConn(c.getCopStreamResponse(ctx, client, req, timeout, connPool))
case tikvrpc.CmdMPPConn:
return wrapErrConn(c.getMPPStreamResponse(ctx, client, req, timeout, connArray))
return wrapErrConn(c.getMPPStreamResponse(ctx, client, req, timeout, connPool))
}
// Or else it's a unary call.
ctx1, cancel := context.WithTimeout(ctx, timeout)
@ -710,7 +410,7 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R
return codec.DecodeResponse(req, resp)
}
func (c *RPCClient) getCopStreamResponse(ctx context.Context, client tikvpb.TikvClient, req *tikvrpc.Request, timeout time.Duration, connArray *connArray) (*tikvrpc.Response, error) {
func (c *RPCClient) getCopStreamResponse(ctx context.Context, client tikvpb.TikvClient, req *tikvrpc.Request, timeout time.Duration, connPool *connPool) (*tikvrpc.Response, error) {
// Coprocessor streaming request.
// Use context to support timeout for grpc streaming client.
ctx1, cancel := context.WithCancel(ctx)
@ -727,7 +427,7 @@ func (c *RPCClient) getCopStreamResponse(ctx context.Context, client tikvpb.Tikv
copStream := resp.Resp.(*tikvrpc.CopStreamResponse)
copStream.Timeout = timeout
copStream.Lease.Cancel = cancel
connArray.streamTimeout <- &copStream.Lease
connPool.streamTimeout <- &copStream.Lease
// Read the first streaming response to get CopStreamResponse.
// This can make error handling much easier, because SendReq() retry on
@ -745,7 +445,7 @@ func (c *RPCClient) getCopStreamResponse(ctx context.Context, client tikvpb.Tikv
}
func (c *RPCClient) getBatchCopStreamResponse(ctx context.Context, client tikvpb.TikvClient, req *tikvrpc.Request, timeout time.Duration, connArray *connArray) (*tikvrpc.Response, error) {
func (c *RPCClient) getBatchCopStreamResponse(ctx context.Context, client tikvpb.TikvClient, req *tikvrpc.Request, timeout time.Duration, connPool *connPool) (*tikvrpc.Response, error) {
// Coprocessor streaming request.
// Use context to support timeout for grpc streaming client.
ctx1, cancel := context.WithCancel(ctx)
@ -762,7 +462,7 @@ func (c *RPCClient) getBatchCopStreamResponse(ctx context.Context, client tikvpb
copStream := resp.Resp.(*tikvrpc.BatchCopStreamResponse)
copStream.Timeout = timeout
copStream.Lease.Cancel = cancel
connArray.streamTimeout <- &copStream.Lease
connPool.streamTimeout <- &copStream.Lease
// Read the first streaming response to get CopStreamResponse.
// This can make error handling much easier, because SendReq() retry on
@ -779,7 +479,7 @@ func (c *RPCClient) getBatchCopStreamResponse(ctx context.Context, client tikvpb
return resp, nil
}
func (c *RPCClient) getMPPStreamResponse(ctx context.Context, client tikvpb.TikvClient, req *tikvrpc.Request, timeout time.Duration, connArray *connArray) (*tikvrpc.Response, error) {
func (c *RPCClient) getMPPStreamResponse(ctx context.Context, client tikvpb.TikvClient, req *tikvrpc.Request, timeout time.Duration, connPool *connPool) (*tikvrpc.Response, error) {
// MPP streaming request.
// Use context to support timeout for grpc streaming client.
ctx1, cancel := context.WithCancel(ctx)
@ -796,7 +496,7 @@ func (c *RPCClient) getMPPStreamResponse(ctx context.Context, client tikvpb.Tikv
copStream := resp.Resp.(*tikvrpc.MPPStreamResponse)
copStream.Timeout = timeout
copStream.Lease.Cancel = cancel
connArray.streamTimeout <- &copStream.Lease
connPool.streamTimeout <- &copStream.Lease
// Read the first streaming response to get CopStreamResponse.
// This can make error handling much easier, because SendReq() retry on
@ -831,20 +531,20 @@ func (c *RPCClient) CloseAddrVer(addr string, ver uint64) error {
c.Unlock()
return nil
}
conn, ok := c.conns[addr]
pool, ok := c.connPools[addr]
if ok {
if conn.ver <= ver {
delete(c.conns, addr)
logutil.BgLogger().Debug("close connection", zap.String("target", addr), zap.Uint64("ver", ver), zap.Uint64("conn.ver", conn.ver))
if pool.ver <= ver {
delete(c.connPools, addr)
logutil.BgLogger().Debug("close connection", zap.String("target", addr), zap.Uint64("ver", ver), zap.Uint64("pool.ver", pool.ver))
} else {
logutil.BgLogger().Debug("ignore close connection", zap.String("target", addr), zap.Uint64("ver", ver), zap.Uint64("conn.ver", conn.ver))
conn = nil
logutil.BgLogger().Debug("ignore close connection", zap.String("target", addr), zap.Uint64("ver", ver), zap.Uint64("pool.ver", pool.ver))
pool = nil
}
}
c.Unlock()
if conn != nil {
conn.Close()
if pool != nil {
pool.Close()
}
return nil
}
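
As a side note on the renamed fields: errors wrapped by WrapErrConn still carry the pool's address and version, so a caller can recover them with errors.As. A minimal caller-side sketch (hypothetical helper, not part of this change, assumed to live in the same client package):

// connInfoFromErr is an illustrative helper: it extracts the target address
// and pool version recorded by WrapErrConn, e.g. to decide whether to call
// CloseAddrVer for that (addr, ver) pair.
func connInfoFromErr(err error) (addr string, ver uint64, ok bool) {
    var errConn *ErrConn
    if errors.As(err, &errConn) {
        return errConn.Addr, errConn.Ver, true
    }
    return "", 0, false
}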

View File

@ -72,10 +72,10 @@ func (c *RPCClient) SendRequestAsync(ctx context.Context, addr string, req *tikv
}
tikvrpc.AttachContext(req, req.Context)
// TODO(zyguan): If the client created `WithGRPCDialOptions(grpc.WithBlock())`, `getConnArray` might be blocked for
// TODO(zyguan): If the client created `WithGRPCDialOptions(grpc.WithBlock())`, `getConnPool` might be blocked for
// a while when the corresponding conn array is uninitialized. However, since tidb won't set this option, we just
// keep `getConnArray` synchronous for now.
connArray, err := c.getConnArray(addr, true)
// keep `getConnPool` synchronous for now.
connPool, err := c.getConnPool(addr, true)
if err != nil {
cb.Invoke(nil, err)
return
@ -113,7 +113,7 @@ func (c *RPCClient) SendRequestAsync(ctx context.Context, addr string, req *tikv
metrics.BatchRequestDurationDone.Observe(elapsed.Seconds())
// rpc metrics
connArray.updateRPCMetrics(req, resp, elapsed)
connPool.updateRPCMetrics(req, resp, elapsed)
// tracing
if spanRPC != nil {
@ -131,7 +131,7 @@ func (c *RPCClient) SendRequestAsync(ctx context.Context, addr string, req *tikv
resp, err = c.option.codec.DecodeResponse(req, resp)
}
return resp, WrapErrConn(err, connArray)
return resp, WrapErrConn(err, connPool)
})
stop = context.AfterFunc(ctx, func() {
@ -139,7 +139,7 @@ func (c *RPCClient) SendRequestAsync(ctx context.Context, addr string, req *tikv
entry.error(ctx.Err())
})
batchConn := connArray.batchConn
batchConn := connPool.batchConn
if val, err := util.EvalFailpoint("mockBatchCommandsChannelFullOnAsyncSend"); err == nil {
mockBatchCommandsChannelFullOnAsyncSend(ctx, batchConn, cb, val)
}

View File

@ -40,7 +40,6 @@ import (
"encoding/json"
"fmt"
"math"
"runtime"
"runtime/trace"
"strings"
"sync"
@ -264,159 +263,6 @@ type batchConnMetrics struct {
bestBatchSize prometheus.Observer
}
type batchConn struct {
// An atomic flag indicates whether the batch is idle or not.
// 0 for busy, others for idle.
idle uint32
// batchCommandsCh used for batch commands.
batchCommandsCh chan *batchCommandsEntry
batchCommandsClients []*batchCommandsClient
tikvTransportLayerLoad uint64
closed chan struct{}
reqBuilder *batchCommandsBuilder
// Notify rpcClient to check the idle flag
idleNotify *uint32
idleDetect *time.Timer
fetchMoreTimer *time.Timer
index uint32
metrics batchConnMetrics
}
func newBatchConn(connCount, maxBatchSize uint, idleNotify *uint32) *batchConn {
return &batchConn{
batchCommandsCh: make(chan *batchCommandsEntry, maxBatchSize),
batchCommandsClients: make([]*batchCommandsClient, 0, connCount),
tikvTransportLayerLoad: 0,
closed: make(chan struct{}),
reqBuilder: newBatchCommandsBuilder(maxBatchSize),
idleNotify: idleNotify,
idleDetect: time.NewTimer(idleTimeout),
}
}
func (a *batchConn) initMetrics(target string) {
a.metrics.pendingRequests = metrics.TiKVBatchPendingRequests.WithLabelValues(target)
a.metrics.batchSize = metrics.TiKVBatchRequests.WithLabelValues(target)
a.metrics.sendLoopWaitHeadDur = metrics.TiKVBatchSendLoopDuration.WithLabelValues(target, "wait-head")
a.metrics.sendLoopWaitMoreDur = metrics.TiKVBatchSendLoopDuration.WithLabelValues(target, "wait-more")
a.metrics.sendLoopSendDur = metrics.TiKVBatchSendLoopDuration.WithLabelValues(target, "send")
a.metrics.recvLoopRecvDur = metrics.TiKVBatchRecvLoopDuration.WithLabelValues(target, "recv")
a.metrics.recvLoopProcessDur = metrics.TiKVBatchRecvLoopDuration.WithLabelValues(target, "process")
a.metrics.batchSendTailLat = metrics.TiKVBatchSendTailLatency.WithLabelValues(target)
a.metrics.batchRecvTailLat = metrics.TiKVBatchRecvTailLatency.WithLabelValues(target)
a.metrics.headArrivalInterval = metrics.TiKVBatchHeadArrivalInterval.WithLabelValues(target)
a.metrics.batchMoreRequests = metrics.TiKVBatchMoreRequests.WithLabelValues(target)
a.metrics.bestBatchSize = metrics.TiKVBatchBestSize.WithLabelValues(target)
}
func (a *batchConn) isIdle() bool {
return atomic.LoadUint32(&a.idle) != 0
}
// fetchAllPendingRequests fetches all pending requests from the channel.
func (a *batchConn) fetchAllPendingRequests(maxBatchSize int) (headRecvTime time.Time, headArrivalInterval time.Duration) {
// Block on the first element.
latestReqStartTime := a.reqBuilder.latestReqStartTime
var headEntry *batchCommandsEntry
select {
case headEntry = <-a.batchCommandsCh:
if !a.idleDetect.Stop() {
<-a.idleDetect.C
}
a.idleDetect.Reset(idleTimeout)
case <-a.idleDetect.C:
a.idleDetect.Reset(idleTimeout)
atomic.AddUint32(&a.idle, 1)
atomic.CompareAndSwapUint32(a.idleNotify, 0, 1)
// This batchConn to be recycled
return time.Now(), 0
case <-a.closed:
return time.Now(), 0
}
if headEntry == nil {
return time.Now(), 0
}
headRecvTime = time.Now()
if headEntry.start.After(latestReqStartTime) && !latestReqStartTime.IsZero() {
headArrivalInterval = headEntry.start.Sub(latestReqStartTime)
}
a.reqBuilder.push(headEntry)
// This loop is for trying best to collect more requests.
for a.reqBuilder.len() < maxBatchSize {
select {
case entry := <-a.batchCommandsCh:
if entry == nil {
return
}
a.reqBuilder.push(entry)
default:
return
}
}
return
}
// fetchMorePendingRequests fetches more pending requests from the channel.
func (a *batchConn) fetchMorePendingRequests(
maxBatchSize int,
batchWaitSize int,
maxWaitTime time.Duration,
) {
// Try to collect `batchWaitSize` requests, or wait `maxWaitTime`.
if a.fetchMoreTimer == nil {
a.fetchMoreTimer = time.NewTimer(maxWaitTime)
} else {
a.fetchMoreTimer.Reset(maxWaitTime)
}
for a.reqBuilder.len() < batchWaitSize {
select {
case entry := <-a.batchCommandsCh:
if entry == nil {
if !a.fetchMoreTimer.Stop() {
<-a.fetchMoreTimer.C
}
return
}
a.reqBuilder.push(entry)
case <-a.fetchMoreTimer.C:
return
}
}
if !a.fetchMoreTimer.Stop() {
<-a.fetchMoreTimer.C
}
// Do an additional non-block try. Here we test the length with `maxBatchSize` instead
// of `batchWaitSize` because trying best to fetch more requests is necessary so that
// we can adjust the `batchWaitSize` dynamically.
yielded := false
for a.reqBuilder.len() < maxBatchSize {
select {
case entry := <-a.batchCommandsCh:
if entry == nil {
return
}
a.reqBuilder.push(entry)
default:
if yielded {
return
}
// yield once to batch more requests.
runtime.Gosched()
yielded = true
}
}
}
const idleTimeout = 3 * time.Minute
var (
// presetBatchPolicies defines a set of [turboBatchOptions] as batch policies.
presetBatchPolicies = map[string]turboBatchOptions{
@ -534,150 +380,6 @@ func (t *turboBatchTrigger) preferredBatchWaitSize(avgBatchWaitSize float64, def
return batchWaitSize
}
// BatchSendLoopPanicCounter is only used for testing.
var BatchSendLoopPanicCounter int64 = 0
var initBatchPolicyWarn sync.Once
func (a *batchConn) batchSendLoop(cfg config.TiKVClient) {
defer func() {
if r := recover(); r != nil {
metrics.TiKVPanicCounter.WithLabelValues(metrics.LabelBatchSendLoop).Inc()
logutil.BgLogger().Error("batchSendLoop",
zap.Any("r", r),
zap.Stack("stack"))
atomic.AddInt64(&BatchSendLoopPanicCounter, 1)
logutil.BgLogger().Info("restart batchSendLoop", zap.Int64("count", atomic.LoadInt64(&BatchSendLoopPanicCounter)))
go a.batchSendLoop(cfg)
}
}()
trigger, ok := newTurboBatchTriggerFromPolicy(cfg.BatchPolicy)
if !ok {
initBatchPolicyWarn.Do(func() {
logutil.BgLogger().Warn("fallback to default batch policy due to invalid value", zap.String("value", cfg.BatchPolicy))
})
}
turboBatchWaitTime := trigger.turboWaitTime()
avgBatchWaitSize := float64(cfg.BatchWaitSize)
for {
sendLoopStartTime := time.Now()
a.reqBuilder.reset()
headRecvTime, headArrivalInterval := a.fetchAllPendingRequests(int(cfg.MaxBatchSize))
if a.reqBuilder.len() == 0 {
// the conn is closed or recycled.
return
}
// curl -X PUT -d 'return(true)' http://0.0.0.0:10080/fail/tikvclient/mockBlockOnBatchClient
if val, err := util.EvalFailpoint("mockBlockOnBatchClient"); err == nil {
if val.(bool) {
time.Sleep(1 * time.Hour)
}
}
if batchSize := a.reqBuilder.len(); batchSize < int(cfg.MaxBatchSize) {
if cfg.MaxBatchWaitTime > 0 && atomic.LoadUint64(&a.tikvTransportLayerLoad) > uint64(cfg.OverloadThreshold) {
// If the target TiKV is overload, wait a while to collect more requests.
metrics.TiKVBatchWaitOverLoad.Inc()
a.fetchMorePendingRequests(int(cfg.MaxBatchSize), int(cfg.BatchWaitSize), cfg.MaxBatchWaitTime)
} else if turboBatchWaitTime > 0 && headArrivalInterval > 0 && trigger.needFetchMore(headArrivalInterval) {
batchWaitSize := trigger.preferredBatchWaitSize(avgBatchWaitSize, int(cfg.BatchWaitSize))
a.fetchMorePendingRequests(int(cfg.MaxBatchSize), batchWaitSize, turboBatchWaitTime)
a.metrics.batchMoreRequests.Observe(float64(a.reqBuilder.len() - batchSize))
}
}
length := a.reqBuilder.len()
avgBatchWaitSize = 0.2*float64(length) + 0.8*avgBatchWaitSize
a.metrics.pendingRequests.Observe(float64(len(a.batchCommandsCh) + length))
a.metrics.bestBatchSize.Observe(avgBatchWaitSize)
a.metrics.headArrivalInterval.Observe(headArrivalInterval.Seconds())
a.metrics.sendLoopWaitHeadDur.Observe(headRecvTime.Sub(sendLoopStartTime).Seconds())
a.metrics.sendLoopWaitMoreDur.Observe(time.Since(sendLoopStartTime).Seconds())
a.getClientAndSend()
sendLoopEndTime := time.Now()
a.metrics.sendLoopSendDur.Observe(sendLoopEndTime.Sub(sendLoopStartTime).Seconds())
if dur := sendLoopEndTime.Sub(headRecvTime); dur > batchSendTailLatThreshold {
a.metrics.batchSendTailLat.Observe(dur.Seconds())
}
}
}
const (
SendFailedReasonNoAvailableLimit = "concurrency limit exceeded"
SendFailedReasonTryLockForSendFail = "tryLockForSend fail"
)
func (a *batchConn) getClientAndSend() {
if val, err := util.EvalFailpoint("mockBatchClientSendDelay"); err == nil {
if timeout, ok := val.(int); ok && timeout > 0 {
time.Sleep(time.Duration(timeout * int(time.Millisecond)))
}
}
// Choose a connection by round-robbin.
var (
cli *batchCommandsClient
target string
)
reasons := make([]string, 0)
hasHighPriorityTask := a.reqBuilder.hasHighPriorityTask()
for i := 0; i < len(a.batchCommandsClients); i++ {
a.index = (a.index + 1) % uint32(len(a.batchCommandsClients))
target = a.batchCommandsClients[a.index].target
// The lock protects the batchCommandsClient from been closed while it's in use.
c := a.batchCommandsClients[a.index]
if hasHighPriorityTask || c.available() > 0 {
if c.tryLockForSend() {
cli = c
break
} else {
reasons = append(reasons, SendFailedReasonTryLockForSendFail)
}
} else {
reasons = append(reasons, SendFailedReasonNoAvailableLimit)
}
}
if cli == nil {
logutil.BgLogger().Info("no available connections", zap.String("target", target), zap.Any("reasons", reasons))
metrics.TiKVNoAvailableConnectionCounter.Inc()
if config.GetGlobalConfig().TiKVClient.MaxConcurrencyRequestLimit == config.DefMaxConcurrencyRequestLimit {
// Only cancel requests when MaxConcurrencyRequestLimit feature is not enabled, to be compatible with the behavior of older versions.
// TODO: But when MaxConcurrencyRequestLimit feature is enabled, the requests won't be canceled and will wait until timeout.
// This behavior may not be reasonable, as the timeout is usually 40s or 60s, which is too long to retry in time.
a.reqBuilder.cancel(errors.New("no available connections"))
}
return
}
defer cli.unlockForSend()
available := cli.available()
reqSendTime := time.Now()
batch := 0
req, forwardingReqs := a.reqBuilder.buildWithLimit(available, func(id uint64, e *batchCommandsEntry) {
cli.batched.Store(id, e)
cli.sent.Add(1)
atomic.StoreInt64(&e.sendLat, int64(reqSendTime.Sub(e.start)))
if trace.IsEnabled() {
trace.Log(e.ctx, "rpc", "send")
}
})
if req != nil {
batch += len(req.RequestIds)
cli.send("", req)
}
for forwardedHost, req := range forwardingReqs {
batch += len(req.RequestIds)
cli.send(forwardedHost, req)
}
if batch > 0 {
a.metrics.batchSize.Observe(float64(batch))
}
}
type tryLock struct {
*sync.Cond
reCreating bool
@ -1127,18 +829,6 @@ func (c *batchCommandsClient) initBatchClient(forwardedHost string) error {
return nil
}
func (a *batchConn) Close() {
// Close all batchRecvLoop.
for _, c := range a.batchCommandsClients {
// After connections are closed, `batchRecvLoop`s will check the flag.
atomic.StoreInt32(&c.closed, 1)
}
// Don't close(batchCommandsCh) because when Close() is called, someone maybe
// calling SendRequest and writing batchCommandsCh, if we close it here the
// writing goroutine will panic.
close(a.closed)
}
func sendBatchRequest(
ctx context.Context,
addr string,

View File

@ -64,7 +64,7 @@ func TestPanicInRecvLoop(t *testing.T) {
rpcClient.option.dialTimeout = time.Second / 3
// Start batchRecvLoop, and it should panic in `failPendingRequests`.
_, err := rpcClient.getConnArray(addr, true, func(cfg *config.TiKVClient) { cfg.GrpcConnectionCount = 1 })
_, err := rpcClient.getConnPool(addr, true, func(cfg *config.TiKVClient) { cfg.GrpcConnectionCount = 1 })
assert.Nil(t, err, "cannot establish local connection due to env problems(e.g. heavy load in test machine), please retry again")
req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{})
@ -103,10 +103,10 @@ func TestRecvErrorInMultipleRecvLoops(t *testing.T) {
_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
assert.Nil(t, err)
}
connArray, err := rpcClient.getConnArray(addr, true)
assert.NotNil(t, connArray)
connPool, err := rpcClient.getConnPool(addr, true)
assert.NotNil(t, connPool)
assert.Nil(t, err)
batchConn := connArray.batchConn
batchConn := connPool.batchConn
assert.NotNil(t, batchConn)
assert.Equal(t, len(batchConn.batchCommandsClients), 1)
batchClient := batchConn.batchCommandsClients[0]

View File

@ -74,28 +74,28 @@ func TestConn(t *testing.T) {
defer client.Close()
addr := "127.0.0.1:6379"
conn1, err := client.getConnArray(addr, true)
conn1, err := client.getConnPool(addr, true)
assert.Nil(t, err)
conn2, err := client.getConnArray(addr, true)
conn2, err := client.getConnPool(addr, true)
assert.Nil(t, err)
assert.False(t, conn2.Get() == conn1.Get())
ver := conn2.ver
assert.Nil(t, client.CloseAddrVer(addr, ver-1))
_, ok := client.conns[addr]
_, ok := client.connPools[addr]
assert.True(t, ok)
assert.Nil(t, client.CloseAddrVer(addr, ver))
_, ok = client.conns[addr]
_, ok = client.connPools[addr]
assert.False(t, ok)
conn3, err := client.getConnArray(addr, true)
conn3, err := client.getConnPool(addr, true)
assert.Nil(t, err)
assert.NotNil(t, conn3)
assert.Equal(t, ver+1, conn3.ver)
client.Close()
conn4, err := client.getConnArray(addr, true)
conn4, err := client.getConnPool(addr, true)
assert.NotNil(t, err)
assert.Nil(t, conn4)
}
@ -105,10 +105,10 @@ func TestGetConnAfterClose(t *testing.T) {
defer client.Close()
addr := "127.0.0.1:6379"
connArray, err := client.getConnArray(addr, true)
connPool, err := client.getConnPool(addr, true)
assert.Nil(t, err)
assert.Nil(t, client.CloseAddr(addr))
conn := connArray.Get()
conn := connPool.Get()
state := conn.GetState()
assert.True(t, state == connectivity.Shutdown)
}
@ -139,7 +139,7 @@ func TestSendWhenReconnect(t *testing.T) {
restoreFn()
}()
addr := server.Addr()
conn, err := rpcClient.getConnArray(addr, true)
conn, err := rpcClient.getConnPool(addr, true)
assert.Nil(t, err)
// Suppose all connections are re-establishing.
@ -680,7 +680,7 @@ func TestBatchClientRecoverAfterServerRestart(t *testing.T) {
}()
req := &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Coprocessor{Coprocessor: &coprocessor.Request{}}}
conn, err := client.getConnArray(addr, true)
conn, err := client.getConnPool(addr, true)
assert.Nil(t, err)
// send some request, it should be success.
for i := 0; i < 100; i++ {
@ -855,7 +855,7 @@ func TestBatchClientReceiveHealthFeedback(t *testing.T) {
client := NewRPCClient()
defer client.Close()
conn, err := client.getConnArray(addr, true)
conn, err := client.getConnPool(addr, true)
assert.NoError(t, err)
tikvClient := tikvpb.NewTikvClient(conn.Get())
@ -944,7 +944,7 @@ func TestRandomRestartStoreAndForwarding(t *testing.T) {
}
}()
conn, err := client1.getConnArray(addr1, true)
conn, err := client1.getConnPool(addr1, true)
assert.Nil(t, err)
for j := 0; j < concurrency; j++ {
wg.Add(1)
@ -1040,7 +1040,7 @@ func TestFastFailWhenNoAvailableConn(t *testing.T) {
}()
req := &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Coprocessor{Coprocessor: &coprocessor.Request{}}}
conn, err := client.getConnArray(addr, true)
conn, err := client.getConnPool(addr, true)
assert.Nil(t, err)
_, err = sendBatchRequest(context.Background(), addr, "", conn.batchConn, req, time.Second, 0)
require.NoError(t, err)
@ -1060,7 +1060,7 @@ func TestFastFailWhenNoAvailableConn(t *testing.T) {
func TestConcurrentCloseConnPanic(t *testing.T) {
client := NewRPCClient()
addr := "127.0.0.1:6379"
_, err := client.getConnArray(addr, true)
_, err := client.getConnPool(addr, true)
assert.Nil(t, err)
var wg sync.WaitGroup
wg.Add(2)

View File

@ -0,0 +1,339 @@
// Copyright 2025 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"runtime"
"runtime/trace"
"sync"
"sync/atomic"
"time"
"github.com/pkg/errors"
"github.com/tikv/client-go/v2/config"
"github.com/tikv/client-go/v2/internal/logutil"
"github.com/tikv/client-go/v2/metrics"
"github.com/tikv/client-go/v2/util"
"go.uber.org/zap"
)
type batchConn struct {
// An atomic flag indicates whether the batch is idle or not.
// 0 for busy, others for idle.
idle uint32
// batchCommandsCh used for batch commands.
batchCommandsCh chan *batchCommandsEntry
batchCommandsClients []*batchCommandsClient
tikvTransportLayerLoad uint64
closed chan struct{}
reqBuilder *batchCommandsBuilder
// Notify rpcClient to check the idle flag
idleNotify *uint32
idleDetect *time.Timer
fetchMoreTimer *time.Timer
index uint32
metrics batchConnMetrics
}
func newBatchConn(connCount, maxBatchSize uint, idleNotify *uint32) *batchConn {
return &batchConn{
batchCommandsCh: make(chan *batchCommandsEntry, maxBatchSize),
batchCommandsClients: make([]*batchCommandsClient, 0, connCount),
tikvTransportLayerLoad: 0,
closed: make(chan struct{}),
reqBuilder: newBatchCommandsBuilder(maxBatchSize),
idleNotify: idleNotify,
idleDetect: time.NewTimer(idleTimeout),
}
}
func (a *batchConn) initMetrics(target string) {
a.metrics.pendingRequests = metrics.TiKVBatchPendingRequests.WithLabelValues(target)
a.metrics.batchSize = metrics.TiKVBatchRequests.WithLabelValues(target)
a.metrics.sendLoopWaitHeadDur = metrics.TiKVBatchSendLoopDuration.WithLabelValues(target, "wait-head")
a.metrics.sendLoopWaitMoreDur = metrics.TiKVBatchSendLoopDuration.WithLabelValues(target, "wait-more")
a.metrics.sendLoopSendDur = metrics.TiKVBatchSendLoopDuration.WithLabelValues(target, "send")
a.metrics.recvLoopRecvDur = metrics.TiKVBatchRecvLoopDuration.WithLabelValues(target, "recv")
a.metrics.recvLoopProcessDur = metrics.TiKVBatchRecvLoopDuration.WithLabelValues(target, "process")
a.metrics.batchSendTailLat = metrics.TiKVBatchSendTailLatency.WithLabelValues(target)
a.metrics.batchRecvTailLat = metrics.TiKVBatchRecvTailLatency.WithLabelValues(target)
a.metrics.headArrivalInterval = metrics.TiKVBatchHeadArrivalInterval.WithLabelValues(target)
a.metrics.batchMoreRequests = metrics.TiKVBatchMoreRequests.WithLabelValues(target)
a.metrics.bestBatchSize = metrics.TiKVBatchBestSize.WithLabelValues(target)
}
func (a *batchConn) isIdle() bool {
return atomic.LoadUint32(&a.idle) != 0
}
// fetchAllPendingRequests fetches all pending requests from the channel.
func (a *batchConn) fetchAllPendingRequests(maxBatchSize int) (headRecvTime time.Time, headArrivalInterval time.Duration) {
// Block on the first element.
latestReqStartTime := a.reqBuilder.latestReqStartTime
var headEntry *batchCommandsEntry
select {
case headEntry = <-a.batchCommandsCh:
if !a.idleDetect.Stop() {
<-a.idleDetect.C
}
a.idleDetect.Reset(idleTimeout)
case <-a.idleDetect.C:
a.idleDetect.Reset(idleTimeout)
atomic.AddUint32(&a.idle, 1)
atomic.CompareAndSwapUint32(a.idleNotify, 0, 1)
// This batchConn is about to be recycled.
return time.Now(), 0
case <-a.closed:
return time.Now(), 0
}
if headEntry == nil {
return time.Now(), 0
}
headRecvTime = time.Now()
if headEntry.start.After(latestReqStartTime) && !latestReqStartTime.IsZero() {
headArrivalInterval = headEntry.start.Sub(latestReqStartTime)
}
a.reqBuilder.push(headEntry)
// This loop does its best to collect more requests without blocking.
for a.reqBuilder.len() < maxBatchSize {
select {
case entry := <-a.batchCommandsCh:
if entry == nil {
return
}
a.reqBuilder.push(entry)
default:
return
}
}
return
}
// fetchMorePendingRequests fetches more pending requests from the channel.
func (a *batchConn) fetchMorePendingRequests(
maxBatchSize int,
batchWaitSize int,
maxWaitTime time.Duration,
) {
// Try to collect `batchWaitSize` requests, or wait `maxWaitTime`.
if a.fetchMoreTimer == nil {
a.fetchMoreTimer = time.NewTimer(maxWaitTime)
} else {
a.fetchMoreTimer.Reset(maxWaitTime)
}
for a.reqBuilder.len() < batchWaitSize {
select {
case entry := <-a.batchCommandsCh:
if entry == nil {
if !a.fetchMoreTimer.Stop() {
<-a.fetchMoreTimer.C
}
return
}
a.reqBuilder.push(entry)
case <-a.fetchMoreTimer.C:
return
}
}
if !a.fetchMoreTimer.Stop() {
<-a.fetchMoreTimer.C
}
// Do an additional non-blocking try. Here we test the length against `maxBatchSize` instead
// of `batchWaitSize`, because fetching as many extra requests as possible lets us
// adjust `batchWaitSize` dynamically.
yielded := false
for a.reqBuilder.len() < maxBatchSize {
select {
case entry := <-a.batchCommandsCh:
if entry == nil {
return
}
a.reqBuilder.push(entry)
default:
if yielded {
return
}
// yield once to batch more requests.
runtime.Gosched()
yielded = true
}
}
}
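
The Stop-then-drain handling of fetchMoreTimer above follows the usual pattern for reusing a time.Timer; the same idea as a small illustrative helper (hypothetical name, not part of this change), using a non-blocking drain for the case where the tick was already consumed:

// resetTimer stops t, discards a tick that may already have fired, and then
// resets it to d, so a stale tick cannot wake the next wait immediately.
func resetTimer(t *time.Timer, d time.Duration) {
    if !t.Stop() {
        select {
        case <-t.C: // discard the pending tick
        default:
        }
    }
    t.Reset(d)
}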
const idleTimeout = 3 * time.Minute
// BatchSendLoopPanicCounter is only used for testing.
var BatchSendLoopPanicCounter int64 = 0
var initBatchPolicyWarn sync.Once
func (a *batchConn) batchSendLoop(cfg config.TiKVClient) {
defer func() {
if r := recover(); r != nil {
metrics.TiKVPanicCounter.WithLabelValues(metrics.LabelBatchSendLoop).Inc()
logutil.BgLogger().Error("batchSendLoop",
zap.Any("r", r),
zap.Stack("stack"))
atomic.AddInt64(&BatchSendLoopPanicCounter, 1)
logutil.BgLogger().Info("restart batchSendLoop", zap.Int64("count", atomic.LoadInt64(&BatchSendLoopPanicCounter)))
go a.batchSendLoop(cfg)
}
}()
trigger, ok := newTurboBatchTriggerFromPolicy(cfg.BatchPolicy)
if !ok {
initBatchPolicyWarn.Do(func() {
logutil.BgLogger().Warn("fallback to default batch policy due to invalid value", zap.String("value", cfg.BatchPolicy))
})
}
turboBatchWaitTime := trigger.turboWaitTime()
avgBatchWaitSize := float64(cfg.BatchWaitSize)
for {
sendLoopStartTime := time.Now()
a.reqBuilder.reset()
headRecvTime, headArrivalInterval := a.fetchAllPendingRequests(int(cfg.MaxBatchSize))
if a.reqBuilder.len() == 0 {
// the conn is closed or recycled.
return
}
// curl -X PUT -d 'return(true)' http://0.0.0.0:10080/fail/tikvclient/mockBlockOnBatchClient
if val, err := util.EvalFailpoint("mockBlockOnBatchClient"); err == nil {
if val.(bool) {
time.Sleep(1 * time.Hour)
}
}
if batchSize := a.reqBuilder.len(); batchSize < int(cfg.MaxBatchSize) {
if cfg.MaxBatchWaitTime > 0 && atomic.LoadUint64(&a.tikvTransportLayerLoad) > uint64(cfg.OverloadThreshold) {
// If the target TiKV is overloaded, wait a while to collect more requests.
metrics.TiKVBatchWaitOverLoad.Inc()
a.fetchMorePendingRequests(int(cfg.MaxBatchSize), int(cfg.BatchWaitSize), cfg.MaxBatchWaitTime)
} else if turboBatchWaitTime > 0 && headArrivalInterval > 0 && trigger.needFetchMore(headArrivalInterval) {
batchWaitSize := trigger.preferredBatchWaitSize(avgBatchWaitSize, int(cfg.BatchWaitSize))
a.fetchMorePendingRequests(int(cfg.MaxBatchSize), batchWaitSize, turboBatchWaitTime)
a.metrics.batchMoreRequests.Observe(float64(a.reqBuilder.len() - batchSize))
}
}
length := a.reqBuilder.len()
avgBatchWaitSize = 0.2*float64(length) + 0.8*avgBatchWaitSize
a.metrics.pendingRequests.Observe(float64(len(a.batchCommandsCh) + length))
a.metrics.bestBatchSize.Observe(avgBatchWaitSize)
a.metrics.headArrivalInterval.Observe(headArrivalInterval.Seconds())
a.metrics.sendLoopWaitHeadDur.Observe(headRecvTime.Sub(sendLoopStartTime).Seconds())
a.metrics.sendLoopWaitMoreDur.Observe(time.Since(sendLoopStartTime).Seconds())
a.getClientAndSend()
sendLoopEndTime := time.Now()
a.metrics.sendLoopSendDur.Observe(sendLoopEndTime.Sub(sendLoopStartTime).Seconds())
if dur := sendLoopEndTime.Sub(headRecvTime); dur > batchSendTailLatThreshold {
a.metrics.batchSendTailLat.Observe(dur.Seconds())
}
}
}
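
The loop above smooths the observed batch length with an exponential moving average, giving the newest sample 20% weight; the same update rule in isolation, as a minimal sketch with a hypothetical name:

// ewmaBatchSize mirrors `avgBatchWaitSize = 0.2*float64(length) + 0.8*avgBatchWaitSize`:
// 20% weight on the latest batch length, 80% on the running average.
func ewmaBatchSize(prevAvg float64, latestLen int) float64 {
    return 0.2*float64(latestLen) + 0.8*prevAvg
}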
const (
SendFailedReasonNoAvailableLimit = "concurrency limit exceeded"
SendFailedReasonTryLockForSendFail = "tryLockForSend fail"
)
func (a *batchConn) getClientAndSend() {
if val, err := util.EvalFailpoint("mockBatchClientSendDelay"); err == nil {
if timeout, ok := val.(int); ok && timeout > 0 {
time.Sleep(time.Duration(timeout * int(time.Millisecond)))
}
}
// Choose a connection by round-robin.
var (
cli *batchCommandsClient
target string
)
reasons := make([]string, 0)
hasHighPriorityTask := a.reqBuilder.hasHighPriorityTask()
for i := 0; i < len(a.batchCommandsClients); i++ {
a.index = (a.index + 1) % uint32(len(a.batchCommandsClients))
target = a.batchCommandsClients[a.index].target
// The lock protects the batchCommandsClient from being closed while it's in use.
c := a.batchCommandsClients[a.index]
if hasHighPriorityTask || c.available() > 0 {
if c.tryLockForSend() {
cli = c
break
} else {
reasons = append(reasons, SendFailedReasonTryLockForSendFail)
}
} else {
reasons = append(reasons, SendFailedReasonNoAvailableLimit)
}
}
if cli == nil {
logutil.BgLogger().Info("no available connections", zap.String("target", target), zap.Any("reasons", reasons))
metrics.TiKVNoAvailableConnectionCounter.Inc()
if config.GetGlobalConfig().TiKVClient.MaxConcurrencyRequestLimit == config.DefMaxConcurrencyRequestLimit {
// Only cancel requests when MaxConcurrencyRequestLimit feature is not enabled, to be compatible with the behavior of older versions.
// TODO: But when MaxConcurrencyRequestLimit feature is enabled, the requests won't be canceled and will wait until timeout.
// This behavior may not be reasonable, as the timeout is usually 40s or 60s, which is too long to retry in time.
a.reqBuilder.cancel(errors.New("no available connections"))
}
return
}
defer cli.unlockForSend()
available := cli.available()
reqSendTime := time.Now()
batch := 0
req, forwardingReqs := a.reqBuilder.buildWithLimit(available, func(id uint64, e *batchCommandsEntry) {
cli.batched.Store(id, e)
cli.sent.Add(1)
atomic.StoreInt64(&e.sendLat, int64(reqSendTime.Sub(e.start)))
if trace.IsEnabled() {
trace.Log(e.ctx, "rpc", "send")
}
})
if req != nil {
batch += len(req.RequestIds)
cli.send("", req)
}
for forwardedHost, req := range forwardingReqs {
batch += len(req.RequestIds)
cli.send(forwardedHost, req)
}
if batch > 0 {
a.metrics.batchSize.Observe(float64(batch))
}
}
func (a *batchConn) Close() {
// Close all batchRecvLoop.
for _, c := range a.batchCommandsClients {
// After connections are closed, `batchRecvLoop`s will check the flag.
atomic.StoreInt32(&c.closed, 1)
}
// Don't close(batchCommandsCh): when Close() is called, someone may still be
// calling SendRequest and writing to batchCommandsCh, and closing it here would
// make that writing goroutine panic.
close(a.closed)
}
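
The comment in Close hinges on the fact that sending on a closed channel panics, which is why shutdown is signalled through the separate closed channel rather than by closing batchCommandsCh. A tiny sketch of that select-on-done pattern (hypothetical name, not part of this change):

// trySend never closes the data channel it writes to; instead, writers
// select on a done channel so a concurrent shutdown turns the send into a
// clean failure rather than a panic on a closed channel.
func trySend(ch chan<- *batchCommandsEntry, done <-chan struct{}, e *batchCommandsEntry) bool {
    select {
    case ch <- e:
        return true
    case <-done:
        return false // shutting down; drop the send
    }
}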

View File

@ -0,0 +1,101 @@
// Copyright 2025 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"sync"
"time"
"github.com/tikv/client-go/v2/internal/logutil"
"github.com/tikv/client-go/v2/metrics"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
)
type monitoredConn struct {
*grpc.ClientConn
Name string
}
func (c *monitoredConn) Close() error {
if c.ClientConn != nil {
err := c.ClientConn.Close()
logutil.BgLogger().Debug("close gRPC connection", zap.String("target", c.Name), zap.Error(err))
return err
}
return nil
}
type connMonitor struct {
m sync.Map
loopOnce sync.Once
stopOnce sync.Once
stop chan struct{}
}
func (c *connMonitor) AddConn(conn *monitoredConn) {
c.m.Store(conn.Name, conn)
}
func (c *connMonitor) RemoveConn(conn *monitoredConn) {
c.m.Delete(conn.Name)
for state := connectivity.Idle; state <= connectivity.Shutdown; state++ {
metrics.TiKVGrpcConnectionState.WithLabelValues(conn.Name, conn.Target(), state.String()).Set(0)
}
}
func (c *connMonitor) Start() {
c.loopOnce.Do(
func() {
c.stop = make(chan struct{})
go c.start()
},
)
}
func (c *connMonitor) Stop() {
c.stopOnce.Do(
func() {
if c.stop != nil {
close(c.stop)
}
},
)
}
func (c *connMonitor) start() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
c.m.Range(func(_, value interface{}) bool {
conn := value.(*monitoredConn)
nowState := conn.GetState()
for state := connectivity.Idle; state <= connectivity.Shutdown; state++ {
if state == nowState {
metrics.TiKVGrpcConnectionState.WithLabelValues(conn.Name, conn.Target(), nowState.String()).Set(1)
} else {
metrics.TiKVGrpcConnectionState.WithLabelValues(conn.Name, conn.Target(), state.String()).Set(0)
}
}
return true
})
case <-c.stop:
return
}
}
}
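
For context, connections are registered with the monitor in monitoredDial and unregistered in connPool.Close; a hedged sketch of that lifecycle (hypothetical helper, not part of this change):

// monitorConn registers conn with the monitor for the duration of fn and
// unregisters it afterwards, which also zeroes its per-state gauges.
// Start is safe to call repeatedly: the loop is launched once via sync.Once.
func monitorConn(m *connMonitor, conn *monitoredConn, fn func()) {
    m.Start()
    m.AddConn(conn)
    defer m.RemoveConn(conn)
    fn()
}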

View File

@ -0,0 +1,255 @@
// Copyright 2025 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/tikv/client-go/v2/config"
tikverr "github.com/tikv/client-go/v2/error"
"github.com/tikv/client-go/v2/metrics"
"github.com/tikv/client-go/v2/tikvrpc"
"github.com/tikv/client-go/v2/util"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/encoding/gzip"
"google.golang.org/grpc/experimental"
"google.golang.org/grpc/keepalive"
)
type connPool struct {
// The target host.
target string
// version of the connection pool, increased by 1 on each reconnect.
ver uint64
index uint32
// streamTimeout is bound to a background goroutine that processes coprocessor streaming timeouts.
streamTimeout chan *tikvrpc.Lease
dialTimeout time.Duration
conns []*monitoredConn
// batchConn is not nil when batch is enabled.
*batchConn
done chan struct{}
monitor *connMonitor
metrics struct {
rpcLatHist *rpcMetrics
rpcSrcLatSum sync.Map
rpcNetLatExternal prometheus.Observer
rpcNetLatInternal prometheus.Observer
}
}
func newConnPool(maxSize uint, addr string, ver uint64, security config.Security,
idleNotify *uint32, enableBatch bool, dialTimeout time.Duration, m *connMonitor, eventListener *atomic.Pointer[ClientEventListener], opts []grpc.DialOption) (*connPool, error) {
a := &connPool{
ver: ver,
index: 0,
conns: make([]*monitoredConn, maxSize),
streamTimeout: make(chan *tikvrpc.Lease, 1024),
done: make(chan struct{}),
dialTimeout: dialTimeout,
monitor: m,
}
a.metrics.rpcLatHist = deriveRPCMetrics(metrics.TiKVSendReqHistogram.MustCurryWith(prometheus.Labels{metrics.LblStore: addr}))
a.metrics.rpcNetLatExternal = metrics.TiKVRPCNetLatencyHistogram.WithLabelValues(addr, "false")
a.metrics.rpcNetLatInternal = metrics.TiKVRPCNetLatencyHistogram.WithLabelValues(addr, "true")
if err := a.Init(addr, security, idleNotify, enableBatch, eventListener, opts...); err != nil {
return nil, err
}
return a, nil
}
func (a *connPool) monitoredDial(ctx context.Context, connName, target string, opts ...grpc.DialOption) (conn *monitoredConn, err error) {
conn = &monitoredConn{
Name: connName,
}
conn.ClientConn, err = grpc.DialContext(ctx, target, opts...)
if err != nil {
return nil, err
}
a.monitor.AddConn(conn)
return conn, nil
}
func (a *connPool) Init(addr string, security config.Security, idleNotify *uint32, enableBatch bool, eventListener *atomic.Pointer[ClientEventListener], opts ...grpc.DialOption) error {
a.target = addr
opt := grpc.WithTransportCredentials(insecure.NewCredentials())
if len(security.ClusterSSLCA) != 0 {
tlsConfig, err := security.ToTLSConfig()
if err != nil {
return errors.WithStack(err)
}
opt = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))
}
cfg := config.GetGlobalConfig()
var (
unaryInterceptor grpc.UnaryClientInterceptor
streamInterceptor grpc.StreamClientInterceptor
)
if cfg.OpenTracingEnable {
unaryInterceptor = grpc_opentracing.UnaryClientInterceptor()
streamInterceptor = grpc_opentracing.StreamClientInterceptor()
}
allowBatch := (cfg.TiKVClient.MaxBatchSize > 0) && enableBatch
if allowBatch {
a.batchConn = newBatchConn(uint(len(a.conns)), cfg.TiKVClient.MaxBatchSize, idleNotify)
a.batchConn.initMetrics(a.target)
}
keepAlive := cfg.TiKVClient.GrpcKeepAliveTime
for i := range a.conns {
ctx, cancel := context.WithTimeout(context.Background(), a.dialTimeout)
var callOptions []grpc.CallOption
callOptions = append(callOptions, grpc.MaxCallRecvMsgSize(MaxRecvMsgSize))
if cfg.TiKVClient.GrpcCompressionType == gzip.Name {
callOptions = append(callOptions, grpc.UseCompressor(gzip.Name))
}
opts = append([]grpc.DialOption{
opt,
grpc.WithInitialWindowSize(cfg.TiKVClient.GrpcInitialWindowSize),
grpc.WithInitialConnWindowSize(cfg.TiKVClient.GrpcInitialConnWindowSize),
grpc.WithUnaryInterceptor(unaryInterceptor),
grpc.WithStreamInterceptor(streamInterceptor),
grpc.WithDefaultCallOptions(callOptions...),
grpc.WithConnectParams(grpc.ConnectParams{
Backoff: backoff.Config{
BaseDelay: 100 * time.Millisecond, // Default was 1s.
Multiplier: 1.6, // Default
Jitter: 0.2, // Default
MaxDelay: 3 * time.Second, // Default was 120s.
},
MinConnectTimeout: a.dialTimeout,
}),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: time.Duration(keepAlive) * time.Second,
Timeout: cfg.TiKVClient.GetGrpcKeepAliveTimeout(),
}),
}, opts...)
if cfg.TiKVClient.GrpcSharedBufferPool {
opts = append(opts, experimental.WithRecvBufferPool(grpc.NewSharedBufferPool()))
}
conn, err := a.monitoredDial(
ctx,
fmt.Sprintf("%s-%d", a.target, i),
addr,
opts...,
)
cancel()
if err != nil {
// Cleanup if the initialization fails.
a.Close()
return errors.WithStack(err)
}
a.conns[i] = conn
if allowBatch {
batchClient := &batchCommandsClient{
target: a.target,
conn: conn.ClientConn,
forwardedClients: make(map[string]*batchCommandsStream),
batched: sync.Map{},
epoch: 0,
closed: 0,
tikvClientCfg: cfg.TiKVClient,
tikvLoad: &a.tikvTransportLayerLoad,
dialTimeout: a.dialTimeout,
tryLock: tryLock{sync.NewCond(new(sync.Mutex)), false},
eventListener: eventListener,
metrics: &a.batchConn.metrics,
}
batchClient.maxConcurrencyRequestLimit.Store(cfg.TiKVClient.MaxConcurrencyRequestLimit)
a.batchCommandsClients = append(a.batchCommandsClients, batchClient)
}
}
go tikvrpc.CheckStreamTimeoutLoop(a.streamTimeout, a.done)
if allowBatch {
go a.batchSendLoop(cfg.TiKVClient)
}
return nil
}
func (a *connPool) Get() *grpc.ClientConn {
next := atomic.AddUint32(&a.index, 1) % uint32(len(a.conns))
return a.conns[next].ClientConn
}
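
Get spreads callers across the pool by advancing an atomic counter modulo the pool size; the selection rule on its own (hypothetical name, not part of this change):

// nextIndex mirrors Get: an atomic increment taken modulo the number of
// connections, so concurrent callers rotate through the pool.
func nextIndex(counter *uint32, poolSize int) uint32 {
    return atomic.AddUint32(counter, 1) % uint32(poolSize)
}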
func (a *connPool) Close() {
if a.batchConn != nil {
a.batchConn.Close()
}
for _, c := range a.conns {
if c != nil {
err := c.Close()
tikverr.Log(err)
if err == nil {
a.monitor.RemoveConn(c)
}
}
}
close(a.done)
}
func (a *connPool) updateRPCMetrics(req *tikvrpc.Request, resp *tikvrpc.Response, latency time.Duration) {
seconds := latency.Seconds()
stale := req.GetStaleRead()
source := req.GetRequestSource()
internal := util.IsInternalRequest(req.GetRequestSource())
a.metrics.rpcLatHist.get(req.Type, stale, internal).Observe(seconds)
srcLatSum, ok := a.metrics.rpcSrcLatSum.Load(source)
if !ok {
srcLatSum = deriveRPCMetrics(metrics.TiKVSendReqSummary.MustCurryWith(
prometheus.Labels{metrics.LblStore: a.target, metrics.LblSource: source}))
a.metrics.rpcSrcLatSum.Store(source, srcLatSum)
}
srcLatSum.(*rpcMetrics).get(req.Type, stale, internal).Observe(seconds)
if execDetail := resp.GetExecDetailsV2(); execDetail != nil {
var totalRpcWallTimeNs uint64
if execDetail.TimeDetailV2 != nil {
totalRpcWallTimeNs = execDetail.TimeDetailV2.TotalRpcWallTimeNs
} else if execDetail.TimeDetail != nil {
totalRpcWallTimeNs = execDetail.TimeDetail.TotalRpcWallTimeNs
}
if totalRpcWallTimeNs > 0 {
lat := latency - time.Duration(totalRpcWallTimeNs)
if internal {
a.metrics.rpcNetLatInternal.Observe(lat.Seconds())
} else {
a.metrics.rpcNetLatExternal.Observe(lat.Seconds())
}
}
}
}
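
updateRPCMetrics estimates network latency by subtracting the server-reported RPC wall time from the end-to-end latency; the arithmetic in isolation (hypothetical name, not part of this change):

// netLatency mirrors the calculation above: total client-observed latency
// minus the wall time TiKV reports for handling the RPC, leaving roughly
// the time spent on the network and in client-side queues.
func netLatency(total time.Duration, serverWallTimeNs uint64) time.Duration {
    if serverWallTimeNs == 0 {
        return 0 // no server-side timing; nothing to derive
    }
    return total - time.Duration(serverWallTimeNs)
}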

View File

@ -756,11 +756,6 @@ func NewRegionCache(pdClient pd.Client, opt ...RegionCacheOpt) *RegionCache {
return c
}
// ForceRefreshAllStores get all stores from PD and refresh store cache.
func (c *RegionCache) ForceRefreshAllStores(ctx context.Context) {
refreshFullStoreList(ctx, c.stores)
}
// Try to refresh full store list. Errors are ignored.
func refreshFullStoreList(ctx context.Context, stores storeCache) {
storeList, err := stores.fetchAllStores(ctx)

View File

@ -761,7 +761,7 @@ func invokeKVStatusAPI(addr string, timeout time.Duration) (l livenessState) {
func createKVHealthClient(ctx context.Context, addr string) (*grpc.ClientConn, healthpb.HealthClient, error) {
// Temporarily load the config directly from the global config; however, it's not a good idea to let RegionCache
// access it.
// TODO: Pass the config in a better way, or use the connArray inner the client directly rather than creating new
// TODO: Pass the config in a better way, or use the connPool inside the client directly rather than creating a new
// connection.
cfg := config.GetGlobalConfig()