diff --git a/stream.go b/stream.go index a0373600f..70c6447a7 100644 --- a/stream.go +++ b/stream.go @@ -189,6 +189,10 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth select { case <-t.Error(): // Incur transport error, simply exit. + case <-s.Done(): + err := Errorf(s.StatusCode(), s.StatusDesc()) + cs.finish(err) + cs.closeTransportStream(err) case <-s.Context().Done(): err := s.Context().Err() cs.finish(err) @@ -251,7 +255,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { if err != nil { cs.finish(err) } - if err == nil || err == io.EOF || err == transport.ErrEarlyDone { + if err == nil || err == io.EOF { return } if _, ok := err.(transport.ConnectionError); !ok { @@ -326,11 +330,6 @@ func (cs *clientStream) CloseSend() (err error) { } }() if err == nil || err == io.EOF { - return - } - if err == transport.ErrEarlyDone { - // If the RPC is done prematurely, Stream.RecvMsg(...) needs to be - // called to get the final status and clear the footprint. return nil } if _, ok := err.(transport.ConnectionError); !ok { diff --git a/transport/http2_client.go b/transport/http2_client.go index 2715e2d0b..373a5f130 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -202,7 +202,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // TODO(zhaoq): Handle uint32 overflow of Stream.id. s := &Stream{ id: t.nextID, - earlyDone: make(chan struct{}), + done: make(chan struct{}), method: callHdr.Method, sendCompress: callHdr.SendCompress, buf: newRecvBuffer(), @@ -419,7 +419,7 @@ func (t *http2Client) CloseStream(s *Stream, err error) { // goroutines (e.g., bi-directional streaming), the caller needs // to call cancel on the stream to interrupt the blocking on // other goroutines. - s.cancel() + //s.cancel() s.mu.Lock() if q := s.fc.resetPendingData(); q > 0 { if n := t.fc.onRead(q); n > 0 { @@ -505,15 +505,15 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { size := http2MaxFrameLen s.sendQuotaPool.add(0) // Wait until the stream has some quota to send the data. - sq, err := wait(s.ctx, s.earlyDone, t.shutdownChan, s.sendQuotaPool.acquire()) + sq, err := wait(s.ctx, s.done, t.shutdownChan, s.sendQuotaPool.acquire()) if err != nil { return err } t.sendQuotaPool.add(0) // Wait until the transport has some quota to send the data. - tq, err := wait(s.ctx, s.earlyDone, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, s.done, t.shutdownChan, t.sendQuotaPool.acquire()) if err != nil { - if _, ok := err.(StreamError); ok { + if _, ok := err.(StreamError); ok || err == io.EOF { t.sendQuotaPool.cancel() } return err @@ -545,8 +545,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { // Indicate there is a writer who is about to write a data frame. t.framer.adjustNumWriters(1) // Got some quota. Try to acquire writing privilege on the transport. - if _, err := wait(s.ctx, s.earlyDone, t.shutdownChan, t.writableChan); err != nil { - if _, ok := err.(StreamError); ok { + if _, err := wait(s.ctx, s.done, t.shutdownChan, t.writableChan); err != nil { + if _, ok := err.(StreamError); ok || err == io.EOF { // Return the connection quota back. t.sendQuotaPool.add(len(p)) } @@ -775,11 +775,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } s.statusCode = state.statusCode s.statusDesc = state.statusDesc - if s.state != streamWriteDone { - // This is required to interrupt any pending blocking Write calls - // when the final RPC status has been arrived. - close(s.earlyDone) - } + close(s.done) s.state = streamDone s.mu.Unlock() s.write(recvMsg{err: io.EOF}) diff --git a/transport/transport.go b/transport/transport.go index 502c14813..b1f4c077a 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -169,8 +169,8 @@ type Stream struct { // ctx is the associated context of the stream. ctx context.Context cancel context.CancelFunc - // earlyDone is closed when the final status arrives prematurely. - earlyDone chan struct{} + // done is closed when the final status arrives prematurely. + done chan struct{} // method records the associated RPC method of the stream. method string recvCompress string @@ -216,6 +216,10 @@ func (s *Stream) SetSendCompress(str string) { s.sendCompress = str } +func (s *Stream) Done() <-chan struct{} { + return s.done +} + // Header acquires the key-value pairs of header metadata once it // is available. It blocks until i) the metadata is ready or ii) there is no // header metadata or iii) the stream is cancelled/expired. @@ -460,7 +464,8 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError { } } -var ErrEarlyDone = StreamErrorf(codes.Internal, "rpc is done prematurely") +// ErrDone indicates +//var ErrDone = StreamErrorf(codes.Internal, "rpc is done") // ConnectionErrorf creates an ConnectionError with the specified error description. func ConnectionErrorf(format string, a ...interface{}) ConnectionError { @@ -505,15 +510,22 @@ func ContextErr(err error) StreamError { // wait blocks until it can receive from ctx.Done, closing, or proceed. // If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. -// If it receives from earlyDone, it returns 0, ErrEarlyDone. +// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise +// it return the StreamError for ctx.Err. // If it receives from closing, it returns 0, ErrConnClosing. // If it receives from proceed, it returns the received integer, nil. -func wait(ctx context.Context, earlyDone, closing <-chan struct{}, proceed <-chan int) (int, error) { +func wait(ctx context.Context, done, closing <-chan struct{}, proceed <-chan int) (int, error) { select { case <-ctx.Done(): return 0, ContextErr(ctx.Err()) - case <-earlyDone: - return 0, ErrEarlyDone + case <-done: + // User cancellation has precedence. + select { + case <-ctx.Done(): + return 0, ContextErr(ctx.Err()) + default: + } + return 0, io.EOF case <-closing: return 0, ErrConnClosing case i := <-proceed: