diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 0756a6b52..ff344a0a6 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -1024,13 +1024,7 @@ func (t *http2Server) Close() error { } // deleteStream deletes the stream s from transport's active streams. -func (t *http2Server) deleteStream(s *Stream, eosReceived bool) (oldState streamState) { - oldState = s.swapState(streamDone) - if oldState == streamDone { - // If the stream was already done, return. - return oldState - } - +func (t *http2Server) deleteStream(s *Stream, eosReceived bool) { // In case stream sending and receiving are invoked in separate // goroutines (e.g., bi-directional streaming), cancel needs to be // called to interrupt the potential blocking on other goroutines. @@ -1052,15 +1046,13 @@ func (t *http2Server) deleteStream(s *Stream, eosReceived bool) (oldState stream atomic.AddInt64(&t.czData.streamsFailed, 1) } } - - return oldState } // finishStream closes the stream and puts the trailing headerFrame into controlbuf. func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) { - oldState := t.deleteStream(s, eosReceived) - // If the stream is already closed, then don't put trailing header to controlbuf. + oldState := s.swapState(streamDone) if oldState == streamDone { + // If the stream was already done, return. return } @@ -1068,14 +1060,18 @@ func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, h streamID: s.id, rst: rst, rstCode: rstCode, - onWrite: func() {}, + onWrite: func() { + t.deleteStream(s, eosReceived) + }, } t.controlBuf.put(hdr) } // closeStream clears the footprint of a stream when the stream is not needed any more. func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eosReceived bool) { + s.swapState(streamDone) t.deleteStream(s, eosReceived) + t.controlBuf.put(&cleanupStream{ streamID: s.id, rst: rst, diff --git a/test/end2end_test.go b/test/end2end_test.go index a7f0fd470..b78ff2733 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5237,6 +5237,7 @@ type stubServer struct { // A client connected to this service the test may use. Created in Start(). client testpb.TestServiceClient cc *grpc.ClientConn + s *grpc.Server addr string // address of listener @@ -5274,6 +5275,7 @@ func (ss *stubServer) Start(sopts []grpc.ServerOption, dopts ...grpc.DialOption) testpb.RegisterTestServiceServer(s, ss) go s.Serve(lis) ss.cleanups = append(ss.cleanups, s.Stop) + ss.s = s target := ss.r.Scheme() + ":///" + ss.addr diff --git a/test/stream_cleanup_test.go b/test/stream_cleanup_test.go index 728c51771..cb31b4eb2 100644 --- a/test/stream_cleanup_test.go +++ b/test/stream_cleanup_test.go @@ -20,7 +20,9 @@ package test import ( "context" + "io" "testing" + "time" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -55,3 +57,78 @@ func (s) TestStreamCleanup(t *testing.T) { t.Fatalf("should succeed, err: %v", err) } } + +func (s) TestStreamCleanupAfterSendStatus(t *testing.T) { + const initialWindowSize uint = 70 * 1024 // Must be higher than default 64K, ignored otherwise + const bodySize = 2 * initialWindowSize // Something that is not going to fit in a single window + + serverReturnedStatus := make(chan struct{}) + + ss := &stubServer{ + fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error { + defer func() { + close(serverReturnedStatus) + }() + return stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: &testpb.Payload{ + Body: make([]byte, bodySize), + }, + }) + }, + } + if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}, grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + // This test makes sure we don't delete stream from server transport's + // activeStreams list too aggressively. + + // 1. Make a long living stream RPC. So server's activeStream list is not + // empty. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + stream, err := ss.client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("FullDuplexCall= _, %v; want _, ", err) + } + + // 2. Wait for service handler to return status. + // + // This will trigger a stream cleanup code, which will eventually remove + // this stream from activeStream. + // + // But the stream removal won't happen because it's supposed to be done + // after the status is sent by loopyWriter, and the status send is blocked + // by flow control. + <-serverReturnedStatus + + // 3. GracefulStop (besides sending goaway) checks the number of + // activeStreams. + // + // It will close the connection if there's no active streams. This won't + // happen because of the pending stream. But if there's a bug in stream + // cleanup that causes stream to be removed too aggressively, the connection + // will be closd and the stream will be broken. + gracefulStopDone := make(chan struct{}) + go func() { + defer close(gracefulStopDone) + ss.s.GracefulStop() + }() + + // 4. Make sure the stream is not broken. + if _, err := stream.Recv(); err != nil { + t.Fatalf("stream.Recv() = _, %v, want _, ", err) + } + if _, err := stream.Recv(); err != io.EOF { + t.Fatalf("stream.Recv() = _, %v, want _, io.EOF", err) + } + + timer := time.NewTimer(time.Second) + select { + case <-gracefulStopDone: + timer.Stop() + case <-timer.C: + t.Fatalf("s.GracefulStop() didn't finish without 1 second after the last RPC") + } +}