diff --git a/internal/client/client.go b/internal/client/client.go index 19d6660d..818ff6bd 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -47,7 +47,6 @@ import ( "sync/atomic" "time" - grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" "github.com/opentracing/opentracing-go" "github.com/pingcap/kvproto/pkg/coprocessor" "github.com/pingcap/kvproto/pkg/debugpb" @@ -57,7 +56,6 @@ import ( "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "github.com/tikv/client-go/v2/config" - tikverr "github.com/tikv/client-go/v2/error" "github.com/tikv/client-go/v2/internal/apicodec" "github.com/tikv/client-go/v2/internal/logutil" "github.com/tikv/client-go/v2/metrics" @@ -66,13 +64,7 @@ import ( "github.com/tikv/client-go/v2/util/async" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/backoff" "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/encoding/gzip" - "google.golang.org/grpc/experimental" - "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" ) @@ -144,306 +136,14 @@ func (e *ErrConn) Unwrap() error { return e.Err } -func WrapErrConn(err error, conn *connArray) error { +func WrapErrConn(err error, pool *connPool) error { if err == nil { return nil } return &ErrConn{ Err: err, - Addr: conn.target, - Ver: conn.ver, - } -} - -type connArray struct { - // The target host. - target string - // version of the connection array, increase by 1 when reconnect. - ver uint64 - - index uint32 - v []*monitoredConn - // streamTimeout binds with a background goroutine to process coprocessor streaming timeout. - streamTimeout chan *tikvrpc.Lease - dialTimeout time.Duration - // batchConn is not null when batch is enabled. 
- *batchConn - done chan struct{} - - monitor *connMonitor - - metrics struct { - rpcLatHist *rpcMetrics - rpcSrcLatSum sync.Map - rpcNetLatExternal prometheus.Observer - rpcNetLatInternal prometheus.Observer - } -} - -func newConnArray(maxSize uint, addr string, ver uint64, security config.Security, - idleNotify *uint32, enableBatch bool, dialTimeout time.Duration, m *connMonitor, eventListener *atomic.Pointer[ClientEventListener], opts []grpc.DialOption) (*connArray, error) { - a := &connArray{ - ver: ver, - index: 0, - v: make([]*monitoredConn, maxSize), - streamTimeout: make(chan *tikvrpc.Lease, 1024), - done: make(chan struct{}), - dialTimeout: dialTimeout, - monitor: m, - } - a.metrics.rpcLatHist = deriveRPCMetrics(metrics.TiKVSendReqHistogram.MustCurryWith(prometheus.Labels{metrics.LblStore: addr})) - a.metrics.rpcNetLatExternal = metrics.TiKVRPCNetLatencyHistogram.WithLabelValues(addr, "false") - a.metrics.rpcNetLatInternal = metrics.TiKVRPCNetLatencyHistogram.WithLabelValues(addr, "true") - if err := a.Init(addr, security, idleNotify, enableBatch, eventListener, opts...); err != nil { - return nil, err - } - return a, nil -} - -type connMonitor struct { - m sync.Map - loopOnce sync.Once - stopOnce sync.Once - stop chan struct{} -} - -func (c *connMonitor) AddConn(conn *monitoredConn) { - c.m.Store(conn.Name, conn) -} - -func (c *connMonitor) RemoveConn(conn *monitoredConn) { - c.m.Delete(conn.Name) - for state := connectivity.Idle; state <= connectivity.Shutdown; state++ { - metrics.TiKVGrpcConnectionState.WithLabelValues(conn.Name, conn.Target(), state.String()).Set(0) - } -} - -func (c *connMonitor) Start() { - c.loopOnce.Do( - func() { - c.stop = make(chan struct{}) - go c.start() - }, - ) -} - -func (c *connMonitor) Stop() { - c.stopOnce.Do( - func() { - if c.stop != nil { - close(c.stop) - } - }, - ) -} - -func (c *connMonitor) start() { - - ticker := time.NewTicker(time.Second) - defer ticker.Stop() - for { - select { - case <-ticker.C: - c.m.Range(func(_, value interface{}) bool { - conn := value.(*monitoredConn) - nowState := conn.GetState() - for state := connectivity.Idle; state <= connectivity.Shutdown; state++ { - if state == nowState { - metrics.TiKVGrpcConnectionState.WithLabelValues(conn.Name, conn.Target(), nowState.String()).Set(1) - } else { - metrics.TiKVGrpcConnectionState.WithLabelValues(conn.Name, conn.Target(), state.String()).Set(0) - } - } - return true - }) - case <-c.stop: - return - } - } -} - -type monitoredConn struct { - *grpc.ClientConn - Name string -} - -func (a *connArray) monitoredDial(ctx context.Context, connName, target string, opts ...grpc.DialOption) (conn *monitoredConn, err error) { - conn = &monitoredConn{ - Name: connName, - } - conn.ClientConn, err = grpc.DialContext(ctx, target, opts...) 
- if err != nil { - return nil, err - } - a.monitor.AddConn(conn) - return conn, nil -} - -func (c *monitoredConn) Close() error { - if c.ClientConn != nil { - err := c.ClientConn.Close() - logutil.BgLogger().Debug("close gRPC connection", zap.String("target", c.Name), zap.Error(err)) - return err - } - return nil -} - -func (a *connArray) Init(addr string, security config.Security, idleNotify *uint32, enableBatch bool, eventListener *atomic.Pointer[ClientEventListener], opts ...grpc.DialOption) error { - a.target = addr - - opt := grpc.WithTransportCredentials(insecure.NewCredentials()) - if len(security.ClusterSSLCA) != 0 { - tlsConfig, err := security.ToTLSConfig() - if err != nil { - return errors.WithStack(err) - } - opt = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)) - } - - cfg := config.GetGlobalConfig() - var ( - unaryInterceptor grpc.UnaryClientInterceptor - streamInterceptor grpc.StreamClientInterceptor - ) - if cfg.OpenTracingEnable { - unaryInterceptor = grpc_opentracing.UnaryClientInterceptor() - streamInterceptor = grpc_opentracing.StreamClientInterceptor() - } - - allowBatch := (cfg.TiKVClient.MaxBatchSize > 0) && enableBatch - if allowBatch { - a.batchConn = newBatchConn(uint(len(a.v)), cfg.TiKVClient.MaxBatchSize, idleNotify) - a.batchConn.initMetrics(a.target) - } - keepAlive := cfg.TiKVClient.GrpcKeepAliveTime - for i := range a.v { - ctx, cancel := context.WithTimeout(context.Background(), a.dialTimeout) - var callOptions []grpc.CallOption - callOptions = append(callOptions, grpc.MaxCallRecvMsgSize(MaxRecvMsgSize)) - if cfg.TiKVClient.GrpcCompressionType == gzip.Name { - callOptions = append(callOptions, grpc.UseCompressor(gzip.Name)) - } - - opts = append([]grpc.DialOption{ - opt, - grpc.WithInitialWindowSize(cfg.TiKVClient.GrpcInitialWindowSize), - grpc.WithInitialConnWindowSize(cfg.TiKVClient.GrpcInitialConnWindowSize), - grpc.WithUnaryInterceptor(unaryInterceptor), - grpc.WithStreamInterceptor(streamInterceptor), - grpc.WithDefaultCallOptions(callOptions...), - grpc.WithConnectParams(grpc.ConnectParams{ - Backoff: backoff.Config{ - BaseDelay: 100 * time.Millisecond, // Default was 1s. - Multiplier: 1.6, // Default - Jitter: 0.2, // Default - MaxDelay: 3 * time.Second, // Default was 120s. - }, - MinConnectTimeout: a.dialTimeout, - }), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: time.Duration(keepAlive) * time.Second, - Timeout: cfg.TiKVClient.GetGrpcKeepAliveTimeout(), - }), - }, opts...) - if cfg.TiKVClient.GrpcSharedBufferPool { - opts = append(opts, experimental.WithRecvBufferPool(grpc.NewSharedBufferPool())) - } - conn, err := a.monitoredDial( - ctx, - fmt.Sprintf("%s-%d", a.target, i), - addr, - opts..., - ) - - cancel() - if err != nil { - // Cleanup if the initialization fails. 
- a.Close() - return errors.WithStack(err) - } - a.v[i] = conn - - if allowBatch { - batchClient := &batchCommandsClient{ - target: a.target, - conn: conn.ClientConn, - forwardedClients: make(map[string]*batchCommandsStream), - batched: sync.Map{}, - epoch: 0, - closed: 0, - tikvClientCfg: cfg.TiKVClient, - tikvLoad: &a.tikvTransportLayerLoad, - dialTimeout: a.dialTimeout, - tryLock: tryLock{sync.NewCond(new(sync.Mutex)), false}, - eventListener: eventListener, - metrics: &a.batchConn.metrics, - } - batchClient.maxConcurrencyRequestLimit.Store(cfg.TiKVClient.MaxConcurrencyRequestLimit) - a.batchCommandsClients = append(a.batchCommandsClients, batchClient) - } - } - go tikvrpc.CheckStreamTimeoutLoop(a.streamTimeout, a.done) - if allowBatch { - go a.batchSendLoop(cfg.TiKVClient) - } - - return nil -} - -func (a *connArray) Get() *grpc.ClientConn { - next := atomic.AddUint32(&a.index, 1) % uint32(len(a.v)) - return a.v[next].ClientConn -} - -func (a *connArray) Close() { - if a.batchConn != nil { - a.batchConn.Close() - } - - for _, c := range a.v { - if c != nil { - err := c.Close() - tikverr.Log(err) - if err == nil { - a.monitor.RemoveConn(c) - } - } - } - - close(a.done) -} - -func (a *connArray) updateRPCMetrics(req *tikvrpc.Request, resp *tikvrpc.Response, latency time.Duration) { - seconds := latency.Seconds() - stale := req.GetStaleRead() - source := req.GetRequestSource() - internal := util.IsInternalRequest(req.GetRequestSource()) - - a.metrics.rpcLatHist.get(req.Type, stale, internal).Observe(seconds) - - srcLatSum, ok := a.metrics.rpcSrcLatSum.Load(source) - if !ok { - srcLatSum = deriveRPCMetrics(metrics.TiKVSendReqSummary.MustCurryWith( - prometheus.Labels{metrics.LblStore: a.target, metrics.LblSource: source})) - a.metrics.rpcSrcLatSum.Store(source, srcLatSum) - } - srcLatSum.(*rpcMetrics).get(req.Type, stale, internal).Observe(seconds) - - if execDetail := resp.GetExecDetailsV2(); execDetail != nil { - var totalRpcWallTimeNs uint64 - if execDetail.TimeDetailV2 != nil { - totalRpcWallTimeNs = execDetail.TimeDetailV2.TotalRpcWallTimeNs - } else if execDetail.TimeDetail != nil { - totalRpcWallTimeNs = execDetail.TimeDetail.TotalRpcWallTimeNs - } - if totalRpcWallTimeNs > 0 { - lat := latency - time.Duration(totalRpcWallTimeNs) - if internal { - a.metrics.rpcNetLatInternal.Observe(lat.Seconds()) - } else { - a.metrics.rpcNetLatExternal.Observe(lat.Seconds()) - } - } + Addr: pool.target, + Ver: pool.ver, } } @@ -485,9 +185,9 @@ func WithCodec(codec apicodec.Codec) Opt { type RPCClient struct { sync.RWMutex - conns map[string]*connArray - vers map[string]uint64 - option *option + connPools map[string]*connPool + vers map[string]uint64 + option *option idleNotify uint32 @@ -505,8 +205,8 @@ var _ Client = &RPCClient{} // NewRPCClient creates a client that manages connections and rpc calls with tikv-servers. 
func NewRPCClient(opts ...Opt) *RPCClient { cli := &RPCClient{ - conns: make(map[string]*connArray), - vers: make(map[string]uint64), + connPools: make(map[string]*connPool), + vers: make(map[string]uint64), option: &option{ dialTimeout: dialTimeout, }, @@ -520,35 +220,35 @@ func NewRPCClient(opts ...Opt) *RPCClient { return cli } -func (c *RPCClient) getConnArray(addr string, enableBatch bool, opt ...func(cfg *config.TiKVClient)) (*connArray, error) { +func (c *RPCClient) getConnPool(addr string, enableBatch bool, opt ...func(cfg *config.TiKVClient)) (*connPool, error) { c.RLock() if c.isClosed { c.RUnlock() return nil, errors.Errorf("rpcClient is closed") } - array, ok := c.conns[addr] + pool, ok := c.connPools[addr] c.RUnlock() if !ok { var err error - array, err = c.createConnArray(addr, enableBatch, opt...) + pool, err = c.createConnPool(addr, enableBatch, opt...) if err != nil { return nil, err } } - // An idle connArray will not change to active again, this avoid the race condition + // An idle connPool will not change to active again, this avoid the race condition // that recycling idle connection close an active connection unexpectedly (idle -> active). - if array.batchConn != nil && array.isIdle() { + if pool.batchConn != nil && pool.isIdle() { return nil, errors.Errorf("rpcClient is idle") } - return array, nil + return pool, nil } -func (c *RPCClient) createConnArray(addr string, enableBatch bool, opts ...func(cfg *config.TiKVClient)) (*connArray, error) { +func (c *RPCClient) createConnPool(addr string, enableBatch bool, opts ...func(cfg *config.TiKVClient)) (*connPool, error) { c.Lock() defer c.Unlock() - array, ok := c.conns[addr] + pool, ok := c.connPools[addr] if !ok { var err error client := config.GetGlobalConfig().TiKVClient @@ -556,7 +256,7 @@ func (c *RPCClient) createConnArray(addr string, enableBatch bool, opts ...func( opt(&client) } ver := c.vers[addr] + 1 - array, err = newConnArray( + pool, err = newConnPool( client.GrpcConnectionCount, addr, ver, @@ -571,10 +271,10 @@ func (c *RPCClient) createConnArray(addr string, enableBatch bool, opts ...func( if err != nil { return nil, err } - c.conns[addr] = array + c.connPools[addr] = pool c.vers[addr] = ver } - return array, nil + return pool, nil } func (c *RPCClient) closeConns() { @@ -582,8 +282,8 @@ func (c *RPCClient) closeConns() { if !c.isClosed { c.isClosed = true // close all connections - for _, array := range c.conns { - array.Close() + for _, pool := range c.connPools { + pool.Close() } } c.Unlock() @@ -595,10 +295,10 @@ func (c *RPCClient) recycleIdleConnArray() { var addrs []string var vers []uint64 c.RLock() - for _, conn := range c.conns { - if conn.batchConn != nil && conn.isIdle() { - addrs = append(addrs, conn.target) - vers = append(vers, conn.ver) + for _, pool := range c.connPools { + if pool.batchConn != nil && pool.isIdle() { + addrs = append(addrs, pool.target) + vers = append(vers, pool.ver) } } c.RUnlock() @@ -627,19 +327,19 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R // TiDB will not send batch commands to TiFlash, to resolve the conflict with Batch Cop Request. // tiflash/tiflash_mpp/tidb don't use BatchCommand. 
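As an aside, the lookup path implemented by getConnPool/createConnPool above boils down to a classic get-or-create under a read/write lock. A minimal sketch, assuming a simplified pool type and omitting the idle and batch checks:

package sketch

import (
	"errors"
	"sync"
)

// pool stands in for connPool; only the fields this sketch needs.
type pool struct {
	target string
	ver    uint64
}

type rpcClient struct {
	sync.RWMutex
	closed bool
	pools  map[string]*pool
	vers   map[string]uint64 // last version handed out per address
}

func (c *rpcClient) getPool(addr string) (*pool, error) {
	// Fast path: read lock, map hit.
	c.RLock()
	if c.closed {
		c.RUnlock()
		return nil, errors.New("rpcClient is closed")
	}
	p, ok := c.pools[addr]
	c.RUnlock()
	if ok {
		return p, nil
	}
	// Slow path: write lock, re-check, then create with a bumped version.
	c.Lock()
	defer c.Unlock()
	if p, ok := c.pools[addr]; ok {
		return p, nil
	}
	ver := c.vers[addr] + 1
	p = &pool{target: addr, ver: ver} // the real code dials the gRPC connections here
	c.pools[addr] = p
	c.vers[addr] = ver
	return p, nil
}

The write-lock re-check is what lets concurrent callers for the same address share one pool, and bumping vers on every recreation is what CloseAddrVer later relies on.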
enableBatch := req.StoreTp == tikvrpc.TiKV - connArray, err := c.getConnArray(addr, enableBatch) + connPool, err := c.getConnPool(addr, enableBatch) if err != nil { return nil, err } wrapErrConn := func(resp *tikvrpc.Response, err error) (*tikvrpc.Response, error) { - return resp, WrapErrConn(err, connArray) + return resp, WrapErrConn(err, connPool) } start := time.Now() defer func() { elapsed := time.Since(start) - connArray.updateRPCMetrics(req, resp, elapsed) + connPool.updateRPCMetrics(req, resp, elapsed) if spanRPC != nil && util.TraceExecDetailsEnabled(ctx) { if si := buildSpanInfoFromResp(resp); si != nil { @@ -654,11 +354,11 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R if config.GetGlobalConfig().TiKVClient.MaxBatchSize > 0 && enableBatch { if batchReq := req.ToBatchCommandsRequest(); batchReq != nil { defer trace.StartRegion(ctx, req.Type.String()).End() - return wrapErrConn(sendBatchRequest(ctx, addr, req.ForwardedHost, connArray.batchConn, batchReq, timeout, pri)) + return wrapErrConn(sendBatchRequest(ctx, addr, req.ForwardedHost, connPool.batchConn, batchReq, timeout, pri)) } } - clientConn := connArray.Get() + clientConn := connPool.Get() if state := clientConn.GetState(); state == connectivity.TransientFailure { storeID := strconv.FormatUint(req.Context.GetPeer().GetStoreId(), 10) metrics.TiKVGRPCConnTransientFailureCounter.WithLabelValues(addr, storeID).Inc() @@ -679,11 +379,11 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R } switch req.Type { case tikvrpc.CmdBatchCop: - return wrapErrConn(c.getBatchCopStreamResponse(ctx, client, req, timeout, connArray)) + return wrapErrConn(c.getBatchCopStreamResponse(ctx, client, req, timeout, connPool)) case tikvrpc.CmdCopStream: - return wrapErrConn(c.getCopStreamResponse(ctx, client, req, timeout, connArray)) + return wrapErrConn(c.getCopStreamResponse(ctx, client, req, timeout, connPool)) case tikvrpc.CmdMPPConn: - return wrapErrConn(c.getMPPStreamResponse(ctx, client, req, timeout, connArray)) + return wrapErrConn(c.getMPPStreamResponse(ctx, client, req, timeout, connPool)) } // Or else it's a unary call. ctx1, cancel := context.WithTimeout(ctx, timeout) @@ -710,7 +410,7 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R return codec.DecodeResponse(req, resp) } -func (c *RPCClient) getCopStreamResponse(ctx context.Context, client tikvpb.TikvClient, req *tikvrpc.Request, timeout time.Duration, connArray *connArray) (*tikvrpc.Response, error) { +func (c *RPCClient) getCopStreamResponse(ctx context.Context, client tikvpb.TikvClient, req *tikvrpc.Request, timeout time.Duration, connPool *connPool) (*tikvrpc.Response, error) { // Coprocessor streaming request. // Use context to support timeout for grpc streaming client. ctx1, cancel := context.WithCancel(ctx) @@ -727,7 +427,7 @@ func (c *RPCClient) getCopStreamResponse(ctx context.Context, client tikvpb.Tikv copStream := resp.Resp.(*tikvrpc.CopStreamResponse) copStream.Timeout = timeout copStream.Lease.Cancel = cancel - connArray.streamTimeout <- &copStream.Lease + connPool.streamTimeout <- &copStream.Lease // Read the first streaming response to get CopStreamResponse. 
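For orientation, a minimal round trip modeled on this package's own tests (CmdEmpty over a local address); the helper name is illustrative and not part of the API. SendRequest resolves the pool via getConnPool and then either batches the command or issues a unary call, as dispatched above:

package client

import (
	"context"
	"time"

	"github.com/pingcap/kvproto/pkg/tikvpb"
	"github.com/tikv/client-go/v2/tikvrpc"
)

// exampleSend sends a single empty batch command and returns the transport
// error, if any. Illustrative only.
func exampleSend(addr string) error {
	cli := NewRPCClient()
	defer cli.Close()

	req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{})
	_, err := cli.SendRequest(context.Background(), addr, req, time.Second)
	return err
}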
// This can make error handling much easier, because SendReq() retry on @@ -745,7 +445,7 @@ func (c *RPCClient) getCopStreamResponse(ctx context.Context, client tikvpb.Tikv } -func (c *RPCClient) getBatchCopStreamResponse(ctx context.Context, client tikvpb.TikvClient, req *tikvrpc.Request, timeout time.Duration, connArray *connArray) (*tikvrpc.Response, error) { +func (c *RPCClient) getBatchCopStreamResponse(ctx context.Context, client tikvpb.TikvClient, req *tikvrpc.Request, timeout time.Duration, connPool *connPool) (*tikvrpc.Response, error) { // Coprocessor streaming request. // Use context to support timeout for grpc streaming client. ctx1, cancel := context.WithCancel(ctx) @@ -762,7 +462,7 @@ func (c *RPCClient) getBatchCopStreamResponse(ctx context.Context, client tikvpb copStream := resp.Resp.(*tikvrpc.BatchCopStreamResponse) copStream.Timeout = timeout copStream.Lease.Cancel = cancel - connArray.streamTimeout <- &copStream.Lease + connPool.streamTimeout <- &copStream.Lease // Read the first streaming response to get CopStreamResponse. // This can make error handling much easier, because SendReq() retry on @@ -779,7 +479,7 @@ func (c *RPCClient) getBatchCopStreamResponse(ctx context.Context, client tikvpb return resp, nil } -func (c *RPCClient) getMPPStreamResponse(ctx context.Context, client tikvpb.TikvClient, req *tikvrpc.Request, timeout time.Duration, connArray *connArray) (*tikvrpc.Response, error) { +func (c *RPCClient) getMPPStreamResponse(ctx context.Context, client tikvpb.TikvClient, req *tikvrpc.Request, timeout time.Duration, connPool *connPool) (*tikvrpc.Response, error) { // MPP streaming request. // Use context to support timeout for grpc streaming client. ctx1, cancel := context.WithCancel(ctx) @@ -796,7 +496,7 @@ func (c *RPCClient) getMPPStreamResponse(ctx context.Context, client tikvpb.Tikv copStream := resp.Resp.(*tikvrpc.MPPStreamResponse) copStream.Timeout = timeout copStream.Lease.Cancel = cancel - connArray.streamTimeout <- &copStream.Lease + connPool.streamTimeout <- &copStream.Lease // Read the first streaming response to get CopStreamResponse. 
// This can make error handling much easier, because SendReq() retry on @@ -831,20 +531,20 @@ func (c *RPCClient) CloseAddrVer(addr string, ver uint64) error { c.Unlock() return nil } - conn, ok := c.conns[addr] + pool, ok := c.connPools[addr] if ok { - if conn.ver <= ver { - delete(c.conns, addr) - logutil.BgLogger().Debug("close connection", zap.String("target", addr), zap.Uint64("ver", ver), zap.Uint64("conn.ver", conn.ver)) + if pool.ver <= ver { + delete(c.connPools, addr) + logutil.BgLogger().Debug("close connection", zap.String("target", addr), zap.Uint64("ver", ver), zap.Uint64("pool.ver", pool.ver)) } else { - logutil.BgLogger().Debug("ignore close connection", zap.String("target", addr), zap.Uint64("ver", ver), zap.Uint64("conn.ver", conn.ver)) - conn = nil + logutil.BgLogger().Debug("ignore close connection", zap.String("target", addr), zap.Uint64("ver", ver), zap.Uint64("pool.ver", pool.ver)) + pool = nil } } c.Unlock() - if conn != nil { - conn.Close() + if pool != nil { + pool.Close() } return nil } diff --git a/internal/client/client_async.go b/internal/client/client_async.go index 04c00f3c..bbd14cc0 100644 --- a/internal/client/client_async.go +++ b/internal/client/client_async.go @@ -72,10 +72,10 @@ func (c *RPCClient) SendRequestAsync(ctx context.Context, addr string, req *tikv } tikvrpc.AttachContext(req, req.Context) - // TODO(zyguan): If the client created `WithGRPCDialOptions(grpc.WithBlock())`, `getConnArray` might be blocked for + // TODO(zyguan): If the client created `WithGRPCDialOptions(grpc.WithBlock())`, `getConnPool` might be blocked for // a while when the corresponding conn array is uninitialized. However, since tidb won't set this option, we just - // keep `getConnArray` synchronous for now. - connArray, err := c.getConnArray(addr, true) + // keep `getConnPool` synchronous for now. + connPool, err := c.getConnPool(addr, true) if err != nil { cb.Invoke(nil, err) return @@ -113,7 +113,7 @@ func (c *RPCClient) SendRequestAsync(ctx context.Context, addr string, req *tikv metrics.BatchRequestDurationDone.Observe(elapsed.Seconds()) // rpc metrics - connArray.updateRPCMetrics(req, resp, elapsed) + connPool.updateRPCMetrics(req, resp, elapsed) // tracing if spanRPC != nil { @@ -131,7 +131,7 @@ func (c *RPCClient) SendRequestAsync(ctx context.Context, addr string, req *tikv resp, err = c.option.codec.DecodeResponse(req, resp) } - return resp, WrapErrConn(err, connArray) + return resp, WrapErrConn(err, connPool) }) stop = context.AfterFunc(ctx, func() { @@ -139,7 +139,7 @@ func (c *RPCClient) SendRequestAsync(ctx context.Context, addr string, req *tikv entry.error(ctx.Err()) }) - batchConn := connArray.batchConn + batchConn := connPool.batchConn if val, err := util.EvalFailpoint("mockBatchCommandsChannelFullOnAsyncSend"); err == nil { mockBatchCommandsChannelFullOnAsyncSend(ctx, batchConn, cb, val) } diff --git a/internal/client/client_batch.go b/internal/client/client_batch.go index 7aa12964..d97e3ae2 100644 --- a/internal/client/client_batch.go +++ b/internal/client/client_batch.go @@ -40,7 +40,6 @@ import ( "encoding/json" "fmt" "math" - "runtime" "runtime/trace" "strings" "sync" @@ -264,159 +263,6 @@ type batchConnMetrics struct { bestBatchSize prometheus.Observer } -type batchConn struct { - // An atomic flag indicates whether the batch is idle or not. - // 0 for busy, others for idle. - idle uint32 - - // batchCommandsCh used for batch commands. 
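A sketch of how a caller might consume the ErrConn wrapper above (hypothetical helper, standard errors.As): the error carries the pool's address and version, so only the pool that actually failed is closed, while a pool recreated in the meantime, which has a larger version, is left untouched by CloseAddrVer:

package client

import "errors"

// closeFailedPool is an illustrative helper, not part of the package API.
func closeFailedPool(c *RPCClient, err error) {
	var connErr *ErrConn
	if errors.As(err, &connErr) {
		// Ignored when the pool for connErr.Addr has already been recreated
		// with a larger version; see CloseAddrVer above.
		_ = c.CloseAddrVer(connErr.Addr, connErr.Ver)
	}
}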
- batchCommandsCh chan *batchCommandsEntry - batchCommandsClients []*batchCommandsClient - tikvTransportLayerLoad uint64 - closed chan struct{} - - reqBuilder *batchCommandsBuilder - - // Notify rpcClient to check the idle flag - idleNotify *uint32 - idleDetect *time.Timer - - fetchMoreTimer *time.Timer - - index uint32 - - metrics batchConnMetrics -} - -func newBatchConn(connCount, maxBatchSize uint, idleNotify *uint32) *batchConn { - return &batchConn{ - batchCommandsCh: make(chan *batchCommandsEntry, maxBatchSize), - batchCommandsClients: make([]*batchCommandsClient, 0, connCount), - tikvTransportLayerLoad: 0, - closed: make(chan struct{}), - reqBuilder: newBatchCommandsBuilder(maxBatchSize), - idleNotify: idleNotify, - idleDetect: time.NewTimer(idleTimeout), - } -} - -func (a *batchConn) initMetrics(target string) { - a.metrics.pendingRequests = metrics.TiKVBatchPendingRequests.WithLabelValues(target) - a.metrics.batchSize = metrics.TiKVBatchRequests.WithLabelValues(target) - a.metrics.sendLoopWaitHeadDur = metrics.TiKVBatchSendLoopDuration.WithLabelValues(target, "wait-head") - a.metrics.sendLoopWaitMoreDur = metrics.TiKVBatchSendLoopDuration.WithLabelValues(target, "wait-more") - a.metrics.sendLoopSendDur = metrics.TiKVBatchSendLoopDuration.WithLabelValues(target, "send") - a.metrics.recvLoopRecvDur = metrics.TiKVBatchRecvLoopDuration.WithLabelValues(target, "recv") - a.metrics.recvLoopProcessDur = metrics.TiKVBatchRecvLoopDuration.WithLabelValues(target, "process") - a.metrics.batchSendTailLat = metrics.TiKVBatchSendTailLatency.WithLabelValues(target) - a.metrics.batchRecvTailLat = metrics.TiKVBatchRecvTailLatency.WithLabelValues(target) - a.metrics.headArrivalInterval = metrics.TiKVBatchHeadArrivalInterval.WithLabelValues(target) - a.metrics.batchMoreRequests = metrics.TiKVBatchMoreRequests.WithLabelValues(target) - a.metrics.bestBatchSize = metrics.TiKVBatchBestSize.WithLabelValues(target) -} - -func (a *batchConn) isIdle() bool { - return atomic.LoadUint32(&a.idle) != 0 -} - -// fetchAllPendingRequests fetches all pending requests from the channel. -func (a *batchConn) fetchAllPendingRequests(maxBatchSize int) (headRecvTime time.Time, headArrivalInterval time.Duration) { - // Block on the first element. - latestReqStartTime := a.reqBuilder.latestReqStartTime - var headEntry *batchCommandsEntry - select { - case headEntry = <-a.batchCommandsCh: - if !a.idleDetect.Stop() { - <-a.idleDetect.C - } - a.idleDetect.Reset(idleTimeout) - case <-a.idleDetect.C: - a.idleDetect.Reset(idleTimeout) - atomic.AddUint32(&a.idle, 1) - atomic.CompareAndSwapUint32(a.idleNotify, 0, 1) - // This batchConn to be recycled - return time.Now(), 0 - case <-a.closed: - return time.Now(), 0 - } - if headEntry == nil { - return time.Now(), 0 - } - headRecvTime = time.Now() - if headEntry.start.After(latestReqStartTime) && !latestReqStartTime.IsZero() { - headArrivalInterval = headEntry.start.Sub(latestReqStartTime) - } - a.reqBuilder.push(headEntry) - - // This loop is for trying best to collect more requests. - for a.reqBuilder.len() < maxBatchSize { - select { - case entry := <-a.batchCommandsCh: - if entry == nil { - return - } - a.reqBuilder.push(entry) - default: - return - } - } - return -} - -// fetchMorePendingRequests fetches more pending requests from the channel. -func (a *batchConn) fetchMorePendingRequests( - maxBatchSize int, - batchWaitSize int, - maxWaitTime time.Duration, -) { - // Try to collect `batchWaitSize` requests, or wait `maxWaitTime`. 
- if a.fetchMoreTimer == nil { - a.fetchMoreTimer = time.NewTimer(maxWaitTime) - } else { - a.fetchMoreTimer.Reset(maxWaitTime) - } - for a.reqBuilder.len() < batchWaitSize { - select { - case entry := <-a.batchCommandsCh: - if entry == nil { - if !a.fetchMoreTimer.Stop() { - <-a.fetchMoreTimer.C - } - return - } - a.reqBuilder.push(entry) - case <-a.fetchMoreTimer.C: - return - } - } - if !a.fetchMoreTimer.Stop() { - <-a.fetchMoreTimer.C - } - - // Do an additional non-block try. Here we test the length with `maxBatchSize` instead - // of `batchWaitSize` because trying best to fetch more requests is necessary so that - // we can adjust the `batchWaitSize` dynamically. - yielded := false - for a.reqBuilder.len() < maxBatchSize { - select { - case entry := <-a.batchCommandsCh: - if entry == nil { - return - } - a.reqBuilder.push(entry) - default: - if yielded { - return - } - // yield once to batch more requests. - runtime.Gosched() - yielded = true - } - } -} - -const idleTimeout = 3 * time.Minute - var ( // presetBatchPolicies defines a set of [turboBatchOptions] as batch policies. presetBatchPolicies = map[string]turboBatchOptions{ @@ -534,150 +380,6 @@ func (t *turboBatchTrigger) preferredBatchWaitSize(avgBatchWaitSize float64, def return batchWaitSize } -// BatchSendLoopPanicCounter is only used for testing. -var BatchSendLoopPanicCounter int64 = 0 - -var initBatchPolicyWarn sync.Once - -func (a *batchConn) batchSendLoop(cfg config.TiKVClient) { - defer func() { - if r := recover(); r != nil { - metrics.TiKVPanicCounter.WithLabelValues(metrics.LabelBatchSendLoop).Inc() - logutil.BgLogger().Error("batchSendLoop", - zap.Any("r", r), - zap.Stack("stack")) - atomic.AddInt64(&BatchSendLoopPanicCounter, 1) - logutil.BgLogger().Info("restart batchSendLoop", zap.Int64("count", atomic.LoadInt64(&BatchSendLoopPanicCounter))) - go a.batchSendLoop(cfg) - } - }() - - trigger, ok := newTurboBatchTriggerFromPolicy(cfg.BatchPolicy) - if !ok { - initBatchPolicyWarn.Do(func() { - logutil.BgLogger().Warn("fallback to default batch policy due to invalid value", zap.String("value", cfg.BatchPolicy)) - }) - } - turboBatchWaitTime := trigger.turboWaitTime() - - avgBatchWaitSize := float64(cfg.BatchWaitSize) - for { - sendLoopStartTime := time.Now() - a.reqBuilder.reset() - - headRecvTime, headArrivalInterval := a.fetchAllPendingRequests(int(cfg.MaxBatchSize)) - if a.reqBuilder.len() == 0 { - // the conn is closed or recycled. - return - } - - // curl -X PUT -d 'return(true)' http://0.0.0.0:10080/fail/tikvclient/mockBlockOnBatchClient - if val, err := util.EvalFailpoint("mockBlockOnBatchClient"); err == nil { - if val.(bool) { - time.Sleep(1 * time.Hour) - } - } - - if batchSize := a.reqBuilder.len(); batchSize < int(cfg.MaxBatchSize) { - if cfg.MaxBatchWaitTime > 0 && atomic.LoadUint64(&a.tikvTransportLayerLoad) > uint64(cfg.OverloadThreshold) { - // If the target TiKV is overload, wait a while to collect more requests. 
- metrics.TiKVBatchWaitOverLoad.Inc() - a.fetchMorePendingRequests(int(cfg.MaxBatchSize), int(cfg.BatchWaitSize), cfg.MaxBatchWaitTime) - } else if turboBatchWaitTime > 0 && headArrivalInterval > 0 && trigger.needFetchMore(headArrivalInterval) { - batchWaitSize := trigger.preferredBatchWaitSize(avgBatchWaitSize, int(cfg.BatchWaitSize)) - a.fetchMorePendingRequests(int(cfg.MaxBatchSize), batchWaitSize, turboBatchWaitTime) - a.metrics.batchMoreRequests.Observe(float64(a.reqBuilder.len() - batchSize)) - } - } - length := a.reqBuilder.len() - avgBatchWaitSize = 0.2*float64(length) + 0.8*avgBatchWaitSize - a.metrics.pendingRequests.Observe(float64(len(a.batchCommandsCh) + length)) - a.metrics.bestBatchSize.Observe(avgBatchWaitSize) - a.metrics.headArrivalInterval.Observe(headArrivalInterval.Seconds()) - a.metrics.sendLoopWaitHeadDur.Observe(headRecvTime.Sub(sendLoopStartTime).Seconds()) - a.metrics.sendLoopWaitMoreDur.Observe(time.Since(sendLoopStartTime).Seconds()) - - a.getClientAndSend() - - sendLoopEndTime := time.Now() - a.metrics.sendLoopSendDur.Observe(sendLoopEndTime.Sub(sendLoopStartTime).Seconds()) - if dur := sendLoopEndTime.Sub(headRecvTime); dur > batchSendTailLatThreshold { - a.metrics.batchSendTailLat.Observe(dur.Seconds()) - } - } -} - -const ( - SendFailedReasonNoAvailableLimit = "concurrency limit exceeded" - SendFailedReasonTryLockForSendFail = "tryLockForSend fail" -) - -func (a *batchConn) getClientAndSend() { - if val, err := util.EvalFailpoint("mockBatchClientSendDelay"); err == nil { - if timeout, ok := val.(int); ok && timeout > 0 { - time.Sleep(time.Duration(timeout * int(time.Millisecond))) - } - } - - // Choose a connection by round-robbin. - var ( - cli *batchCommandsClient - target string - ) - reasons := make([]string, 0) - hasHighPriorityTask := a.reqBuilder.hasHighPriorityTask() - for i := 0; i < len(a.batchCommandsClients); i++ { - a.index = (a.index + 1) % uint32(len(a.batchCommandsClients)) - target = a.batchCommandsClients[a.index].target - // The lock protects the batchCommandsClient from been closed while it's in use. - c := a.batchCommandsClients[a.index] - if hasHighPriorityTask || c.available() > 0 { - if c.tryLockForSend() { - cli = c - break - } else { - reasons = append(reasons, SendFailedReasonTryLockForSendFail) - } - } else { - reasons = append(reasons, SendFailedReasonNoAvailableLimit) - } - } - if cli == nil { - logutil.BgLogger().Info("no available connections", zap.String("target", target), zap.Any("reasons", reasons)) - metrics.TiKVNoAvailableConnectionCounter.Inc() - if config.GetGlobalConfig().TiKVClient.MaxConcurrencyRequestLimit == config.DefMaxConcurrencyRequestLimit { - // Only cancel requests when MaxConcurrencyRequestLimit feature is not enabled, to be compatible with the behavior of older versions. - // TODO: But when MaxConcurrencyRequestLimit feature is enabled, the requests won't be canceled and will wait until timeout. - // This behavior may not be reasonable, as the timeout is usually 40s or 60s, which is too long to retry in time. 
- a.reqBuilder.cancel(errors.New("no available connections")) - } - return - } - defer cli.unlockForSend() - available := cli.available() - reqSendTime := time.Now() - batch := 0 - req, forwardingReqs := a.reqBuilder.buildWithLimit(available, func(id uint64, e *batchCommandsEntry) { - cli.batched.Store(id, e) - cli.sent.Add(1) - atomic.StoreInt64(&e.sendLat, int64(reqSendTime.Sub(e.start))) - if trace.IsEnabled() { - trace.Log(e.ctx, "rpc", "send") - } - }) - if req != nil { - batch += len(req.RequestIds) - cli.send("", req) - } - for forwardedHost, req := range forwardingReqs { - batch += len(req.RequestIds) - cli.send(forwardedHost, req) - } - if batch > 0 { - a.metrics.batchSize.Observe(float64(batch)) - } -} - type tryLock struct { *sync.Cond reCreating bool @@ -1127,18 +829,6 @@ func (c *batchCommandsClient) initBatchClient(forwardedHost string) error { return nil } -func (a *batchConn) Close() { - // Close all batchRecvLoop. - for _, c := range a.batchCommandsClients { - // After connections are closed, `batchRecvLoop`s will check the flag. - atomic.StoreInt32(&c.closed, 1) - } - // Don't close(batchCommandsCh) because when Close() is called, someone maybe - // calling SendRequest and writing batchCommandsCh, if we close it here the - // writing goroutine will panic. - close(a.closed) -} - func sendBatchRequest( ctx context.Context, addr string, diff --git a/internal/client/client_fail_test.go b/internal/client/client_fail_test.go index 6da27594..bc530e80 100644 --- a/internal/client/client_fail_test.go +++ b/internal/client/client_fail_test.go @@ -64,7 +64,7 @@ func TestPanicInRecvLoop(t *testing.T) { rpcClient.option.dialTimeout = time.Second / 3 // Start batchRecvLoop, and it should panic in `failPendingRequests`. - _, err := rpcClient.getConnArray(addr, true, func(cfg *config.TiKVClient) { cfg.GrpcConnectionCount = 1 }) + _, err := rpcClient.getConnPool(addr, true, func(cfg *config.TiKVClient) { cfg.GrpcConnectionCount = 1 }) assert.Nil(t, err, "cannot establish local connection due to env problems(e.g. 
heavy load in test machine), please retry again") req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{}) @@ -103,10 +103,10 @@ func TestRecvErrorInMultipleRecvLoops(t *testing.T) { _, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second) assert.Nil(t, err) } - connArray, err := rpcClient.getConnArray(addr, true) - assert.NotNil(t, connArray) + connPool, err := rpcClient.getConnPool(addr, true) + assert.NotNil(t, connPool) assert.Nil(t, err) - batchConn := connArray.batchConn + batchConn := connPool.batchConn assert.NotNil(t, batchConn) assert.Equal(t, len(batchConn.batchCommandsClients), 1) batchClient := batchConn.batchCommandsClients[0] diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 62629342..5608eec1 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -74,28 +74,28 @@ func TestConn(t *testing.T) { defer client.Close() addr := "127.0.0.1:6379" - conn1, err := client.getConnArray(addr, true) + conn1, err := client.getConnPool(addr, true) assert.Nil(t, err) - conn2, err := client.getConnArray(addr, true) + conn2, err := client.getConnPool(addr, true) assert.Nil(t, err) assert.False(t, conn2.Get() == conn1.Get()) ver := conn2.ver assert.Nil(t, client.CloseAddrVer(addr, ver-1)) - _, ok := client.conns[addr] + _, ok := client.connPools[addr] assert.True(t, ok) assert.Nil(t, client.CloseAddrVer(addr, ver)) - _, ok = client.conns[addr] + _, ok = client.connPools[addr] assert.False(t, ok) - conn3, err := client.getConnArray(addr, true) + conn3, err := client.getConnPool(addr, true) assert.Nil(t, err) assert.NotNil(t, conn3) assert.Equal(t, ver+1, conn3.ver) client.Close() - conn4, err := client.getConnArray(addr, true) + conn4, err := client.getConnPool(addr, true) assert.NotNil(t, err) assert.Nil(t, conn4) } @@ -105,10 +105,10 @@ func TestGetConnAfterClose(t *testing.T) { defer client.Close() addr := "127.0.0.1:6379" - connArray, err := client.getConnArray(addr, true) + connPool, err := client.getConnPool(addr, true) assert.Nil(t, err) assert.Nil(t, client.CloseAddr(addr)) - conn := connArray.Get() + conn := connPool.Get() state := conn.GetState() assert.True(t, state == connectivity.Shutdown) } @@ -139,7 +139,7 @@ func TestSendWhenReconnect(t *testing.T) { restoreFn() }() addr := server.Addr() - conn, err := rpcClient.getConnArray(addr, true) + conn, err := rpcClient.getConnPool(addr, true) assert.Nil(t, err) // Suppose all connections are re-establishing. @@ -680,7 +680,7 @@ func TestBatchClientRecoverAfterServerRestart(t *testing.T) { }() req := &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Coprocessor{Coprocessor: &coprocessor.Request{}}} - conn, err := client.getConnArray(addr, true) + conn, err := client.getConnPool(addr, true) assert.Nil(t, err) // send some request, it should be success. 
for i := 0; i < 100; i++ { @@ -855,7 +855,7 @@ func TestBatchClientReceiveHealthFeedback(t *testing.T) { client := NewRPCClient() defer client.Close() - conn, err := client.getConnArray(addr, true) + conn, err := client.getConnPool(addr, true) assert.NoError(t, err) tikvClient := tikvpb.NewTikvClient(conn.Get()) @@ -944,7 +944,7 @@ func TestRandomRestartStoreAndForwarding(t *testing.T) { } }() - conn, err := client1.getConnArray(addr1, true) + conn, err := client1.getConnPool(addr1, true) assert.Nil(t, err) for j := 0; j < concurrency; j++ { wg.Add(1) @@ -1040,7 +1040,7 @@ func TestFastFailWhenNoAvailableConn(t *testing.T) { }() req := &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Coprocessor{Coprocessor: &coprocessor.Request{}}} - conn, err := client.getConnArray(addr, true) + conn, err := client.getConnPool(addr, true) assert.Nil(t, err) _, err = sendBatchRequest(context.Background(), addr, "", conn.batchConn, req, time.Second, 0) require.NoError(t, err) @@ -1060,7 +1060,7 @@ func TestFastFailWhenNoAvailableConn(t *testing.T) { func TestConcurrentCloseConnPanic(t *testing.T) { client := NewRPCClient() addr := "127.0.0.1:6379" - _, err := client.getConnArray(addr, true) + _, err := client.getConnPool(addr, true) assert.Nil(t, err) var wg sync.WaitGroup wg.Add(2) diff --git a/internal/client/conn_batch.go b/internal/client/conn_batch.go new file mode 100644 index 00000000..bf5e926a --- /dev/null +++ b/internal/client/conn_batch.go @@ -0,0 +1,339 @@ +// Copyright 2025 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "runtime" + "runtime/trace" + "sync" + "sync/atomic" + "time" + + "github.com/pkg/errors" + "github.com/tikv/client-go/v2/config" + "github.com/tikv/client-go/v2/internal/logutil" + "github.com/tikv/client-go/v2/metrics" + "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +type batchConn struct { + // An atomic flag indicates whether the batch is idle or not. + // 0 for busy, others for idle. + idle uint32 + + // batchCommandsCh used for batch commands. 
+ batchCommandsCh chan *batchCommandsEntry + batchCommandsClients []*batchCommandsClient + tikvTransportLayerLoad uint64 + closed chan struct{} + + reqBuilder *batchCommandsBuilder + + // Notify rpcClient to check the idle flag + idleNotify *uint32 + idleDetect *time.Timer + + fetchMoreTimer *time.Timer + + index uint32 + + metrics batchConnMetrics +} + +func newBatchConn(connCount, maxBatchSize uint, idleNotify *uint32) *batchConn { + return &batchConn{ + batchCommandsCh: make(chan *batchCommandsEntry, maxBatchSize), + batchCommandsClients: make([]*batchCommandsClient, 0, connCount), + tikvTransportLayerLoad: 0, + closed: make(chan struct{}), + reqBuilder: newBatchCommandsBuilder(maxBatchSize), + idleNotify: idleNotify, + idleDetect: time.NewTimer(idleTimeout), + } +} + +func (a *batchConn) initMetrics(target string) { + a.metrics.pendingRequests = metrics.TiKVBatchPendingRequests.WithLabelValues(target) + a.metrics.batchSize = metrics.TiKVBatchRequests.WithLabelValues(target) + a.metrics.sendLoopWaitHeadDur = metrics.TiKVBatchSendLoopDuration.WithLabelValues(target, "wait-head") + a.metrics.sendLoopWaitMoreDur = metrics.TiKVBatchSendLoopDuration.WithLabelValues(target, "wait-more") + a.metrics.sendLoopSendDur = metrics.TiKVBatchSendLoopDuration.WithLabelValues(target, "send") + a.metrics.recvLoopRecvDur = metrics.TiKVBatchRecvLoopDuration.WithLabelValues(target, "recv") + a.metrics.recvLoopProcessDur = metrics.TiKVBatchRecvLoopDuration.WithLabelValues(target, "process") + a.metrics.batchSendTailLat = metrics.TiKVBatchSendTailLatency.WithLabelValues(target) + a.metrics.batchRecvTailLat = metrics.TiKVBatchRecvTailLatency.WithLabelValues(target) + a.metrics.headArrivalInterval = metrics.TiKVBatchHeadArrivalInterval.WithLabelValues(target) + a.metrics.batchMoreRequests = metrics.TiKVBatchMoreRequests.WithLabelValues(target) + a.metrics.bestBatchSize = metrics.TiKVBatchBestSize.WithLabelValues(target) +} + +func (a *batchConn) isIdle() bool { + return atomic.LoadUint32(&a.idle) != 0 +} + +// fetchAllPendingRequests fetches all pending requests from the channel. +func (a *batchConn) fetchAllPendingRequests(maxBatchSize int) (headRecvTime time.Time, headArrivalInterval time.Duration) { + // Block on the first element. + latestReqStartTime := a.reqBuilder.latestReqStartTime + var headEntry *batchCommandsEntry + select { + case headEntry = <-a.batchCommandsCh: + if !a.idleDetect.Stop() { + <-a.idleDetect.C + } + a.idleDetect.Reset(idleTimeout) + case <-a.idleDetect.C: + a.idleDetect.Reset(idleTimeout) + atomic.AddUint32(&a.idle, 1) + atomic.CompareAndSwapUint32(a.idleNotify, 0, 1) + // This batchConn to be recycled + return time.Now(), 0 + case <-a.closed: + return time.Now(), 0 + } + if headEntry == nil { + return time.Now(), 0 + } + headRecvTime = time.Now() + if headEntry.start.After(latestReqStartTime) && !latestReqStartTime.IsZero() { + headArrivalInterval = headEntry.start.Sub(latestReqStartTime) + } + a.reqBuilder.push(headEntry) + + // This loop is for trying best to collect more requests. + for a.reqBuilder.len() < maxBatchSize { + select { + case entry := <-a.batchCommandsCh: + if entry == nil { + return + } + a.reqBuilder.push(entry) + default: + return + } + } + return +} + +// fetchMorePendingRequests fetches more pending requests from the channel. +func (a *batchConn) fetchMorePendingRequests( + maxBatchSize int, + batchWaitSize int, + maxWaitTime time.Duration, +) { + // Try to collect `batchWaitSize` requests, or wait `maxWaitTime`. 
+ if a.fetchMoreTimer == nil { + a.fetchMoreTimer = time.NewTimer(maxWaitTime) + } else { + a.fetchMoreTimer.Reset(maxWaitTime) + } + for a.reqBuilder.len() < batchWaitSize { + select { + case entry := <-a.batchCommandsCh: + if entry == nil { + if !a.fetchMoreTimer.Stop() { + <-a.fetchMoreTimer.C + } + return + } + a.reqBuilder.push(entry) + case <-a.fetchMoreTimer.C: + return + } + } + if !a.fetchMoreTimer.Stop() { + <-a.fetchMoreTimer.C + } + + // Do an additional non-block try. Here we test the length with `maxBatchSize` instead + // of `batchWaitSize` because trying best to fetch more requests is necessary so that + // we can adjust the `batchWaitSize` dynamically. + yielded := false + for a.reqBuilder.len() < maxBatchSize { + select { + case entry := <-a.batchCommandsCh: + if entry == nil { + return + } + a.reqBuilder.push(entry) + default: + if yielded { + return + } + // yield once to batch more requests. + runtime.Gosched() + yielded = true + } + } +} + +const idleTimeout = 3 * time.Minute + +// BatchSendLoopPanicCounter is only used for testing. +var BatchSendLoopPanicCounter int64 = 0 + +var initBatchPolicyWarn sync.Once + +func (a *batchConn) batchSendLoop(cfg config.TiKVClient) { + defer func() { + if r := recover(); r != nil { + metrics.TiKVPanicCounter.WithLabelValues(metrics.LabelBatchSendLoop).Inc() + logutil.BgLogger().Error("batchSendLoop", + zap.Any("r", r), + zap.Stack("stack")) + atomic.AddInt64(&BatchSendLoopPanicCounter, 1) + logutil.BgLogger().Info("restart batchSendLoop", zap.Int64("count", atomic.LoadInt64(&BatchSendLoopPanicCounter))) + go a.batchSendLoop(cfg) + } + }() + + trigger, ok := newTurboBatchTriggerFromPolicy(cfg.BatchPolicy) + if !ok { + initBatchPolicyWarn.Do(func() { + logutil.BgLogger().Warn("fallback to default batch policy due to invalid value", zap.String("value", cfg.BatchPolicy)) + }) + } + turboBatchWaitTime := trigger.turboWaitTime() + + avgBatchWaitSize := float64(cfg.BatchWaitSize) + for { + sendLoopStartTime := time.Now() + a.reqBuilder.reset() + + headRecvTime, headArrivalInterval := a.fetchAllPendingRequests(int(cfg.MaxBatchSize)) + if a.reqBuilder.len() == 0 { + // the conn is closed or recycled. + return + } + + // curl -X PUT -d 'return(true)' http://0.0.0.0:10080/fail/tikvclient/mockBlockOnBatchClient + if val, err := util.EvalFailpoint("mockBlockOnBatchClient"); err == nil { + if val.(bool) { + time.Sleep(1 * time.Hour) + } + } + + if batchSize := a.reqBuilder.len(); batchSize < int(cfg.MaxBatchSize) { + if cfg.MaxBatchWaitTime > 0 && atomic.LoadUint64(&a.tikvTransportLayerLoad) > uint64(cfg.OverloadThreshold) { + // If the target TiKV is overload, wait a while to collect more requests. 
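The collection strategy of fetchAllPendingRequests and fetchMorePendingRequests above reduces to: block for the head request, then drain whatever is already queued without blocking. A compact model, with a placeholder entry type and idle handling reduced to a single timer case:

package sketch

import "time"

// collectBatch blocks for the head entry (or gives up when the idle timer
// fires), then performs a non-blocking drain up to maxBatchSize, mirroring the
// select/default loop used by the batch send path.
func collectBatch(ch <-chan string, maxBatchSize int, idle *time.Timer) []string {
	var batch []string
	select {
	case head := <-ch:
		batch = append(batch, head)
	case <-idle.C:
		return nil // idle: the real code flags the pool for recycling here
	}
	for len(batch) < maxBatchSize {
		select {
		case e := <-ch:
			batch = append(batch, e)
		default:
			return batch // nothing else pending; send what we have
		}
	}
	return batch
}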
+ metrics.TiKVBatchWaitOverLoad.Inc() + a.fetchMorePendingRequests(int(cfg.MaxBatchSize), int(cfg.BatchWaitSize), cfg.MaxBatchWaitTime) + } else if turboBatchWaitTime > 0 && headArrivalInterval > 0 && trigger.needFetchMore(headArrivalInterval) { + batchWaitSize := trigger.preferredBatchWaitSize(avgBatchWaitSize, int(cfg.BatchWaitSize)) + a.fetchMorePendingRequests(int(cfg.MaxBatchSize), batchWaitSize, turboBatchWaitTime) + a.metrics.batchMoreRequests.Observe(float64(a.reqBuilder.len() - batchSize)) + } + } + length := a.reqBuilder.len() + avgBatchWaitSize = 0.2*float64(length) + 0.8*avgBatchWaitSize + a.metrics.pendingRequests.Observe(float64(len(a.batchCommandsCh) + length)) + a.metrics.bestBatchSize.Observe(avgBatchWaitSize) + a.metrics.headArrivalInterval.Observe(headArrivalInterval.Seconds()) + a.metrics.sendLoopWaitHeadDur.Observe(headRecvTime.Sub(sendLoopStartTime).Seconds()) + a.metrics.sendLoopWaitMoreDur.Observe(time.Since(sendLoopStartTime).Seconds()) + + a.getClientAndSend() + + sendLoopEndTime := time.Now() + a.metrics.sendLoopSendDur.Observe(sendLoopEndTime.Sub(sendLoopStartTime).Seconds()) + if dur := sendLoopEndTime.Sub(headRecvTime); dur > batchSendTailLatThreshold { + a.metrics.batchSendTailLat.Observe(dur.Seconds()) + } + } +} + +const ( + SendFailedReasonNoAvailableLimit = "concurrency limit exceeded" + SendFailedReasonTryLockForSendFail = "tryLockForSend fail" +) + +func (a *batchConn) getClientAndSend() { + if val, err := util.EvalFailpoint("mockBatchClientSendDelay"); err == nil { + if timeout, ok := val.(int); ok && timeout > 0 { + time.Sleep(time.Duration(timeout * int(time.Millisecond))) + } + } + + // Choose a connection by round-robbin. + var ( + cli *batchCommandsClient + target string + ) + reasons := make([]string, 0) + hasHighPriorityTask := a.reqBuilder.hasHighPriorityTask() + for i := 0; i < len(a.batchCommandsClients); i++ { + a.index = (a.index + 1) % uint32(len(a.batchCommandsClients)) + target = a.batchCommandsClients[a.index].target + // The lock protects the batchCommandsClient from been closed while it's in use. + c := a.batchCommandsClients[a.index] + if hasHighPriorityTask || c.available() > 0 { + if c.tryLockForSend() { + cli = c + break + } else { + reasons = append(reasons, SendFailedReasonTryLockForSendFail) + } + } else { + reasons = append(reasons, SendFailedReasonNoAvailableLimit) + } + } + if cli == nil { + logutil.BgLogger().Info("no available connections", zap.String("target", target), zap.Any("reasons", reasons)) + metrics.TiKVNoAvailableConnectionCounter.Inc() + if config.GetGlobalConfig().TiKVClient.MaxConcurrencyRequestLimit == config.DefMaxConcurrencyRequestLimit { + // Only cancel requests when MaxConcurrencyRequestLimit feature is not enabled, to be compatible with the behavior of older versions. + // TODO: But when MaxConcurrencyRequestLimit feature is enabled, the requests won't be canceled and will wait until timeout. + // This behavior may not be reasonable, as the timeout is usually 40s or 60s, which is too long to retry in time. 
+ a.reqBuilder.cancel(errors.New("no available connections")) + } + return + } + defer cli.unlockForSend() + available := cli.available() + reqSendTime := time.Now() + batch := 0 + req, forwardingReqs := a.reqBuilder.buildWithLimit(available, func(id uint64, e *batchCommandsEntry) { + cli.batched.Store(id, e) + cli.sent.Add(1) + atomic.StoreInt64(&e.sendLat, int64(reqSendTime.Sub(e.start))) + if trace.IsEnabled() { + trace.Log(e.ctx, "rpc", "send") + } + }) + if req != nil { + batch += len(req.RequestIds) + cli.send("", req) + } + for forwardedHost, req := range forwardingReqs { + batch += len(req.RequestIds) + cli.send(forwardedHost, req) + } + if batch > 0 { + a.metrics.batchSize.Observe(float64(batch)) + } +} + +func (a *batchConn) Close() { + // Close all batchRecvLoop. + for _, c := range a.batchCommandsClients { + // After connections are closed, `batchRecvLoop`s will check the flag. + atomic.StoreInt32(&c.closed, 1) + } + // Don't close(batchCommandsCh) because when Close() is called, someone maybe + // calling SendRequest and writing batchCommandsCh, if we close it here the + // writing goroutine will panic. + close(a.closed) +} diff --git a/internal/client/conn_monitor.go b/internal/client/conn_monitor.go new file mode 100644 index 00000000..07e51afc --- /dev/null +++ b/internal/client/conn_monitor.go @@ -0,0 +1,101 @@ +// Copyright 2025 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
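The client selection in getClientAndSend above is a round-robin scan that skips clients without remaining concurrency quota. A simplified model (the real code additionally requires tryLockForSend to succeed and treats a nil result as the "no available connections" path):

package sketch

// batchClient is a stand-in exposing only what this sketch needs.
type batchClient struct {
	available int64 // remaining concurrency quota
}

// pickClient advances the shared round-robin index and returns the first
// client with quota left; high-priority batches bypass the quota check.
func pickClient(clients []*batchClient, index *uint32, highPriority bool) *batchClient {
	n := uint32(len(clients))
	for i := uint32(0); i < n; i++ {
		*index = (*index + 1) % n
		c := clients[*index]
		if highPriority || c.available > 0 {
			return c
		}
	}
	return nil
}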
+ +package client + +import ( + "sync" + "time" + + "github.com/tikv/client-go/v2/internal/logutil" + "github.com/tikv/client-go/v2/metrics" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" +) + +type monitoredConn struct { + *grpc.ClientConn + Name string +} + +func (c *monitoredConn) Close() error { + if c.ClientConn != nil { + err := c.ClientConn.Close() + logutil.BgLogger().Debug("close gRPC connection", zap.String("target", c.Name), zap.Error(err)) + return err + } + return nil +} + +type connMonitor struct { + m sync.Map + loopOnce sync.Once + stopOnce sync.Once + stop chan struct{} +} + +func (c *connMonitor) AddConn(conn *monitoredConn) { + c.m.Store(conn.Name, conn) +} + +func (c *connMonitor) RemoveConn(conn *monitoredConn) { + c.m.Delete(conn.Name) + for state := connectivity.Idle; state <= connectivity.Shutdown; state++ { + metrics.TiKVGrpcConnectionState.WithLabelValues(conn.Name, conn.Target(), state.String()).Set(0) + } +} + +func (c *connMonitor) Start() { + c.loopOnce.Do( + func() { + c.stop = make(chan struct{}) + go c.start() + }, + ) +} + +func (c *connMonitor) Stop() { + c.stopOnce.Do( + func() { + if c.stop != nil { + close(c.stop) + } + }, + ) +} + +func (c *connMonitor) start() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + c.m.Range(func(_, value interface{}) bool { + conn := value.(*monitoredConn) + nowState := conn.GetState() + for state := connectivity.Idle; state <= connectivity.Shutdown; state++ { + if state == nowState { + metrics.TiKVGrpcConnectionState.WithLabelValues(conn.Name, conn.Target(), nowState.String()).Set(1) + } else { + metrics.TiKVGrpcConnectionState.WithLabelValues(conn.Name, conn.Target(), state.String()).Set(0) + } + } + return true + }) + case <-c.stop: + return + } + } +} diff --git a/internal/client/conn_pool.go b/internal/client/conn_pool.go new file mode 100644 index 00000000..2ac8f032 --- /dev/null +++ b/internal/client/conn_pool.go @@ -0,0 +1,255 @@ +// Copyright 2025 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" + "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" + "github.com/tikv/client-go/v2/config" + tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/metrics" + "github.com/tikv/client-go/v2/tikvrpc" + "github.com/tikv/client-go/v2/util" + "google.golang.org/grpc" + "google.golang.org/grpc/backoff" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/encoding/gzip" + "google.golang.org/grpc/experimental" + "google.golang.org/grpc/keepalive" +) + +type connPool struct { + // The target host. + target string + // version of the connection pool, increase by 1 when reconnect. 
+ ver uint64 + + index uint32 + // streamTimeout binds with a background goroutine to process coprocessor streaming timeout. + streamTimeout chan *tikvrpc.Lease + dialTimeout time.Duration + conns []*monitoredConn + // batchConn is not null when batch is enabled. + *batchConn + done chan struct{} + + monitor *connMonitor + + metrics struct { + rpcLatHist *rpcMetrics + rpcSrcLatSum sync.Map + rpcNetLatExternal prometheus.Observer + rpcNetLatInternal prometheus.Observer + } +} + +func newConnPool(maxSize uint, addr string, ver uint64, security config.Security, + idleNotify *uint32, enableBatch bool, dialTimeout time.Duration, m *connMonitor, eventListener *atomic.Pointer[ClientEventListener], opts []grpc.DialOption) (*connPool, error) { + a := &connPool{ + ver: ver, + index: 0, + conns: make([]*monitoredConn, maxSize), + streamTimeout: make(chan *tikvrpc.Lease, 1024), + done: make(chan struct{}), + dialTimeout: dialTimeout, + monitor: m, + } + a.metrics.rpcLatHist = deriveRPCMetrics(metrics.TiKVSendReqHistogram.MustCurryWith(prometheus.Labels{metrics.LblStore: addr})) + a.metrics.rpcNetLatExternal = metrics.TiKVRPCNetLatencyHistogram.WithLabelValues(addr, "false") + a.metrics.rpcNetLatInternal = metrics.TiKVRPCNetLatencyHistogram.WithLabelValues(addr, "true") + if err := a.Init(addr, security, idleNotify, enableBatch, eventListener, opts...); err != nil { + return nil, err + } + return a, nil +} + +func (a *connPool) monitoredDial(ctx context.Context, connName, target string, opts ...grpc.DialOption) (conn *monitoredConn, err error) { + conn = &monitoredConn{ + Name: connName, + } + conn.ClientConn, err = grpc.DialContext(ctx, target, opts...) + if err != nil { + return nil, err + } + a.monitor.AddConn(conn) + return conn, nil +} + +func (a *connPool) Init(addr string, security config.Security, idleNotify *uint32, enableBatch bool, eventListener *atomic.Pointer[ClientEventListener], opts ...grpc.DialOption) error { + a.target = addr + + opt := grpc.WithTransportCredentials(insecure.NewCredentials()) + if len(security.ClusterSSLCA) != 0 { + tlsConfig, err := security.ToTLSConfig() + if err != nil { + return errors.WithStack(err) + } + opt = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)) + } + + cfg := config.GetGlobalConfig() + var ( + unaryInterceptor grpc.UnaryClientInterceptor + streamInterceptor grpc.StreamClientInterceptor + ) + if cfg.OpenTracingEnable { + unaryInterceptor = grpc_opentracing.UnaryClientInterceptor() + streamInterceptor = grpc_opentracing.StreamClientInterceptor() + } + + allowBatch := (cfg.TiKVClient.MaxBatchSize > 0) && enableBatch + if allowBatch { + a.batchConn = newBatchConn(uint(len(a.conns)), cfg.TiKVClient.MaxBatchSize, idleNotify) + a.batchConn.initMetrics(a.target) + } + keepAlive := cfg.TiKVClient.GrpcKeepAliveTime + for i := range a.conns { + ctx, cancel := context.WithTimeout(context.Background(), a.dialTimeout) + var callOptions []grpc.CallOption + callOptions = append(callOptions, grpc.MaxCallRecvMsgSize(MaxRecvMsgSize)) + if cfg.TiKVClient.GrpcCompressionType == gzip.Name { + callOptions = append(callOptions, grpc.UseCompressor(gzip.Name)) + } + + opts = append([]grpc.DialOption{ + opt, + grpc.WithInitialWindowSize(cfg.TiKVClient.GrpcInitialWindowSize), + grpc.WithInitialConnWindowSize(cfg.TiKVClient.GrpcInitialConnWindowSize), + grpc.WithUnaryInterceptor(unaryInterceptor), + grpc.WithStreamInterceptor(streamInterceptor), + grpc.WithDefaultCallOptions(callOptions...), + grpc.WithConnectParams(grpc.ConnectParams{ + Backoff: backoff.Config{ + 
BaseDelay: 100 * time.Millisecond, // Default was 1s. + Multiplier: 1.6, // Default + Jitter: 0.2, // Default + MaxDelay: 3 * time.Second, // Default was 120s. + }, + MinConnectTimeout: a.dialTimeout, + }), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: time.Duration(keepAlive) * time.Second, + Timeout: cfg.TiKVClient.GetGrpcKeepAliveTimeout(), + }), + }, opts...) + if cfg.TiKVClient.GrpcSharedBufferPool { + opts = append(opts, experimental.WithRecvBufferPool(grpc.NewSharedBufferPool())) + } + conn, err := a.monitoredDial( + ctx, + fmt.Sprintf("%s-%d", a.target, i), + addr, + opts..., + ) + + cancel() + if err != nil { + // Cleanup if the initialization fails. + a.Close() + return errors.WithStack(err) + } + a.conns[i] = conn + + if allowBatch { + batchClient := &batchCommandsClient{ + target: a.target, + conn: conn.ClientConn, + forwardedClients: make(map[string]*batchCommandsStream), + batched: sync.Map{}, + epoch: 0, + closed: 0, + tikvClientCfg: cfg.TiKVClient, + tikvLoad: &a.tikvTransportLayerLoad, + dialTimeout: a.dialTimeout, + tryLock: tryLock{sync.NewCond(new(sync.Mutex)), false}, + eventListener: eventListener, + metrics: &a.batchConn.metrics, + } + batchClient.maxConcurrencyRequestLimit.Store(cfg.TiKVClient.MaxConcurrencyRequestLimit) + a.batchCommandsClients = append(a.batchCommandsClients, batchClient) + } + } + go tikvrpc.CheckStreamTimeoutLoop(a.streamTimeout, a.done) + if allowBatch { + go a.batchSendLoop(cfg.TiKVClient) + } + + return nil +} + +func (a *connPool) Get() *grpc.ClientConn { + next := atomic.AddUint32(&a.index, 1) % uint32(len(a.conns)) + return a.conns[next].ClientConn +} + +func (a *connPool) Close() { + if a.batchConn != nil { + a.batchConn.Close() + } + + for _, c := range a.conns { + if c != nil { + err := c.Close() + tikverr.Log(err) + if err == nil { + a.monitor.RemoveConn(c) + } + } + } + + close(a.done) +} + +func (a *connPool) updateRPCMetrics(req *tikvrpc.Request, resp *tikvrpc.Response, latency time.Duration) { + seconds := latency.Seconds() + stale := req.GetStaleRead() + source := req.GetRequestSource() + internal := util.IsInternalRequest(req.GetRequestSource()) + + a.metrics.rpcLatHist.get(req.Type, stale, internal).Observe(seconds) + + srcLatSum, ok := a.metrics.rpcSrcLatSum.Load(source) + if !ok { + srcLatSum = deriveRPCMetrics(metrics.TiKVSendReqSummary.MustCurryWith( + prometheus.Labels{metrics.LblStore: a.target, metrics.LblSource: source})) + a.metrics.rpcSrcLatSum.Store(source, srcLatSum) + } + srcLatSum.(*rpcMetrics).get(req.Type, stale, internal).Observe(seconds) + + if execDetail := resp.GetExecDetailsV2(); execDetail != nil { + var totalRpcWallTimeNs uint64 + if execDetail.TimeDetailV2 != nil { + totalRpcWallTimeNs = execDetail.TimeDetailV2.TotalRpcWallTimeNs + } else if execDetail.TimeDetail != nil { + totalRpcWallTimeNs = execDetail.TimeDetail.TotalRpcWallTimeNs + } + if totalRpcWallTimeNs > 0 { + lat := latency - time.Duration(totalRpcWallTimeNs) + if internal { + a.metrics.rpcNetLatInternal.Observe(lat.Seconds()) + } else { + a.metrics.rpcNetLatExternal.Observe(lat.Seconds()) + } + } + } +} diff --git a/internal/locate/region_cache.go b/internal/locate/region_cache.go index 4fa8d3fe..93c2a098 100644 --- a/internal/locate/region_cache.go +++ b/internal/locate/region_cache.go @@ -756,11 +756,6 @@ func NewRegionCache(pdClient pd.Client, opt ...RegionCacheOpt) *RegionCache { return c } -// ForceRefreshAllStores get all stores from PD and refresh store cache. 
-func (c *RegionCache) ForceRefreshAllStores(ctx context.Context) { - refreshFullStoreList(ctx, c.stores) -} - // Try to refresh full store list. Errors are ignored. func refreshFullStoreList(ctx context.Context, stores storeCache) { storeList, err := stores.fetchAllStores(ctx) diff --git a/internal/locate/store_cache.go b/internal/locate/store_cache.go index b6048142..9ac20ec9 100644 --- a/internal/locate/store_cache.go +++ b/internal/locate/store_cache.go @@ -761,7 +761,7 @@ func invokeKVStatusAPI(addr string, timeout time.Duration) (l livenessState) { func createKVHealthClient(ctx context.Context, addr string) (*grpc.ClientConn, healthpb.HealthClient, error) { // Temporarily directly load the config from the global config, however it's not a good idea to let RegionCache to // access it. - // TODO: Pass the config in a better way, or use the connArray inner the client directly rather than creating new + // TODO: Pass the config in a better way, or use the connPool inner the client directly rather than creating new // connection. cfg := config.GetGlobalConfig()
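As a closing note on the metrics moved into conn_pool.go: the rpcNetLat observers in updateRPCMetrics record the client-observed latency minus the server-reported RPC wall time, which approximates network and client-side queueing overhead. A small sketch of that derivation:

package sketch

import "time"

// netOverhead returns the portion of an RPC's client-observed latency that was
// not spent inside the server handler, or false when the server reported no
// wall time (TotalRpcWallTimeNs == 0) and the value cannot be derived.
func netOverhead(clientLatency time.Duration, totalRPCWallTimeNs uint64) (time.Duration, bool) {
	if totalRPCWallTimeNs == 0 {
		return 0, false
	}
	return clientLatency - time.Duration(totalRPCWallTimeNs), true
}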