server: fix race between GracefulStop and new incoming connections (#1745)

New connections can race with GracefulStop: the server may accept a connection and then close it immediately. If a connection is accepted before GracefulStop has had a chance to close the listeners, the server should handle that connection rather than surface an error to the client.
dfawley 2017-12-18 15:38:51 -08:00 committed by GitHub
parent 0547980095
commit 2720857d97
4 changed files with 283 additions and 49 deletions
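In rough terms, the fix described above replaces the server's context-based cancellation with explicit bookkeeping: a sync.WaitGroup (serveWG) covers the Serve accept loop and each just-accepted connection's handshake goroutine, and GracefulStop closes the listeners, waits on that group, and only then drains and waits for registered connections. The sketch below is a simplified, self-contained illustration of this pattern, not the gRPC server itself; names such as handle and newServer are made up for the example.

```go
// Illustrative sketch of the pattern (hypothetical names, not the gRPC code):
// serveWG covers the accept loop and each handshake goroutine, so GracefulStop
// cannot clear the connection map before a just-accepted conn is registered.
package main

import (
	"errors"
	"net"
	"sync"
)

type server struct {
	mu      sync.Mutex
	cv      *sync.Cond            // signaled when a connection is removed
	conns   map[net.Conn]bool     // live connections; nil once stopped
	lis     map[net.Listener]bool // open listeners; nil once stopped
	serveWG sync.WaitGroup        // counts Serve loops and in-flight handshakes
}

func newServer() *server {
	s := &server{conns: map[net.Conn]bool{}, lis: map[net.Listener]bool{}}
	s.cv = sync.NewCond(&s.mu)
	return s
}

func (s *server) Serve(lis net.Listener) error {
	s.mu.Lock()
	if s.lis == nil { // Serve called after GracefulStop.
		s.mu.Unlock()
		lis.Close()
		return errors.New("server stopped")
	}
	s.lis[lis] = true
	s.mu.Unlock()
	s.serveWG.Add(1)
	defer s.serveWG.Done()
	for {
		c, err := lis.Accept()
		if err != nil {
			return nil // listener closed, e.g. by GracefulStop
		}
		// Count this connection before handing it off, so GracefulStop
		// waits until it has been added to s.conns (or rejected and closed).
		s.serveWG.Add(1)
		go func() {
			defer s.serveWG.Done()
			if !s.addConn(c) {
				return // already stopped; addConn closed c
			}
			// Long-lived serving is not counted by serveWG; GracefulStop
			// waits for it via s.conns and the condition variable instead.
			go func() {
				handle(c)
				s.removeConn(c)
			}()
		}()
	}
}

func (s *server) addConn(c net.Conn) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.conns == nil { // server stopped: reject and close
		c.Close()
		return false
	}
	s.conns[c] = true
	return true
}

func (s *server) removeConn(c net.Conn) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.conns, c)
	s.cv.Broadcast()
}

func (s *server) GracefulStop() {
	s.mu.Lock()
	for l := range s.lis {
		l.Close()
	}
	s.lis = nil
	s.mu.Unlock()
	// Only after every Serve loop and handshake goroutine has exited can we
	// be sure no new connections will be added to s.conns.
	s.serveWG.Wait()
	s.mu.Lock()
	for len(s.conns) != 0 {
		s.cv.Wait()
	}
	s.conns = nil
	s.mu.Unlock()
}

func handle(c net.Conn) { c.Close() } // placeholder for real per-conn serving

func main() {}
```

Note that the long-lived per-connection serving goroutine is deliberately not counted by serveWG: once a connection is registered in s.conns, GracefulStop tracks its lifetime through the condition variable instead, so the WaitGroup only has to cover the short window between Accept and addConn.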

server.go

@@ -92,11 +92,7 @@ type Server struct {
conns map[io.Closer]bool
serve bool
drain bool
ctx context.Context
cancel context.CancelFunc
// A CondVar to let GracefulStop() blocks until all the pending RPCs are finished
// and all the transport goes away.
cv *sync.Cond
cv *sync.Cond // signaled when connections close for GracefulStop
m map[string]*service // service name -> service info
events trace.EventLog
@@ -104,6 +100,7 @@ type Server struct {
done chan struct{}
quitOnce sync.Once
doneOnce sync.Once
serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop
}
type options struct {
@@ -343,7 +340,6 @@ func NewServer(opt ...ServerOption) *Server {
done: make(chan struct{}),
}
s.cv = sync.NewCond(&s.mu)
s.ctx, s.cancel = context.WithCancel(context.Background())
if EnableTracing {
_, file, line, _ := runtime.Caller(1)
s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
@@ -474,10 +470,19 @@ func (s *Server) Serve(lis net.Listener) error {
s.printf("serving")
s.serve = true
if s.lis == nil {
// Serve called after Stop or GracefulStop.
s.mu.Unlock()
lis.Close()
return ErrServerStopped
}
s.serveWG.Add(1)
defer func() {
s.serveWG.Done()
// Block until Stop or GracefulStop is ready for us to return.
<-s.done
}()
s.lis[lis] = true
s.mu.Unlock()
defer func() {
@@ -511,33 +516,40 @@ func (s *Server) Serve(lis net.Listener) error {
timer := time.NewTimer(tempDelay)
select {
case <-timer.C:
case <-s.ctx.Done():
case <-s.quit:
timer.Stop()
return nil
}
timer.Stop()
continue
}
s.mu.Lock()
s.printf("done serving; Accept = %v", err)
s.mu.Unlock()
// If Stop or GracefulStop is called, block until they are done and return nil
// If Stop or GracefulStop is called, return nil.
select {
case <-s.quit:
<-s.done
return nil
default:
}
return err
}
tempDelay = 0
// Start a new goroutine to deal with rawConn
// so we don't stall this Accept loop goroutine.
go s.handleRawConn(rawConn)
// Start a new goroutine to deal with rawConn so we don't stall this Accept
// loop goroutine.
//
// Make sure we account for the goroutine so GracefulStop doesn't nil out
// s.conns before this conn can be added.
s.serveWG.Add(1)
go func() {
s.handleRawConn(rawConn)
s.serveWG.Done()
}()
}
}
// handleRawConn is run in its own goroutine and handles a just-accepted
// connection that has not had any I/O performed on it yet.
// handleRawConn forks a goroutine to handle a just-accepted connection that
// has not had any I/O performed on it yet.
func (s *Server) handleRawConn(rawConn net.Conn) {
rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))
conn, authInfo, err := s.useTransportAuthenticator(rawConn)
@@ -562,17 +574,28 @@ func (s *Server) handleRawConn(rawConn net.Conn) {
}
s.mu.Unlock()
var serve func()
c := conn.(io.Closer)
if s.opts.useHandlerImpl {
rawConn.SetDeadline(time.Time{})
s.serveUsingHandler(conn)
serve = func() { s.serveUsingHandler(conn) }
} else {
// Finish handshaking (HTTP2)
st := s.newHTTP2Transport(conn, authInfo)
if st == nil {
return
}
rawConn.SetDeadline(time.Time{})
s.serveStreams(st)
c = st
serve = func() { s.serveStreams(st) }
}
rawConn.SetDeadline(time.Time{})
if !s.addConn(c) {
return
}
go func() {
serve()
s.removeConn(c)
}()
}
// newHTTP2Transport sets up a http/2 transport (using the
@@ -599,15 +622,10 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr
grpclog.Warningln("grpc: Server.Serve failed to create ServerTransport: ", err)
return nil
}
if !s.addConn(st) {
st.Close()
return nil
}
return st
}
func (s *Server) serveStreams(st transport.ServerTransport) {
defer s.removeConn(st)
defer st.Close()
var wg sync.WaitGroup
st.HandleStreams(func(stream *transport.Stream) {
@@ -641,11 +659,6 @@ var _ http.Handler = (*Server)(nil)
//
// conn is the *tls.Conn that's already been authenticated.
func (s *Server) serveUsingHandler(conn net.Conn) {
if !s.addConn(conn) {
conn.Close()
return
}
defer s.removeConn(conn)
h2s := &http2.Server{
MaxConcurrentStreams: s.opts.maxConcurrentStreams,
}
@@ -685,7 +698,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
if !s.addConn(st) {
st.Close()
return
}
defer s.removeConn(st)
@@ -715,9 +727,15 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
func (s *Server) addConn(c io.Closer) bool {
s.mu.Lock()
defer s.mu.Unlock()
if s.conns == nil || s.drain {
if s.conns == nil {
c.Close()
return false
}
if s.drain {
// Transport added after we drained our existing conns: drain it
// immediately.
c.(transport.ServerTransport).Drain()
}
s.conns[c] = true
return true
}
@@ -1158,6 +1176,7 @@ func (s *Server) Stop() {
})
defer func() {
s.serveWG.Wait()
s.doneOnce.Do(func() {
close(s.done)
})
@@ -1180,7 +1199,6 @@ func (s *Server) Stop() {
}
s.mu.Lock()
s.cancel()
if s.events != nil {
s.events.Finish()
s.events = nil
@@ -1203,21 +1221,27 @@ func (s *Server) GracefulStop() {
}()
s.mu.Lock()
defer s.mu.Unlock()
if s.conns == nil {
s.mu.Unlock()
return
}
for lis := range s.lis {
lis.Close()
}
s.lis = nil
s.cancel()
if !s.drain {
for c := range s.conns {
c.(transport.ServerTransport).Drain()
}
s.drain = true
}
// Wait for serving threads to be ready to exit. Only then can we be sure no
// new conns will be created.
s.mu.Unlock()
s.serveWG.Wait()
s.mu.Lock()
for len(s.conns) != 0 {
s.cv.Wait()
}
@@ -1226,6 +1250,7 @@ func (s *Server) GracefulStop() {
s.events.Finish()
s.events = nil
}
s.mu.Unlock()
}
func init() {

test/gracefulstop_test.go (new file, 207 lines)

@@ -0,0 +1,207 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package test
import (
"fmt"
"net"
"sync"
"testing"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/test/leakcheck"
testpb "google.golang.org/grpc/test/grpc_testing"
)
type delayListener struct {
net.Listener
closeCalled chan struct{}
acceptCalled chan struct{}
allowCloseCh chan struct{}
cc *delayConn
}
func (d *delayListener) Accept() (net.Conn, error) {
select {
case <-d.acceptCalled:
// On the second call, block until closed, then return an error.
<-d.closeCalled
<-d.allowCloseCh
return nil, fmt.Errorf("listener is closed")
default:
close(d.acceptCalled)
return d.Listener.Accept()
}
}
func (d *delayListener) allowClose() {
close(d.allowCloseCh)
}
func (d *delayListener) Close() error {
close(d.closeCalled)
go func() {
<-d.allowCloseCh
d.Listener.Close()
}()
return nil
}
func (d *delayListener) allowClientRead() {
d.cc.allowRead()
}
func (d *delayListener) Dial(to time.Duration) (net.Conn, error) {
c, err := net.DialTimeout("tcp", d.Listener.Addr().String(), to)
if err != nil {
return nil, err
}
d.cc = &delayConn{Conn: c, blockRead: make(chan struct{})}
return d.cc, nil
}
func (d *delayListener) clientWriteCalledChan() <-chan struct{} {
return d.cc.writeCalledChan()
}
type delayConn struct {
net.Conn
blockRead chan struct{}
mu sync.Mutex
writeCalled chan struct{}
}
func (d *delayConn) writeCalledChan() <-chan struct{} {
d.mu.Lock()
defer d.mu.Unlock()
d.writeCalled = make(chan struct{})
return d.writeCalled
}
func (d *delayConn) allowRead() {
close(d.blockRead)
}
func (d *delayConn) Read(b []byte) (n int, err error) {
<-d.blockRead
return d.Conn.Read(b)
}
func (d *delayConn) Write(b []byte) (n int, err error) {
d.mu.Lock()
if d.writeCalled != nil {
close(d.writeCalled)
d.writeCalled = nil
}
d.mu.Unlock()
return d.Conn.Write(b)
}
func TestGracefulStop(t *testing.T) {
defer leakcheck.Check(t)
// This test ensures GracefulStop cannot race and break RPCs on new
// connections created after GracefulStop was called but before
// listener.Accept() returns a "closing" error.
//
// Steps of this test:
// 1. Start Server
// 2. GracefulStop() Server after listener's Accept is called, but don't
// allow Accept() to exit when Close() is called on it.
// 3. Create a new connection to the server after listener.Close() is called.
// Server will want to send a GoAway on the new conn, but we delay client
// reads until 5.
// 4. Send an RPC on the new connection.
// 5. Allow the client to read the GoAway. The RPC should complete
// successfully.
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error listenening: %v", err)
}
dlis := &delayListener{
Listener: lis,
acceptCalled: make(chan struct{}),
closeCalled: make(chan struct{}),
allowCloseCh: make(chan struct{}),
}
d := func(_ string, to time.Duration) (net.Conn, error) { return dlis.Dial(to) }
ss := &stubServer{
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}
s := grpc.NewServer()
testpb.RegisterTestServiceServer(s, ss)
// 1. Start Server
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
s.Serve(dlis)
wg.Done()
}()
// 2. GracefulStop() Server after listener's Accept is called, but don't
// allow Accept() to exit when Close() is called on it.
<-dlis.acceptCalled
wg.Add(1)
go func() {
s.GracefulStop()
wg.Done()
}()
// 3. Create a new connection to the server after listener.Close() is called.
// Server will want to send a GoAway on the new conn, but we delay it
// until 5.
<-dlis.closeCalled // Block until GracefulStop calls dlis.Close()
// Now dial. The listener's Accept method will return a valid connection,
// even though GracefulStop has closed the listener.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
cc, err := grpc.DialContext(ctx, "", grpc.WithInsecure(), grpc.WithBlock(), grpc.WithDialer(d))
if err != nil {
t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
}
cancel()
client := testpb.NewTestServiceClient(cc)
defer cc.Close()
dlis.allowClose()
wcch := dlis.clientWriteCalledChan()
go func() {
// 5. Allow the client to read the GoAway. The RPC should complete
// successfully.
<-wcch
dlis.allowClientRead()
}()
// 4. Send an RPC on the new connection.
// The server would send a GOAWAY first, but we are delaying the server's
// writes for now until the client writes more than the preface.
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() = %v; want <nil>", err)
}
// 5. happens above, then we finish the call.
cancel()
wg.Wait()
}

transport/http2_client.go

@@ -20,6 +20,7 @@ package transport
import (
"bytes"
"fmt"
"io"
"math"
"net"
@@ -302,7 +303,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
t.framer.writer.Flush()
go func() {
loopyWriter(t.ctx, t.controlBuf, t.itemHandler)
t.Close()
t.conn.Close()
}()
if t.kp.Time != infinity {
go t.keepalive()
@@ -1124,7 +1125,6 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
s.mu.Unlock()
return
}
if len(state.mdata) > 0 {
s.trailer = state.mdata
}
@@ -1237,8 +1237,7 @@ func (t *http2Client) applySettings(ss []http2.Setting) {
// TODO(mmukhi): A lot of this code (and code in other places in the transport layer)
// is duplicated between the client and the server.
// The transport layer needs to be refactored to take care of this.
func (t *http2Client) itemHandler(i item) error {
var err error
func (t *http2Client) itemHandler(i item) (err error) {
defer func() {
if err != nil {
errorf(" error in itemHandler: %v", err)
@@ -1246,10 +1245,11 @@ func (t *http2Client) itemHandler(i item) error {
}()
switch i := i.(type) {
case *dataFrame:
err = t.framer.fr.WriteData(i.streamID, i.endStream, i.d)
if err == nil {
i.f()
if err := t.framer.fr.WriteData(i.streamID, i.endStream, i.d); err != nil {
return err
}
i.f()
return nil
case *headerFrame:
t.hBuf.Reset()
for _, f := range i.hf {
@@ -1283,31 +1283,33 @@ func (t *http2Client) itemHandler(i item) error {
return err
}
}
return nil
case *windowUpdate:
err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
return t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
case *settings:
err = t.framer.fr.WriteSettings(i.ss...)
return t.framer.fr.WriteSettings(i.ss...)
case *settingsAck:
err = t.framer.fr.WriteSettingsAck()
return t.framer.fr.WriteSettingsAck()
case *resetStream:
// If the server needs to be informed about stream closing,
// then we need to make sure the RST_STREAM frame is written to
// the wire before the headers of the next stream waiting on
// streamsQuota. We ensure this by adding to the streamsQuota pool
// only after having acquired the writableChan to send RST_STREAM.
err = t.framer.fr.WriteRSTStream(i.streamID, i.code)
err := t.framer.fr.WriteRSTStream(i.streamID, i.code)
t.streamsQuota.add(1)
return err
case *flushIO:
err = t.framer.writer.Flush()
return t.framer.writer.Flush()
case *ping:
if !i.ack {
t.bdpEst.timesnap(i.data)
}
err = t.framer.fr.WritePing(i.ack, i.data)
return t.framer.fr.WritePing(i.ack, i.data)
default:
errorf("transport: http2Client.controller got unexpected item type %v", i)
return fmt.Errorf("transport: http2Client.controller got unexpected item type %v", i)
}
return err
}
// keepalive running in a separate goroutine makes sure the connection is alive by sending pings.

transport/http2_server.go

@@ -259,7 +259,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
go func() {
loopyWriter(t.ctx, t.controlBuf, t.itemHandler)
t.Close()
t.conn.Close()
}()
go t.keepalive()
return t, nil