mirror of https://github.com/grpc/grpc-go.git
				
				
				
			support goaway
This commit is contained in:
		
							parent
							
								
									0e86f69ef3
								
							
						
					
					
						commit
						873cc272c2
					
				| 
						 | 
				
			
			@ -635,6 +635,7 @@ func (ac *addrConn) transportMonitor() {
 | 
			
		|||
			if t.Err() == transport.ErrConnDrain {
 | 
			
		||||
				ac.mu.Unlock()
 | 
			
		||||
				ac.tearDown(errConnDrain)
 | 
			
		||||
				ac.cc.newAddrConn(ac.addr, true)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			ac.state = TransientFailure
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -385,6 +385,12 @@ func toRPCErr(err error) error {
 | 
			
		|||
			desc: e.Desc,
 | 
			
		||||
		}
 | 
			
		||||
	case transport.ConnectionError:
 | 
			
		||||
		if err == transport.ErrConnDrain {
 | 
			
		||||
			return &rpcError{
 | 
			
		||||
				code: codes.Unavailable,
 | 
			
		||||
				desc: e.Desc,
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return &rpcError{
 | 
			
		||||
			code: codes.Internal,
 | 
			
		||||
			desc: e.Desc,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										32
									
								
								server.go
								
								
								
								
							
							
						
						
									
										32
									
								
								server.go
								
								
								
								
							| 
						 | 
				
			
			@ -92,6 +92,8 @@ type Server struct {
 | 
			
		|||
	mu     sync.Mutex // guards following
 | 
			
		||||
	lis    map[net.Listener]bool
 | 
			
		||||
	conns  map[io.Closer]bool
 | 
			
		||||
	drain  bool
 | 
			
		||||
	cv     *sync.Cond
 | 
			
		||||
	m      map[string]*service // service name -> service info
 | 
			
		||||
	events trace.EventLog
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -186,6 +188,7 @@ func NewServer(opt ...ServerOption) *Server {
 | 
			
		|||
		conns: make(map[io.Closer]bool),
 | 
			
		||||
		m:     make(map[string]*service),
 | 
			
		||||
	}
 | 
			
		||||
	s.cv = sync.NewCond(&s.mu)
 | 
			
		||||
	if EnableTracing {
 | 
			
		||||
		_, file, line, _ := runtime.Caller(1)
 | 
			
		||||
		s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
 | 
			
		||||
| 
						 | 
				
			
			@ -468,7 +471,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
 | 
			
		|||
func (s *Server) addConn(c io.Closer) bool {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	defer s.mu.Unlock()
 | 
			
		||||
	if s.conns == nil {
 | 
			
		||||
	if s.conns == nil || s.drain {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	s.conns[c] = true
 | 
			
		||||
| 
						 | 
				
			
			@ -480,6 +483,7 @@ func (s *Server) removeConn(c io.Closer) {
 | 
			
		|||
	defer s.mu.Unlock()
 | 
			
		||||
	if s.conns != nil {
 | 
			
		||||
		delete(s.conns, c)
 | 
			
		||||
		s.cv.Signal()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -766,14 +770,14 @@ func (s *Server) Stop() {
 | 
			
		|||
	s.mu.Lock()
 | 
			
		||||
	listeners := s.lis
 | 
			
		||||
	s.lis = nil
 | 
			
		||||
	cs := s.conns
 | 
			
		||||
	st := s.conns
 | 
			
		||||
	s.conns = nil
 | 
			
		||||
	s.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	for lis := range listeners {
 | 
			
		||||
		lis.Close()
 | 
			
		||||
	}
 | 
			
		||||
	for c := range cs {
 | 
			
		||||
	for c := range st {
 | 
			
		||||
		c.Close()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -785,6 +789,28 @@ func (s *Server) Stop() {
 | 
			
		|||
	s.mu.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) GracefulStop() {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	s.drain = true
 | 
			
		||||
	for lis := range s.lis {
 | 
			
		||||
		lis.Close()
 | 
			
		||||
	}
 | 
			
		||||
	for c := range s.conns {
 | 
			
		||||
		c.(transport.ServerTransport).GoAway()
 | 
			
		||||
	}
 | 
			
		||||
	for len(s.conns) != 0 {
 | 
			
		||||
		s.cv.Wait()
 | 
			
		||||
	}
 | 
			
		||||
	s.lis = nil
 | 
			
		||||
	s.conns = nil
 | 
			
		||||
	if s.events != nil {
 | 
			
		||||
		s.events.Finish()
 | 
			
		||||
		s.events = nil
 | 
			
		||||
	}
 | 
			
		||||
	s.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	internal.TestingCloseConns = func(arg interface{}) {
 | 
			
		||||
		arg.(*Server).testingCloseConns()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -195,6 +195,9 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 | 
			
		|||
				cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
 | 
			
		||||
			}
 | 
			
		||||
			cs.closeTransportStream(nil)
 | 
			
		||||
		case <-s.GoAway():
 | 
			
		||||
			cs.finish(errConnDrain)
 | 
			
		||||
			cs.closeTransportStream(errConnDrain)
 | 
			
		||||
		case <-s.Context().Done():
 | 
			
		||||
			err := s.Context().Err()
 | 
			
		||||
			cs.finish(err)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -572,6 +572,55 @@ func TestFailFast(t *testing.T) {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestServerGoAway(t *testing.T) {
 | 
			
		||||
	defer leakCheck(t)()
 | 
			
		||||
	for _, e := range listTestEnv() {
 | 
			
		||||
		if e.name == "handler-tls" {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		//if e.name != "tcp-clear" {
 | 
			
		||||
		//	continue
 | 
			
		||||
		//}
 | 
			
		||||
		testServerGoAway(t, e)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testServerGoAway(t *testing.T, e env) {
 | 
			
		||||
	te := newTest(t, e)
 | 
			
		||||
	te.userAgent = testAppUA
 | 
			
		||||
	te.declareLogNoise(
 | 
			
		||||
		"transport: http2Client.notifyError got notified that the client transport was broken EOF",
 | 
			
		||||
		"grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing",
 | 
			
		||||
		"grpc: Conn.resetTransport failed to create client transport: connection error",
 | 
			
		||||
		"grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix",
 | 
			
		||||
	)
 | 
			
		||||
	te.startServer(&testServer{security: e.security})
 | 
			
		||||
	defer te.tearDown()
 | 
			
		||||
 | 
			
		||||
	cc := te.clientConn()
 | 
			
		||||
	tc := testpb.NewTestServiceClient(cc)
 | 
			
		||||
	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
 | 
			
		||||
		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
 | 
			
		||||
	}
 | 
			
		||||
	ch := make(chan struct{})
 | 
			
		||||
	go func() {
 | 
			
		||||
		te.srv.GracefulStop()
 | 
			
		||||
		close(ch)
 | 
			
		||||
	}()
 | 
			
		||||
	for {
 | 
			
		||||
		ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
 | 
			
		||||
		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		break
 | 
			
		||||
	}
 | 
			
		||||
	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err == nil || grpc.Code(err) != codes.Unavailable {
 | 
			
		||||
		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, error code: %d", err, codes.Unavailable)
 | 
			
		||||
	}
 | 
			
		||||
	<-ch
 | 
			
		||||
	awaitNewConnLogOutput()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testFailFast(t *testing.T, e env) {
 | 
			
		||||
	te := newTest(t, e)
 | 
			
		||||
	te.userAgent = testAppUA
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -72,6 +72,11 @@ type resetStream struct {
 | 
			
		|||
 | 
			
		||||
func (*resetStream) item() {}
 | 
			
		||||
 | 
			
		||||
type goAway struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (*goAway) item() {}
 | 
			
		||||
 | 
			
		||||
type flushIO struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -370,6 +370,9 @@ func (ht *serverHandlerTransport) runStream() {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ht *serverHandlerTransport) GoAway() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// mapRecvMsgError returns the non-nil err into the appropriate
 | 
			
		||||
// error value as expected by callers of *grpc.parser.recvMsg.
 | 
			
		||||
// In particular, in can only be:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -205,6 +205,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
 | 
			
		|||
	s := &Stream{
 | 
			
		||||
		id:            t.nextID,
 | 
			
		||||
		done:          make(chan struct{}),
 | 
			
		||||
		goAway:        make(chan struct{}),
 | 
			
		||||
		method:        callHdr.Method,
 | 
			
		||||
		sendCompress:  callHdr.SendCompress,
 | 
			
		||||
		buf:           newRecvBuffer(),
 | 
			
		||||
| 
						 | 
				
			
			@ -220,6 +221,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
 | 
			
		|||
	s.ctx, s.cancel = context.WithCancel(ctx)
 | 
			
		||||
	s.dec = &recvBufferReader{
 | 
			
		||||
		ctx:    s.ctx,
 | 
			
		||||
		goAway: s.goAway,
 | 
			
		||||
		recv:   s.buf,
 | 
			
		||||
	}
 | 
			
		||||
	return s
 | 
			
		||||
| 
						 | 
				
			
			@ -443,13 +445,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
 | 
			
		|||
// accessed any more.
 | 
			
		||||
func (t *http2Client) Close() (err error) {
 | 
			
		||||
	t.mu.Lock()
 | 
			
		||||
	if t.state == reachable {
 | 
			
		||||
		close(t.errorChan)
 | 
			
		||||
	}
 | 
			
		||||
	if t.state == closing {
 | 
			
		||||
		t.mu.Unlock()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if t.state == reachable {
 | 
			
		||||
		close(t.errorChan)
 | 
			
		||||
	}
 | 
			
		||||
	t.state = closing
 | 
			
		||||
	t.mu.Unlock()
 | 
			
		||||
	close(t.shutdownChan)
 | 
			
		||||
| 
						 | 
				
			
			@ -732,16 +734,11 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
 | 
			
		|||
 | 
			
		||||
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
 | 
			
		||||
	t.mu.Lock()
 | 
			
		||||
	if t.state == reachable {
 | 
			
		||||
		t.goAwayID = f.LastStreamID
 | 
			
		||||
	t.err = ErrDrain
 | 
			
		||||
		t.err = ErrConnDrain
 | 
			
		||||
		close(t.errorChan)
 | 
			
		||||
 | 
			
		||||
	// Notify the streams which were initiated after the server sent GOAWAY.
 | 
			
		||||
	//for i := f.LastStreamID + 2; i < t.nextID; i += 2 {
 | 
			
		||||
	//	if s, ok := t.activeStreams[i]; ok {
 | 
			
		||||
	//		close(s.goAway)
 | 
			
		||||
	//	}
 | 
			
		||||
	//}
 | 
			
		||||
	}
 | 
			
		||||
	t.mu.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -196,15 +196,22 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
 | 
			
		|||
	s.recvCompress = state.encoding
 | 
			
		||||
	s.method = state.method
 | 
			
		||||
	t.mu.Lock()
 | 
			
		||||
	if t.state == draining {
 | 
			
		||||
		t.mu.Unlock()
 | 
			
		||||
		t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if t.state != reachable {
 | 
			
		||||
		t.mu.Unlock()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if uint32(len(t.activeStreams)) >= t.maxStreams {
 | 
			
		||||
		t.mu.Unlock()
 | 
			
		||||
		t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
 | 
			
		||||
	t.activeStreams[s.id] = s
 | 
			
		||||
	t.mu.Unlock()
 | 
			
		||||
| 
						 | 
				
			
			@ -263,13 +270,16 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
 | 
			
		|||
		switch frame := frame.(type) {
 | 
			
		||||
		case *http2.MetaHeadersFrame:
 | 
			
		||||
			id := frame.Header().StreamID
 | 
			
		||||
			t.mu.Lock()
 | 
			
		||||
			if id%2 != 1 || id <= t.maxStreamID {
 | 
			
		||||
				t.mu.Unlock()
 | 
			
		||||
				// illegal gRPC stream id.
 | 
			
		||||
				grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id)
 | 
			
		||||
				t.Close()
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
			t.maxStreamID = id
 | 
			
		||||
			t.mu.Unlock()
 | 
			
		||||
			t.operateHeaders(frame, handle)
 | 
			
		||||
		case *http2.DataFrame:
 | 
			
		||||
			t.handleData(frame)
 | 
			
		||||
| 
						 | 
				
			
			@ -282,6 +292,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
 | 
			
		|||
		case *http2.WindowUpdateFrame:
 | 
			
		||||
			t.handleWindowUpdate(frame)
 | 
			
		||||
		case *http2.GoAwayFrame:
 | 
			
		||||
			t.Close()
 | 
			
		||||
			break
 | 
			
		||||
		default:
 | 
			
		||||
			grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
 | 
			
		||||
| 
						 | 
				
			
			@ -675,6 +686,12 @@ func (t *http2Server) controller() {
 | 
			
		|||
					}
 | 
			
		||||
				case *resetStream:
 | 
			
		||||
					t.framer.writeRSTStream(true, i.streamID, i.code)
 | 
			
		||||
				case *goAway:
 | 
			
		||||
					t.mu.Lock()
 | 
			
		||||
					sid := t.maxStreamID
 | 
			
		||||
					t.state = draining
 | 
			
		||||
					t.mu.Unlock()
 | 
			
		||||
					t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil)
 | 
			
		||||
				case *flushIO:
 | 
			
		||||
					t.framer.flushWrite()
 | 
			
		||||
				case *ping:
 | 
			
		||||
| 
						 | 
				
			
			@ -742,3 +759,7 @@ func (t *http2Server) closeStream(s *Stream) {
 | 
			
		|||
func (t *http2Server) RemoteAddr() net.Addr {
 | 
			
		||||
	return t.conn.RemoteAddr()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *http2Server) GoAway() {
 | 
			
		||||
	t.controlBuf.put(&goAway{})
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -53,10 +53,6 @@ import (
 | 
			
		|||
	"google.golang.org/grpc/metadata"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	ErrDrain = ConnectionErrorf("transport: Server stopped accepting new RPCs")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// recvMsg represents the received msg from the transport. All transport
 | 
			
		||||
// protocol specific info has been removed.
 | 
			
		||||
type recvMsg struct {
 | 
			
		||||
| 
						 | 
				
			
			@ -147,7 +143,7 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
 | 
			
		|||
	case <-r.ctx.Done():
 | 
			
		||||
		return 0, ContextErr(r.ctx.Err())
 | 
			
		||||
	case <-r.goAway:
 | 
			
		||||
		return 0, ErrConnDrain
 | 
			
		||||
		return 0, ErrStreamDrain
 | 
			
		||||
	case i := <-r.recv.get():
 | 
			
		||||
		r.recv.load()
 | 
			
		||||
		m := i.(*recvMsg)
 | 
			
		||||
| 
						 | 
				
			
			@ -478,6 +474,9 @@ type ServerTransport interface {
 | 
			
		|||
 | 
			
		||||
	// RemoteAddr returns the remote network address.
 | 
			
		||||
	RemoteAddr() net.Addr
 | 
			
		||||
 | 
			
		||||
	// GoAway ...
 | 
			
		||||
	GoAway()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// StreamErrorf creates an StreamError with the specified error code and description.
 | 
			
		||||
| 
						 | 
				
			
			@ -509,6 +508,7 @@ func (e ConnectionError) Error() string {
 | 
			
		|||
var (
 | 
			
		||||
	ErrConnClosing = ConnectionError{Desc: "transport is closing"}
 | 
			
		||||
	ErrConnDrain   = ConnectionError{Desc: "transport is being drained"}
 | 
			
		||||
	ErrStreamDrain = StreamErrorf(codes.Unavailable, "afjlalf")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// StreamError is an error that only affects one stream within a connection.
 | 
			
		||||
| 
						 | 
				
			
			@ -536,7 +536,7 @@ func ContextErr(err error) StreamError {
 | 
			
		|||
// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err.
 | 
			
		||||
// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise
 | 
			
		||||
// it return the StreamError for ctx.Err.
 | 
			
		||||
// If it receives from goAway, it returns 0, ErrConnDrain.
 | 
			
		||||
// If it receives from goAway, it returns 0, ErrStreamDrain.
 | 
			
		||||
// If it receives from closing, it returns 0, ErrConnClosing.
 | 
			
		||||
// If it receives from proceed, it returns the received integer, nil.
 | 
			
		||||
func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) {
 | 
			
		||||
| 
						 | 
				
			
			@ -552,7 +552,7 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-
 | 
			
		|||
		}
 | 
			
		||||
		return 0, io.EOF
 | 
			
		||||
	case <-goAway:
 | 
			
		||||
		return 0, ErrConnDrain
 | 
			
		||||
		return 0, ErrStreamDrain
 | 
			
		||||
	case <-closing:
 | 
			
		||||
		return 0, ErrConnClosing
 | 
			
		||||
	case i := <-proceed:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue