From d021e89b3f47668fd1a1e97b14c391cb9a1eaf8d Mon Sep 17 00:00:00 2001 From: Jean de Klerk Date: Wed, 20 Mar 2019 13:58:29 -0600 Subject: [PATCH] internal: fix Dial_OneBackoffPerRetryGroup (#2689) * internal: fix Dial_OneBackoffPerRetryGroup Instead of mutating global variables, switches getMinConnectDeadline to a dial option. Fixes #2687. * rename getMinConnectTimeoutFunc to minConnectTimeout, ditto dial opt --- clientconn.go | 11 ++--- clientconn_state_transition_test.go | 17 ++++--- clientconn_test.go | 74 ++++++++++++----------------- dialoptions.go | 19 ++++++-- 4 files changed, 60 insertions(+), 61 deletions(-) diff --git a/clientconn.go b/clientconn.go index 4ff588424..959582456 100644 --- a/clientconn.go +++ b/clientconn.go @@ -68,11 +68,6 @@ var ( errConnClosing = errors.New("grpc: the connection is closing") // errBalancerClosed indicates that the balancer is closed. errBalancerClosed = errors.New("grpc: balancer is closed") - // We use an accessor so that minConnectTimeout can be - // atomically read and updated while testing. - getMinConnectTimeout = func() time.Duration { - return minConnectTimeout - } ) // The following errors are returned from Dial and DialContext @@ -971,7 +966,11 @@ func (ac *addrConn) resetTransport() { addrs := ac.addrs backoffFor := ac.dopts.bs.Backoff(ac.backoffIdx) // This will be the duration that dial gets to finish. - dialDuration := getMinConnectTimeout() + dialDuration := minConnectTimeout + if ac.dopts.minConnectTimeout != nil { + dialDuration = ac.dopts.minConnectTimeout() + } + if dialDuration < backoffFor { // Give dial more time as we keep failing to connect. 
dialDuration = backoffFor diff --git a/clientconn_state_transition_test.go b/clientconn_state_transition_test.go index 37fd3ad97..a924d9d27 100644 --- a/clientconn_state_transition_test.go +++ b/clientconn_state_transition_test.go @@ -22,7 +22,6 @@ import ( "context" "net" "sync" - "sync/atomic" "testing" "time" @@ -46,12 +45,6 @@ func init() { // except that it is unbuffered, so each read and write will wait for the other // side's corresponding write or read. func (s) TestStateTransitions_SingleAddress(t *testing.T) { - mctBkp := getMinConnectTimeout() - defer func() { - atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(mctBkp)) - }() - atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(time.Millisecond)*100) - for _, test := range []struct { desc string want []connectivity.State @@ -163,8 +156,14 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s connMu.Unlock() }() - client, err := DialContext(ctx, "", WithWaitForHandshake(), WithInsecure(), - WithBalancerName(stateRecordingBalancerName), WithDialer(pl.Dialer()), withBackoff(noBackoff{})) + client, err := DialContext(ctx, + "", + WithWaitForHandshake(), + WithInsecure(), + WithBalancerName(stateRecordingBalancerName), + WithDialer(pl.Dialer()), + withBackoff(noBackoff{}), + withMinConnectDeadline(func() time.Duration { return time.Millisecond * 100 })) if err != nil { t.Fatal(err) } diff --git a/clientconn_test.go b/clientconn_test.go index b619344eb..e8e464253 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -42,16 +42,6 @@ import ( "google.golang.org/grpc/testdata" ) -var ( - mutableMinConnectTimeout = time.Second * 20 -) - -func init() { - getMinConnectTimeout = func() time.Duration { - return time.Duration(atomic.LoadInt64((*int64)(&mutableMinConnectTimeout))) - } -} - func assertState(wantState connectivity.State, cc *ClientConn) (connectivity.State, bool) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() 
@@ -255,11 +245,15 @@ func (s) TestDialWaitsForServerSettingsAndFails(t *testing.T) { defer conn.Close() } }() - cleanup := setMinConnectTimeout(time.Second / 4) - defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - client, err := DialContext(ctx, lis.Addr().String(), WithInsecure(), WithWaitForHandshake(), WithBlock(), withBackoff(noBackoff{})) + client, err := DialContext(ctx, + lis.Addr().String(), + WithInsecure(), + WithWaitForHandshake(), + WithBlock(), + withBackoff(noBackoff{}), + withMinConnectDeadline(func() time.Duration { return time.Second / 4 })) lis.Close() if err == nil { client.Close() @@ -300,11 +294,14 @@ func (s) TestDialWaitsForServerSettingsViaEnvAndFails(t *testing.T) { defer conn.Close() } }() - cleanup := setMinConnectTimeout(time.Second / 4) - defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - client, err := DialContext(ctx, lis.Addr().String(), WithInsecure(), WithBlock(), withBackoff(noBackoff{})) + client, err := DialContext(ctx, + lis.Addr().String(), + WithInsecure(), + WithBlock(), + withBackoff(noBackoff{}), + withMinConnectDeadline(func() time.Duration { return time.Second / 4 })) lis.Close() if err == nil { client.Close() @@ -358,19 +355,16 @@ func (s) TestDialDoesNotWaitForServerSettings(t *testing.T) { close(dialDone) } +// 1. Client connects to a server that doesn't send preface. +// 2. After minConnectTimeout(500 ms here), client disconnects and retries. +// 3. The new server sends its preface. +// 4. Client doesn't kill the connection this time. func (s) TestCloseConnectionWhenServerPrefaceNotReceived(t *testing.T) { // Restore current setting after test. old := envconfig.RequireHandshake defer func() { envconfig.RequireHandshake = old }() envconfig.RequireHandshake = envconfig.RequireHandshakeOn - // 1. Client connects to a server that doesn't send preface. - // 2. 
After minConnectTimeout(500 ms here), client disconnects and retries. - // 3. The new server sends its preface. - // 4. Client doesn't kill the connection this time. - cleanup := setMinConnectTimeout(time.Millisecond * 500) - defer cleanup() - lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Error while listening. Err: %v", err) @@ -424,7 +418,7 @@ func (s) TestCloseConnectionWhenServerPrefaceNotReceived(t *testing.T) { break } }() - client, err := Dial(lis.Addr().String(), WithInsecure()) + client, err := Dial(lis.Addr().String(), WithInsecure(), withMinConnectDeadline(func() time.Duration { return time.Millisecond * 500 })) if err != nil { t.Fatalf("Error while dialing. Err: %v", err) } @@ -610,12 +604,8 @@ func (s) TestWithAuthorityAndTLS(t *testing.T) { // backoff once per "round" of attempts instead of once per address (n times // per "round" of attempts). func (s) TestDial_OneBackoffPerRetryGroup(t *testing.T) { - getMinConnectTimeoutBackup := getMinConnectTimeout - defer func() { - getMinConnectTimeout = getMinConnectTimeoutBackup - }() var attempts uint32 - getMinConnectTimeout = func() time.Duration { + getMinConnectTimeout := func() time.Duration { if atomic.AddUint32(&attempts, 1) == 1 { // Once all addresses are exhausted, hang around and wait for the // client.Close to happen rather than re-starting a new round of @@ -671,7 +661,11 @@ func (s) TestDial_OneBackoffPerRetryGroup(t *testing.T) { {Addr: lis1.Addr().String()}, {Addr: lis2.Addr().String()}, }) - client, err := DialContext(ctx, "this-gets-overwritten", WithInsecure(), WithBalancerName(stateRecordingBalancerName), withResolverBuilder(rb)) + client, err := DialContext(ctx, "this-gets-overwritten", + WithInsecure(), + WithBalancerName(stateRecordingBalancerName), + withResolverBuilder(rb), + withMinConnectDeadline(getMinConnectTimeout)) if err != nil { t.Fatal(err) } @@ -1079,9 +1073,6 @@ func (s) TestBackoffCancel(t *testing.T) { // UpdateAddresses should cause the next 
reconnect to begin from the top of the // list if the connection is not READY. func (s) TestUpdateAddresses_RetryFromFirstAddr(t *testing.T) { - cleanup := setMinConnectTimeout(time.Hour) - defer cleanup() - lis1, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Error while listening. Err: %v", err) @@ -1188,7 +1179,13 @@ func (s) TestUpdateAddresses_RetryFromFirstAddr(t *testing.T) { rb := manual.NewBuilderWithScheme("whatever") rb.InitialAddrs(addrsList) - client, err := Dial("this-gets-overwritten", WithInsecure(), WithWaitForHandshake(), withResolverBuilder(rb), withBackoff(noBackoff{}), WithBalancerName(stateRecordingBalancerName)) + client, err := Dial("this-gets-overwritten", + WithInsecure(), + WithWaitForHandshake(), + withResolverBuilder(rb), + withBackoff(noBackoff{}), + WithBalancerName(stateRecordingBalancerName), + withMinConnectDeadline(func() time.Duration { return time.Hour })) if err != nil { t.Fatal(err) } @@ -1235,12 +1232,3 @@ func (s) TestUpdateAddresses_RetryFromFirstAddr(t *testing.T) { t.Fatal("timed out waiting for any server to be contacted after tryUpdateAddrs") } } - -// Set the minConnectTimeout. Be sure to defer cleanup! -func setMinConnectTimeout(newMin time.Duration) (cleanup func()) { - mctBkp := getMinConnectTimeout() - atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(newMin)) - return func() { - atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(mctBkp)) - } -} diff --git a/dialoptions.go b/dialoptions.go index 537b25860..a0743a9e7 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -62,6 +62,7 @@ type dialOptions struct { disableRetry bool disableHealthCheck bool healthCheckFunc internal.HealthChecker + minConnectTimeout func() time.Duration } // DialOption configures how we set up the connection. @@ -470,7 +471,8 @@ func WithMaxHeaderListSize(s uint32) DialOption { }) } -// WithDisableHealthCheck disables the LB channel health checking for all SubConns of this ClientConn. 
+// WithDisableHealthCheck disables the LB channel health checking for all +// SubConns of this ClientConn. // // This API is EXPERIMENTAL. func WithDisableHealthCheck() DialOption { @@ -479,8 +481,8 @@ func WithDisableHealthCheck() DialOption { }) } -// withHealthCheckFunc replaces the default health check function with the provided one. It makes -// tests easier to change the health check function. +// withHealthCheckFunc replaces the default health check function with the +// provided one. It makes tests easier to change the health check function. // // For testing purpose only. func withHealthCheckFunc(f internal.HealthChecker) DialOption { @@ -500,3 +502,14 @@ func defaultDialOptions() dialOptions { }, } } + +// withMinConnectDeadline specifies the function that clientconn uses to +// get minConnectTimeout. This can be used to make connection attempts happen +// faster/slower. +// +// For testing purpose only. +func withMinConnectDeadline(f func() time.Duration) DialOption { + return newFuncDialOption(func(o *dialOptions) { + o.minConnectTimeout = f + }) +}