transport : wait for goroutines to exit before transport closes (#7666)

This commit is contained in:
eshitachandwani 2024-10-10 15:34:25 +05:30 committed by GitHub
parent 00b9e140ce
commit b850ea533f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 124 additions and 17 deletions

View File

@ -1140,10 +1140,15 @@ func (cc *ClientConn) Close() error {
<-cc.resolverWrapper.serializer.Done() <-cc.resolverWrapper.serializer.Done()
<-cc.balancerWrapper.serializer.Done() <-cc.balancerWrapper.serializer.Done()
var wg sync.WaitGroup
for ac := range conns { for ac := range conns {
ac.tearDown(ErrClientConnClosing) wg.Add(1)
go func(ac *addrConn) {
defer wg.Done()
ac.tearDown(ErrClientConnClosing)
}(ac)
} }
wg.Wait()
cc.addTraceEvent("deleted") cc.addTraceEvent("deleted")
// TraceEvent needs to be called before RemoveEntry, as TraceEvent may add // TraceEvent needs to be called before RemoveEntry, as TraceEvent may add
// trace reference to the entity being deleted, and thus prevent it from being // trace reference to the entity being deleted, and thus prevent it from being

View File

@ -86,9 +86,9 @@ type http2Client struct {
writerDone chan struct{} // sync point to enable testing. writerDone chan struct{} // sync point to enable testing.
// goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
// that the server sent GoAway on this transport. // that the server sent GoAway on this transport.
goAway chan struct{} goAway chan struct{}
keepaliveDone chan struct{} // Closed when the keepalive goroutine exits.
framer *framer framer *framer
// controlBuf delivers all the control related tasks (e.g., window // controlBuf delivers all the control related tasks (e.g., window
// updates, reset streams, and various settings) to the controller. // updates, reset streams, and various settings) to the controller.
// Do not access controlBuf with mu held. // Do not access controlBuf with mu held.
@ -335,6 +335,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
readerDone: make(chan struct{}), readerDone: make(chan struct{}),
writerDone: make(chan struct{}), writerDone: make(chan struct{}),
goAway: make(chan struct{}), goAway: make(chan struct{}),
keepaliveDone: make(chan struct{}),
framer: newFramer(conn, writeBufSize, readBufSize, opts.SharedWriteBuffer, maxHeaderListSize), framer: newFramer(conn, writeBufSize, readBufSize, opts.SharedWriteBuffer, maxHeaderListSize),
fc: &trInFlow{limit: uint32(icwz)}, fc: &trInFlow{limit: uint32(icwz)},
scheme: scheme, scheme: scheme,
@ -1029,6 +1030,12 @@ func (t *http2Client) Close(err error) {
} }
t.cancel() t.cancel()
t.conn.Close() t.conn.Close()
// Waits for the reader and keepalive goroutines to exit before returning to
// ensure all resources are cleaned up before Close can return.
<-t.readerDone
if t.keepaliveEnabled {
<-t.keepaliveDone
}
channelz.RemoveEntry(t.channelz.ID) channelz.RemoveEntry(t.channelz.ID)
var st *status.Status var st *status.Status
if len(goAwayDebugMessage) > 0 { if len(goAwayDebugMessage) > 0 {
@ -1316,11 +1323,11 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
t.controlBuf.put(pingAck) t.controlBuf.put(pingAck)
} }
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) error {
t.mu.Lock() t.mu.Lock()
if t.state == closing { if t.state == closing {
t.mu.Unlock() t.mu.Unlock()
return return nil
} }
if f.ErrCode == http2.ErrCodeEnhanceYourCalm && string(f.DebugData()) == "too_many_pings" { if f.ErrCode == http2.ErrCodeEnhanceYourCalm && string(f.DebugData()) == "too_many_pings" {
// When a client receives a GOAWAY with error code ENHANCE_YOUR_CALM and debug // When a client receives a GOAWAY with error code ENHANCE_YOUR_CALM and debug
@ -1332,8 +1339,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
id := f.LastStreamID id := f.LastStreamID
if id > 0 && id%2 == 0 { if id > 0 && id%2 == 0 {
t.mu.Unlock() t.mu.Unlock()
t.Close(connectionErrorf(true, nil, "received goaway with non-zero even-numbered stream id: %v", id)) return connectionErrorf(true, nil, "received goaway with non-zero even-numbered stream id: %v", id)
return
} }
// A client can receive multiple GoAways from the server (see // A client can receive multiple GoAways from the server (see
// https://github.com/grpc/grpc-go/issues/1387). The idea is that the first // https://github.com/grpc/grpc-go/issues/1387). The idea is that the first
@ -1350,8 +1356,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
// If there are multiple GoAways the first one should always have an ID greater than the following ones. // If there are multiple GoAways the first one should always have an ID greater than the following ones.
if id > t.prevGoAwayID { if id > t.prevGoAwayID {
t.mu.Unlock() t.mu.Unlock()
t.Close(connectionErrorf(true, nil, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID)) return connectionErrorf(true, nil, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID)
return
} }
default: default:
t.setGoAwayReason(f) t.setGoAwayReason(f)
@ -1375,8 +1380,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
t.prevGoAwayID = id t.prevGoAwayID = id
if len(t.activeStreams) == 0 { if len(t.activeStreams) == 0 {
t.mu.Unlock() t.mu.Unlock()
t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) return connectionErrorf(true, nil, "received goaway and there are no active streams")
return
} }
streamsToClose := make([]*Stream, 0) streamsToClose := make([]*Stream, 0)
@ -1393,6 +1397,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
for _, stream := range streamsToClose { for _, stream := range streamsToClose {
t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false)
} }
return nil
} }
// setGoAwayReason sets the value of t.goAwayReason based // setGoAwayReason sets the value of t.goAwayReason based
@ -1628,7 +1633,13 @@ func (t *http2Client) readServerPreface() error {
// network connection. If the server preface is not read successfully, an // network connection. If the server preface is not read successfully, an
// error is pushed to errCh; otherwise errCh is closed with no error. // error is pushed to errCh; otherwise errCh is closed with no error.
func (t *http2Client) reader(errCh chan<- error) { func (t *http2Client) reader(errCh chan<- error) {
defer close(t.readerDone) var errClose error
defer func() {
close(t.readerDone)
if errClose != nil {
t.Close(errClose)
}
}()
if err := t.readServerPreface(); err != nil { if err := t.readServerPreface(); err != nil {
errCh <- err errCh <- err
@ -1669,7 +1680,7 @@ func (t *http2Client) reader(errCh chan<- error) {
continue continue
} }
// Transport error. // Transport error.
t.Close(connectionErrorf(true, err, "error reading from server: %v", err)) errClose = connectionErrorf(true, err, "error reading from server: %v", err)
return return
} }
switch frame := frame.(type) { switch frame := frame.(type) {
@ -1684,7 +1695,7 @@ func (t *http2Client) reader(errCh chan<- error) {
case *http2.PingFrame: case *http2.PingFrame:
t.handlePing(frame) t.handlePing(frame)
case *http2.GoAwayFrame: case *http2.GoAwayFrame:
t.handleGoAway(frame) errClose = t.handleGoAway(frame)
case *http2.WindowUpdateFrame: case *http2.WindowUpdateFrame:
t.handleWindowUpdate(frame) t.handleWindowUpdate(frame)
default: default:
@ -1697,6 +1708,13 @@ func (t *http2Client) reader(errCh chan<- error) {
// keepalive running in a separate goroutine makes sure the connection is alive by sending pings. // keepalive running in a separate goroutine makes sure the connection is alive by sending pings.
func (t *http2Client) keepalive() { func (t *http2Client) keepalive() {
var err error
defer func() {
close(t.keepaliveDone)
if err != nil {
t.Close(err)
}
}()
p := &ping{data: [8]byte{}} p := &ping{data: [8]byte{}}
// True iff a ping has been sent, and no data has been received since then. // True iff a ping has been sent, and no data has been received since then.
outstandingPing := false outstandingPing := false
@ -1720,7 +1738,7 @@ func (t *http2Client) keepalive() {
continue continue
} }
if outstandingPing && timeoutLeft <= 0 { if outstandingPing && timeoutLeft <= 0 {
t.Close(connectionErrorf(true, nil, "keepalive ping failed to receive ACK within timeout")) err = connectionErrorf(true, nil, "keepalive ping failed to receive ACK within timeout")
return return
} }
t.mu.Lock() t.mu.Lock()

View File

@ -44,6 +44,7 @@ import (
) )
const defaultTestTimeout = 10 * time.Second const defaultTestTimeout = 10 * time.Second
const defaultTestShortTimeout = 10 * time.Millisecond
// TestMaxConnectionIdle tests that a server will send GoAway to an idle // TestMaxConnectionIdle tests that a server will send GoAway to an idle
// client. An idle client is one who doesn't make any RPC calls for a duration // client. An idle client is one who doesn't make any RPC calls for a duration

View File

@ -2781,6 +2781,89 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) {
} }
} }
// readHangingConn is a wrapper around net.Conn that makes the Read() hang when
// Close() is called.
type readHangingConn struct {
net.Conn
readHangConn chan struct{} // Read() hangs until this channel is closed by Close().
closed *atomic.Bool // Set to true when Close() is called.
}
func (hc *readHangingConn) Read(b []byte) (n int, err error) {
n, err = hc.Conn.Read(b)
if hc.closed.Load() {
<-hc.readHangConn // hang the read till we want
}
return n, err
}
func (hc *readHangingConn) Close() error {
hc.closed.Store(true)
return hc.Conn.Close()
}
// Tests that closing a client transport does not return until the reader
// goroutine exits.
func (s) TestClientCloseReturnsAfterReaderCompletes(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
server := setUpServerOnly(t, 0, &ServerConfig{}, normal)
defer server.stop()
addr := resolver.Address{Addr: "localhost:" + server.port}
isReaderHanging := &atomic.Bool{}
readHangConn := make(chan struct{})
copts := ConnectOptions{
Dialer: func(_ context.Context, addr string) (net.Conn, error) {
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
return &readHangingConn{Conn: conn, readHangConn: readHangConn, closed: isReaderHanging}, nil
},
ChannelzParent: channelzSubChannel(t),
}
// Create a client transport with a custom dialer that hangs the Read()
// after Close().
ct, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {})
if err != nil {
t.Fatalf("Failed to create transport: %v", err)
}
if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil {
t.Fatalf("Failed to open stream: %v", err)
}
// Closing the client transport will result in the underlying net.Conn being
// closed, which will result in readHangingConn.Read() to hang. This will
// stall the exit of the reader goroutine, and will stall client
// transport's Close from returning.
transportClosed := make(chan struct{})
go func() {
ct.Close(errors.New("manually closed by client"))
close(transportClosed)
}()
// Wait for a short duration and ensure that the client transport's Close()
// does not return.
select {
case <-transportClosed:
t.Fatal("Transport closed before reader completed")
case <-time.After(defaultTestShortTimeout):
}
// Closing the channel will unblock the reader goroutine and will ensure
// that the client transport's Close() returns.
close(readHangConn)
select {
case <-transportClosed:
case <-time.After(defaultTestTimeout):
t.Fatal("Timeout when waiting for transport to close")
}
}
// hangingConn is a net.Conn wrapper for testing, simulating hanging connections // hangingConn is a net.Conn wrapper for testing, simulating hanging connections
// after a GOAWAY frame is sent, of which Write operations pause until explicitly // after a GOAWAY frame is sent, of which Write operations pause until explicitly
// signaled or a timeout occurs. // signaled or a timeout occurs.