diff --git a/clientconn.go b/clientconn.go index c3c7691dc..f6b39e4a6 100644 --- a/clientconn.go +++ b/clientconn.go @@ -625,13 +625,18 @@ func (ac *addrConn) transportMonitor() { // the addrConn is idle (i.e., no RPC in flight). case <-ac.shutdownChan: return - case <-t.Error(): + case <-t.Done(): ac.mu.Lock() if ac.state == Shutdown { // ac.tearDown(...) has been invoked. ac.mu.Unlock() return } + if t.Err() == transport.ErrConnDrain { + ac.mu.Unlock() + ac.tearDown(errConnDrain) + return + } ac.state = TransientFailure ac.stateCV.Broadcast() ac.mu.Unlock() diff --git a/stream.go b/stream.go index a182e077d..2940cbbd3 100644 --- a/stream.go +++ b/stream.go @@ -184,7 +184,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // when there is no pending I/O operations on this stream. go func() { select { - case <-t.Error(): + case <-t.Done(): // Incur transport error, simply exit. case <-s.Done(): // TODO: The trace of the RPC is terminated here when there is no pending diff --git a/transport/http2_client.go b/transport/http2_client.go index 4f22be09e..bde77f112 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -71,6 +71,7 @@ type http2Client struct { shutdownChan chan struct{} // errorChan is closed to notify the I/O error to the caller. errorChan chan struct{} + err error framer *framer hBuf *bytes.Buffer // the buffer for HPACK encoding @@ -97,6 +98,7 @@ type http2Client struct { maxStreams int // the per-stream outbound flow control window size set by the peer. 
streamSendQuota uint32 + goAwayID uint32 } // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 @@ -279,7 +281,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea checkStreamsQuota := t.streamsQuota != nil t.mu.Unlock() if checkStreamsQuota { - sq, err := wait(ctx, nil, t.shutdownChan, t.streamsQuota.acquire()) + sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire()) if err != nil { return nil, err } @@ -288,7 +290,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea t.streamsQuota.add(sq - 1) } } - if _, err := wait(ctx, nil, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { // Return the quota back now because there is no stream returned to the caller. if _, ok := err.(StreamError); ok && checkStreamsQuota { t.streamsQuota.add(1) @@ -480,6 +482,12 @@ func (t *http2Client) GracefulClose() error { return nil } t.state = draining + // Notify the streams which were initiated after the server sent GOAWAY. + for i := t.goAwayID + 2; i < t.nextID; i += 2 { + if s, ok := t.activeStreams[i]; ok { + close(s.goAway) + } + } active := len(t.activeStreams) t.mu.Unlock() if active == 0 { @@ -500,13 +508,13 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { size := http2MaxFrameLen s.sendQuotaPool.add(0) // Wait until the stream has some quota to send the data. - sq, err := wait(s.ctx, s.done, t.shutdownChan, s.sendQuotaPool.acquire()) + sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire()) if err != nil { return err } t.sendQuotaPool.add(0) // Wait until the transport has some quota to send the data. 
- tq, err := wait(s.ctx, s.done, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire()) if err != nil { if _, ok := err.(StreamError); ok || err == io.EOF { t.sendQuotaPool.cancel() @@ -540,7 +548,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { // Indicate there is a writer who is about to write a data frame. t.framer.adjustNumWriters(1) // Got some quota. Try to acquire writing privilege on the transport. - if _, err := wait(s.ctx, s.done, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil { if _, ok := err.(StreamError); ok || err == io.EOF { // Return the connection quota back. t.sendQuotaPool.add(len(p)) @@ -723,7 +731,18 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { } func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { - // TODO(zhaoq): GoAwayFrame handler to be implemented + t.mu.Lock() + t.goAwayID = f.LastStreamID + t.err = ErrDrain + close(t.errorChan) + + // Notify the streams which were initiated after the server sent GOAWAY. 
+ //for i := f.LastStreamID + 2; i < t.nextID; i += 2 { + // if s, ok := t.activeStreams[i]; ok { + // close(s.goAway) + // } + //} + t.mu.Unlock() } func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { @@ -928,10 +947,14 @@ func (t *http2Client) controller() { } } -func (t *http2Client) Error() <-chan struct{} { +func (t *http2Client) Done() <-chan struct{} { return t.errorChan } +func (t *http2Client) Err() error { + return t.err +} + func (t *http2Client) notifyError(err error) { t.mu.Lock() defer t.mu.Unlock() diff --git a/transport/http2_server.go b/transport/http2_server.go index 9e35fdd85..2467630a1 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -451,7 +451,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { } s.headerOk = true s.mu.Unlock() - if _, err := wait(s.ctx, nil, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { return err } t.hBuf.Reset() @@ -491,7 +491,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s headersSent = true } s.mu.Unlock() - if _, err := wait(s.ctx, nil, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { return err } t.hBuf.Reset() @@ -540,7 +540,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } s.mu.Unlock() if writeHeaderFrame { - if _, err := wait(s.ctx, nil, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { return err } t.hBuf.Reset() @@ -568,13 +568,13 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { size := http2MaxFrameLen s.sendQuotaPool.add(0) // Wait until the stream has some quota to send the data. 
- sq, err := wait(s.ctx, nil, t.shutdownChan, s.sendQuotaPool.acquire()) + sq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.sendQuotaPool.acquire()) if err != nil { return err } t.sendQuotaPool.add(0) // Wait until the transport has some quota to send the data. - tq, err := wait(s.ctx, nil, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire()) if err != nil { if _, ok := err.(StreamError); ok { t.sendQuotaPool.cancel() @@ -600,7 +600,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { t.framer.adjustNumWriters(1) // Got some quota. Try to acquire writing privilege on the // transport. - if _, err := wait(s.ctx, nil, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { if _, ok := err.(StreamError); ok { // Return the connection quota back. t.sendQuotaPool.add(ps) diff --git a/transport/transport.go b/transport/transport.go index 4dab57459..4a7b83c57 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -53,6 +53,10 @@ import ( "google.golang.org/grpc/metadata" ) +var ( + ErrDrain = ConnectionErrorf("transport: Server stopped accepting new RPCs") +) + // recvMsg represents the received msg from the transport. All transport // protocol specific info has been removed. type recvMsg struct { @@ -120,10 +124,11 @@ func (b *recvBuffer) get() <-chan item { // recvBufferReader implements io.Reader interface to read the data from // recvBuffer. type recvBufferReader struct { - ctx context.Context - recv *recvBuffer - last *bytes.Reader // Stores the remaining data in the previous calls. - err error + ctx context.Context + goAway chan struct{} + recv *recvBuffer + last *bytes.Reader // Stores the remaining data in the previous calls. + err error } // Read reads the next len(p) bytes from last. 
If last is drained, it tries to @@ -141,6 +146,8 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) { select { case <-r.ctx.Done(): return 0, ContextErr(r.ctx.Err()) + case <-r.goAway: + return 0, ErrConnDrain case i := <-r.recv.get(): r.recv.load() m := i.(*recvMsg) @@ -171,6 +178,8 @@ type Stream struct { cancel context.CancelFunc // done is closed when the final status arrives. done chan struct{} + // goAway is closed to notify the stream that the server's GOAWAY does not cover it. + goAway chan struct{} // method records the associated RPC method of the stream. method string recvCompress string @@ -220,6 +229,10 @@ func (s *Stream) Done() <-chan struct{} { return s.done } +func (s *Stream) GoAway() <-chan struct{} { + return s.goAway +} + // Header acquires the key-value pairs of header metadata once it // is available. It blocks until i) the metadata is ready or ii) there is no // header metadata or iii) the stream is cancelled/expired. @@ -422,7 +435,18 @@ type ClientTransport interface { // this in order to take action (e.g., close the current transport // and create a new one) in error case. It should not return nil // once the transport is initiated. - Error() <-chan struct{} + //Error() <-chan struct{} + + // Done returns a channel that is closed when some I/O error + // happens or ClientTransport receives the draining signal from the server + // (e.g., GOAWAY frame in HTTP/2). Typically the caller should have + // a goroutine to monitor this in order to take action (e.g., close + // the current transport and create a new one) in error case. It should + // not return nil once the transport is initiated. + Done() <-chan struct{} + + // Err returns the error recorded by the transport (e.g., the drain error set when a GOAWAY frame is received). + Err() error } // ServerTransport is the common interface for all gRPC server-side transport @@ -482,7 +506,10 @@ func (e ConnectionError) Error() string { } // ErrConnClosing indicates that the transport is closing. 
-var ErrConnClosing = ConnectionError{Desc: "transport is closing"} +var ( + ErrConnClosing = ConnectionError{Desc: "transport is closing"} + ErrConnDrain = ConnectionError{Desc: "transport is being drained"} +) // StreamError is an error that only affects one stream within a connection. type StreamError struct { @@ -509,9 +536,10 @@ func ContextErr(err error) StreamError { // If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. // If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise // it return the StreamError for ctx.Err. +// If it receives from goAway, it returns 0, ErrConnDrain. // If it receives from closing, it returns 0, ErrConnClosing. // If it receives from proceed, it returns the received integer, nil. -func wait(ctx context.Context, done, closing <-chan struct{}, proceed <-chan int) (int, error) { +func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) { select { case <-ctx.Done(): return 0, ContextErr(ctx.Err()) @@ -523,6 +551,8 @@ func wait(ctx context.Context, done, closing <-chan struct{}, proceed <-chan int default: } return 0, io.EOF + case <-goAway: + return 0, ErrConnDrain case <-closing: return 0, ErrConnClosing case i := <-proceed: diff --git a/transport/transport_test.go b/transport/transport_test.go index ce015da29..a98f27e5b 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -271,8 +271,8 @@ func TestClientSendAndReceive(t *testing.T) { func TestClientErrorNotify(t *testing.T) { server, ct := setUp(t, 0, math.MaxUint32, normal) go server.stop() - // ct.reader should detect the error and activate ct.Error(). - <-ct.Error() + // ct.reader should detect the error and activate ct.Done(). 
+ <-ct.Done() ct.Close() } @@ -309,7 +309,7 @@ func TestClientMix(t *testing.T) { s.stop() }(s) go func(ct ClientTransport) { - <-ct.Error() + <-ct.Done() ct.Close() }(ct) for i := 0; i < 1000; i++ { @@ -700,7 +700,7 @@ func TestClientWithMisbehavedServer(t *testing.T) { } } // http2Client.errChan is closed due to connection flow control window size violation. - <-conn.Error() + <-conn.Done() ct.Close() server.stop() }