diff --git a/balancer/leastrequest/leastrequest.go b/balancer/leastrequest/leastrequest.go index dd46dfa8f..f758f954e 100644 --- a/balancer/leastrequest/leastrequest.go +++ b/balancer/leastrequest/leastrequest.go @@ -97,11 +97,8 @@ func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Ba } type leastRequestBalancer struct { - // Embeds balancer.Balancer because needs to intercept UpdateClientConnState - // to learn about choiceCount. - balancer.Balancer - // Embeds balancer.ClientConn because needs to intercept UpdateState calls - // from the child balancer. + // Embeds balancer.ClientConn because we need to intercept UpdateState + // calls from the child balancer. balancer.ClientConn child balancer.Balancer logger *internalgrpclog.PrefixLogger @@ -118,6 +115,21 @@ func (lrb *leastRequestBalancer) Close() { lrb.endpointRPCCounts = nil } +func (lrb *leastRequestBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + lrb.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state) +} + +func (lrb *leastRequestBalancer) ResolverError(err error) { + // Will cause inline picker update from endpoint sharding. + lrb.child.ResolverError(err) +} + +func (lrb *leastRequestBalancer) ExitIdle() { + if ei, ok := lrb.child.(balancer.ExitIdler); ok { // Should always be ok, as child is endpoint sharding. + ei.ExitIdle() + } +} + func (lrb *leastRequestBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error { lrCfg, ok := ccs.BalancerConfig.(*LBConfig) if !ok { diff --git a/balancer/leastrequest/balancer_test.go b/balancer/leastrequest/leastrequest_test.go similarity index 91% rename from balancer/leastrequest/balancer_test.go rename to balancer/leastrequest/leastrequest_test.go index e0043db39..d8073470f 100644 --- a/balancer/leastrequest/balancer_test.go +++ b/balancer/leastrequest/leastrequest_test.go @@ -27,12 +27,13 @@ import ( "time" "github.com/google/go-cmp/cmp" - "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" testgrpc "google.golang.org/grpc/interop/grpc_testing" testpb "google.golang.org/grpc/interop/grpc_testing" "google.golang.org/grpc/peer" @@ -42,7 +43,8 @@ import ( ) const ( - defaultTestTimeout = 5 * time.Second + defaultTestTimeout = 5 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond ) type s struct { @@ -706,3 +708,63 @@ func (s) TestLeastRequestEndpoints_MultipleAddresses(t *testing.T) { t.Fatalf("error in expected round robin: %v", err) } } + +// Test tests that the least request balancer properly surfaces resolver +// errors. +func (s) TestLeastRequestEndpoints_ResolverError(t *testing.T) { + const sc = `{"loadBalancingConfig": [{"least_request_experimental": {}}]}` + mr := manual.NewBuilderWithScheme("lr-e2e") + defer mr.Close() + + cc, err := grpc.NewClient( + mr.Scheme()+":///", + grpc.WithResolvers(mr), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultServiceConfig(sc), + ) + if err != nil { + t.Fatalf("grpc.NewClient() failed: %v", err) + } + defer cc.Close() + + // We need to pass an endpoint with a valid address to the resolver before + // reporting an error - otherwise endpointsharding does not report the + // error through. + lis, err := testutils.LocalTCPListener() + if err != nil { + t.Fatalf("net.Listen() failed: %v", err) + } + // Act like a server that closes the connection without sending a server + // preface. + go func() { + conn, err := lis.Accept() + if err != nil { + t.Errorf("Unexpected error when accepting a connection: %v", err) + } + conn.Close() + }() + mr.UpdateState(resolver.State{ + Endpoints: []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}}, + }) + cc.Connect() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) + + // Report an error through the resolver + resolverErr := fmt.Errorf("simulated resolver error") + mr.CC().ReportError(resolverErr) + + // Ensure the client returns the expected resolver error. + testServiceClient := testgrpc.NewTestServiceClient(cc) + for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) { + _, err = testServiceClient.EmptyCall(ctx, &testpb.Empty{}) + if strings.Contains(err.Error(), resolverErr.Error()) { + break + } + } + if ctx.Err() != nil { + t.Fatalf("Timeout when waiting for RPCs to fail with error containing %s. Last error: %v", resolverErr, err) + } +}