leastrequest: Delegate subchannel creation to pickfirst (#7969)

Arjan Singh Bal 2025-01-15 12:20:45 +05:30 committed by GitHub
parent 74ac821433
commit 130c1d73d0
2 changed files with 329 additions and 73 deletions
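For orientation, here is a minimal usage sketch (not part of this commit) showing how a client opts into the least_request_experimental policy that the files below implement. The policy name and the choiceCount field mirror the service config used in the tests; the target address is a placeholder, and the blank import assumes the package path of the balancer file in this diff.

package main

import (
    "log"

    "google.golang.org/grpc"
    _ "google.golang.org/grpc/balancer/leastrequest" // registers least_request_experimental
    "google.golang.org/grpc/credentials/insecure"
)

func main() {
    // "dns:///localhost:50051" is a placeholder target, not taken from this commit.
    conn, err := grpc.NewClient(
        "dns:///localhost:50051",
        grpc.WithTransportCredentials(insecure.NewCredentials()),
        grpc.WithDefaultServiceConfig(`{
            "loadBalancingConfig": [
                {"least_request_experimental": {"choiceCount": 2}}
            ]
        }`),
    )
    if err != nil {
        log.Fatalf("grpc.NewClient() failed: %v", err)
    }
    defer conn.Close()
    // Use the connection with any generated stubs as usual.
}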

View File

@@ -117,14 +117,9 @@ func (s) TestParseConfig(t *testing.T) {
}
}
// setupBackends spins up three test backends, each listening on a port on
// localhost. The three backends always reply with an empty response with no
// error, and for streaming receive until hitting an EOF error.
func setupBackends(t *testing.T) []string {
t.Helper()
const numBackends = 3
addresses := make([]string, numBackends)
// Construct and start three working backends.
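// startBackends spins up numBackends stub backends listening on localhost and
// registers t.Cleanup callbacks to stop them when the test ends.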
func startBackends(t *testing.T, numBackends int) []*stubserver.StubServer {
backends := make([]*stubserver.StubServer, 0, numBackends)
// Construct and start working backends.
for i := 0; i < numBackends; i++ {
backend := &stubserver.StubServer{
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
@@ -140,7 +135,21 @@ func setupBackends(t *testing.T) []string {
}
t.Logf("Started good TestService backend at: %q", backend.Address)
t.Cleanup(func() { backend.Stop() })
addresses[i] = backend.Address
backends = append(backends, backend)
}
return backends
}
// setupBackends spins up numBackends test backends, each listening on a port
// on localhost. The backends always reply with an empty response with no
// error, and for streaming receive until hitting an EOF error.
func setupBackends(t *testing.T, numBackends int) []string {
t.Helper()
addresses := make([]string, numBackends)
backends := startBackends(t, numBackends)
// Collect the addresses of the started backends.
for i := 0; i < numBackends; i++ {
addresses[i] = backends[i].Address
}
return addresses
}
@@ -205,7 +214,7 @@ func (s) TestLeastRequestE2E(t *testing.T) {
index++
return ret
}
addresses := setupBackends(t)
addresses := setupBackends(t, 3)
mr := manual.NewBuilderWithScheme("lr-e2e")
defer mr.Close()
@@ -321,7 +330,7 @@ func (s) TestLeastRequestPersistsCounts(t *testing.T) {
index++
return ret
}
addresses := setupBackends(t)
addresses := setupBackends(t, 3)
mr := manual.NewBuilderWithScheme("lr-e2e")
defer mr.Close()
@@ -462,7 +471,7 @@ func (s) TestLeastRequestPersistsCounts(t *testing.T) {
// and makes 100 RPCs asynchronously. This makes sure no race conditions happen
// in this scenario.
func (s) TestConcurrentRPCs(t *testing.T) {
addresses := setupBackends(t)
addresses := setupBackends(t, 3)
mr := manual.NewBuilderWithScheme("lr-e2e")
defer mr.Close()
@@ -508,5 +517,192 @@ func (s) TestConcurrentRPCs(t *testing.T) {
}()
}
wg.Wait()
}
// Tests that the least request balancer persists RPC counts once it gets new
// picker updates and backends within an endpoint go down. It first updates the
// balancer with two endpoints having two addresses each and verifies that
// requests are round robined across the first address of each endpoint. It
// then stops the active backend in endpoint[0] and verifies that the balancer
// starts using the second address in endpoint[0]. The test then creates a
// bunch of streams on the two endpoints. Then, it updates the balancer with
// three endpoints, including the two previous ones. Any newly created streams
// should be started on the new endpoint. Finally, the test shuts down the
// active backend in endpoint[1] and endpoint[2] and verifies that new RPCs are
// round robined across the active backends in endpoint[1] and endpoint[2].
func (s) TestLeastRequestEndpoints_MultipleAddresses(t *testing.T) {
defer func(u func() uint32) {
randuint32 = u
}(randuint32)
var index int
indexes := []uint32{
0, 0, 1, 1,
}
randuint32 = func() uint32 {
ret := indexes[index%len(indexes)]
index++
return ret
}
backends := startBackends(t, 6)
mr := manual.NewBuilderWithScheme("lr-e2e")
defer mr.Close()
// Configure least request as top level balancer of channel.
lrscJSON := `
{
"loadBalancingConfig": [
{
"least_request_experimental": {
"choiceCount": 2
}
}
]
}`
endpoints := []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: backends[0].Address}, {Addr: backends[1].Address}}},
{Addresses: []resolver.Address{{Addr: backends[2].Address}, {Addr: backends[3].Address}}},
{Addresses: []resolver.Address{{Addr: backends[4].Address}, {Addr: backends[5].Address}}},
}
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(lrscJSON)
firstTwoEndpoints := []resolver.Endpoint{endpoints[0], endpoints[1]}
mr.InitialState(resolver.State{
Endpoints: firstTwoEndpoints,
ServiceConfig: sc,
})
cc, err := grpc.NewClient(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
defer cc.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
testServiceClient := testgrpc.NewTestServiceClient(cc)
// Wait for the two backends to round robin across. This happens because a
// child pickfirst transitioning into READY causes a new picker update. Once
// the picker update with the two backends is present, this test can start
// to populate those backends with streams.
wantAddrs := []resolver.Address{
endpoints[0].Addresses[0],
endpoints[1].Addresses[0],
}
if err := checkRoundRobinRPCs(ctx, testServiceClient, wantAddrs); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}
// Shut down one of the addresses in endpoints[0]; the child pickfirst
// should fall back to the next address in endpoints[0].
backends[0].Stop()
wantAddrs = []resolver.Address{
endpoints[0].Addresses[1],
endpoints[1].Addresses[0],
}
if err := checkRoundRobinRPCs(ctx, testServiceClient, wantAddrs); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}
// Start 50 streaming RPCs, and leave them unfinished for the duration of
// the test. This will populate the first two endpoints with many active
// RPCs.
for i := 0; i < 50; i++ {
_, err := testServiceClient.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("testServiceClient.FullDuplexCall failed: %v", err)
}
}
// Update the least request balancer to choice count 3. Also update the
// address list adding a third endpoint. Alongside the injected randomness,
// this should trigger the least request balancer to search all created
// endpoints. Thus, since endpoint 3 is the new endpoint and the first two
// endpoints are populated with RPCs, once the picker update of all 3 READY
// pickfirsts takes effect, all new streams should be started on endpoint 3.
index = 0
indexes = []uint32{
0, 1, 2, 3, 4, 5,
}
lrscJSON = `
{
"loadBalancingConfig": [
{
"least_request_experimental": {
"choiceCount": 3
}
}
]
}`
sc = internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(lrscJSON)
mr.UpdateState(resolver.State{
Endpoints: endpoints,
ServiceConfig: sc,
})
newAddress := endpoints[2].Addresses[0]
// Poll for only endpoint 3 to show up. This requires a polling loop because
// the picker update with all three endpoints doesn't take effect
// immediately; it needs the third pickfirst to become READY.
if err := checkRoundRobinRPCs(ctx, testServiceClient, []resolver.Address{newAddress}); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}
// Start 25 RPCs, but don't finish them. They should all start on endpoint 3,
// since the first two endpoints each already hold 25 active RPCs (50 streams
// round robined across two endpoints) and the injected randomness/choiceCount
// causes all 3 endpoints to be compared every iteration.
for i := 0; i < 25; i++ {
stream, err := testServiceClient.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("testServiceClient.FullDuplexCall failed: %v", err)
}
p, ok := peer.FromContext(stream.Context())
if !ok {
t.Fatalf("testServiceClient.FullDuplexCall has no Peer")
}
if p.Addr.String() != newAddress.Addr {
t.Fatalf("testServiceClient.FullDuplexCall's Peer got: %v, want: %v", p.Addr.String(), newAddress)
}
}
// Now that 25 RPCs are active on each endpoint, the next three RPCs should
// round robin, since choiceCount is three and the injected random indexes
// cause the picker to compare the outstanding requests of all three
// endpoints on each iteration.
wantAddrCount := map[string]int{
endpoints[0].Addresses[1].Addr: 1,
endpoints[1].Addresses[0].Addr: 1,
endpoints[2].Addresses[0].Addr: 1,
}
gotAddrCount := make(map[string]int)
for i := 0; i < len(endpoints); i++ {
stream, err := testServiceClient.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("testServiceClient.FullDuplexCall failed: %v", err)
}
p, ok := peer.FromContext(stream.Context())
if !ok {
t.Fatalf("testServiceClient.FullDuplexCall has no Peer")
}
if p.Addr != nil {
gotAddrCount[p.Addr.String()]++
}
}
if diff := cmp.Diff(gotAddrCount, wantAddrCount); diff != "" {
t.Fatalf("addr count (-got:, +want): %v", diff)
}
// Shut down the backends behind the active address of endpoint[1] and
// endpoint[2]. This should result in their streams failing. New requests
// should now round robin between the remaining addresses of endpoint[1]
// and endpoint[2].
backends[2].Stop()
backends[4].Stop()
index = 0
indexes = []uint32{
0, 1, 2, 2, 1, 0,
}
wantAddrs = []resolver.Address{
endpoints[1].Addresses[1],
endpoints[2].Addresses[1],
}
if err := checkRoundRobinRPCs(ctx, testServiceClient, wantAddrs); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}
}
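The round robin assertions above use the package's checkRoundRobinRPCs helper, which is defined elsewhere and not part of this diff. Purely as an illustration of the kind of check performed, here is a hedged sketch with the same call shape; it assumes imports for context, fmt, time, grpc, peer, resolver, and the generated test stubs, and the real helper's polling strategy may differ.

// checkRoundRobinSketch is an illustrative stand-in for checkRoundRobinRPCs:
// it polls EmptyCall until one pass of len(wantAddrs) consecutive RPCs hits
// each wanted address exactly once (in any order), or the context expires.
func checkRoundRobinSketch(ctx context.Context, client testgrpc.TestServiceClient, wantAddrs []resolver.Address) error {
    want := make(map[string]bool, len(wantAddrs))
    for _, a := range wantAddrs {
        want[a.Addr] = true
    }
    for ; ctx.Err() == nil; time.Sleep(time.Millisecond) {
        seen := make(map[string]bool, len(wantAddrs))
        pass := true
        for i := 0; i < len(wantAddrs); i++ {
            var p peer.Peer
            if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&p)); err != nil {
                pass = false
                break
            }
            if !want[p.Addr.String()] {
                pass = false
                break
            }
            seen[p.Addr.String()] = true
        }
        if pass && len(seen) == len(want) {
            return nil
        }
    }
    return fmt.Errorf("timed out waiting for RPCs to round robin across %v", wantAddrs)
}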

View File

@@ -23,21 +23,28 @@ import (
"encoding/json"
"fmt"
rand "math/rand/v2"
"sync"
"sync/atomic"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/balancer/endpointsharding"
"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
// randuint32 is a global to stub out in tests.
var randuint32 = rand.Uint32
// Name is the name of the least request balancer.
const Name = "least_request_experimental"
var logger = grpclog.Component("least-request")
var (
// randuint32 is a global to stub out in tests.
randuint32 = rand.Uint32
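// endpointShardingLBConfig is the LB config passed to the endpointsharding
// child, which creates a pick_first child for every endpoint.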
endpointShardingLBConfig = endpointsharding.PickFirstConfig
logger = grpclog.Component("least-request")
)
func init() {
balancer.Register(bb{})
@@ -80,9 +87,13 @@ func (bb) Name() string {
}
func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
b := &leastRequestBalancer{scRPCCounts: make(map[balancer.SubConn]*atomic.Int32)}
baseBuilder := base.NewBalancerBuilder(Name, b, base.Config{HealthCheck: true})
b.Balancer = baseBuilder.Build(cc, bOpts)
b := &leastRequestBalancer{
ClientConn: cc,
endpointRPCCounts: resolver.NewEndpointMap(),
}
b.child = endpointsharding.NewBalancer(b, bOpts)
b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[%p] ", b))
b.logger.Infof("Created")
return b
}
@@ -90,94 +101,143 @@ type leastRequestBalancer struct {
// Embeds balancer.Balancer because needs to intercept UpdateClientConnState
// to learn about choiceCount.
balancer.Balancer
// Embeds balancer.ClientConn because it needs to intercept UpdateState calls
// from the child balancer.
balancer.ClientConn
child balancer.Balancer
logger *internalgrpclog.PrefixLogger
mu sync.Mutex
choiceCount uint32
scRPCCounts map[balancer.SubConn]*atomic.Int32 // Hold onto RPC counts to keep track for subsequent picker updates.
// endpointRPCCounts holds RPC counts to keep track for subsequent picker
// updates.
endpointRPCCounts *resolver.EndpointMap // endpoint -> *atomic.Int32
}
func (lrb *leastRequestBalancer) UpdateClientConnState(s balancer.ClientConnState) error {
lrCfg, ok := s.BalancerConfig.(*LBConfig)
func (lrb *leastRequestBalancer) Close() {
lrb.child.Close()
lrb.endpointRPCCounts = nil
}
func (lrb *leastRequestBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
lrCfg, ok := ccs.BalancerConfig.(*LBConfig)
if !ok {
logger.Errorf("least-request: received config with unexpected type %T: %v", s.BalancerConfig, s.BalancerConfig)
logger.Errorf("least-request: received config with unexpected type %T: %v", ccs.BalancerConfig, ccs.BalancerConfig)
return balancer.ErrBadResolverState
}
lrb.mu.Lock()
lrb.choiceCount = lrCfg.ChoiceCount
return lrb.Balancer.UpdateClientConnState(s)
lrb.mu.Unlock()
// Enable the health listener in pickfirst children for client side health
// checks and outlier detection, if configured.
ccs.ResolverState = pickfirstleaf.EnableHealthListener(ccs.ResolverState)
ccs.BalancerConfig = endpointShardingLBConfig
return lrb.child.UpdateClientConnState(ccs)
}
type scWithRPCCount struct {
sc balancer.SubConn
type endpointState struct {
picker balancer.Picker
numRPCs *atomic.Int32
}
func (lrb *leastRequestBalancer) Build(info base.PickerBuildInfo) balancer.Picker {
if logger.V(2) {
logger.Infof("least-request: Build called with info: %v", info)
}
if len(info.ReadySCs) == 0 {
return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
}
for sc := range lrb.scRPCCounts {
if _, ok := info.ReadySCs[sc]; !ok { // If no longer ready, no more need for the ref to count active RPCs.
delete(lrb.scRPCCounts, sc)
func (lrb *leastRequestBalancer) UpdateState(state balancer.State) {
var readyEndpoints []endpointsharding.ChildState
for _, child := range endpointsharding.ChildStatesFromPicker(state.Picker) {
if child.State.ConnectivityState == connectivity.Ready {
readyEndpoints = append(readyEndpoints, child)
}
}
// Create new refs if needed.
for sc := range info.ReadySCs {
if _, ok := lrb.scRPCCounts[sc]; !ok {
lrb.scRPCCounts[sc] = new(atomic.Int32)
// If no ready pickers are present, simply defer to the round robin picker
// from endpoint sharding, which will round robin across the most relevant
// pick first children in the highest precedence connectivity state.
if len(readyEndpoints) == 0 {
lrb.ClientConn.UpdateState(state)
return
}
lrb.mu.Lock()
defer lrb.mu.Unlock()
if logger.V(2) {
lrb.logger.Infof("UpdateState called with ready endpoints: %v", readyEndpoints)
}
// Reconcile endpoints.
newEndpoints := resolver.NewEndpointMap() // endpoint -> nil
for _, child := range readyEndpoints {
newEndpoints.Set(child.Endpoint, nil)
}
// If endpoints are no longer ready, no need to count their active RPCs.
for _, endpoint := range lrb.endpointRPCCounts.Keys() {
if _, ok := newEndpoints.Get(endpoint); !ok {
lrb.endpointRPCCounts.Delete(endpoint)
}
}
// Copy refs to counters into picker.
scs := make([]scWithRPCCount, 0, len(info.ReadySCs))
for sc := range info.ReadySCs {
scs = append(scs, scWithRPCCount{
sc: sc,
numRPCs: lrb.scRPCCounts[sc], // guaranteed to be present due to algorithm
endpointStates := make([]endpointState, 0, len(readyEndpoints))
for _, child := range readyEndpoints {
var counter *atomic.Int32
if val, ok := lrb.endpointRPCCounts.Get(child.Endpoint); !ok {
// Create new counts if needed.
counter = new(atomic.Int32)
lrb.endpointRPCCounts.Set(child.Endpoint, counter)
} else {
counter = val.(*atomic.Int32)
}
endpointStates = append(endpointStates, endpointState{
picker: child.State.Picker,
numRPCs: counter,
})
}
return &picker{
choiceCount: lrb.choiceCount,
subConns: scs,
}
lrb.ClientConn.UpdateState(balancer.State{
Picker: &picker{
choiceCount: lrb.choiceCount,
endpointStates: endpointStates,
},
ConnectivityState: connectivity.Ready,
})
}
type picker struct {
// choiceCount is the number of random SubConns to find the one with
// the least request.
choiceCount uint32
// Built out when receives list of ready RPCs.
subConns []scWithRPCCount
// choiceCount is the number of random endpoints to sample for choosing the
// one with the least requests.
choiceCount uint32
endpointStates []endpointState
}
func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
var pickedSC *scWithRPCCount
var pickedSCNumRPCs int32
func (p *picker) Pick(pInfo balancer.PickInfo) (balancer.PickResult, error) {
var pickedEndpointState *endpointState
var pickedEndpointNumRPCs int32
for i := 0; i < int(p.choiceCount); i++ {
index := randuint32() % uint32(len(p.subConns))
sc := p.subConns[index]
n := sc.numRPCs.Load()
if pickedSC == nil || n < pickedSCNumRPCs {
pickedSC = &sc
pickedSCNumRPCs = n
index := randuint32() % uint32(len(p.endpointStates))
endpointState := p.endpointStates[index]
n := endpointState.numRPCs.Load()
if pickedEndpointState == nil || n < pickedEndpointNumRPCs {
pickedEndpointState = &endpointState
pickedEndpointNumRPCs = n
}
}
result, err := pickedEndpointState.picker.Pick(pInfo)
if err != nil {
return result, err
}
// "The counter for a subchannel should be atomically incremented by one
// after it has been successfully picked by the picker." - A48
pickedSC.numRPCs.Add(1)
pickedEndpointState.numRPCs.Add(1)
// "the picker should add a callback for atomically decrementing the
// subchannel counter once the RPC finishes (regardless of Status code)." -
// A48.
done := func(balancer.DoneInfo) {
pickedSC.numRPCs.Add(-1)
originalDone := result.Done
result.Done = func(info balancer.DoneInfo) {
pickedEndpointState.numRPCs.Add(-1)
if originalDone != nil {
originalDone(info)
}
}
return balancer.PickResult{
SubConn: pickedSC.sc,
Done: done,
}, nil
return result, nil
}
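As a standalone illustration of the sampling loop in Pick above: the picker draws choiceCount endpoints at random (with replacement) and keeps the one with the fewest in-flight RPCs, per gRFC A48. The sketch below uses hypothetical names and, unlike the real picker, does not delegate the final pick to the chosen endpoint's pickfirst picker.

package main

import (
    "fmt"
    rand "math/rand/v2"
    "sync/atomic"
)

// endpoint pairs a name with an in-flight RPC counter, loosely mirroring the
// endpointState type above (names here are illustrative only).
type endpoint struct {
    name     string
    inFlight atomic.Int32
}

// pickLeastLoaded samples choiceCount endpoints at random (with replacement),
// keeps the sampled endpoint with the fewest in-flight RPCs, and increments
// that endpoint's counter, as gRFC A48 describes.
func pickLeastLoaded(endpoints []*endpoint, choiceCount uint32) *endpoint {
    var best *endpoint
    var bestN int32
    for i := uint32(0); i < choiceCount; i++ {
        e := endpoints[rand.Uint32()%uint32(len(endpoints))]
        if n := e.inFlight.Load(); best == nil || n < bestN {
            best, bestN = e, n
        }
    }
    best.inFlight.Add(1) // the caller must Add(-1) once the RPC finishes
    return best
}

func main() {
    eps := []*endpoint{{name: "a"}, {name: "b"}, {name: "c"}}
    eps[0].inFlight.Store(10) // pretend "a" already has many active RPCs
    fmt.Println("picked:", pickLeastLoaded(eps, 2).name)
}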