xds: use locality from the connected address for load reporting (#7378)

This commit is contained in:
Brad Town 2024-07-10 12:51:11 -07:00 committed by GitHub
parent 45d44a736e
commit 9c5b31d74b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 91 additions and 44 deletions

View File

@ -72,8 +72,21 @@ func unregisterForTesting(name string) {
delete(m, name) delete(m, name)
} }
// connectedAddress extracts the address recorded on the given SubConnState.
// The result is meaningful only when the state is READY.
func connectedAddress(scs SubConnState) resolver.Address {
	addr := scs.connectedAddress
	return addr
}
// setConnectedAddress records addr as the connected address of the
// SubConnState that scs points to.
func setConnectedAddress(scs *SubConnState, addr resolver.Address) {
	state := scs
	state.connectedAddress = addr
}
func init() { func init() {
internal.BalancerUnregister = unregisterForTesting internal.BalancerUnregister = unregisterForTesting
internal.ConnectedAddress = connectedAddress
internal.SetConnectedAddress = setConnectedAddress
} }
// Get returns the resolver builder registered with the given name. // Get returns the resolver builder registered with the given name.
@ -410,6 +423,9 @@ type SubConnState struct {
// ConnectionError is set if the ConnectivityState is TransientFailure, // ConnectionError is set if the ConnectivityState is TransientFailure,
// describing the reason the SubConn failed. Otherwise, it is nil. // describing the reason the SubConn failed. Otherwise, it is nil.
ConnectionError error ConnectionError error
// connectedAddress contains the connected address when ConnectivityState is
// Ready. Otherwise, it is indeterminate.
connectedAddress resolver.Address
} }
// ClientConnState describes the state of a ClientConn relevant to the // ClientConnState describes the state of a ClientConn relevant to the

View File

