From 110450d45ec9b606e8af18809e57c42770a3d2df Mon Sep 17 00:00:00 2001
From: iamqizhao
Date: Wed, 27 Jul 2016 17:27:10 -0700
Subject: [PATCH] fix races introduced by goaway

---
 clientconn.go             |   6 +-
 server.go                 |   4 +-
 test/end2end_test.go      | 140 ++++++++++++++++++++++++++++++++++++++
 transport/http2_client.go |  17 +++--
 transport/http2_server.go |   5 ++
 5 files changed, 166 insertions(+), 6 deletions(-)

diff --git a/clientconn.go b/clientconn.go
index 3206d6747..01e3ef5f3 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -296,6 +296,8 @@ const (
 	TransientFailure
 	// Shutdown indicates the ClientConn has started shutting down.
 	Shutdown
+	// Drain indicates the ClientConn starts draining.
+	Drain
 )
 
 func (s ConnectivityState) String() string {
@@ -310,6 +312,8 @@ func (s ConnectivityState) String() string {
 		return "TRANSIENT_FAILURE"
 	case Shutdown:
 		return "SHUTDOWN"
+	case Drain:
+		return "DRAIN"
 	default:
 		panic(fmt.Sprintf("unknown connectivity state: %d", s))
 	}
@@ -632,7 +636,7 @@ func (ac *addrConn) transportMonitor() {
 		case <-t.Error():
 			ac.mu.Lock()
 			if ac.state == Shutdown {
-				// ac.tearDown(...) has been invoked.
+				// ac has been shutdown.
 				ac.mu.Unlock()
 				return
 			}
diff --git a/server.go b/server.go
index 1a250c796..fbf96bf73 100644
--- a/server.go
+++ b/server.go
@@ -774,6 +774,8 @@ func (s *Server) Stop() {
 	s.lis = nil
 	st := s.conns
 	s.conns = nil
+	// interrupt GracefulStop if Stop and GracefulStop are called concurrently.
+	s.cv.Signal()
 	s.mu.Unlock()
 
 	for lis := range listeners {
@@ -803,13 +805,13 @@ func (s *Server) GracefulStop() {
 	for lis := range s.lis {
 		lis.Close()
 	}
+	s.lis = nil
 	for c := range s.conns {
 		c.(transport.ServerTransport).Drain()
 	}
 	for len(s.conns) != 0 {
 		s.cv.Wait()
 	}
-	s.lis = nil
 	s.conns = nil
 	if s.events != nil {
 		s.events.Finish()
diff --git a/test/end2end_test.go b/test/end2end_test.go
index cdbc4c555..5fd61d5cf 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -686,6 +686,146 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) {
 	awaitNewConnLogOutput()
 }
 
+func TestConcurrentClientConnCloseAndServerGoAway(t *testing.T) {
+	defer leakCheck(t)()
+	for _, e := range listTestEnv() {
+		if e.name == "handler-tls" {
+			continue
+		}
+		testConcurrentClientConnCloseAndServerGoAway(t, e)
+	}
+}
+
+func testConcurrentClientConnCloseAndServerGoAway(t *testing.T, e env) {
+	te := newTest(t, e)
+	te.userAgent = testAppUA
+	te.declareLogNoise(
+		"transport: http2Client.notifyError got notified that the client transport was broken EOF",
+		"grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing",
+		"grpc: Conn.resetTransport failed to create client transport: connection error",
+		"grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix",
+	)
+	te.startServer(&testServer{security: e.security})
+	defer te.tearDown()
+
+	cc := te.clientConn()
+	tc := testpb.NewTestServiceClient(cc)
+	stream, err := tc.FullDuplexCall(context.Background(), grpc.FailFast(false))
+	if err != nil {
+		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+	}
+	// Finish an RPC to make sure the connection is good.
+	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
+		t.Fatalf("%v.EmptyCall(_, _, _) = _, %v, want _, <nil>", tc, err)
+	}
+	ch := make(chan struct{})
+	go func() {
+		te.srv.GracefulStop()
+		close(ch)
+	}()
+	// Loop until the server side GoAway signal is propagated to the client.
+	for {
+		ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
+		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil {
+			continue
+		}
+		break
+	}
+	// Stop the server and close all the connections.
+	te.srv.Stop()
+	respParam := []*testpb.ResponseParameters{
+		{
+			Size: proto.Int32(1),
+		},
+	}
+	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(100))
+	if err != nil {
+		t.Fatal(err)
+	}
+	req := &testpb.StreamingOutputCallRequest{
+		ResponseType:       testpb.PayloadType_COMPRESSABLE.Enum(),
+		ResponseParameters: respParam,
+		Payload:            payload,
+	}
+	if err := stream.Send(req); err == nil {
+		if _, err := stream.Recv(); err == nil {
+			t.Fatalf("%v.Recv() = _, %v, want _, <non-nil>", stream, err)
+		}
+	}
+	<-ch
+	awaitNewConnLogOutput()
+}
+
+func TestConcurrentServerStopAndGoAway(t *testing.T) {
+	defer leakCheck(t)()
+	for _, e := range listTestEnv() {
+		if e.name == "handler-tls" {
+			continue
+		}
+		testConcurrentServerStopAndGoAway(t, e)
+	}
+}
+
+func testConcurrentServerStopAndGoAway(t *testing.T, e env) {
+	te := newTest(t, e)
+	te.userAgent = testAppUA
+	te.declareLogNoise(
+		"transport: http2Client.notifyError got notified that the client transport was broken EOF",
+		"grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing",
+		"grpc: Conn.resetTransport failed to create client transport: connection error",
+		"grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix",
+	)
+	te.startServer(&testServer{security: e.security})
+	defer te.tearDown()
+
+	cc := te.clientConn()
+	tc := testpb.NewTestServiceClient(cc)
+	stream, err := tc.FullDuplexCall(context.Background(), grpc.FailFast(false))
+	if err != nil {
+		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+	}
+	// Finish an RPC to make sure the connection is good.
+	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
+		t.Fatalf("%v.EmptyCall(_, _, _) = _, %v, want _, <nil>", tc, err)
+	}
+	ch := make(chan struct{})
+	go func() {
+		te.srv.GracefulStop()
+		close(ch)
+	}()
+	// Loop until the server side GoAway signal is propagated to the client.
+	for {
+		ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
+		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil {
+			continue
+		}
+		break
+	}
+	// Stop the server and close all the connections.
+	te.srv.Stop()
+	respParam := []*testpb.ResponseParameters{
+		{
+			Size: proto.Int32(1),
+		},
+	}
+	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(100))
+	if err != nil {
+		t.Fatal(err)
+	}
+	req := &testpb.StreamingOutputCallRequest{
+		ResponseType:       testpb.PayloadType_COMPRESSABLE.Enum(),
+		ResponseParameters: respParam,
+		Payload:            payload,
+	}
+	if err := stream.Send(req); err == nil {
+		if _, err := stream.Recv(); err == nil {
+			t.Fatalf("%v.Recv() = _, %v, want _, <non-nil>", stream, err)
+		}
+	}
+	<-ch
+	awaitNewConnLogOutput()
+}
+
 func TestFailFast(t *testing.T) {
 	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 036da819f..0426cc939 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -454,7 +454,7 @@ func (t *http2Client) Close() (err error) {
 		t.mu.Unlock()
 		return
 	}
-	if t.state == reachable {
+	if t.state == reachable || t.state == draining {
 		close(t.errorChan)
 	}
 	t.state = closing
@@ -856,7 +856,11 @@ func (t *http2Client) reader() {
 	// Check the validity of server preface.
 	frame, err := t.framer.readFrame()
 	if err != nil {
-		t.notifyError(err)
+		if t.state == draining {
+			t.Close()
+		} else {
+			t.notifyError(err)
+		}
 		return
 	}
 	sf, ok := frame.(*http2.SettingsFrame)
@@ -884,7 +888,12 @@
 				continue
 			} else {
 				// Transport error.
-				t.notifyError(err)
+				if t.state == draining {
+					// A network error happened after the connection is drained. Fail the connection immediately.
+					t.Close()
+				} else {
+					t.notifyError(err)
+				}
 				return
 			}
 		}
@@ -993,7 +1002,7 @@ func (t *http2Client) notifyError(err error) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
 	// make sure t.errorChan is closed only once.
-	if t.state == reachable {
+	if t.state == reachable || t.state == draining {
 		t.state = unreachable
 		close(t.errorChan)
 		grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
diff --git a/transport/http2_server.go b/transport/http2_server.go
index 2322c9387..357f01eeb 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -680,6 +680,11 @@ func (t *http2Server) controller() {
 					t.framer.writeRSTStream(true, i.streamID, i.code)
 				case *goAway:
 					t.mu.Lock()
+					if t.state == closing {
+						t.mu.Unlock()
+						// The transport is closing.
+						return
+					}
 					sid := t.maxStreamID
 					t.state = draining
 					t.mu.Unlock()
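
Note on the server.go hunk: GracefulStop parks on the condition variable until every connection has drained, so a concurrent Stop has to wake it after nil-ing out s.conns, otherwise GracefulStop would block forever. The following is a minimal, self-contained sketch of just that coordination; miniServer, removeConn, and the integer connection ids are invented for illustration and are not the real grpc.Server fields.

// Sketch only: models how Stop's Signal interrupts a waiting GracefulStop.
package main

import (
	"fmt"
	"sync"
	"time"
)

type miniServer struct {
	mu    sync.Mutex
	cv    *sync.Cond
	conns map[int]bool // connections still draining, keyed by an arbitrary id
}

func newMiniServer(n int) *miniServer {
	s := &miniServer{conns: make(map[int]bool)}
	s.cv = sync.NewCond(&s.mu)
	for i := 0; i < n; i++ {
		s.conns[i] = true
	}
	return s
}

// GracefulStop blocks until every connection has drained. len of a nil map
// is 0, so a concurrent Stop that nils out conns and signals the cond var
// wakes this loop and lets it return instead of waiting forever.
func (s *miniServer) GracefulStop() {
	s.mu.Lock()
	for len(s.conns) != 0 {
		s.cv.Wait()
	}
	s.conns = nil
	s.mu.Unlock()
}

// Stop tears everything down immediately; the Signal is what interrupts a
// GracefulStop that is already waiting.
func (s *miniServer) Stop() {
	s.mu.Lock()
	s.conns = nil
	s.cv.Signal()
	s.mu.Unlock()
}

// removeConn is what a drained connection would call as it goes away.
func (s *miniServer) removeConn(id int) {
	s.mu.Lock()
	if s.conns != nil {
		delete(s.conns, id)
		s.cv.Signal()
	}
	s.mu.Unlock()
}

func main() {
	s := newMiniServer(3)
	done := make(chan struct{})
	go func() {
		s.GracefulStop()
		close(done)
	}()
	s.removeConn(0) // one connection drains on its own
	time.Sleep(10 * time.Millisecond)
	s.Stop() // interrupts the wait on the remaining two
	<-done
	fmt.Println("GracefulStop returned after Stop")
}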