diff --git a/clientconn_test.go b/clientconn_test.go index 9f3299970..ee8372ad8 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -25,12 +25,14 @@ import ( "math" "net" "strings" + "sync" "sync/atomic" "testing" "time" "golang.org/x/net/http2" "google.golang.org/grpc/backoff" + "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" @@ -44,6 +46,17 @@ import ( "google.golang.org/grpc/testdata" ) +const ( + defaultTestTimeout = 10 * time.Second + stateRecordingBalancerName = "state_recording_balancer" +) + +var testBalancerBuilder = newStateRecordingBalancerBuilder() + +func init() { + balancer.Register(testBalancerBuilder) +} + func parseCfg(r *manual.Resolver, s string) *serviceconfig.ParseResult { scpr := r.CC.ParseServiceConfig(s) if scpr.Err != nil { @@ -221,8 +234,10 @@ func (s) TestDialWaitsForServerSettingsAndFails(t *testing.T) { lis.Addr().String(), WithTransportCredentials(insecure.NewCredentials()), WithReturnConnectionError(), - withBackoff(noBackoff{}), - withMinConnectDeadline(func() time.Duration { return time.Second / 4 })) + WithConnectParams(ConnectParams{ + Backoff: backoff.Config{}, + MinConnectTimeout: 250 * time.Millisecond, + })) lis.Close() if err == nil { client.Close() @@ -453,7 +468,6 @@ func (s) TestDial_OneBackoffPerRetryGroup(t *testing.T) { }}) client, err := DialContext(ctx, "whatever:///this-gets-overwritten", WithTransportCredentials(insecure.NewCredentials()), - WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)), WithResolvers(rb), withMinConnectDeadline(getMinConnectTimeout)) if err != nil { @@ -976,9 +990,11 @@ func (s) TestUpdateAddresses_NoopIfCalledWithSameAddresses(t *testing.T) { client, err := Dial("whatever:///this-gets-overwritten", WithTransportCredentials(insecure.NewCredentials()), WithResolvers(rb), - withBackoff(noBackoff{}), - WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)), - withMinConnectDeadline(func() time.Duration { return time.Hour })) + WithConnectParams(ConnectParams{ + Backoff: backoff.Config{}, + MinConnectTimeout: time.Hour, + }), + WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName))) if err != nil { t.Fatal(err) } @@ -1113,6 +1129,66 @@ func testDefaultServiceConfigWhenResolverReturnInvalidServiceConfig(t *testing.T } } +type stateRecordingBalancer struct { + notifier chan<- connectivity.State + balancer.Balancer +} + +func (b *stateRecordingBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) { + b.notifier <- s.ConnectivityState + b.Balancer.UpdateSubConnState(sc, s) +} + +func (b *stateRecordingBalancer) ResetNotifier(r chan<- connectivity.State) { + b.notifier = r +} + +func (b *stateRecordingBalancer) Close() { + b.Balancer.Close() +} + +type stateRecordingBalancerBuilder struct { + mu sync.Mutex + notifier chan connectivity.State // The notifier used in the last Balancer. +} + +func newStateRecordingBalancerBuilder() *stateRecordingBalancerBuilder { + return &stateRecordingBalancerBuilder{} +} + +func (b *stateRecordingBalancerBuilder) Name() string { + return stateRecordingBalancerName +} + +func (b *stateRecordingBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + stateNotifications := make(chan connectivity.State, 10) + b.mu.Lock() + b.notifier = stateNotifications + b.mu.Unlock() + return &stateRecordingBalancer{ + notifier: stateNotifications, + Balancer: balancer.Get("pick_first").Build(cc, opts), + } +} + +func (b *stateRecordingBalancerBuilder) nextStateNotifier() <-chan connectivity.State { + b.mu.Lock() + defer b.mu.Unlock() + ret := b.notifier + b.notifier = nil + return ret +} + +// Keep reading until something causes the connection to die (EOF, server +// closed, etc). Useful as a tool for mindlessly keeping the connection +// healthy, since the client will error if things like client prefaces are not +// accepted in a timely fashion. +func keepReading(conn net.Conn) { + buf := make([]byte, 1024) + for _, err := conn.Read(buf); err == nil; _, err = conn.Read(buf) { + } +} + // stayConnected makes cc stay connected by repeatedly calling cc.Connect() // until the state becomes Shutdown or until 10 seconds elapses. func stayConnected(cc *ClientConn) { diff --git a/clientconn_state_transition_test.go b/test/clientconn_state_transition_test.go similarity index 85% rename from clientconn_state_transition_test.go rename to test/clientconn_state_transition_test.go index d1c1321b3..1f15c6905 100644 --- a/clientconn_state_transition_test.go +++ b/test/clientconn_state_transition_test.go @@ -16,7 +16,7 @@ * */ -package grpc +package test import ( "context" @@ -27,6 +27,8 @@ import ( "time" "golang.org/x/net/http2" + "google.golang.org/grpc" + "google.golang.org/grpc/backoff" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" @@ -35,10 +37,7 @@ import ( "google.golang.org/grpc/resolver/manual" ) -const ( - stateRecordingBalancerName = "state_recoding_balancer" - defaultTestTimeout = 10 * time.Second -) +const stateRecordingBalancerName = "state_recording_balancer" var testBalancerBuilder = newStateRecordingBalancerBuilder() @@ -158,17 +157,22 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s connMu.Unlock() }() - client, err := Dial("", - WithTransportCredentials(insecure.NewCredentials()), - WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)), - WithDialer(pl.Dialer()), - withBackoff(noBackoff{}), - withMinConnectDeadline(func() time.Duration { return time.Millisecond * 100 })) + client, err := grpc.Dial("", + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)), + grpc.WithDialer(pl.Dialer()), + grpc.WithConnectParams(grpc.ConnectParams{ + Backoff: backoff.Config{}, + MinConnectTimeout: 100 * time.Millisecond, + })) if err != nil { t.Fatal(err) } defer client.Close() - go stayConnected(client) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + go stayConnected(ctx, client) stateNotifications := testBalancerBuilder.nextStateNotifier() for i := 0; i < len(want); i++ { @@ -225,14 +229,17 @@ func (s) TestStateTransitions_ReadyToConnecting(t *testing.T) { conn.Close() }() - client, err := Dial(lis.Addr().String(), - WithTransportCredentials(insecure.NewCredentials()), - WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName))) + client, err := grpc.Dial(lis.Addr().String(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName))) if err != nil { t.Fatal(err) } defer client.Close() - go stayConnected(client) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + go stayConnected(ctx, client) stateNotifications := testBalancerBuilder.nextStateNotifier() @@ -310,10 +317,10 @@ func (s) TestStateTransitions_TriesAllAddrsBeforeTransientFailure(t *testing.T) {Addr: lis1.Addr().String()}, {Addr: lis2.Addr().String()}, }}) - client, err := Dial("whatever:///this-gets-overwritten", - WithTransportCredentials(insecure.NewCredentials()), - WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)), - WithResolvers(rb)) + client, err := grpc.Dial("whatever:///this-gets-overwritten", + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)), + grpc.WithResolvers(rb)) if err != nil { t.Fatal(err) } @@ -396,15 +403,18 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { {Addr: lis1.Addr().String()}, {Addr: lis2.Addr().String()}, }}) - client, err := Dial("whatever:///this-gets-overwritten", - WithTransportCredentials(insecure.NewCredentials()), - WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)), - WithResolvers(rb)) + client, err := grpc.Dial("whatever:///this-gets-overwritten", + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)), + grpc.WithResolvers(rb)) if err != nil { t.Fatal(err) } defer client.Close() - go stayConnected(client) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + go stayConnected(ctx, client) stateNotifications := testBalancerBuilder.nextStateNotifier() want := []connectivity.State{ @@ -413,8 +423,6 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { connectivity.Idle, connectivity.Connecting, } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() for i := 0; i < len(want); i++ { select { case <-ctx.Done(): @@ -473,7 +481,7 @@ func (b *stateRecordingBalancerBuilder) Build(cc balancer.ClientConn, opts balan b.mu.Unlock() return &stateRecordingBalancer{ notifier: stateNotifications, - Balancer: balancer.Get(PickFirstBalancerName).Build(cc, opts), + Balancer: balancer.Get("pick_first").Build(cc, opts), } } @@ -485,10 +493,6 @@ func (b *stateRecordingBalancerBuilder) nextStateNotifier() <-chan connectivity. return ret } -type noBackoff struct{} - -func (b noBackoff) Backoff(int) time.Duration { return time.Duration(0) } - // Keep reading until something causes the connection to die (EOF, server // closed, etc). Useful as a tool for mindlessly keeping the connection // healthy, since the client will error if things like client prefaces are not @@ -498,3 +502,20 @@ func keepReading(conn net.Conn) { for _, err := conn.Read(buf); err == nil; _, err = conn.Read(buf) { } } + +// stayConnected makes cc stay connected by repeatedly calling cc.Connect() +// until the state becomes Shutdown or until ithe context expires. +func stayConnected(ctx context.Context, cc *grpc.ClientConn) { + for { + state := cc.GetState() + switch state { + case connectivity.Idle: + cc.Connect() + case connectivity.Shutdown: + return + } + if !cc.WaitForStateChange(ctx, state) { + return + } + } +}