@ -25,12 +25,15 @@ import (
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/gracefulswitch" "google.golang.org/grpc/internal/balancer/gracefulswitch"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
) )
// setConnectedAddress is the hook published through internal.SetConnectedAddress,
// asserted to its concrete function type. The assertion is unchecked, so a
// missing or mismatched hook panics during package initialization.
var setConnectedAddress = internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address))
// ccBalancerWrapper sits between the ClientConn and the Balancer. // ccBalancerWrapper sits between the ClientConn and the Balancer.
// //
// ccBalancerWrapper implements methods corresponding to the ones on the // ccBalancerWrapper implements methods corresponding to the ones on the
@ -252,7 +255,7 @@ type acBalancerWrapper struct {
// updateState is invoked by grpc to push a subConn state update to the // updateState is invoked by grpc to push a subConn state update to the
// underlying balancer. // underlying balancer.
func (acbw *acBalancerWrapper) updateState(s connectivity.State, err error) { func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) {
acbw.ccb.serializer.Schedule(func(ctx context.Context) { acbw.ccb.serializer.Schedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil { if ctx.Err() != nil || acbw.ccb.balancer == nil {
return return
@ -260,7 +263,11 @@ func (acbw *acBalancerWrapper) updateState(s connectivity.State, err error) {
// Even though it is optional for balancers, gracefulswitch ensures // Even though it is optional for balancers, gracefulswitch ensures
// opts.StateListener is set, so this cannot ever be nil. // opts.StateListener is set, so this cannot ever be nil.
// TODO: delete this comment when UpdateSubConnState is removed. // TODO: delete this comment when UpdateSubConnState is removed.
acbw.stateListener(balancer.SubConnState{ConnectivityState: s, ConnectionError: err}) scs := balancer.SubConnState{ConnectivityState: s, ConnectionError: err}
if s == connectivity.Ready {
setConnectedAddress(&scs, curAddr)
}
acbw.stateListener(scs)
}) })
} }

View File

@ -24,6 +24,7 @@ import (
"fmt" "fmt"
"math" "math"
"net/url" "net/url"
"slices"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -812,17 +813,11 @@ func (cc *ClientConn) applyFailingLBLocked(sc *serviceconfig.ParseResult) {
cc.csMgr.updateState(connectivity.TransientFailure) cc.csMgr.updateState(connectivity.TransientFailure)
} }
// Makes a copy of the input addresses slice and clears out the balancer // Makes a copy of the input addresses slice. Addresses are passed during
// attributes field. Addresses are passed during subconn creation and address // subconn creation and address update operations.
// update operations. In both cases, we will clear the balancer attributes by func copyAddresses(in []resolver.Address) []resolver.Address {
// calling this function, and therefore we will be able to use the Equal method
// provided by the resolver.Address type for comparison.
func copyAddressesWithoutBalancerAttributes(in []resolver.Address) []resolver.Address {
out := make([]resolver.Address, len(in)) out := make([]resolver.Address, len(in))
for i := range in { copy(out, in)
out[i] = in[i]
out[i].BalancerAttributes = nil
}
return out return out
} }
@ -837,7 +832,7 @@ func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer.
ac := &addrConn{ ac := &addrConn{
state: connectivity.Idle, state: connectivity.Idle,
cc: cc, cc: cc,
addrs: copyAddressesWithoutBalancerAttributes(addrs), addrs: copyAddresses(addrs),
scopts: opts, scopts: opts,
dopts: cc.dopts, dopts: cc.dopts,
channelz: channelz.RegisterSubChannel(cc.channelz, ""), channelz: channelz.RegisterSubChannel(cc.channelz, ""),
@ -923,22 +918,24 @@ func (ac *addrConn) connect() error {
return nil return nil
} }
func equalAddresses(a, b []resolver.Address) bool { // equalAddressIgnoringBalAttributes returns true if a and b are considered equal.
if len(a) != len(b) { // This is different from the Equal method on the resolver.Address type which
return false // considers all fields to determine equality. Here, we only consider fields
} // that are meaningful to the subConn.
for i, v := range a { func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool {
if !v.Equal(b[i]) { return a.Addr == b.Addr && a.ServerName == b.ServerName &&
return false a.Attributes.Equal(b.Attributes) &&
} a.Metadata == b.Metadata
} }
return true
func equalAddressesIgnoringBalAttributes(a, b []resolver.Address) bool {
return slices.EqualFunc(a, b, func(a, b resolver.Address) bool { return equalAddressIgnoringBalAttributes(&a, &b) })
} }
// updateAddrs updates ac.addrs with the new addresses list and handles active // updateAddrs updates ac.addrs with the new addresses list and handles active
// connections or connection attempts. // connections or connection attempts.
func (ac *addrConn) updateAddrs(addrs []resolver.Address) { func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
addrs = copyAddressesWithoutBalancerAttributes(addrs) addrs = copyAddresses(addrs)
limit := len(addrs) limit := len(addrs)
if limit > 5 { if limit > 5 {
limit = 5 limit = 5
@ -946,7 +943,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
channelz.Infof(logger, ac.channelz, "addrConn: updateAddrs addrs (%d of %d): %v", limit, len(addrs), addrs[:limit]) channelz.Infof(logger, ac.channelz, "addrConn: updateAddrs addrs (%d of %d): %v", limit, len(addrs), addrs[:limit])
ac.mu.Lock() ac.mu.Lock()
if equalAddresses(ac.addrs, addrs) { if equalAddressesIgnoringBalAttributes(ac.addrs, addrs) {
ac.mu.Unlock() ac.mu.Unlock()
return return
} }
@ -965,7 +962,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
// Try to find the connected address. // Try to find the connected address.
for _, a := range addrs { for _, a := range addrs {
a.ServerName = ac.cc.getServerName(a) a.ServerName = ac.cc.getServerName(a)
if a.Equal(ac.curAddr) { if equalAddressIgnoringBalAttributes(&a, &ac.curAddr) {
// We are connected to a valid address, so do nothing but // We are connected to a valid address, so do nothing but
// update the addresses. // update the addresses.
ac.mu.Unlock() ac.mu.Unlock()
@ -1211,7 +1208,7 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
} else { } else {
channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr) channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr)
} }
ac.acbw.updateState(s, lastErr) ac.acbw.updateState(s, ac.curAddr, lastErr)
} }
// adjustParams updates parameters used to create transports upon // adjustParams updates parameters used to create transports upon

View File

@ -208,6 +208,13 @@ var (
// ShuffleAddressListForTesting pseudo-randomizes the order of addresses. n // ShuffleAddressListForTesting pseudo-randomizes the order of addresses. n
// is the number of elements. swap swaps the elements with indexes i and j. // is the number of elements. swap swaps the elements with indexes i and j.
ShuffleAddressListForTesting any // func(n int, swap func(i, j int)) ShuffleAddressListForTesting any // func(n int, swap func(i, j int))
// ConnectedAddress returns the connected address for a SubConnState. The
// address is only valid if the state is READY.
ConnectedAddress any // func(scs balancer.SubConnState) resolver.Address
// SetConnectedAddress sets the connected address for a SubConnState.
SetConnectedAddress any // func(scs *balancer.SubConnState, addr resolver.Address)
) )
// HealthChecker defines the signature of the client-side LB channel health // HealthChecker defines the signature of the client-side LB channel health

View File

@ -32,6 +32,7 @@ import (
"google.golang.org/grpc/balancer/base" "google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/internal/balancer/stub"
"google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/grpctest"
internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
@ -637,7 +638,10 @@ func (s) TestLoadReporting(t *testing.T) {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready}) scs := balancer.SubConnState{ConnectivityState: connectivity.Ready}
sca := internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address))
sca(&scs, addrs[0])
sc1.UpdateState(scs)
// Test pick with one backend. // Test pick with one backend.
const successCount = 5 const successCount = 5
const errorCount = 5 const errorCount = 5

