mirror of https://github.com/grpc/grpc-go.git
Fix test race: Atomically access minConnecTimout in testing environment. (#1897)
This commit is contained in:
parent
3ae2a613bc
commit
207e2760fd
|
|
@ -45,6 +45,11 @@ import (
|
||||||
"google.golang.org/grpc/transport"
|
"google.golang.org/grpc/transport"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// minimum time to give a connection to complete
|
||||||
|
minConnectTimeout = 20 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrClientConnClosing indicates that the operation is illegal because
|
// ErrClientConnClosing indicates that the operation is illegal because
|
||||||
// the ClientConn is closing.
|
// the ClientConn is closing.
|
||||||
|
|
@ -60,8 +65,11 @@ var (
|
||||||
errConnUnavailable = errors.New("grpc: the connection is unavailable")
|
errConnUnavailable = errors.New("grpc: the connection is unavailable")
|
||||||
// errBalancerClosed indicates that the balancer is closed.
|
// errBalancerClosed indicates that the balancer is closed.
|
||||||
errBalancerClosed = errors.New("grpc: balancer is closed")
|
errBalancerClosed = errors.New("grpc: balancer is closed")
|
||||||
// minimum time to give a connection to complete
|
// We use an accessor so that minConnectTimeout can be
|
||||||
minConnectTimeout = 20 * time.Second
|
// atomically read and updated while testing.
|
||||||
|
getMinConnectTimeout = func() time.Duration {
|
||||||
|
return minConnectTimeout
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// The following errors are returned from Dial and DialContext
|
// The following errors are returned from Dial and DialContext
|
||||||
|
|
@ -1055,7 +1063,7 @@ func (ac *addrConn) resetTransport() error {
|
||||||
// connection.
|
// connection.
|
||||||
backoffFor := ac.dopts.bs.backoff(connectRetryNum) // time.Duration.
|
backoffFor := ac.dopts.bs.backoff(connectRetryNum) // time.Duration.
|
||||||
// This will be the duration that dial gets to finish.
|
// This will be the duration that dial gets to finish.
|
||||||
dialDuration := minConnectTimeout
|
dialDuration := getMinConnectTimeout()
|
||||||
if backoffFor > dialDuration {
|
if backoffFor > dialDuration {
|
||||||
// Give dial more time as we keep failing to connect.
|
// Give dial more time as we keep failing to connect.
|
||||||
dialDuration = backoffFor
|
dialDuration = backoffFor
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,16 @@ import (
|
||||||
"google.golang.org/grpc/testdata"
|
"google.golang.org/grpc/testdata"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
mutableMinConnectTimeout = time.Second * 20
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
getMinConnectTimeout = func() time.Duration {
|
||||||
|
return time.Duration(atomic.LoadInt64((*int64)(&mutableMinConnectTimeout)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func assertState(wantState connectivity.State, cc *ClientConn) (connectivity.State, bool) {
|
func assertState(wantState connectivity.State, cc *ClientConn) (connectivity.State, bool) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
@ -169,13 +179,14 @@ func TestDialWaitsForServerSettings(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCloseConnectionWhenServerPrefaceNotReceived(t *testing.T) {
|
func TestCloseConnectionWhenServerPrefaceNotReceived(t *testing.T) {
|
||||||
mctBkp := minConnectTimeout
|
mctBkp := getMinConnectTimeout()
|
||||||
// Call this only after transportMonitor goroutine has ended.
|
// Call this only after transportMonitor goroutine has ended.
|
||||||
defer func() {
|
defer func() {
|
||||||
minConnectTimeout = mctBkp
|
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(mctBkp))
|
||||||
|
|
||||||
}()
|
}()
|
||||||
defer leakcheck.Check(t)
|
defer leakcheck.Check(t)
|
||||||
minConnectTimeout = time.Millisecond * 500
|
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(time.Millisecond)*500)
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
lis, err := net.Listen("tcp", "localhost:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error while listening. Err: %v", err)
|
t.Fatalf("Error while listening. Err: %v", err)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue