mirror of https://github.com/grpc/grpc-go.git
fix races introduce by goaway
This commit is contained in:
parent
f921887ab5
commit
110450d45e
|
@ -296,6 +296,8 @@ const (
|
||||||
TransientFailure
|
TransientFailure
|
||||||
// Shutdown indicates the ClientConn has started shutting down.
|
// Shutdown indicates the ClientConn has started shutting down.
|
||||||
Shutdown
|
Shutdown
|
||||||
|
// Drain
|
||||||
|
Drain
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s ConnectivityState) String() string {
|
func (s ConnectivityState) String() string {
|
||||||
|
@ -310,6 +312,8 @@ func (s ConnectivityState) String() string {
|
||||||
return "TRANSIENT_FAILURE"
|
return "TRANSIENT_FAILURE"
|
||||||
case Shutdown:
|
case Shutdown:
|
||||||
return "SHUTDOWN"
|
return "SHUTDOWN"
|
||||||
|
case Drain:
|
||||||
|
return "DRAIN"
|
||||||
default:
|
default:
|
||||||
panic(fmt.Sprintf("unknown connectivity state: %d", s))
|
panic(fmt.Sprintf("unknown connectivity state: %d", s))
|
||||||
}
|
}
|
||||||
|
@ -632,7 +636,7 @@ func (ac *addrConn) transportMonitor() {
|
||||||
case <-t.Error():
|
case <-t.Error():
|
||||||
ac.mu.Lock()
|
ac.mu.Lock()
|
||||||
if ac.state == Shutdown {
|
if ac.state == Shutdown {
|
||||||
// ac.tearDown(...) has been invoked.
|
// ac has been shutdown.
|
||||||
ac.mu.Unlock()
|
ac.mu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -774,6 +774,8 @@ func (s *Server) Stop() {
|
||||||
s.lis = nil
|
s.lis = nil
|
||||||
st := s.conns
|
st := s.conns
|
||||||
s.conns = nil
|
s.conns = nil
|
||||||
|
// interrupt GracefulStop if Stop and GracefulStop are called concurrently.
|
||||||
|
s.cv.Signal()
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
for lis := range listeners {
|
for lis := range listeners {
|
||||||
|
@ -803,13 +805,13 @@ func (s *Server) GracefulStop() {
|
||||||
for lis := range s.lis {
|
for lis := range s.lis {
|
||||||
lis.Close()
|
lis.Close()
|
||||||
}
|
}
|
||||||
|
s.lis = nil
|
||||||
for c := range s.conns {
|
for c := range s.conns {
|
||||||
c.(transport.ServerTransport).Drain()
|
c.(transport.ServerTransport).Drain()
|
||||||
}
|
}
|
||||||
for len(s.conns) != 0 {
|
for len(s.conns) != 0 {
|
||||||
s.cv.Wait()
|
s.cv.Wait()
|
||||||
}
|
}
|
||||||
s.lis = nil
|
|
||||||
s.conns = nil
|
s.conns = nil
|
||||||
if s.events != nil {
|
if s.events != nil {
|
||||||
s.events.Finish()
|
s.events.Finish()
|
||||||
|
|
|
@ -686,6 +686,146 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) {
|
||||||
awaitNewConnLogOutput()
|
awaitNewConnLogOutput()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConcurrentClientConnCloseAndServerGoAway(t *testing.T) {
|
||||||
|
defer leakCheck(t)()
|
||||||
|
for _, e := range listTestEnv() {
|
||||||
|
if e.name == "handler-tls" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
testConcurrentClientConnCloseAndServerGoAway(t, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testConcurrentClientConnCloseAndServerGoAway(t *testing.T, e env) {
|
||||||
|
te := newTest(t, e)
|
||||||
|
te.userAgent = testAppUA
|
||||||
|
te.declareLogNoise(
|
||||||
|
"transport: http2Client.notifyError got notified that the client transport was broken EOF",
|
||||||
|
"grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing",
|
||||||
|
"grpc: Conn.resetTransport failed to create client transport: connection error",
|
||||||
|
"grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix",
|
||||||
|
)
|
||||||
|
te.startServer(&testServer{security: e.security})
|
||||||
|
defer te.tearDown()
|
||||||
|
|
||||||
|
cc := te.clientConn()
|
||||||
|
tc := testpb.NewTestServiceClient(cc)
|
||||||
|
stream, err := tc.FullDuplexCall(context.Background(), grpc.FailFast(false))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||||
|
}
|
||||||
|
// Finish an RPC to make sure the connection is good.
|
||||||
|
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
|
||||||
|
t.Fatalf("%v.EmptyCall(_, _, _) = _, %v, want _, <nil>", tc, err)
|
||||||
|
}
|
||||||
|
ch := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
te.srv.GracefulStop()
|
||||||
|
close(ch)
|
||||||
|
}()
|
||||||
|
// Loop until the server side GoAway signal is propagated to the client.
|
||||||
|
for {
|
||||||
|
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||||
|
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Stop the server and close all the connections.
|
||||||
|
te.srv.Stop()
|
||||||
|
respParam := []*testpb.ResponseParameters{
|
||||||
|
{
|
||||||
|
Size: proto.Int32(1),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(100))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req := &testpb.StreamingOutputCallRequest{
|
||||||
|
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
|
||||||
|
ResponseParameters: respParam,
|
||||||
|
Payload: payload,
|
||||||
|
}
|
||||||
|
if err := stream.Send(req); err == nil {
|
||||||
|
if _, err := stream.Recv(); err == nil {
|
||||||
|
t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
<-ch
|
||||||
|
awaitNewConnLogOutput()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentServerStopAndGoAway(t *testing.T) {
|
||||||
|
defer leakCheck(t)()
|
||||||
|
for _, e := range listTestEnv() {
|
||||||
|
if e.name == "handler-tls" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
testConcurrentServerStopAndGoAway(t, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testConcurrentServerStopAndGoAway(t *testing.T, e env) {
|
||||||
|
te := newTest(t, e)
|
||||||
|
te.userAgent = testAppUA
|
||||||
|
te.declareLogNoise(
|
||||||
|
"transport: http2Client.notifyError got notified that the client transport was broken EOF",
|
||||||
|
"grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing",
|
||||||
|
"grpc: Conn.resetTransport failed to create client transport: connection error",
|
||||||
|
"grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix",
|
||||||
|
)
|
||||||
|
te.startServer(&testServer{security: e.security})
|
||||||
|
defer te.tearDown()
|
||||||
|
|
||||||
|
cc := te.clientConn()
|
||||||
|
tc := testpb.NewTestServiceClient(cc)
|
||||||
|
stream, err := tc.FullDuplexCall(context.Background(), grpc.FailFast(false))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||||
|
}
|
||||||
|
// Finish an RPC to make sure the connection is good.
|
||||||
|
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
|
||||||
|
t.Fatalf("%v.EmptyCall(_, _, _) = _, %v, want _, <nil>", tc, err)
|
||||||
|
}
|
||||||
|
ch := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
te.srv.GracefulStop()
|
||||||
|
close(ch)
|
||||||
|
}()
|
||||||
|
// Loop until the server side GoAway signal is propagated to the client.
|
||||||
|
for {
|
||||||
|
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||||
|
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Stop the server and close all the connections.
|
||||||
|
te.srv.Stop()
|
||||||
|
respParam := []*testpb.ResponseParameters{
|
||||||
|
{
|
||||||
|
Size: proto.Int32(1),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(100))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req := &testpb.StreamingOutputCallRequest{
|
||||||
|
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
|
||||||
|
ResponseParameters: respParam,
|
||||||
|
Payload: payload,
|
||||||
|
}
|
||||||
|
if err := stream.Send(req); err == nil {
|
||||||
|
if _, err := stream.Recv(); err == nil {
|
||||||
|
t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
<-ch
|
||||||
|
awaitNewConnLogOutput()
|
||||||
|
}
|
||||||
|
|
||||||
func TestFailFast(t *testing.T) {
|
func TestFailFast(t *testing.T) {
|
||||||
defer leakCheck(t)()
|
defer leakCheck(t)()
|
||||||
for _, e := range listTestEnv() {
|
for _, e := range listTestEnv() {
|
||||||
|
|
|
@ -454,7 +454,7 @@ func (t *http2Client) Close() (err error) {
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if t.state == reachable {
|
if t.state == reachable || t.state == draining {
|
||||||
close(t.errorChan)
|
close(t.errorChan)
|
||||||
}
|
}
|
||||||
t.state = closing
|
t.state = closing
|
||||||
|
@ -856,7 +856,11 @@ func (t *http2Client) reader() {
|
||||||
// Check the validity of server preface.
|
// Check the validity of server preface.
|
||||||
frame, err := t.framer.readFrame()
|
frame, err := t.framer.readFrame()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.notifyError(err)
|
if t.state == draining {
|
||||||
|
t.Close()
|
||||||
|
} else {
|
||||||
|
t.notifyError(err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sf, ok := frame.(*http2.SettingsFrame)
|
sf, ok := frame.(*http2.SettingsFrame)
|
||||||
|
@ -884,7 +888,12 @@ func (t *http2Client) reader() {
|
||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
// Transport error.
|
// Transport error.
|
||||||
t.notifyError(err)
|
if t.state == draining {
|
||||||
|
// A network error happened after the connection is drained. Fail the connection immediately.
|
||||||
|
t.Close()
|
||||||
|
} else {
|
||||||
|
t.notifyError(err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -993,7 +1002,7 @@ func (t *http2Client) notifyError(err error) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
// make sure t.errorChan is closed only once.
|
// make sure t.errorChan is closed only once.
|
||||||
if t.state == reachable {
|
if t.state == reachable || t.state == draining {
|
||||||
t.state = unreachable
|
t.state = unreachable
|
||||||
close(t.errorChan)
|
close(t.errorChan)
|
||||||
grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
|
grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
|
||||||
|
|
|
@ -680,6 +680,11 @@ func (t *http2Server) controller() {
|
||||||
t.framer.writeRSTStream(true, i.streamID, i.code)
|
t.framer.writeRSTStream(true, i.streamID, i.code)
|
||||||
case *goAway:
|
case *goAway:
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
|
if t.state == closing {
|
||||||
|
t.mu.Unlock()
|
||||||
|
// The transport is closing.
|
||||||
|
return
|
||||||
|
}
|
||||||
sid := t.maxStreamID
|
sid := t.maxStreamID
|
||||||
t.state = draining
|
t.state = draining
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
|
|
Loading…
Reference in New Issue