diff --git a/internal/client/client.go b/internal/client/client.go
index 0212081c..e6a65ce6 100644
--- a/internal/client/client.go
+++ b/internal/client/client.go
@@ -119,9 +119,45 @@ type ClientEventListener interface {
 	OnHealthFeedback(feedback *tikvpb.HealthFeedback)
 }
 
+// ClientExt is a client with extended interfaces.
+type ClientExt interface {
+	// CloseAddrVer closes gRPC connections to the address with an additional `ver` parameter.
+	// Each new connection will have an incremented `ver` value, and attempts to close a previous `ver` will be ignored.
+	// Passing `math.MaxUint64` as the `ver` parameter will forcefully close all connections to the address.
+	CloseAddrVer(addr string, ver uint64) error
+}
+
+// ErrConn wraps an error with the target address and version of the connection.
+type ErrConn struct {
+	Err  error
+	Addr string
+	Ver  uint64
+}
+
+func (e *ErrConn) Error() string {
+	return fmt.Sprintf("[%s](%d) %s", e.Addr, e.Ver, e.Err.Error())
+}
+
+func (e *ErrConn) Unwrap() error {
+	return e.Err
+}
+
+func WrapErrConn(err error, conn *connArray) error {
+	if err == nil {
+		return nil
+	}
+	return &ErrConn{
+		Err:  err,
+		Addr: conn.target,
+		Ver:  conn.ver,
+	}
+}
+
 type connArray struct {
 	// The target host.
 	target string
+	// version of the connection array, increased by 1 on each reconnect.
+	ver uint64
 
 	index uint32
 	v     []*monitoredConn
@@ -135,9 +171,10 @@ type connArray struct {
 	monitor *connMonitor
 }
 
-func newConnArray(maxSize uint, addr string, security config.Security,
+func newConnArray(maxSize uint, addr string, ver uint64, security config.Security,
 	idleNotify *uint32, enableBatch bool, dialTimeout time.Duration, m *connMonitor, eventListener *atomic.Pointer[ClientEventListener], opts []grpc.DialOption) (*connArray, error) {
 	a := &connArray{
+		ver:           ver,
 		index:         0,
 		v:             make([]*monitoredConn, maxSize),
 		streamTimeout: make(chan *tikvrpc.Lease, 1024),
@@ -232,7 +269,9 @@ func (a *connArray) monitoredDial(ctx context.Context, connName, target string,
 
 func (c *monitoredConn) Close() error {
 	if c.ClientConn != nil {
-		return c.ClientConn.Close()
+		err := c.ClientConn.Close()
+		logutil.BgLogger().Debug("close gRPC connection", zap.String("target", c.Name), zap.Error(err))
+		return err
 	}
 	return nil
 }
@@ -402,6 +441,7 @@ type RPCClient struct {
 	sync.RWMutex
 
 	conns  map[string]*connArray
+	vers   map[string]uint64
 	option *option
 
 	idleNotify uint32
@@ -421,6 +461,7 @@ var _ Client = &RPCClient{}
 func NewRPCClient(opts ...Opt) *RPCClient {
 	cli := &RPCClient{
 		conns: make(map[string]*connArray),
+		vers:  make(map[string]uint64),
 		option: &option{
 			dialTimeout: dialTimeout,
 		},
@@ -469,9 +510,11 @@ func (c *RPCClient) createConnArray(addr string, enableBatch bool, opts ...func(
 		for _, opt := range opts {
 			opt(&client)
 		}
+		ver := c.vers[addr] + 1
 		array, err = newConnArray(
 			client.GrpcConnectionCount,
 			addr,
+			ver,
 			c.option.security,
 			&c.idleNotify,
 			enableBatch,
@@ -484,6 +527,7 @@
 			return nil, err
 		}
 		c.conns[addr] = array
+		c.vers[addr] = ver
 	}
 	return array, nil
 }
@@ -621,6 +665,10 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
 		return nil, err
 	}
 
+	wrapErrConn := func(resp *tikvrpc.Response, err error) (*tikvrpc.Response, error) {
+		return resp, WrapErrConn(err, connArray)
+	}
+
 	start := time.Now()
 	staleRead := req.GetStaleRead()
 	defer func() {
@@ -644,7 +692,7 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
 	if config.GetGlobalConfig().TiKVClient.MaxBatchSize > 0 && enableBatch {
 		if batchReq := req.ToBatchCommandsRequest(); batchReq != nil {
 			defer trace.StartRegion(ctx, req.Type.String()).End()
-			return sendBatchRequest(ctx, addr, req.ForwardedHost, connArray.batchConn, batchReq, timeout, pri)
+			return wrapErrConn(sendBatchRequest(ctx, addr, req.ForwardedHost, connArray.batchConn, batchReq, timeout, pri))
 		}
 	}
 
@@ -658,7 +706,7 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
 		client := debugpb.NewDebugClient(clientConn)
 		ctx1, cancel := context.WithTimeout(ctx, timeout)
 		defer cancel()
-		return tikvrpc.CallDebugRPC(ctx1, client, req)
+		return wrapErrConn(tikvrpc.CallDebugRPC(ctx1, client, req))
 	}
 
 	client := tikvpb.NewTikvClient(clientConn)
@@ -669,16 +717,16 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
 	}
 	switch req.Type {
 	case tikvrpc.CmdBatchCop:
-		return c.getBatchCopStreamResponse(ctx, client, req, timeout, connArray)
+		return wrapErrConn(c.getBatchCopStreamResponse(ctx, client, req, timeout, connArray))
 	case tikvrpc.CmdCopStream:
-		return c.getCopStreamResponse(ctx, client, req, timeout, connArray)
+		return wrapErrConn(c.getCopStreamResponse(ctx, client, req, timeout, connArray))
 	case tikvrpc.CmdMPPConn:
-		return c.getMPPStreamResponse(ctx, client, req, timeout, connArray)
+		return wrapErrConn(c.getMPPStreamResponse(ctx, client, req, timeout, connArray))
 	}
 	// Or else it's a unary call.
 	ctx1, cancel := context.WithTimeout(ctx, timeout)
 	defer cancel()
-	return tikvrpc.CallRPC(ctx1, client, req)
+	return wrapErrConn(tikvrpc.CallRPC(ctx1, client, req))
 }
 
 // SendRequest sends a Request to server and receives Response.
@@ -812,11 +860,20 @@ func (c *RPCClient) Close() error {
 
 // CloseAddr closes gRPC connections to the address.
 func (c *RPCClient) CloseAddr(addr string) error {
+	return c.CloseAddrVer(addr, math.MaxUint64)
+}
+
+func (c *RPCClient) CloseAddrVer(addr string, ver uint64) error {
 	c.Lock()
 	conn, ok := c.conns[addr]
 	if ok {
-		delete(c.conns, addr)
-		logutil.BgLogger().Debug("close connection", zap.String("target", addr))
+		if conn.ver <= ver {
+			delete(c.conns, addr)
+			logutil.BgLogger().Debug("close connection", zap.String("target", addr), zap.Uint64("ver", ver), zap.Uint64("conn.ver", conn.ver))
+		} else {
+			logutil.BgLogger().Debug("ignore close connection", zap.String("target", addr), zap.Uint64("ver", ver), zap.Uint64("conn.ver", conn.ver))
+			conn = nil
+		}
 	}
 	c.Unlock()
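The core of the change above is the compare-before-delete in `CloseAddrVer`: a close request only wins if the live connection's version is not newer, while `math.MaxUint64` preserves the old force-close behavior of `CloseAddr`. A minimal, self-contained sketch of that semantics (a plain map standing in for the real `connArray` pool; all names here are illustrative, not the actual API):

```go
package main

import (
	"fmt"
	"math"
)

// conn is a stand-in for connArray: just an address plus a version.
type conn struct {
	addr string
	ver  uint64
}

type pool struct {
	conns map[string]*conn
	vers  map[string]uint64
}

// get returns the live connection, dialing a new one (with ver+1) if absent,
// mirroring createConnArray.
func (p *pool) get(addr string) *conn {
	if c, ok := p.conns[addr]; ok {
		return c
	}
	v := p.vers[addr] + 1
	c := &conn{addr: addr, ver: v}
	p.conns[addr] = c
	p.vers[addr] = v
	return c
}

// closeVer mirrors CloseAddrVer: only close if the live version is <= ver.
func (p *pool) closeVer(addr string, ver uint64) {
	if c, ok := p.conns[addr]; ok && c.ver <= ver {
		delete(p.conns, addr)
	}
}

func main() {
	p := &pool{conns: map[string]*conn{}, vers: map[string]uint64{}}
	c1 := p.get("s1")                 // ver 1
	p.closeVer("s1", c1.ver)          // closes ver 1
	c2 := p.get("s1")                 // reconnect, ver 2
	p.closeVer("s1", c1.ver)          // stale close with ver 1: ignored
	fmt.Println(p.conns["s1"] == c2)  // true: the ver-2 conn survived
	p.closeVer("s1", math.MaxUint64)  // force close, like CloseAddr
	fmt.Println(p.conns["s1"] == nil) // true
}
```

Note that `vers` deliberately survives deletion of the `connArray`, which is what lets the next `createConnArray` hand out `ver+1` instead of restarting from 1.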
diff --git a/internal/client/client_batch.go b/internal/client/client_batch.go
index 3964efc1..ece58c3f 100644
--- a/internal/client/client_batch.go
+++ b/internal/client/client_batch.go
@@ -876,6 +876,9 @@ func sendBatchRequest(
 			logutil.Logger(ctx).Debug("send request is cancelled",
 				zap.String("to", addr), zap.String("cause", ctx.Err().Error()))
 			return nil, errors.WithStack(ctx.Err())
+		case <-batchConn.closed:
+			logutil.Logger(ctx).Debug("send request is cancelled (batchConn closed)", zap.String("to", addr))
+			return nil, errors.New("batchConn closed")
 		case <-timer.C:
 			return nil, errors.WithMessage(context.DeadlineExceeded, "wait sendLoop")
 		}
@@ -893,6 +896,10 @@ func sendBatchRequest(
 			logutil.Logger(ctx).Debug("wait response is cancelled",
 				zap.String("to", addr), zap.String("cause", ctx.Err().Error()))
 			return nil, errors.WithStack(ctx.Err())
+		case <-batchConn.closed:
+			atomic.StoreInt32(&entry.canceled, 1)
+			logutil.Logger(ctx).Debug("wait response is cancelled (batchConn closed)", zap.String("to", addr))
+			return nil, errors.New("batchConn closed")
 		case <-timer.C:
 			atomic.StoreInt32(&entry.canceled, 1)
 			reason := fmt.Sprintf("wait recvLoop timeout,timeout:%s, wait_duration:%s:", timeout, waitDuration)
@@ -904,16 +911,18 @@ func (c *RPCClient) recycleIdleConnArray() {
 	start := time.Now()
 
 	var addrs []string
+	var vers []uint64
 	c.RLock()
 	for _, conn := range c.conns {
 		if conn.batchConn != nil && conn.isIdle() {
 			addrs = append(addrs, conn.target)
+			vers = append(vers, conn.ver)
 		}
 	}
 	c.RUnlock()
 
-	for _, addr := range addrs {
-		c.CloseAddr(addr)
+	for i, addr := range addrs {
+		c.CloseAddrVer(addr, vers[i])
 	}
 
 	metrics.TiKVBatchClientRecycle.Observe(time.Since(start).Seconds())
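The new `<-batchConn.closed` cases rely on Go's closed-channel broadcast: closing a channel unblocks every pending `select` on it at once, so requests queued on a recycled connection fail immediately instead of waiting out their full timers. A runnable sketch of the pattern, using hypothetical names rather than the real `batchConn` internals:

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// waiter blocks until a result arrives, the owner is closed, or the timeout fires,
// echoing the select in sendBatchRequest.
func waiter(res <-chan string, closed <-chan struct{}, timeout time.Duration) (string, error) {
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case r := <-res:
		return r, nil
	case <-closed: // close(closed) wakes all waiters at once
		return "", errors.New("batchConn closed")
	case <-timer.C:
		return "", errors.New("wait recvLoop timeout")
	}
}

func main() {
	closed := make(chan struct{})
	go func() {
		time.Sleep(10 * time.Millisecond)
		close(closed) // e.g. the conn array was recycled
	}()
	_, err := waiter(make(chan string), closed, time.Second)
	fmt.Println(err) // "batchConn closed" after ~10ms, not after 1s
}
```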
diff --git a/internal/client/client_test.go b/internal/client/client_test.go
index 8794a593..e70a4585 100644
--- a/internal/client/client_test.go
+++ b/internal/client/client_test.go
@@ -80,12 +80,18 @@ func TestConn(t *testing.T) {
 	assert.Nil(t, err)
 	assert.False(t, conn2.Get() == conn1.Get())
 
-	assert.Nil(t, client.CloseAddr(addr))
+	ver := conn2.ver
+	assert.Nil(t, client.CloseAddrVer(addr, ver-1))
 	_, ok := client.conns[addr]
+	assert.True(t, ok)
+	assert.Nil(t, client.CloseAddrVer(addr, ver))
+	_, ok = client.conns[addr]
 	assert.False(t, ok)
+
 	conn3, err := client.getConnArray(addr, true)
 	assert.Nil(t, err)
 	assert.NotNil(t, conn3)
+	assert.Equal(t, ver+1, conn3.ver)
 
 	client.Close()
 	conn4, err := client.getConnArray(addr, true)
@@ -879,3 +885,34 @@ func TestBatchClientReceiveHealthFeedback(t *testing.T) {
 		assert.Fail(t, "health feedback not received")
 	}
 }
+
+func TestErrConn(t *testing.T) {
+	e := errors.New("conn error")
+	err1 := &ErrConn{Err: e, Addr: "127.0.0.1", Ver: 10}
+	err2 := &ErrConn{Err: e, Addr: "127.0.0.1", Ver: 10}
+
+	e3 := errors.New("conn error 3")
+	err3 := &ErrConn{Err: e3}
+
+	err4 := errors.New("not ErrConn")
+
+	assert.True(t, errors.Is(err1, err1))
+	assert.True(t, errors.Is(fmt.Errorf("%w", err1), err1))
+	assert.False(t, errors.Is(fmt.Errorf("%w", err2), err1)) // err2 != err1 (errors.Is compares by pointer identity)
+	assert.False(t, errors.Is(fmt.Errorf("%w", err4), err1))
+
+	var errConn *ErrConn
+	assert.True(t, errors.As(err1, &errConn))
+	assert.Equal(t, "127.0.0.1", errConn.Addr)
+	assert.EqualValues(t, 10, errConn.Ver)
+	assert.EqualError(t, errConn.Err, "conn error")
+
+	assert.True(t, errors.As(err3, &errConn))
+	assert.EqualError(t, e3, "conn error 3")
+
+	assert.False(t, errors.As(err4, &errConn))
+
+	errMsg := errors.New("unknown")
+	assert.True(t, errors.As(err1, &errMsg))
+	assert.EqualError(t, err1, errMsg.Error())
+}
diff --git a/internal/locate/region_request.go b/internal/locate/region_request.go
index d8159ada..221dbee7 100644
--- a/internal/locate/region_request.go
+++ b/internal/locate/region_request.go
@@ -218,6 +218,14 @@ func (s *RegionRequestSender) GetClient() client.Client {
 	return s.client
 }
 
+// getClientExt returns the client as a ClientExt.
+// It returns nil if the client does not implement ClientExt.
+// Don't use it on critical paths.
+func (s *RegionRequestSender) getClientExt() client.ClientExt {
+	ext, _ := s.client.(client.ClientExt)
+	return ext
+}
+
 // SetStoreAddr specifies the dest store address.
 func (s *RegionRequestSender) SetStoreAddr(addr string) {
 	s.storeAddr = addr
@@ -1964,10 +1972,19 @@ func (s *RegionRequestSender) onSendFail(bo *retry.Backoffer, ctx *RPCContext, r
 		case <-bo.GetCtx().Done():
 			return errors.WithStack(err)
 		default:
-			// If we don't cancel, but the error code is Canceled, it must be from grpc remote.
-			// This may happen when tikv is killed and exiting.
-			// Backoff and retry in this case.
-			logutil.Logger(bo.GetCtx()).Warn("receive a grpc cancel signal from remote", zap.Error(err))
+			// If we don't cancel, but the error code is Canceled, it may be canceled by keepalive or by the gRPC remote.
+			// When canceled by keepalive, we need to re-establish the connection, otherwise subsequent requests will keep failing.
+			// Cancellation from the gRPC remote may happen when tikv is killed and exiting.
+			// Close the connection, backoff, and retry.
+			logutil.Logger(bo.GetCtx()).Warn("receive a grpc cancel signal", zap.Error(err))
+			var errConn *client.ErrConn
+			if errors.As(err, &errConn) {
+				if ext := s.getClientExt(); ext != nil {
+					ext.CloseAddrVer(errConn.Addr, errConn.Ver)
+				} else {
+					s.client.CloseAddr(errConn.Addr)
+				}
+			}
 		}
 	}
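The recovery path in `onSendFail` relies on `errors.As` walking the `Unwrap` chain, so the `Addr`/`Ver` pair attached by `WrapErrConn` survives any further `%w` wrapping between `sendRequest` and the retry logic. A small standalone sketch of that unwrap-then-close flow (the `errConn` type mirrors `client.ErrConn`; the addresses and the print in place of a real `CloseAddrVer` call are illustrative):

```go
package main

import (
	"errors"
	"fmt"
)

// errConn mirrors client.ErrConn: an error tagged with address and version.
type errConn struct {
	err  error
	addr string
	ver  uint64
}

func (e *errConn) Error() string { return fmt.Sprintf("[%s](%d) %s", e.addr, e.ver, e.err) }
func (e *errConn) Unwrap() error { return e.err }

func main() {
	// Simulate a gRPC cancel error tagged in sendRequest and wrapped again upstream.
	cause := &errConn{err: errors.New("context canceled"), addr: "127.0.0.1:20160", ver: 3}
	err := fmt.Errorf("send request failed: %w", cause)

	var ec *errConn
	if errors.As(err, &ec) {
		// Exactly the version that produced the error gets closed; a newer
		// connection that has replaced it in the meantime is left alone.
		fmt.Printf("would call CloseAddrVer(%q, %d)\n", ec.addr, ec.ver)
	}
}
```

This is also why passing `errConn.Ver` (rather than `math.MaxUint64`) matters here: a concurrent reconnect may already have produced a healthy, higher-versioned connection that should not be torn down.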
diff --git a/internal/locate/region_request_test.go b/internal/locate/region_request_test.go
index 6150f0e4..b1204407 100644
--- a/internal/locate/region_request_test.go
+++ b/internal/locate/region_request_test.go
@@ -37,6 +37,7 @@ package locate
 import (
 	"context"
 	"fmt"
+	"math"
 	"math/rand"
 	"net"
 	"sync"
@@ -104,6 +105,7 @@ func (s *testRegionRequestToSingleStoreSuite) TearDownTest() {
 type fnClient struct {
 	fn         func(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error)
 	closedAddr string
+	closedVer  uint64
 }
 
 func (f *fnClient) Close() error {
@@ -111,7 +113,12 @@ func (f *fnClient) Close() error {
 }
 
 func (f *fnClient) CloseAddr(addr string) error {
+	return f.CloseAddrVer(addr, math.MaxUint64)
+}
+
+func (f *fnClient) CloseAddrVer(addr string, ver uint64) error {
 	f.closedAddr = addr
+	f.closedVer = ver
 	return nil
 }
 
@@ -684,6 +691,8 @@ func (s *testRegionRequestToSingleStoreSuite) TestCloseConnectionOnStoreNotMatch
 	regionErr, _ := resp.GetRegionError()
 	s.NotNil(regionErr)
 	s.Equal(target, client.closedAddr)
+	var expected uint64 = math.MaxUint64
+	s.Equal(expected, client.closedVer)
 }
 
 func (s *testRegionRequestToSingleStoreSuite) TestKVReadTimeoutWithDisableBatchClient() {
@@ -792,3 +801,20 @@ func (s *testRegionRequestToSingleStoreSuite) TestClusterIDInReq() {
 	regionErr, _ := resp.GetRegionError()
 	s.Nil(regionErr)
 }
+
+type emptyClient struct {
+	client.Client
+}
+
+func (s *testRegionRequestToSingleStoreSuite) TestClientExt() {
+	var cli client.Client = client.NewRPCClient()
+	sender := NewRegionRequestSender(s.cache, cli)
+	s.NotNil(sender.client)
+	s.NotNil(sender.getClientExt())
+	cli.Close()
+
+	cli = &emptyClient{}
+	sender = NewRegionRequestSender(s.cache, cli)
+	s.NotNil(sender.client)
+	s.Nil(sender.getClientExt())
+}
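The `TestClientExt` cases exercise a plain optional-interface pattern: assert the client to the extension interface and fall back to the unversioned `CloseAddr` when the assertion fails, just as `onSendFail` does. A hedged sketch of that call-site shape (the two interfaces echo the diff; the concrete types and `closeConn` helper are invented for illustration):

```go
package main

import "fmt"

type Client interface {
	CloseAddr(addr string) error
}

// ClientExt is the optional extension, as in internal/client.
type ClientExt interface {
	CloseAddrVer(addr string, ver uint64) error
}

type basicClient struct{}

func (basicClient) CloseAddr(addr string) error { return nil }

type extClient struct{ basicClient }

func (extClient) CloseAddrVer(addr string, ver uint64) error { return nil }

// closeConn prefers the versioned close when the client supports it.
func closeConn(c Client, addr string, ver uint64) {
	if ext, ok := c.(ClientExt); ok {
		_ = ext.CloseAddrVer(addr, ver) // precise, version-aware close
	} else {
		_ = c.CloseAddr(addr) // fallback: force close, like ver == MaxUint64
	}
}

func main() {
	closeConn(basicClient{}, "s1", 3) // takes the CloseAddr fallback
	closeConn(extClient{}, "s1", 3)   // takes the CloseAddrVer path
	fmt.Println("ok")
}
```

Keeping `ClientExt` separate from `Client` means third-party `Client` implementations (like `emptyClient` above) keep compiling unchanged and simply get the fallback behavior.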