View File

@ -31,6 +31,7 @@ import (
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/gracefulswitch" "google.golang.org/grpc/internal/balancer/gracefulswitch"
"google.golang.org/grpc/internal/buffer" "google.golang.org/grpc/internal/buffer"
"google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpclog"
@ -52,6 +53,8 @@ const (
defaultRequestCountMax = 1024 defaultRequestCountMax = 1024
) )
// connectedAddress is the hook published through internal.ConnectedAddress,
// asserted to its concrete function type. The assertion is unchecked, so a
// missing or mismatched hook panics during package initialization.
var connectedAddress = internal.ConnectedAddress.(func(balancer.SubConnState) resolver.Address)
func init() { func init() {
balancer.Register(bb{}) balancer.Register(bb{})
} }
@ -360,22 +363,35 @@ func (scw *scWrapper) localityID() xdsinternal.LocalityID {
func (b *clusterImplBalancer) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { func (b *clusterImplBalancer) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
clusterName := b.getClusterName() clusterName := b.getClusterName()
newAddrs := make([]resolver.Address, len(addrs)) newAddrs := make([]resolver.Address, len(addrs))
var lID xdsinternal.LocalityID
for i, addr := range addrs { for i, addr := range addrs {
newAddrs[i] = xds.SetXDSHandshakeClusterName(addr, clusterName) newAddrs[i] = xds.SetXDSHandshakeClusterName(addr, clusterName)
lID = xdsinternal.GetLocalityID(newAddrs[i])
} }
var sc balancer.SubConn var sc balancer.SubConn
scw := &scWrapper{}
oldListener := opts.StateListener oldListener := opts.StateListener
opts.StateListener = func(state balancer.SubConnState) { b.updateSubConnState(sc, state, oldListener) } opts.StateListener = func(state balancer.SubConnState) {
b.updateSubConnState(sc, state, oldListener)
if state.ConnectivityState != connectivity.Ready {
return
}
// Read connected address and call updateLocalityID() based on the connected
// address's locality. https://github.com/grpc/grpc-go/issues/7339
addr := connectedAddress(state)
lID := xdsinternal.GetLocalityID(addr)
if lID.Empty() {
if b.logger.V(2) {
b.logger.Infof("Locality ID for %s unexpectedly empty", addr)
}
return
}
scw.updateLocalityID(lID)
}
sc, err := b.ClientConn.NewSubConn(newAddrs, opts) sc, err := b.ClientConn.NewSubConn(newAddrs, opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Wrap this SubConn in a wrapper, and add it to the map. scw.SubConn = sc
ret := &scWrapper{SubConn: sc} return scw, nil
ret.updateLocalityID(lID)
return ret, nil
} }
func (b *clusterImplBalancer) RemoveSubConn(sc balancer.SubConn) { func (b *clusterImplBalancer) RemoveSubConn(sc balancer.SubConn) {

View File

@ -310,14 +310,9 @@ func (s) TestLoadReportingPickFirstMultiLocality(t *testing.T) {
} }
mgmtServer.LRSServer.LRSResponseChan <- &resp mgmtServer.LRSServer.LRSResponseChan <- &resp
// Wait for load to be reported for locality of server 2. // Wait for load to be reported for locality of server 1.
// We (incorrectly) wait for load report for region-2 because presently if err := waitForSuccessfulLoadReport(ctx, mgmtServer.LRSServer, "region-1"); err != nil {
// pickfirst always reports load for the locality of the last address in the t.Fatalf("Server 1 did not receive load due to error: %v", err)
// subconn. This will be fixed by ensuring there is only one address per
// subconn.
// TODO(#7339): Change region to region-1 once fixed.
if err := waitForSuccessfulLoadReport(ctx, mgmtServer.LRSServer, "region-2"); err != nil {
t.Fatalf("region-2 did not receive load due to error: %v", err)
} }
// Stop server 1 and send one more rpc. Now the request should go to server 2. // Stop server 1 and send one more rpc. Now the request should go to server 2.

View File

@ -852,7 +852,7 @@ func (s) TestUpdateAddresses(t *testing.T) {
} }
func scwsEqual(gotSCWS subConnWithState, wantSCWS subConnWithState) error { func scwsEqual(gotSCWS subConnWithState, wantSCWS subConnWithState) error {
if gotSCWS.sc != wantSCWS.sc || !cmp.Equal(gotSCWS.state, wantSCWS.state, cmp.AllowUnexported(subConnWrapper{}, addressInfo{}), cmpopts.IgnoreFields(subConnWrapper{}, "scUpdateCh")) { if gotSCWS.sc != wantSCWS.sc || !cmp.Equal(gotSCWS.state, wantSCWS.state, cmp.AllowUnexported(subConnWrapper{}, addressInfo{}, balancer.SubConnState{}), cmpopts.IgnoreFields(subConnWrapper{}, "scUpdateCh")) {
return fmt.Errorf("received SubConnState: %+v, want %+v", gotSCWS, wantSCWS) return fmt.Errorf("received SubConnState: %+v, want %+v", gotSCWS, wantSCWS)
} }
return nil return nil

View File

@ -55,6 +55,11 @@ func (l LocalityID) Equal(o any) bool {
return l.Region == ol.Region && l.Zone == ol.Zone && l.SubZone == ol.SubZone return l.Region == ol.Region && l.Zone == ol.Zone && l.SubZone == ol.SubZone
} }
// Empty reports whether every component of the locality ID (Region, Zone and
// SubZone) is the empty string.
func (l LocalityID) Empty() bool {
	if l.Region != "" {
		return false
	}
	if l.Zone != "" {
		return false
	}
	return l.SubZone == ""
}
// LocalityIDFromString converts a json representation of locality, into a // LocalityIDFromString converts a json representation of locality, into a
// LocalityID struct. // LocalityID struct.
func LocalityIDFromString(s string) (ret LocalityID, _ error) { func LocalityIDFromString(s string) (ret LocalityID, _ error) {