diff --git a/clientconn.go b/clientconn.go index 8402c19e3..26166b89b 100644 --- a/clientconn.go +++ b/clientconn.go @@ -1237,9 +1237,11 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne addr.ServerName = ac.cc.getServerName(addr) hctx, hcancel := context.WithCancel(ac.ctx) - onClose := grpcsync.OnceFunc(func() { + onClose := func(r transport.GoAwayReason) { ac.mu.Lock() defer ac.mu.Unlock() + // adjust params based on GoAwayReason + ac.adjustParams(r) if ac.state == connectivity.Shutdown { // Already shut down. tearDown() already cleared the transport and // canceled hctx via ac.ctx, and we expected this connection to be @@ -1260,19 +1262,13 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne // Always go idle and wait for the LB policy to initiate a new // connection attempt. ac.updateConnectivityState(connectivity.Idle, nil) - }) - onGoAway := func(r transport.GoAwayReason) { - ac.mu.Lock() - ac.adjustParams(r) - ac.mu.Unlock() - onClose() } connectCtx, cancel := context.WithDeadline(ac.ctx, connectDeadline) defer cancel() copts.ChannelzParentID = ac.channelzID - newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onGoAway, onClose) + newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onClose) if err != nil { if logger.V(2) { logger.Infof("Creating new client transport to %q: %v", addr, err) diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 3e582a285..8cf2a7a11 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -140,8 +140,7 @@ type http2Client struct { channelzID *channelz.Identifier czData *channelzData - onGoAway func(GoAwayReason) - onClose func() + onClose func(GoAwayReason) bufferPool *bufferPool @@ -197,7 +196,7 @@ func isTemporary(err error) bool { // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // and starts to receive messages on it. Non-nil error returns if construction // fails. -func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onGoAway func(GoAwayReason), onClose func()) (_ *http2Client, err error) { +func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onClose func(GoAwayReason)) (_ *http2Client, err error) { scheme := "http" ctx, cancel := context.WithCancel(ctx) defer func() { @@ -343,7 +342,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts streamQuota: defaultMaxStreamsClient, streamsQuotaAvailable: make(chan struct{}, 1), czData: new(channelzData), - onGoAway: onGoAway, keepaliveEnabled: keepaliveEnabled, bufferPool: newBufferPool(), onClose: onClose, @@ -957,7 +955,9 @@ func (t *http2Client) Close(err error) { } // Call t.onClose ASAP to prevent the client from attempting to create new // streams. - t.onClose() + if t.state != draining { + t.onClose(GoAwayInvalid) + } t.state = closing streams := t.activeStreams t.activeStreams = nil @@ -1010,6 +1010,7 @@ func (t *http2Client) GracefulClose() { if logger.V(logLevel) { logger.Infof("transport: GracefulClose called") } + t.onClose(GoAwayInvalid) t.state = draining active := len(t.activeStreams) t.mu.Unlock() @@ -1290,8 +1291,10 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { // Notify the clientconn about the GOAWAY before we set the state to // draining, to allow the client to stop attempting to create streams // before disallowing new streams on this connection. - t.onGoAway(t.goAwayReason) - t.state = draining + if t.state != draining { + t.onClose(t.goAwayReason) + t.state = draining + } } // All streams with IDs greater than the GoAwayId // and smaller than the previous GoAway ID should be killed. diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 6cff20c8e..0ac77ea4f 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -583,8 +583,8 @@ type ConnectOptions struct { // NewClientTransport establishes the transport with the required ConnectOptions // and returns it to the caller. -func NewClientTransport(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onGoAway func(GoAwayReason), onClose func()) (ClientTransport, error) { - return newHTTP2Client(connectCtx, ctx, addr, opts, onGoAway, onClose) +func NewClientTransport(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onClose func(GoAwayReason)) (ClientTransport, error) { + return newHTTP2Client(connectCtx, ctx, addr, opts, onClose) } // Options provides additional hints and information for message diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 517cd4015..b41378b00 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -452,7 +452,7 @@ func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts copts.ChannelzParentID = channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil) connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) - ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {}, func() {}) + ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {}) if connErr != nil { cancel() // Do not cancel in success path. t.Fatalf("failed to create transport: %v", connErr) @@ -483,7 +483,7 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.C connCh <- conn }() connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) - tr, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}, func() {}) + tr, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}) if err != nil { cancel() // Do not cancel in success path. // Server clean-up. @@ -1287,7 +1287,7 @@ func (s) TestClientHonorsConnectContext(t *testing.T) { time.AfterFunc(100*time.Millisecond, cancel) copts := ConnectOptions{ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil)} - _, err = NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}, func() {}) + _, err = NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}) if err == nil { t.Fatalf("NewClientTransport() returned successfully; wanted error") } @@ -1299,7 +1299,7 @@ func (s) TestClientHonorsConnectContext(t *testing.T) { // Test context deadline. connectCtx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - _, err = NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}, func() {}) + _, err = NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}) if err == nil { t.Fatalf("NewClientTransport() returned successfully; wanted error") } @@ -1378,7 +1378,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) { defer cancel() copts := ConnectOptions{ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil)} - ct, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}, func() {}) + ct, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}) if err != nil { t.Fatalf("Error while creating client transport: %v", err) } @@ -2282,7 +2282,7 @@ func (s) TestClientHandshakeInfo(t *testing.T) { TransportCredentials: creds, ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil), } - tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {}, func() {}) + tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {}) if err != nil { t.Fatalf("NewClientTransport(): %v", err) } @@ -2323,7 +2323,7 @@ func (s) TestClientHandshakeInfoDialer(t *testing.T) { Dialer: dialer, ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil), } - tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {}, func() {}) + tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {}) if err != nil { t.Fatalf("NewClientTransport(): %v", err) }