mirror of https://github.com/grpc/grpc-go.git
transport: simplify httpClient by moving onGoAway func to onClose (#5885)
This commit is contained in:
parent
5ff7dfcd79
commit
07ac97c355
|
@ -1237,9 +1237,11 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
|
||||||
addr.ServerName = ac.cc.getServerName(addr)
|
addr.ServerName = ac.cc.getServerName(addr)
|
||||||
hctx, hcancel := context.WithCancel(ac.ctx)
|
hctx, hcancel := context.WithCancel(ac.ctx)
|
||||||
|
|
||||||
onClose := grpcsync.OnceFunc(func() {
|
onClose := func(r transport.GoAwayReason) {
|
||||||
ac.mu.Lock()
|
ac.mu.Lock()
|
||||||
defer ac.mu.Unlock()
|
defer ac.mu.Unlock()
|
||||||
|
// adjust params based on GoAwayReason
|
||||||
|
ac.adjustParams(r)
|
||||||
if ac.state == connectivity.Shutdown {
|
if ac.state == connectivity.Shutdown {
|
||||||
// Already shut down. tearDown() already cleared the transport and
|
// Already shut down. tearDown() already cleared the transport and
|
||||||
// canceled hctx via ac.ctx, and we expected this connection to be
|
// canceled hctx via ac.ctx, and we expected this connection to be
|
||||||
|
@ -1260,19 +1262,13 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
|
||||||
// Always go idle and wait for the LB policy to initiate a new
|
// Always go idle and wait for the LB policy to initiate a new
|
||||||
// connection attempt.
|
// connection attempt.
|
||||||
ac.updateConnectivityState(connectivity.Idle, nil)
|
ac.updateConnectivityState(connectivity.Idle, nil)
|
||||||
})
|
|
||||||
onGoAway := func(r transport.GoAwayReason) {
|
|
||||||
ac.mu.Lock()
|
|
||||||
ac.adjustParams(r)
|
|
||||||
ac.mu.Unlock()
|
|
||||||
onClose()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
connectCtx, cancel := context.WithDeadline(ac.ctx, connectDeadline)
|
connectCtx, cancel := context.WithDeadline(ac.ctx, connectDeadline)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
copts.ChannelzParentID = ac.channelzID
|
copts.ChannelzParentID = ac.channelzID
|
||||||
|
|
||||||
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onGoAway, onClose)
|
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onClose)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if logger.V(2) {
|
if logger.V(2) {
|
||||||
logger.Infof("Creating new client transport to %q: %v", addr, err)
|
logger.Infof("Creating new client transport to %q: %v", addr, err)
|
||||||
|
|
|
@ -140,8 +140,7 @@ type http2Client struct {
|
||||||
channelzID *channelz.Identifier
|
channelzID *channelz.Identifier
|
||||||
czData *channelzData
|
czData *channelzData
|
||||||
|
|
||||||
onGoAway func(GoAwayReason)
|
onClose func(GoAwayReason)
|
||||||
onClose func()
|
|
||||||
|
|
||||||
bufferPool *bufferPool
|
bufferPool *bufferPool
|
||||||
|
|
||||||
|
@ -197,7 +196,7 @@ func isTemporary(err error) bool {
|
||||||
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
|
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
|
||||||
// and starts to receive messages on it. Non-nil error returns if construction
|
// and starts to receive messages on it. Non-nil error returns if construction
|
||||||
// fails.
|
// fails.
|
||||||
func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onGoAway func(GoAwayReason), onClose func()) (_ *http2Client, err error) {
|
func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onClose func(GoAwayReason)) (_ *http2Client, err error) {
|
||||||
scheme := "http"
|
scheme := "http"
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -343,7 +342,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
|
||||||
streamQuota: defaultMaxStreamsClient,
|
streamQuota: defaultMaxStreamsClient,
|
||||||
streamsQuotaAvailable: make(chan struct{}, 1),
|
streamsQuotaAvailable: make(chan struct{}, 1),
|
||||||
czData: new(channelzData),
|
czData: new(channelzData),
|
||||||
onGoAway: onGoAway,
|
|
||||||
keepaliveEnabled: keepaliveEnabled,
|
keepaliveEnabled: keepaliveEnabled,
|
||||||
bufferPool: newBufferPool(),
|
bufferPool: newBufferPool(),
|
||||||
onClose: onClose,
|
onClose: onClose,
|
||||||
|
@ -957,7 +955,9 @@ func (t *http2Client) Close(err error) {
|
||||||
}
|
}
|
||||||
// Call t.onClose ASAP to prevent the client from attempting to create new
|
// Call t.onClose ASAP to prevent the client from attempting to create new
|
||||||
// streams.
|
// streams.
|
||||||
t.onClose()
|
if t.state != draining {
|
||||||
|
t.onClose(GoAwayInvalid)
|
||||||
|
}
|
||||||
t.state = closing
|
t.state = closing
|
||||||
streams := t.activeStreams
|
streams := t.activeStreams
|
||||||
t.activeStreams = nil
|
t.activeStreams = nil
|
||||||
|
@ -1010,6 +1010,7 @@ func (t *http2Client) GracefulClose() {
|
||||||
if logger.V(logLevel) {
|
if logger.V(logLevel) {
|
||||||
logger.Infof("transport: GracefulClose called")
|
logger.Infof("transport: GracefulClose called")
|
||||||
}
|
}
|
||||||
|
t.onClose(GoAwayInvalid)
|
||||||
t.state = draining
|
t.state = draining
|
||||||
active := len(t.activeStreams)
|
active := len(t.activeStreams)
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
|
@ -1290,8 +1291,10 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
|
||||||
// Notify the clientconn about the GOAWAY before we set the state to
|
// Notify the clientconn about the GOAWAY before we set the state to
|
||||||
// draining, to allow the client to stop attempting to create streams
|
// draining, to allow the client to stop attempting to create streams
|
||||||
// before disallowing new streams on this connection.
|
// before disallowing new streams on this connection.
|
||||||
t.onGoAway(t.goAwayReason)
|
if t.state != draining {
|
||||||
t.state = draining
|
t.onClose(t.goAwayReason)
|
||||||
|
t.state = draining
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// All streams with IDs greater than the GoAwayId
|
// All streams with IDs greater than the GoAwayId
|
||||||
// and smaller than the previous GoAway ID should be killed.
|
// and smaller than the previous GoAway ID should be killed.
|
||||||
|
|
|
@ -583,8 +583,8 @@ type ConnectOptions struct {
|
||||||
|
|
||||||
// NewClientTransport establishes the transport with the required ConnectOptions
|
// NewClientTransport establishes the transport with the required ConnectOptions
|
||||||
// and returns it to the caller.
|
// and returns it to the caller.
|
||||||
func NewClientTransport(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onGoAway func(GoAwayReason), onClose func()) (ClientTransport, error) {
|
func NewClientTransport(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onClose func(GoAwayReason)) (ClientTransport, error) {
|
||||||
return newHTTP2Client(connectCtx, ctx, addr, opts, onGoAway, onClose)
|
return newHTTP2Client(connectCtx, ctx, addr, opts, onClose)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options provides additional hints and information for message
|
// Options provides additional hints and information for message
|
||||||
|
|
|
@ -452,7 +452,7 @@ func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts
|
||||||
copts.ChannelzParentID = channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil)
|
copts.ChannelzParentID = channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil)
|
||||||
|
|
||||||
connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
|
connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
|
||||||
ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {}, func() {})
|
ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {})
|
||||||
if connErr != nil {
|
if connErr != nil {
|
||||||
cancel() // Do not cancel in success path.
|
cancel() // Do not cancel in success path.
|
||||||
t.Fatalf("failed to create transport: %v", connErr)
|
t.Fatalf("failed to create transport: %v", connErr)
|
||||||
|
@ -483,7 +483,7 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.C
|
||||||
connCh <- conn
|
connCh <- conn
|
||||||
}()
|
}()
|
||||||
connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
|
connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
|
||||||
tr, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}, func() {})
|
tr, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cancel() // Do not cancel in success path.
|
cancel() // Do not cancel in success path.
|
||||||
// Server clean-up.
|
// Server clean-up.
|
||||||
|
@ -1287,7 +1287,7 @@ func (s) TestClientHonorsConnectContext(t *testing.T) {
|
||||||
time.AfterFunc(100*time.Millisecond, cancel)
|
time.AfterFunc(100*time.Millisecond, cancel)
|
||||||
|
|
||||||
copts := ConnectOptions{ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil)}
|
copts := ConnectOptions{ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil)}
|
||||||
_, err = NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}, func() {})
|
_, err = NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("NewClientTransport() returned successfully; wanted error")
|
t.Fatalf("NewClientTransport() returned successfully; wanted error")
|
||||||
}
|
}
|
||||||
|
@ -1299,7 +1299,7 @@ func (s) TestClientHonorsConnectContext(t *testing.T) {
|
||||||
// Test context deadline.
|
// Test context deadline.
|
||||||
connectCtx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
|
connectCtx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
_, err = NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}, func() {})
|
_, err = NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("NewClientTransport() returned successfully; wanted error")
|
t.Fatalf("NewClientTransport() returned successfully; wanted error")
|
||||||
}
|
}
|
||||||
|
@ -1378,7 +1378,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
copts := ConnectOptions{ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil)}
|
copts := ConnectOptions{ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil)}
|
||||||
ct, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}, func() {})
|
ct, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error while creating client transport: %v", err)
|
t.Fatalf("Error while creating client transport: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -2282,7 +2282,7 @@ func (s) TestClientHandshakeInfo(t *testing.T) {
|
||||||
TransportCredentials: creds,
|
TransportCredentials: creds,
|
||||||
ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil),
|
ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil),
|
||||||
}
|
}
|
||||||
tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {}, func() {})
|
tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientTransport(): %v", err)
|
t.Fatalf("NewClientTransport(): %v", err)
|
||||||
}
|
}
|
||||||
|
@ -2323,7 +2323,7 @@ func (s) TestClientHandshakeInfoDialer(t *testing.T) {
|
||||||
Dialer: dialer,
|
Dialer: dialer,
|
||||||
ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil),
|
ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil),
|
||||||
}
|
}
|
||||||
tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {}, func() {})
|
tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientTransport(): %v", err)
|
t.Fatalf("NewClientTransport(): %v", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue