allow multiple GoAways and retrying on illegal streams

This commit is contained in:
iamqizhao 2016-07-25 16:35:32 -07:00
parent e40dc9bff9
commit f1e4d3b180
8 changed files with 54 additions and 25 deletions

View File

@ -179,7 +179,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
put() put()
put = nil put = nil
} }
if _, ok := err.(transport.ConnectionError); ok { // Retry a non-failfast RPC when
// i) there is a connection error; or
// ii) the server started to drain before this RPC was initiated.
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast { if c.failFast {
return toRPCErr(err) return toRPCErr(err)
} }

View File

@ -697,14 +697,21 @@ func (ac *addrConn) tearDown(err error) {
} }
ac.cc.mu.Unlock() ac.cc.mu.Unlock()
}() }()
if ac.state == Shutdown {
return
}
ac.state = Shutdown
if ac.down != nil { if ac.down != nil {
ac.down(downErrorf(false, false, "%v", err)) ac.down(downErrorf(false, false, "%v", err))
ac.down = nil ac.down = nil
} }
if err == errConnDrain && ac.transport != nil {
// GracefulClose(...) may be executed multiple times when
// i) receiving multiple GoAway frames from the server; or
// ii) there are concurrent name resolver/Balancer triggered
// address removal and GoAway.
ac.transport.GracefulClose()
}
if ac.state == Shutdown {
return
}
ac.state = Shutdown
ac.stateCV.Broadcast() ac.stateCV.Broadcast()
if ac.events != nil { if ac.events != nil {
ac.events.Finish() ac.events.Finish()
@ -714,13 +721,9 @@ func (ac *addrConn) tearDown(err error) {
close(ac.ready) close(ac.ready)
ac.ready = nil ac.ready = nil
} }
if ac.transport != nil { if ac.transport != nil && err != errConnDrain {
if err == errConnDrain {
ac.transport.GracefulClose()
} else {
ac.transport.Close() ac.transport.Close()
} }
}
if ac.shutdownChan != nil { if ac.shutdownChan != nil {
close(ac.shutdownChan) close(ac.shutdownChan)
} }

View File

@ -795,12 +795,16 @@ func (s *Server) Stop() {
// connections and RPCs and blocks until all the pending RPCs are finished. // connections and RPCs and blocks until all the pending RPCs are finished.
func (s *Server) GracefulStop() { func (s *Server) GracefulStop() {
s.mu.Lock() s.mu.Lock()
if s.drain == true || s.conns == nil {
s.mu.Lock()
return
}
s.drain = true s.drain = true
for lis := range s.lis { for lis := range s.lis {
lis.Close() lis.Close()
} }
for c := range s.conns { for c := range s.conns {
c.(transport.ServerTransport).GoAway() c.(transport.ServerTransport).Drain()
} }
for len(s.conns) != 0 { for len(s.conns) != 0 {
s.cv.Wait() s.cv.Wait()

View File

@ -2114,6 +2114,7 @@ func interestingGoroutines() (gs []string) {
if stack == "" || if stack == "" ||
strings.Contains(stack, "testing.Main(") || strings.Contains(stack, "testing.Main(") ||
strings.Contains(stack, "testing.tRunner(") ||
strings.Contains(stack, "runtime.goexit") || strings.Contains(stack, "runtime.goexit") ||
strings.Contains(stack, "created by runtime.gc") || strings.Contains(stack, "created by runtime.gc") ||
strings.Contains(stack, "created by google3/base/go/log.init") || strings.Contains(stack, "created by google3/base/go/log.init") ||

View File

@ -370,8 +370,8 @@ func (ht *serverHandlerTransport) runStream() {
} }
} }
func (ht *serverHandlerTransport) GoAway() { func (ht *serverHandlerTransport) Drain() {
panic("not implemented") panic("Drain() is not implemented")
} }
// mapRecvMsgError returns the non-nil err into the appropriate // mapRecvMsgError returns the non-nil err into the appropriate

View File

@ -102,6 +102,8 @@ type http2Client struct {
streamSendQuota uint32 streamSendQuota uint32
// goAwayID records the Last-Stream-ID in the GoAway frame from the server. // goAwayID records the Last-Stream-ID in the GoAway frame from the server.
goAwayID uint32 goAwayID uint32
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
prevGoAwayID uint32
} }
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
@ -483,21 +485,25 @@ func (t *http2Client) GracefulClose() error {
t.mu.Unlock() t.mu.Unlock()
return nil return nil
} }
if t.state == draining {
t.mu.Unlock()
return nil
}
t.state = draining
// Notify the streams which were initiated after the server sent GOAWAY. // Notify the streams which were initiated after the server sent GOAWAY.
select { select {
case <-t.goAway: case <-t.goAway:
for i := t.goAwayID + 2; i < t.nextID; i += 2 { n := t.prevGoAwayID
if n == 0 && t.nextID > 1 {
n = t.nextID - 2
}
for i := t.goAwayID + 2; i <= n; i += 2 {
if s, ok := t.activeStreams[i]; ok { if s, ok := t.activeStreams[i]; ok {
close(s.goAway) close(s.goAway)
} }
} }
default: default:
} }
if t.state == draining {
t.mu.Unlock()
return nil
}
t.state = draining
active := len(t.activeStreams) active := len(t.activeStreams)
t.mu.Unlock() t.mu.Unlock()
if active == 0 { if active == 0 {
@ -742,10 +748,22 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
t.mu.Lock() t.mu.Lock()
if t.state == reachable { if t.state == reachable || t.state == draining {
if t.goAwayID > 0 && t.goAwayID < f.LastStreamID {
id := t.goAwayID
t.mu.Unlock()
t.notifyError(ConnectionErrorf("received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID))
return
}
t.prevGoAwayID = t.goAwayID
t.goAwayID = f.LastStreamID t.goAwayID = f.LastStreamID
select {
case <-t.goAway:
// t.goAway has been closed (i.e.,multiple GoAways).
default:
close(t.goAway) close(t.goAway)
} }
}
t.mu.Unlock() t.mu.Unlock()
} }

View File

@ -755,6 +755,6 @@ func (t *http2Server) RemoteAddr() net.Addr {
return t.conn.RemoteAddr() return t.conn.RemoteAddr()
} }
func (t *http2Server) GoAway() { func (t *http2Server) Drain() {
t.controlBuf.put(&goAway{}) t.controlBuf.put(&goAway{})
} }

View File

@ -473,8 +473,8 @@ type ServerTransport interface {
// RemoteAddr returns the remote network address. // RemoteAddr returns the remote network address.
RemoteAddr() net.Addr RemoteAddr() net.Addr
// GoAway notifies the client this ServerTransport stops accepting new RPCs. // Drain notifies the client this ServerTransport stops accepting new RPCs.
GoAway() Drain()
} }
// StreamErrorf creates an StreamError with the specified error code and description. // StreamErrorf creates an StreamError with the specified error code and description.