test: move clientconn state transition test to test/ directory (#5551)

This commit is contained in:
Easwar Swaminathan 2022-08-02 12:31:30 -07:00 committed by GitHub
parent 23f015c36d
commit 57aaa10b8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 135 additions and 38 deletions

View File

@ -25,12 +25,14 @@ import (
"math" "math"
"net" "net"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"google.golang.org/grpc/backoff" "google.golang.org/grpc/backoff"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
@ -44,6 +46,17 @@ import (
"google.golang.org/grpc/testdata" "google.golang.org/grpc/testdata"
) )
const (
defaultTestTimeout = 10 * time.Second
stateRecordingBalancerName = "state_recording_balancer"
)
var testBalancerBuilder = newStateRecordingBalancerBuilder()
func init() {
balancer.Register(testBalancerBuilder)
}
func parseCfg(r *manual.Resolver, s string) *serviceconfig.ParseResult { func parseCfg(r *manual.Resolver, s string) *serviceconfig.ParseResult {
scpr := r.CC.ParseServiceConfig(s) scpr := r.CC.ParseServiceConfig(s)
if scpr.Err != nil { if scpr.Err != nil {
@ -221,8 +234,10 @@ func (s) TestDialWaitsForServerSettingsAndFails(t *testing.T) {
lis.Addr().String(), lis.Addr().String(),
WithTransportCredentials(insecure.NewCredentials()), WithTransportCredentials(insecure.NewCredentials()),
WithReturnConnectionError(), WithReturnConnectionError(),
withBackoff(noBackoff{}), WithConnectParams(ConnectParams{
withMinConnectDeadline(func() time.Duration { return time.Second / 4 })) Backoff: backoff.Config{},
MinConnectTimeout: 250 * time.Millisecond,
}))
lis.Close() lis.Close()
if err == nil { if err == nil {
client.Close() client.Close()
@ -453,7 +468,6 @@ func (s) TestDial_OneBackoffPerRetryGroup(t *testing.T) {
}}) }})
client, err := DialContext(ctx, "whatever:///this-gets-overwritten", client, err := DialContext(ctx, "whatever:///this-gets-overwritten",
WithTransportCredentials(insecure.NewCredentials()), WithTransportCredentials(insecure.NewCredentials()),
WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)),
WithResolvers(rb), WithResolvers(rb),
withMinConnectDeadline(getMinConnectTimeout)) withMinConnectDeadline(getMinConnectTimeout))
if err != nil { if err != nil {
@ -976,9 +990,11 @@ func (s) TestUpdateAddresses_NoopIfCalledWithSameAddresses(t *testing.T) {
client, err := Dial("whatever:///this-gets-overwritten", client, err := Dial("whatever:///this-gets-overwritten",
WithTransportCredentials(insecure.NewCredentials()), WithTransportCredentials(insecure.NewCredentials()),
WithResolvers(rb), WithResolvers(rb),
withBackoff(noBackoff{}), WithConnectParams(ConnectParams{
WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)), Backoff: backoff.Config{},
withMinConnectDeadline(func() time.Duration { return time.Hour })) MinConnectTimeout: time.Hour,
}),
WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1113,6 +1129,66 @@ func testDefaultServiceConfigWhenResolverReturnInvalidServiceConfig(t *testing.T
} }
} }
type stateRecordingBalancer struct {
notifier chan<- connectivity.State
balancer.Balancer
}
func (b *stateRecordingBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) {
b.notifier <- s.ConnectivityState
b.Balancer.UpdateSubConnState(sc, s)
}
func (b *stateRecordingBalancer) ResetNotifier(r chan<- connectivity.State) {
b.notifier = r
}
func (b *stateRecordingBalancer) Close() {
b.Balancer.Close()
}
type stateRecordingBalancerBuilder struct {
mu sync.Mutex
notifier chan connectivity.State // The notifier used in the last Balancer.
}
func newStateRecordingBalancerBuilder() *stateRecordingBalancerBuilder {
return &stateRecordingBalancerBuilder{}
}
func (b *stateRecordingBalancerBuilder) Name() string {
return stateRecordingBalancerName
}
func (b *stateRecordingBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
stateNotifications := make(chan connectivity.State, 10)
b.mu.Lock()
b.notifier = stateNotifications
b.mu.Unlock()
return &stateRecordingBalancer{
notifier: stateNotifications,
Balancer: balancer.Get("pick_first").Build(cc, opts),
}
}
func (b *stateRecordingBalancerBuilder) nextStateNotifier() <-chan connectivity.State {
b.mu.Lock()
defer b.mu.Unlock()
ret := b.notifier
b.notifier = nil
return ret
}
// Keep reading until something causes the connection to die (EOF, server
// closed, etc). Useful as a tool for mindlessly keeping the connection
// healthy, since the client will error if things like client prefaces are not
// accepted in a timely fashion.
func keepReading(conn net.Conn) {
buf := make([]byte, 1024)
for _, err := conn.Read(buf); err == nil; _, err = conn.Read(buf) {
}
}
// stayConnected makes cc stay connected by repeatedly calling cc.Connect() // stayConnected makes cc stay connected by repeatedly calling cc.Connect()
// until the state becomes Shutdown or until 10 seconds elapses. // until the state becomes Shutdown or until 10 seconds elapses.
func stayConnected(cc *ClientConn) { func stayConnected(cc *ClientConn) {

View File

@ -16,7 +16,7 @@
* *
*/ */
package grpc package test
import ( import (
"context" "context"
@ -27,6 +27,8 @@ import (
"time" "time"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
@ -35,10 +37,7 @@ import (
"google.golang.org/grpc/resolver/manual" "google.golang.org/grpc/resolver/manual"
) )
const ( const stateRecordingBalancerName = "state_recording_balancer"
stateRecordingBalancerName = "state_recoding_balancer"
defaultTestTimeout = 10 * time.Second
)
var testBalancerBuilder = newStateRecordingBalancerBuilder() var testBalancerBuilder = newStateRecordingBalancerBuilder()
@ -158,17 +157,22 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s
connMu.Unlock() connMu.Unlock()
}() }()
client, err := Dial("", client, err := grpc.Dial("",
WithTransportCredentials(insecure.NewCredentials()), grpc.WithTransportCredentials(insecure.NewCredentials()),
WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)),
WithDialer(pl.Dialer()), grpc.WithDialer(pl.Dialer()),
withBackoff(noBackoff{}), grpc.WithConnectParams(grpc.ConnectParams{
withMinConnectDeadline(func() time.Duration { return time.Millisecond * 100 })) Backoff: backoff.Config{},
MinConnectTimeout: 100 * time.Millisecond,
}))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer client.Close() defer client.Close()
go stayConnected(client)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
go stayConnected(ctx, client)
stateNotifications := testBalancerBuilder.nextStateNotifier() stateNotifications := testBalancerBuilder.nextStateNotifier()
for i := 0; i < len(want); i++ { for i := 0; i < len(want); i++ {
@ -225,14 +229,17 @@ func (s) TestStateTransitions_ReadyToConnecting(t *testing.T) {
conn.Close() conn.Close()
}() }()
client, err := Dial(lis.Addr().String(), client, err := grpc.Dial(lis.Addr().String(),
WithTransportCredentials(insecure.NewCredentials()), grpc.WithTransportCredentials(insecure.NewCredentials()),
WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName))) grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer client.Close() defer client.Close()
go stayConnected(client)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
go stayConnected(ctx, client)
stateNotifications := testBalancerBuilder.nextStateNotifier() stateNotifications := testBalancerBuilder.nextStateNotifier()
@ -310,10 +317,10 @@ func (s) TestStateTransitions_TriesAllAddrsBeforeTransientFailure(t *testing.T)
{Addr: lis1.Addr().String()}, {Addr: lis1.Addr().String()},
{Addr: lis2.Addr().String()}, {Addr: lis2.Addr().String()},
}}) }})
client, err := Dial("whatever:///this-gets-overwritten", client, err := grpc.Dial("whatever:///this-gets-overwritten",
WithTransportCredentials(insecure.NewCredentials()), grpc.WithTransportCredentials(insecure.NewCredentials()),
WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)),
WithResolvers(rb)) grpc.WithResolvers(rb))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -396,15 +403,18 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) {
{Addr: lis1.Addr().String()}, {Addr: lis1.Addr().String()},
{Addr: lis2.Addr().String()}, {Addr: lis2.Addr().String()},
}}) }})
client, err := Dial("whatever:///this-gets-overwritten", client, err := grpc.Dial("whatever:///this-gets-overwritten",
WithTransportCredentials(insecure.NewCredentials()), grpc.WithTransportCredentials(insecure.NewCredentials()),
WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)),
WithResolvers(rb)) grpc.WithResolvers(rb))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer client.Close() defer client.Close()
go stayConnected(client)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
go stayConnected(ctx, client)
stateNotifications := testBalancerBuilder.nextStateNotifier() stateNotifications := testBalancerBuilder.nextStateNotifier()
want := []connectivity.State{ want := []connectivity.State{
@ -413,8 +423,6 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) {
connectivity.Idle, connectivity.Idle,
connectivity.Connecting, connectivity.Connecting,
} }
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i := 0; i < len(want); i++ { for i := 0; i < len(want); i++ {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -473,7 +481,7 @@ func (b *stateRecordingBalancerBuilder) Build(cc balancer.ClientConn, opts balan
b.mu.Unlock() b.mu.Unlock()
return &stateRecordingBalancer{ return &stateRecordingBalancer{
notifier: stateNotifications, notifier: stateNotifications,
Balancer: balancer.Get(PickFirstBalancerName).Build(cc, opts), Balancer: balancer.Get("pick_first").Build(cc, opts),
} }
} }
@ -485,10 +493,6 @@ func (b *stateRecordingBalancerBuilder) nextStateNotifier() <-chan connectivity.
return ret return ret
} }
type noBackoff struct{}
func (b noBackoff) Backoff(int) time.Duration { return time.Duration(0) }
// Keep reading until something causes the connection to die (EOF, server // Keep reading until something causes the connection to die (EOF, server
// closed, etc). Useful as a tool for mindlessly keeping the connection // closed, etc). Useful as a tool for mindlessly keeping the connection
// healthy, since the client will error if things like client prefaces are not // healthy, since the client will error if things like client prefaces are not
@ -498,3 +502,20 @@ func keepReading(conn net.Conn) {
for _, err := conn.Read(buf); err == nil; _, err = conn.Read(buf) { for _, err := conn.Read(buf); err == nil; _, err = conn.Read(buf) {
} }
} }
// stayConnected makes cc stay connected by repeatedly calling cc.Connect()
// until the state becomes Shutdown or until ithe context expires.
func stayConnected(ctx context.Context, cc *grpc.ClientConn) {
for {
state := cc.GetState()
switch state {
case connectivity.Idle:
cc.Connect()
case connectivity.Shutdown:
return
}
if !cc.WaitForStateChange(ctx, state) {
return
}
}
}