mirror of https://github.com/grpc/grpc-go.git
server: fix race between GracefulStop and new incoming connections (#1745)
New connections can race with GracefulStop such that the server will accept the connection, but then close it immediately. If a connection is accepted before GracefulStop has a chance to effectively cancel the listeners, the server should handle it to avoid client errors.
parent 0547980095
commit 2720857d97
server.go | 93
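To make the race in the commit message concrete, here is a minimal sketch of the pre-fix shape of the problem (the types and names below are illustrative, not grpc-go's real internals): the accept loop hands each raw connection to a goroutine, and a concurrent GracefulStop can finish draining and nil out the connection map before that goroutine registers the new conn, so the conn is torn down abruptly.

package graceful

import (
	"net"
	"sync"
)

// raceyServer is an illustrative stand-in for the relevant grpc.Server
// fields; it is not the real type.
type raceyServer struct {
	mu    sync.Mutex
	conns map[net.Conn]bool // nil once the server has (gracefully) stopped
}

func (s *raceyServer) serve(lis net.Listener) {
	for {
		c, err := lis.Accept()
		if err != nil {
			return // listener closed
		}
		// Race window: GracefulStop can close the listener, drain every
		// registered conn, and set s.conns to nil right here, before the
		// goroutine below runs.
		go func() {
			s.mu.Lock()
			if s.conns == nil {
				s.mu.Unlock()
				// The just-accepted conn is closed abruptly, so the
				// client sees an error even though the stop was
				// supposed to be graceful.
				c.Close()
				return
			}
			s.conns[c] = true
			s.mu.Unlock()
			// ... handshake and serve the connection ...
		}()
	}
}

The diff below closes this window by counting the accept-loop and handler goroutines in a sync.WaitGroup (s.serveWG) so GracefulStop can wait for them after closing the listeners.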
@@ -92,11 +92,7 @@ type Server struct {
 	conns  map[io.Closer]bool
 	serve  bool
 	drain  bool
-	ctx    context.Context
-	cancel context.CancelFunc
-	// A CondVar to let GracefulStop() blocks until all the pending RPCs are finished
-	// and all the transport goes away.
-	cv     *sync.Cond
+	cv     *sync.Cond // signaled when connections close for GracefulStop
 	m      map[string]*service // service name -> service info
 	events trace.EventLog
@@ -104,6 +100,7 @@ type Server struct {
 	done     chan struct{}
 	quitOnce sync.Once
 	doneOnce sync.Once
+	serveWG  sync.WaitGroup // counts active Serve goroutines for GracefulStop
 }

 type options struct {
@@ -343,7 +340,6 @@ func NewServer(opt ...ServerOption) *Server {
 		done: make(chan struct{}),
 	}
 	s.cv = sync.NewCond(&s.mu)
-	s.ctx, s.cancel = context.WithCancel(context.Background())
 	if EnableTracing {
 		_, file, line, _ := runtime.Caller(1)
 		s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
@@ -474,10 +470,19 @@ func (s *Server) Serve(lis net.Listener) error {
 	s.printf("serving")
 	s.serve = true
 	if s.lis == nil {
 		// Serve called after Stop or GracefulStop.
 		s.mu.Unlock()
 		lis.Close()
 		return ErrServerStopped
 	}
+
+	s.serveWG.Add(1)
+	defer func() {
+		s.serveWG.Done()
+		// Block until Stop or GracefulStop is ready for us to return.
+		<-s.done
+	}()
+
 	s.lis[lis] = true
 	s.mu.Unlock()
 	defer func() {
@@ -511,33 +516,40 @@ func (s *Server) Serve(lis net.Listener) error {
 				timer := time.NewTimer(tempDelay)
 				select {
 				case <-timer.C:
-				case <-s.ctx.Done():
+				case <-s.quit:
+					timer.Stop()
+					return nil
 				}
 				timer.Stop()
 				continue
 			}
 			s.mu.Lock()
 			s.printf("done serving; Accept = %v", err)
 			s.mu.Unlock()

-			// If Stop or GracefulStop is called, block until they are done and return nil
+			// If Stop or GracefulStop is called, return nil.
 			select {
 			case <-s.quit:
-				<-s.done
 				return nil
 			default:
 			}
 			return err
 		}
 		tempDelay = 0
-		// Start a new goroutine to deal with rawConn
-		// so we don't stall this Accept loop goroutine.
-		go s.handleRawConn(rawConn)
+		// Start a new goroutine to deal with rawConn so we don't stall this Accept
+		// loop goroutine.
+		//
+		// Make sure we account for the goroutine so GracefulStop doesn't nil out
+		// s.conns before this conn can be added.
+		s.serveWG.Add(1)
+		go func() {
+			s.handleRawConn(rawConn)
+			s.serveWG.Done()
+		}()
	}
}

-// handleRawConn is run in its own goroutine and handles a just-accepted
-// connection that has not had any I/O performed on it yet.
+// handleRawConn forks a goroutine to handle a just-accepted connection that
+// has not had any I/O performed on it yet.
 func (s *Server) handleRawConn(rawConn net.Conn) {
 	rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))
 	conn, authInfo, err := s.useTransportAuthenticator(rawConn)
@@ -562,17 +574,28 @@ func (s *Server) handleRawConn(rawConn net.Conn) {
 	}
 	s.mu.Unlock()

+	var serve func()
+	c := conn.(io.Closer)
 	if s.opts.useHandlerImpl {
-		rawConn.SetDeadline(time.Time{})
-		s.serveUsingHandler(conn)
+		serve = func() { s.serveUsingHandler(conn) }
 	} else {
 		// Finish handshaking (HTTP2)
 		st := s.newHTTP2Transport(conn, authInfo)
 		if st == nil {
 			return
 		}
-		rawConn.SetDeadline(time.Time{})
-		s.serveStreams(st)
+		c = st
+		serve = func() { s.serveStreams(st) }
 	}
+
+	rawConn.SetDeadline(time.Time{})
+	if !s.addConn(c) {
+		return
+	}
+	go func() {
+		serve()
+		s.removeConn(c)
+	}()
 }

 // newHTTP2Transport sets up a http/2 transport (using the
@@ -599,15 +622,10 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr
 		grpclog.Warningln("grpc: Server.Serve failed to create ServerTransport: ", err)
 		return nil
 	}
-	if !s.addConn(st) {
-		st.Close()
-		return nil
-	}
 	return st
 }

 func (s *Server) serveStreams(st transport.ServerTransport) {
-	defer s.removeConn(st)
 	defer st.Close()
 	var wg sync.WaitGroup
 	st.HandleStreams(func(stream *transport.Stream) {
@@ -641,11 +659,6 @@ var _ http.Handler = (*Server)(nil)
 //
 // conn is the *tls.Conn that's already been authenticated.
 func (s *Server) serveUsingHandler(conn net.Conn) {
-	if !s.addConn(conn) {
-		conn.Close()
-		return
-	}
-	defer s.removeConn(conn)
 	h2s := &http2.Server{
 		MaxConcurrentStreams: s.opts.maxConcurrentStreams,
 	}
@@ -685,7 +698,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	if !s.addConn(st) {
-		st.Close()
 		return
 	}
 	defer s.removeConn(st)
@@ -715,9 +727,15 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
 func (s *Server) addConn(c io.Closer) bool {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	if s.conns == nil || s.drain {
+	if s.conns == nil {
+		c.Close()
 		return false
 	}
+	if s.drain {
+		// Transport added after we drained our existing conns: drain it
+		// immediately.
+		c.(transport.ServerTransport).Drain()
+	}
 	s.conns[c] = true
 	return true
 }
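The new addConn separates the two refusal cases: after a hard Stop, s.conns is nil and the connection is closed outright; during GracefulStop, s.drain is set and a late-arriving transport is still admitted but drained immediately (for HTTP/2, Drain sends a GOAWAY, so the client can finish in-flight RPCs and go away cleanly instead of seeing a hard close). A standalone sketch of that admit-then-drain shape, with a made-up Drainer interface standing in for transport.ServerTransport:

package graceful

import (
	"io"
	"sync"
)

// Drainer stands in for grpc's transport.ServerTransport; Drain asks the
// peer to stop opening streams (a GOAWAY frame in HTTP/2).
type Drainer interface {
	io.Closer
	Drain()
}

type registry struct {
	mu       sync.Mutex
	conns    map[Drainer]bool // nil once the server is hard-stopped
	draining bool             // true once graceful shutdown has begun
}

// add mirrors the two-step check the commit introduces in addConn.
func (r *registry) add(c Drainer) bool {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.conns == nil {
		// Hard stop already ran: refuse the connection outright.
		c.Close()
		return false
	}
	if r.draining {
		// The conn was accepted after the existing conns were drained,
		// so drain it immediately: the client finishes in-flight RPCs
		// instead of hitting an abrupt reset.
		c.Drain()
	}
	r.conns[c] = true
	return true
}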
@@ -1158,6 +1176,7 @@ func (s *Server) Stop() {
 	})

 	defer func() {
+		s.serveWG.Wait()
 		s.doneOnce.Do(func() {
 			close(s.done)
 		})
@@ -1180,7 +1199,6 @@ func (s *Server) Stop() {
 	}

 	s.mu.Lock()
-	s.cancel()
 	if s.events != nil {
 		s.events.Finish()
 		s.events = nil
@@ -1203,21 +1221,27 @@ func (s *Server) GracefulStop() {
 	}()

 	s.mu.Lock()
-	defer s.mu.Unlock()
 	if s.conns == nil {
+		s.mu.Unlock()
 		return
 	}
+
 	for lis := range s.lis {
 		lis.Close()
 	}
 	s.lis = nil
-	s.cancel()
 	if !s.drain {
 		for c := range s.conns {
 			c.(transport.ServerTransport).Drain()
 		}
 		s.drain = true
 	}
+
+	// Wait for serving threads to be ready to exit. Only then can we be sure no
+	// new conns will be created.
+	s.mu.Unlock()
+	s.serveWG.Wait()
+	s.mu.Lock()

 	for len(s.conns) != 0 {
 		s.cv.Wait()
 	}
@@ -1226,6 +1250,7 @@ func (s *Server) GracefulStop() {
 		s.events.Finish()
 		s.events = nil
 	}
+	s.mu.Unlock()
 }

 func init() {
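Taken together, the GracefulStop changes establish a strict shutdown order: close the listeners, drain the existing connections, wait for every Serve/handler goroutine (which may still be registering a freshly accepted conn via addConn), and only then wait on the condition variable for the connection count to reach zero. Note the unlock/Wait/lock dance: handler goroutines need s.mu inside addConn, so waiting on serveWG while holding the lock would deadlock. A condensed sketch of the sequence, with illustrative names rather than the actual method:

package graceful

import (
	"net"
	"sync"
)

type drainCloser interface {
	Drain() // e.g. send an HTTP/2 GOAWAY
	Close() error
}

// gracefulServer carries just the fields the shutdown sequence touches;
// the real ones live on grpc.Server.
type gracefulServer struct {
	mu      sync.Mutex
	cv      *sync.Cond // created with sync.NewCond(&s.mu)
	lis     map[net.Listener]bool
	conns   map[drainCloser]bool
	drain   bool
	serveWG sync.WaitGroup
}

func (s *gracefulServer) gracefulStop() {
	s.mu.Lock()
	for l := range s.lis {
		l.Close() // 1. stop accepting new connections
	}
	s.lis = nil
	if !s.drain {
		for c := range s.conns {
			c.Drain() // 2. ask existing connections to wind down
		}
		s.drain = true
	}

	// 3. Release the lock while waiting: in-flight handler goroutines
	// still need s.mu to register their conns (addConn drains them
	// immediately because drain is set).
	s.mu.Unlock()
	s.serveWG.Wait()
	s.mu.Lock()

	// 4. Now no new conns can appear; wait for the existing ones to go.
	for len(s.conns) != 0 {
		s.cv.Wait() // signaled whenever a connection is removed
	}
	s.conns = nil
	s.mu.Unlock()
}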
@@ -0,0 +1,207 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package test
+
+import (
+	"fmt"
+	"net"
+	"sync"
+	"testing"
+	"time"
+
+	"golang.org/x/net/context"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/test/leakcheck"
+
+	testpb "google.golang.org/grpc/test/grpc_testing"
+)
+
+type delayListener struct {
+	net.Listener
+	closeCalled  chan struct{}
+	acceptCalled chan struct{}
+	allowCloseCh chan struct{}
+	cc           *delayConn
+}
+
+func (d *delayListener) Accept() (net.Conn, error) {
+	select {
+	case <-d.acceptCalled:
+		// On the second call, block until closed, then return an error.
+		<-d.closeCalled
+		<-d.allowCloseCh
+		return nil, fmt.Errorf("listener is closed")
+	default:
+		close(d.acceptCalled)
+		return d.Listener.Accept()
+	}
+}
+
+func (d *delayListener) allowClose() {
+	close(d.allowCloseCh)
+}
+func (d *delayListener) Close() error {
+	close(d.closeCalled)
+	go func() {
+		<-d.allowCloseCh
+		d.Listener.Close()
+	}()
+	return nil
+}
+
+func (d *delayListener) allowClientRead() {
+	d.cc.allowRead()
+}
+
+func (d *delayListener) Dial(to time.Duration) (net.Conn, error) {
+	c, err := net.DialTimeout("tcp", d.Listener.Addr().String(), to)
+	if err != nil {
+		return nil, err
+	}
+	d.cc = &delayConn{Conn: c, blockRead: make(chan struct{})}
+	return d.cc, nil
+}
+
+func (d *delayListener) clientWriteCalledChan() <-chan struct{} {
+	return d.cc.writeCalledChan()
+}
+
+type delayConn struct {
+	net.Conn
+	blockRead   chan struct{}
+	mu          sync.Mutex
+	writeCalled chan struct{}
+}
+
+func (d *delayConn) writeCalledChan() <-chan struct{} {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	d.writeCalled = make(chan struct{})
+	return d.writeCalled
+}
+func (d *delayConn) allowRead() {
+	close(d.blockRead)
+}
+func (d *delayConn) Read(b []byte) (n int, err error) {
+	<-d.blockRead
+	return d.Conn.Read(b)
+}
+func (d *delayConn) Write(b []byte) (n int, err error) {
+	d.mu.Lock()
+	if d.writeCalled != nil {
+		close(d.writeCalled)
+		d.writeCalled = nil
+	}
+	d.mu.Unlock()
+	return d.Conn.Write(b)
+}
+
+func TestGracefulStop(t *testing.T) {
+	defer leakcheck.Check(t)
+	// This test ensures GracefulStop cannot race and break RPCs on new
+	// connections created after GracefulStop was called but before
+	// listener.Accept() returns a "closing" error.
+	//
+	// Steps of this test:
+	// 1. Start Server
+	// 2. GracefulStop() Server after listener's Accept is called, but don't
+	//    allow Accept() to exit when Close() is called on it.
+	// 3. Create a new connection to the server after listener.Close() is called.
+	//    Server will want to send a GoAway on the new conn, but we delay client
+	//    reads until 5.
+	// 4. Send an RPC on the new connection.
+	// 5. Allow the client to read the GoAway. The RPC should complete
+	//    successfully.
+
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Error listening: %v", err)
+	}
+	dlis := &delayListener{
+		Listener:     lis,
+		acceptCalled: make(chan struct{}),
+		closeCalled:  make(chan struct{}),
+		allowCloseCh: make(chan struct{}),
+	}
+	d := func(_ string, to time.Duration) (net.Conn, error) { return dlis.Dial(to) }
+
+	ss := &stubServer{
+		emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
+			return &testpb.Empty{}, nil
+		},
+	}
+	s := grpc.NewServer()
+	testpb.RegisterTestServiceServer(s, ss)
+
+	// 1. Start Server
+	wg := sync.WaitGroup{}
+	wg.Add(1)
+	go func() {
+		s.Serve(dlis)
+		wg.Done()
+	}()
+
+	// 2. GracefulStop() Server after listener's Accept is called, but don't
+	//    allow Accept() to exit when Close() is called on it.
+	<-dlis.acceptCalled
+	wg.Add(1)
+	go func() {
+		s.GracefulStop()
+		wg.Done()
+	}()
+
+	// 3. Create a new connection to the server after listener.Close() is called.
+	//    Server will want to send a GoAway on the new conn, but we delay it
+	//    until 5.
+
+	<-dlis.closeCalled // Block until GracefulStop calls dlis.Close()
+
+	// Now dial. The listener's Accept method will return a valid connection,
+	// even though GracefulStop has closed the listener.
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	cc, err := grpc.DialContext(ctx, "", grpc.WithInsecure(), grpc.WithBlock(), grpc.WithDialer(d))
+	if err != nil {
+		t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
+	}
+	cancel()
+	client := testpb.NewTestServiceClient(cc)
+	defer cc.Close()
+
+	dlis.allowClose()
+
+	wcch := dlis.clientWriteCalledChan()
+	go func() {
+		// 5. Allow the client to read the GoAway. The RPC should complete
+		//    successfully.
+		<-wcch
+		dlis.allowClientRead()
+	}()
+
+	// 4. Send an RPC on the new connection.
+	// The server would send a GOAWAY first, but we are delaying the server's
+	// writes for now until the client writes more than the preface.
+	ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
+	if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
+		t.Fatalf("EmptyCall() = %v; want <nil>", err)
+	}
+
+	// 5. happens above, then we finish the call.
+	cancel()
+	wg.Wait()
+}
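The heavy lifting in this test is done by the delayListener/delayConn wrappers, which turn a nondeterministic race into a fixed interleaving: Accept hands out one real connection even after Close, the listener's actual Close is deferred until the test permits it, and the client conn's Read blocks so the server's GOAWAY cannot be consumed before the RPC has been written. Wrapping net.Conn this way is a reusable trick for pinning down I/O ordering in tests; a minimal standalone version (names are illustrative, not taken from the test above):

package graceful

import "net"

// gatedConn delays all reads until the test releases them, so bytes the
// peer has already sent (e.g. a GOAWAY frame) only become visible at a
// moment the test chooses.
type gatedConn struct {
	net.Conn
	release chan struct{}
}

func newGatedConn(c net.Conn) *gatedConn {
	return &gatedConn{Conn: c, release: make(chan struct{})}
}

func (g *gatedConn) Read(b []byte) (int, error) {
	<-g.release // parked until allowRead; the close makes all later Reads pass
	return g.Conn.Read(b)
}

// allowRead must be called exactly once.
func (g *gatedConn) allowRead() { close(g.release) }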
@@ -20,6 +20,7 @@ package transport

 import (
 	"bytes"
+	"fmt"
 	"io"
 	"math"
 	"net"
@@ -302,7 +303,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
 	t.framer.writer.Flush()
 	go func() {
 		loopyWriter(t.ctx, t.controlBuf, t.itemHandler)
-		t.Close()
+		t.conn.Close()
 	}()
 	if t.kp.Time != infinity {
 		go t.keepalive()
@@ -1124,7 +1125,6 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
 		s.mu.Unlock()
 		return
 	}
-
 	if len(state.mdata) > 0 {
 		s.trailer = state.mdata
 	}
@@ -1237,8 +1237,7 @@ func (t *http2Client) applySettings(ss []http2.Setting) {
 // TODO(mmukhi): A lot of this code (and code in other places in the transport layer)
 // is duplicated between the client and the server.
 // The transport layer needs to be refactored to take care of this.
-func (t *http2Client) itemHandler(i item) error {
-	var err error
+func (t *http2Client) itemHandler(i item) (err error) {
 	defer func() {
 		if err != nil {
 			errorf(" error in itemHandler: %v", err)
@@ -1246,10 +1245,11 @@ func (t *http2Client) itemHandler(i item) error {
 	}()
 	switch i := i.(type) {
 	case *dataFrame:
-		err = t.framer.fr.WriteData(i.streamID, i.endStream, i.d)
-		if err == nil {
-			i.f()
+		if err := t.framer.fr.WriteData(i.streamID, i.endStream, i.d); err != nil {
+			return err
 		}
+		i.f()
+		return nil
 	case *headerFrame:
 		t.hBuf.Reset()
 		for _, f := range i.hf {
@@ -1283,31 +1283,33 @@ func (t *http2Client) itemHandler(i item) error {
 				return err
 			}
 		}
+		return nil
 	case *windowUpdate:
-		err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
+		return t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
 	case *settings:
-		err = t.framer.fr.WriteSettings(i.ss...)
+		return t.framer.fr.WriteSettings(i.ss...)
 	case *settingsAck:
-		err = t.framer.fr.WriteSettingsAck()
+		return t.framer.fr.WriteSettingsAck()
 	case *resetStream:
 		// If the server needs to be intimated about stream closing,
 		// then we need to make sure the RST_STREAM frame is written to
 		// the wire before the headers of the next stream waiting on
 		// streamQuota. We ensure this by adding to the streamsQuota pool
 		// only after having acquired the writableChan to send RST_STREAM.
-		err = t.framer.fr.WriteRSTStream(i.streamID, i.code)
+		err := t.framer.fr.WriteRSTStream(i.streamID, i.code)
 		t.streamsQuota.add(1)
+		return err
 	case *flushIO:
-		err = t.framer.writer.Flush()
+		return t.framer.writer.Flush()
 	case *ping:
 		if !i.ack {
 			t.bdpEst.timesnap(i.data)
 		}
-		err = t.framer.fr.WritePing(i.ack, i.data)
+		return t.framer.fr.WritePing(i.ack, i.data)
 	default:
 		errorf("transport: http2Client.controller got unexpected item type %v", i)
+		return fmt.Errorf("transport: http2Client.controller got unexpected item type %v", i)
 	}
-	return err
 }

 // keepalive running in a separate goroutine makes sure the connection is alive by sending pings.
@@ -259,7 +259,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err

 	go func() {
 		loopyWriter(t.ctx, t.controlBuf, t.itemHandler)
-		t.Close()
+		t.conn.Close()
 	}()
 	go t.keepalive()
 	return t, nil
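The itemHandler rewrite in the client transport replaces a function-wide err variable, checked only at the very end, with a named return value and an early return in every case arm. The deferred errorf still observes whichever error is about to escape, and loopyWriter appears to treat a non-nil return as fatal: it exits, and the wrapping goroutine then closes the raw connection. The shape of that pattern in isolation (a sketch, not the transport code):

package graceful

import (
	"errors"
	"fmt"
	"log"
)

// process mirrors the named-return-plus-defer pattern: every arm returns
// immediately, and the deferred hook logs whichever error is returned.
func process(item interface{}) (err error) {
	defer func() {
		if err != nil {
			log.Printf("error in process: %v", err)
		}
	}()
	switch it := item.(type) {
	case string:
		if it == "" {
			return errors.New("empty string item")
		}
		return nil
	case int:
		return nil // nothing to do for ints in this sketch
	default:
		return fmt.Errorf("unexpected item type %T", it)
	}
}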