diff --git a/clientconn.go b/clientconn.go index 05a1abc91..8e3b1bb8a 100644 --- a/clientconn.go +++ b/clientconn.go @@ -251,7 +251,7 @@ func WithUserAgent(s string) DialOption { } // WithKeepaliveParams returns a DialOption that specifies keepalive paramaters for the client transport. -func WithKeepaliveParams(k keepalive.Params) DialOption { +func WithKeepaliveParams(k *keepalive.Params) DialOption { return func(o *dialOptions) { o.copts.KeepaliveParams = k } diff --git a/keepalive/keepalive.go b/keepalive/keepalive.go index 1546c6b18..4a18d52ca 100644 --- a/keepalive/keepalive.go +++ b/keepalive/keepalive.go @@ -2,7 +2,6 @@ package keepalive import ( "math" - "sync" "time" ) @@ -16,35 +15,20 @@ type Params struct { PermitWithoutStream bool } -// DefaultParams contains default values for keepalive parameters. -var DefaultParams = Params{ - Time: time.Duration(math.MaxInt64), // default to infinite. - Timeout: time.Duration(20 * time.Second), +// Validate fills in defaults, in place, for keepalive parameters left at zero. +// A zero Time becomes Infinity and a zero Timeout becomes TwentySec. +func (p *Params) Validate() { + if p.Time == 0 { + p.Time = Infinity + } + if p.Timeout == 0 { + p.Timeout = TwentySec + } } -// mu is a mutex to protect Enabled variable. -var mu = sync.Mutex{} - -// enable is a knob used to turn keepalive on or off. -var enable = false - -// Enabled exposes the value of enable variable. -func Enabled() bool { - mu.Lock() - defer mu.Unlock() - return enable -} - -// Enable can be called to enable keepalives. -func Enable() { - mu.Lock() - defer mu.Unlock() - enable = true -} - -// Disable can be called to disable keepalive. -func Disable() { - mu.Lock() - defer mu.Unlock() - enable = false -} +const ( + // Infinity is the default value of keepalive time. + Infinity = time.Duration(math.MaxInt64) + // TwentySec is the default value of timeout. 
+ TwentySec = time.Duration(20 * time.Second) +) diff --git a/transport/control.go b/transport/control.go index 2586cba46..ca93de29c 100644 --- a/transport/control.go +++ b/transport/control.go @@ -51,6 +51,11 @@ const ( // The following defines various control items which could flow through // the control buffer of transport. They represent different aspects of // control tasks, e.g., flow control, settings, streaming resetting, etc. + +type fireKeepaliveTimer struct{} + +func (fireKeepaliveTimer) item() {} + type windowUpdate struct { streamID uint32 increment uint32 diff --git a/transport/http2_client.go b/transport/http2_client.go index 8655e7401..ee4fda8a1 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -101,10 +101,14 @@ type http2Client struct { creds []credentials.PerRPCCredentials - // Counter to keep track of activity(reading and writing on transport). + // Counter to keep track of reading activity on transport. activity uint64 // accessed atomically. + // Flag recording whether the keepalive check was skipped because there + // were no active streams and keepalive.PermitWithoutStream was false. + // keepaliveSkipped = 1 means skipped; 0 means not skipped. + keepaliveSkipped uint32 // accessed atomically. // keepalive parameters. 
- kp keepalive.Params + kp *keepalive.Params statsHandler stats.Handler mu sync.Mutex // guard the following variables @@ -188,9 +192,10 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( if opts.UserAgent != "" { ua = opts.UserAgent + " " + ua } - kp := keepalive.DefaultParams - if opts.KeepaliveParams != (keepalive.Params{}) { + kp := defaultKeepaliveParams + if opts.KeepaliveParams != nil { kp = opts.KeepaliveParams + kp.Validate() } var buf bytes.Buffer t := &http2Client{ @@ -384,6 +389,11 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea s := t.newStream(ctx, callHdr) s.clientStatsCtx = userCtx t.activeStreams[s.id] = s + // If the number of active streams is now equal to 1, then check if keepalive + // was being skipped. If so, fire the keepalive timer. + if len(t.activeStreams) == 1 && atomic.LoadUint32(&t.keepaliveSkipped) == 1 { + t.controlBuf.put(fireKeepaliveTimer{}) + } // This stream is not counted when applySetings(...) initialize t.streamsQuota. // Reset t.streamsQuota to the right value. @@ -717,7 +727,6 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { break } } - atomic.AddUint64(&t.activity, 1) if !opts.Last { return nil } @@ -1082,7 +1091,7 @@ func (t *http2Client) applySettings(ss []http2.Setting) { // frames (e.g., window update, reset stream, setting, etc.) to the server. func (t *http2Client) controller() { timer := time.NewTimer(t.kp.Time) - if !keepalive.Enabled() { + if t.kp.Time == keepalive.Infinity { // Prevent the timer from firing, ever. if !timer.Stop() { <-timer.C @@ -1090,17 +1099,36 @@ func (t *http2Client) controller() { } isPingSent := false keepalivePing := &ping{data: [8]byte{}} + // select toggles between control channel and writable channel. + // We need to wait on writable channel only after having received + // a control message that requires controller to take an action. 
+ // However, while waiting on either of these channels, the keepalive + // timer channel or shutdown channel might trigger. Such toggling + // takes care of this case. cchan := t.controlBuf.get() - wchan := nil + var wchan chan int + var controlMsg item for { select { - case i := <-cchan: + case controlMsg = <-cchan: t.controlBuf.load() + // If controlMsg is of type fireKeepaliveTimer, + // then check if the keepaliveSkipped flag is still set. + if _, ok := controlMsg.(fireKeepaliveTimer); ok { + if atomic.LoadUint32(&t.keepaliveSkipped) == 1 { + // Reset the timer to 0 so that it fires. + if !timer.Stop() { + <-timer.C + } + timer.Reset(0) + } + continue + } wchan = t.writableChan cchan = nil continue case <-wchan: - switch i := i.(type) { + switch i := controlMsg.(type) { case *windowUpdate: t.framer.writeWindowUpdate(true, i.streamID, i.increment) case *settings: @@ -1127,22 +1155,30 @@ func (t *http2Client) controller() { t.mu.Lock() ns := len(t.activeStreams) t.mu.Unlock() - // Get the activity counter value and reset it. - a := atomic.SwapUint64(&t.activity, 0) - if a > 0 || (!t.kp.PermitWithoutStream && ns < 1) { + if !t.kp.PermitWithoutStream && ns < 1 { timer.Reset(t.kp.Time) isPingSent = false - } else { - if !isPingSent { - // Send ping. - t.controlBuf.put(keepalivePing) - isPingSent = true - timer.Reset(t.kp.Timeout) - } else { - t.Close() - continue - } + // Set the flag that signifies keepalive was skipped. + atomic.StoreUint32(&t.keepaliveSkipped, 1) + continue } + // Reset the keepaliveSkipped flag. + atomic.StoreUint32(&t.keepaliveSkipped, 0) + // Get the activity counter value and reset it. + a := atomic.SwapUint64(&t.activity, 0) + if a > 0 { + timer.Reset(t.kp.Time) + isPingSent = false + continue + } + if !isPingSent { + // Send ping. 
+ t.controlBuf.put(keepalivePing) + isPingSent = true + timer.Reset(t.kp.Timeout) + continue + } + t.Close() case <-t.shutdownChan: return } diff --git a/transport/transport.go b/transport/transport.go index e1331bbfb..f6b1754b6 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -384,11 +384,17 @@ type ConnectOptions struct { // TransportCredentials stores the Authenticator required to setup a client connection. TransportCredentials credentials.TransportCredentials // KeepaliveParams stores the keepalive parameters. - KeepaliveParams keepalive.Params + KeepaliveParams *keepalive.Params // StatsHandler stores the handler for stats. StatsHandler stats.Handler } +// defaultKeepaliveParams holds the default values for keepalive parameters. +var defaultKeepaliveParams = &keepalive.Params{ + Time: keepalive.Infinity, // default to infinite. + Timeout: keepalive.TwentySec, +} + // TargetInfo contains the information of the target such as network address and metadata. type TargetInfo struct { Addr string diff --git a/transport/transport_test.go b/transport/transport_test.go index fe3907ed4..3ab5a9f7e 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -298,10 +298,8 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Con } func TestKeepaliveClientClosesIdleTransport(t *testing.T) { - keepalive.Enable() - defer keepalive.Disable() done := make(chan net.Conn, 1) - tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.Params{ + tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: &keepalive.Params{ Time: 2 * time.Second, // Keepalive time = 2 sec. Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. PermitWithoutStream: true, // Run keepalive even with no RPCs. 
@@ -324,10 +322,8 @@ func TestKeepaliveClientClosesIdleTransport(t *testing.T) { } func TestKeepaliveClientStaysHealthyOnIdleTransport(t *testing.T) { - keepalive.Enable() - defer keepalive.Disable() done := make(chan net.Conn, 1) - tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.Params{ + tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: &keepalive.Params{ Time: 2 * time.Second, // Keepalive time = 2 sec. Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. PermitWithoutStream: false, // Don't run keepalive even with no RPCs. @@ -350,10 +346,8 @@ func TestKeepaliveClientStaysHealthyOnIdleTransport(t *testing.T) { } func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { - keepalive.Enable() - defer keepalive.Disable() done := make(chan net.Conn, 1) - tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.Params{ + tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: &keepalive.Params{ Time: 2 * time.Second, // Keepalive time = 2 sec. Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. PermitWithoutStream: false, // Don't run keepalive even with no RPCs. @@ -381,12 +375,10 @@ func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { } func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { - keepalive.Enable() - defer keepalive.Disable() - s, tr := setUpWithOptions(t, 0, math.MaxUint32, normal, ConnectOptions{KeepaliveParams: keepalive.Params{ + s, tr := setUpWithOptions(t, 0, math.MaxUint32, normal, ConnectOptions{KeepaliveParams: &keepalive.Params{ Time: 2 * time.Second, // Keepalive time = 2 sec. Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. - PermitWithoutStream: true, // Don't run keepalive even with no RPCs. + PermitWithoutStream: true, // Run keepalive even with no RPCs. }}) defer s.stop() defer tr.Close()