balancer/weightedroundrobin: add load balancing policy (A58) (#6241)

This commit is contained in:
Doug Fawley 2023-05-08 10:01:08 -07:00 committed by GitHub
parent c44f77e12d
commit 5c4bee51c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1545 additions and 21 deletions

View File

@ -0,0 +1,532 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package weightedroundrobin
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"unsafe"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/balancer/weightedroundrobin/internal"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/orca"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
v3orcapb "github.com/cncf/xds/go/xds/data/orca/v3"
)
// Name is the name of the weighted round robin balancer.
const Name = "weighted_round_robin_experimental"
func init() {
balancer.Register(bb{})
}
type bb struct{}
func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
b := &wrrBalancer{
cc: cc,
subConns: resolver.NewAddressMap(),
csEvltr: &balancer.ConnectivityStateEvaluator{},
scMap: make(map[balancer.SubConn]*weightedSubConn),
connectivityState: connectivity.Connecting,
}
b.logger = prefixLogger(b)
b.logger.Infof("Created")
return b
}
func (bb) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
lbCfg := &lbConfig{
// Default values as documented in A58.
OOBReportingPeriod: 10 * time.Second,
BlackoutPeriod: 10 * time.Second,
WeightExpirationPeriod: 3 * time.Minute,
WeightUpdatePeriod: time.Second,
ErrorUtilizationPenalty: 1,
}
if err := json.Unmarshal(js, lbCfg); err != nil {
return nil, fmt.Errorf("wrr: unable to unmarshal LB policy config: %s, error: %v", string(js), err)
}
if lbCfg.ErrorUtilizationPenalty < 0 {
return nil, fmt.Errorf("wrr: errorUtilizationPenalty must be non-negative")
}
// For easier comparisons later, ensure the OOB reporting period is unset
// (0s) when OOB reports are disabled.
if !lbCfg.EnableOOBLoadReport {
lbCfg.OOBReportingPeriod = 0
}
// Impose lower bound of 100ms on weightUpdatePeriod.
if !internal.AllowAnyWeightUpdatePeriod && lbCfg.WeightUpdatePeriod < 100*time.Millisecond {
lbCfg.WeightUpdatePeriod = 100 * time.Millisecond
}
return lbCfg, nil
}
func (bb) Name() string {
return Name
}
// wrrBalancer implements the weighted round robin LB policy.
type wrrBalancer struct {
cc balancer.ClientConn
logger *grpclog.PrefixLogger
// The following fields are only accessed on calls into the LB policy, and
// do not need a mutex.
cfg *lbConfig // active config
subConns *resolver.AddressMap // active weightedSubConns mapped by address
scMap map[balancer.SubConn]*weightedSubConn
connectivityState connectivity.State // aggregate state
csEvltr *balancer.ConnectivityStateEvaluator
resolverErr error // the last error reported by the resolver; cleared on successful resolution
connErr error // the last connection error; cleared upon leaving TransientFailure
stopPicker func()
}
func (b *wrrBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
b.logger.Infof("UpdateCCS: %v", ccs)
b.resolverErr = nil
cfg, ok := ccs.BalancerConfig.(*lbConfig)
if !ok {
return fmt.Errorf("wrr: received nil or illegal BalancerConfig (type %T): %v", ccs.BalancerConfig, ccs.BalancerConfig)
}
b.cfg = cfg
b.updateAddresses(ccs.ResolverState.Addresses)
if len(ccs.ResolverState.Addresses) == 0 {
b.ResolverError(errors.New("resolver produced zero addresses")) // will call regeneratePicker
return balancer.ErrBadResolverState
}
b.regeneratePicker()
return nil
}
func (b *wrrBalancer) updateAddresses(addrs []resolver.Address) {
addrsSet := resolver.NewAddressMap()
// Loop through new address list and create subconns for any new addresses.
for _, addr := range addrs {
if _, ok := addrsSet.Get(addr); ok {
// Redundant address; skip.
continue
}
addrsSet.Set(addr, nil)
var wsc *weightedSubConn
wsci, ok := b.subConns.Get(addr)
if ok {
wsc = wsci.(*weightedSubConn)
} else {
// addr is a new address (not existing in b.subConns).
sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{})
if err != nil {
b.logger.Warningf("Failed to create new SubConn for address %v: %v", addr, err)
continue
}
wsc = &weightedSubConn{
SubConn: sc,
logger: b.logger,
connectivityState: connectivity.Idle,
// Initially, we set load reports to off, because they are not
// running upon initial weightedSubConn creation.
cfg: &lbConfig{EnableOOBLoadReport: false},
}
b.subConns.Set(addr, wsc)
b.scMap[sc] = wsc
b.csEvltr.RecordTransition(connectivity.Shutdown, connectivity.Idle)
sc.Connect()
}
// Update config for existing weightedSubConn or send update for first
// time to new one. Ensures an OOB listener is running if needed
// (and stops the existing one if applicable).
wsc.updateConfig(b.cfg)
}
// Loop through existing subconns and remove ones that are not in addrs.
for _, addr := range b.subConns.Keys() {
if _, ok := addrsSet.Get(addr); ok {
// Existing address also in new address list; skip.
continue
}
// addr was removed by resolver. Remove.
wsci, _ := b.subConns.Get(addr)
wsc := wsci.(*weightedSubConn)
b.cc.RemoveSubConn(wsc.SubConn)
b.subConns.Delete(addr)
}
}
func (b *wrrBalancer) ResolverError(err error) {
b.resolverErr = err
if b.subConns.Len() == 0 {
b.connectivityState = connectivity.TransientFailure
}
if b.connectivityState != connectivity.TransientFailure {
// No need to update the picker since no error is being returned.
return
}
b.regeneratePicker()
}
func (b *wrrBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
wsc := b.scMap[sc]
if wsc == nil {
b.logger.Errorf("UpdateSubConnState called with an unknown SubConn: %p, %v", sc, state)
return
}
if b.logger.V(2) {
logger.Infof("UpdateSubConnState(%+v, %+v)", sc, state)
}
cs := state.ConnectivityState
if cs == connectivity.TransientFailure {
// Save error to be reported via picker.
b.connErr = state.ConnectionError
}
if cs == connectivity.Shutdown {
delete(b.scMap, sc)
// The subconn was removed from b.subConns when the address was removed
// in updateAddresses.
}
oldCS := wsc.updateConnectivityState(cs)
b.connectivityState = b.csEvltr.RecordTransition(oldCS, cs)
// Regenerate picker when one of the following happens:
// - this sc entered or left ready
// - the aggregated state of balancer is TransientFailure
// (may need to update error message)
if (cs == connectivity.Ready) != (oldCS == connectivity.Ready) ||
b.connectivityState == connectivity.TransientFailure {
b.regeneratePicker()
}
}
// Close stops the balancer. It cancels any ongoing scheduler updates and
// stops any ORCA listeners.
func (b *wrrBalancer) Close() {
if b.stopPicker != nil {
b.stopPicker()
b.stopPicker = nil
}
for _, wsc := range b.scMap {
// Ensure any lingering OOB watchers are stopped.
wsc.updateConnectivityState(connectivity.Shutdown)
}
}
// ExitIdle is ignored; we always connect to all backends.
func (b *wrrBalancer) ExitIdle() {}
func (b *wrrBalancer) readySubConns() []*weightedSubConn {
var ret []*weightedSubConn
for _, v := range b.subConns.Values() {
wsc := v.(*weightedSubConn)
if wsc.connectivityState == connectivity.Ready {
ret = append(ret, wsc)
}
}
return ret
}
// mergeErrors builds an error from the last connection error and the last
// resolver error. Must only be called if b.connectivityState is
// TransientFailure.
func (b *wrrBalancer) mergeErrors() error {
// connErr must always be non-nil unless there are no SubConns, in which
// case resolverErr must be non-nil.
if b.connErr == nil {
return fmt.Errorf("last resolver error: %v", b.resolverErr)
}
if b.resolverErr == nil {
return fmt.Errorf("last connection error: %v", b.connErr)
}
return fmt.Errorf("last connection error: %v; last resolver error: %v", b.connErr, b.resolverErr)
}
func (b *wrrBalancer) regeneratePicker() {
if b.stopPicker != nil {
b.stopPicker()
b.stopPicker = nil
}
switch b.connectivityState {
case connectivity.TransientFailure:
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: base.NewErrPicker(b.mergeErrors()),
})
return
case connectivity.Connecting, connectivity.Idle:
// Idle could happen very briefly if all subconns are Idle and we've
// asked them to connect but they haven't reported Connecting yet.
// Report the same as Connecting since this is temporary.
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable),
})
return
case connectivity.Ready:
b.connErr = nil
}
p := &picker{
v: grpcrand.Uint32(), // start the scheduler at a random point
cfg: b.cfg,
subConns: b.readySubConns(),
}
var ctx context.Context
ctx, b.stopPicker = context.WithCancel(context.Background())
p.start(ctx)
b.cc.UpdateState(balancer.State{
ConnectivityState: b.connectivityState,
Picker: p,
})
}
// picker is the WRR policy's picker. It uses live-updating backend weights to
// update the scheduler periodically and ensure picks are routed proportional
// to those weights.
type picker struct {
scheduler unsafe.Pointer // *scheduler; accessed atomically
v uint32 // incrementing value used by the scheduler; accessed atomically
cfg *lbConfig // active config when picker created
subConns []*weightedSubConn // all READY subconns
}
// scWeights returns a slice containing the weights from p.subConns in the same
// order as p.subConns.
func (p *picker) scWeights() []float64 {
ws := make([]float64, len(p.subConns))
now := internal.TimeNow()
for i, wsc := range p.subConns {
ws[i] = wsc.weight(now, p.cfg.WeightExpirationPeriod, p.cfg.BlackoutPeriod)
}
return ws
}
func (p *picker) inc() uint32 {
return atomic.AddUint32(&p.v, 1)
}
func (p *picker) regenerateScheduler() {
s := newScheduler(p.scWeights(), p.inc)
atomic.StorePointer(&p.scheduler, unsafe.Pointer(&s))
}
func (p *picker) start(ctx context.Context) {
p.regenerateScheduler()
if len(p.subConns) == 1 {
// No need to regenerate weights with only one backend.
return
}
go func() {
ticker := time.NewTicker(p.cfg.WeightUpdatePeriod)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
p.regenerateScheduler()
}
}
}()
}
func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
// Read the scheduler atomically. All scheduler operations are threadsafe,
// and if the scheduler is replaced during this usage, we want to use the
// scheduler that was live when the pick started.
sched := *(*scheduler)(atomic.LoadPointer(&p.scheduler))
pickedSC := p.subConns[sched.nextIndex()]
pr := balancer.PickResult{SubConn: pickedSC.SubConn}
if !p.cfg.EnableOOBLoadReport {
pr.Done = func(info balancer.DoneInfo) {
if load, ok := info.ServerLoad.(*v3orcapb.OrcaLoadReport); ok && load != nil {
pickedSC.OnLoadReport(load)
}
}
}
return pr, nil
}
// weightedSubConn is the wrapper of a subconn that holds the subconn and its
// weight (and other parameters relevant to computing the effective weight).
// When needed, it also tracks connectivity state, listens for metrics updates
// by implementing the orca.OOBListener interface and manages that listener.
type weightedSubConn struct {
balancer.SubConn
logger *grpclog.PrefixLogger
// The following fields are only accessed on calls into the LB policy, and
// do not need a mutex.
connectivityState connectivity.State
stopORCAListener func()
// The following fields are accessed asynchronously and are protected by
// mu. Note that mu may not be held when calling into the stopORCAListener
// or when registering a new listener, as those calls require the ORCA
// producer mu which is held when calling the listener, and the listener
// holds mu.
mu sync.Mutex
weightVal float64
nonEmptySince time.Time
lastUpdated time.Time
cfg *lbConfig
}
func (w *weightedSubConn) OnLoadReport(load *v3orcapb.OrcaLoadReport) {
if w.logger.V(2) {
w.logger.Infof("Received load report for subchannel %v: %v", w.SubConn, load)
}
// Update weights of this subchannel according to the reported load
if load.CpuUtilization == 0 || load.RpsFractional == 0 {
if w.logger.V(2) {
w.logger.Infof("Ignoring empty load report for subchannel %v", w.SubConn)
}
return
}
w.mu.Lock()
defer w.mu.Unlock()
errorRate := load.Eps / load.RpsFractional
w.weightVal = load.RpsFractional / (load.CpuUtilization + errorRate*w.cfg.ErrorUtilizationPenalty)
if w.logger.V(2) {
w.logger.Infof("New weight for subchannel %v: %v", w.SubConn, w.weightVal)
}
w.lastUpdated = internal.TimeNow()
if w.nonEmptySince == (time.Time{}) {
w.nonEmptySince = w.lastUpdated
}
}
// updateConfig updates the parameters of the WRR policy and
// stops/starts/restarts the ORCA OOB listener.
func (w *weightedSubConn) updateConfig(cfg *lbConfig) {
w.mu.Lock()
oldCfg := w.cfg
w.cfg = cfg
w.mu.Unlock()
newPeriod := cfg.OOBReportingPeriod
if cfg.EnableOOBLoadReport == oldCfg.EnableOOBLoadReport &&
newPeriod == oldCfg.OOBReportingPeriod {
// Load reporting wasn't enabled before or after, or load reporting was
// enabled before and after, and had the same period. (Note that with
// load reporting disabled, OOBReportingPeriod is always 0.)
return
}
// (Optionally stop and) start the listener to use the new config's
// settings for OOB reporting.
if w.stopORCAListener != nil {
w.stopORCAListener()
}
if !cfg.EnableOOBLoadReport {
w.stopORCAListener = nil
return
}
if w.logger.V(2) {
w.logger.Infof("Registering ORCA listener for %v with interval %v", w.SubConn, newPeriod)
}
opts := orca.OOBListenerOptions{ReportInterval: newPeriod}
w.stopORCAListener = orca.RegisterOOBListener(w.SubConn, w, opts)
}
func (w *weightedSubConn) updateConnectivityState(cs connectivity.State) connectivity.State {
switch cs {
case connectivity.Idle:
// Always reconnect when idle.
w.SubConn.Connect()
case connectivity.Ready:
// If we transition back to READY state, reset nonEmptySince so that we
// apply the blackout period after we start receiving load data. Note
// that we cannot guarantee that we will never receive lingering
// callbacks for backend metric reports from the previous connection
// after the new connection has been established, but they should be
// masked by new backend metric reports from the new connection by the
// time the blackout period ends.
w.mu.Lock()
w.nonEmptySince = time.Time{}
w.mu.Unlock()
case connectivity.Shutdown:
if w.stopORCAListener != nil {
w.stopORCAListener()
}
}
oldCS := w.connectivityState
if oldCS == connectivity.TransientFailure &&
(cs == connectivity.Connecting || cs == connectivity.Idle) {
// Once a subconn enters TRANSIENT_FAILURE, ignore subsequent IDLE or
// CONNECTING transitions to prevent the aggregated state from being
// always CONNECTING when many backends exist but are all down.
return oldCS
}
w.connectivityState = cs
return oldCS
}
// weight returns the current effective weight of the subconn, taking into
// account the parameters. Returns 0 for blacked out or expired data, which
// will cause the backend weight to be treated as the mean of the weights of
// the other backends.
func (w *weightedSubConn) weight(now time.Time, weightExpirationPeriod, blackoutPeriod time.Duration) float64 {
w.mu.Lock()
defer w.mu.Unlock()
// If the most recent update was longer ago than the expiration period,
// reset nonEmptySince so that we apply the blackout period again if we
// start getting data again in the future, and return 0.
if now.Sub(w.lastUpdated) >= weightExpirationPeriod {
w.nonEmptySince = time.Time{}
return 0
}
// If we don't have at least blackoutPeriod worth of data, return 0.
if blackoutPeriod != 0 && (w.nonEmptySince == (time.Time{}) || now.Sub(w.nonEmptySince) < blackoutPeriod) {
return 0
}
return w.weightVal
}

View File

@ -0,0 +1,713 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package weightedroundrobin_test
import (
"context"
"encoding/json"
"fmt"
"sync"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils/roundrobin"
"google.golang.org/grpc/orca"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
wrr "google.golang.org/grpc/balancer/weightedroundrobin"
iwrr "google.golang.org/grpc/balancer/weightedroundrobin/internal"
testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
)
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
const defaultTestTimeout = 10 * time.Second
const weightUpdatePeriod = 50 * time.Millisecond
const oobReportingInterval = 10 * time.Millisecond
func init() {
iwrr.AllowAnyWeightUpdatePeriod = true
}
func boolp(b bool) *bool { return &b }
func float64p(f float64) *float64 { return &f }
func durationp(d time.Duration) *time.Duration { return &d }
var (
perCallConfig = iwrr.LBConfig{
EnableOOBLoadReport: boolp(false),
OOBReportingPeriod: durationp(5 * time.Millisecond),
BlackoutPeriod: durationp(0),
WeightExpirationPeriod: durationp(time.Minute),
WeightUpdatePeriod: durationp(weightUpdatePeriod),
ErrorUtilizationPenalty: float64p(0),
}
oobConfig = iwrr.LBConfig{
EnableOOBLoadReport: boolp(true),
OOBReportingPeriod: durationp(5 * time.Millisecond),
BlackoutPeriod: durationp(0),
WeightExpirationPeriod: durationp(time.Minute),
WeightUpdatePeriod: durationp(weightUpdatePeriod),
ErrorUtilizationPenalty: float64p(0),
}
)
type testServer struct {
*stubserver.StubServer
oobMetrics orca.ServerMetricsRecorder // Attached to the OOB stream.
callMetrics orca.CallMetricsRecorder // Attached to per-call metrics.
}
type reportType int
const (
reportNone reportType = iota
reportOOB
reportCall
reportBoth
)
func startServer(t *testing.T, r reportType) *testServer {
t.Helper()
smr := orca.NewServerMetricsRecorder()
cmr := orca.NewServerMetricsRecorder().(orca.CallMetricsRecorder)
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
if r := orca.CallMetricsRecorderFromContext(ctx); r != nil {
// Copy metrics from what the test set in cmr into r.
sm := cmr.(orca.ServerMetricsProvider).ServerMetrics()
r.SetCPUUtilization(sm.CPUUtilization)
r.SetQPS(sm.QPS)
r.SetEPS(sm.EPS)
}
return &testpb.Empty{}, nil
},
}
var sopts []grpc.ServerOption
if r == reportCall || r == reportBoth {
sopts = append(sopts, orca.CallMetricsServerOption(nil))
}
if r == reportOOB || r == reportBoth {
oso := orca.ServiceOptions{
ServerMetricsProvider: smr,
MinReportingInterval: 10 * time.Millisecond,
}
internal.ORCAAllowAnyMinReportingInterval.(func(so *orca.ServiceOptions))(&oso)
sopts = append(sopts, stubserver.RegisterServiceServerOption(func(s *grpc.Server) {
if err := orca.Register(s, oso); err != nil {
t.Fatalf("Failed to register orca service: %v", err)
}
}))
}
if err := ss.StartServer(sopts...); err != nil {
t.Fatalf("Error starting server: %v", err)
}
t.Cleanup(ss.Stop)
return &testServer{
StubServer: ss,
oobMetrics: smr,
callMetrics: cmr,
}
}
func svcConfig(t *testing.T, wrrCfg iwrr.LBConfig) string {
t.Helper()
m, err := json.Marshal(wrrCfg)
if err != nil {
t.Fatalf("Error marshaling JSON %v: %v", wrrCfg, err)
}
sc := fmt.Sprintf(`{"loadBalancingConfig": [ {%q:%v} ] }`, wrr.Name, string(m))
t.Logf("Marshaled service config: %v", sc)
return sc
}
// Tests basic functionality with one address. With only one address, load
// reporting doesn't affect routing at all.
func (s) TestBalancer_OneAddress(t *testing.T) {
testCases := []struct {
rt reportType
cfg iwrr.LBConfig
}{
{rt: reportNone, cfg: perCallConfig},
{rt: reportCall, cfg: perCallConfig},
{rt: reportOOB, cfg: oobConfig},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("reportType:%v", tc.rt), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
srv := startServer(t, tc.rt)
sc := svcConfig(t, tc.cfg)
if err := srv.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
// Perform many RPCs to ensure the LB policy works with 1 address.
for i := 0; i < 100; i++ {
srv.callMetrics.SetQPS(float64(i))
srv.oobMetrics.SetQPS(float64(i))
if _, err := srv.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("Error from EmptyCall: %v", err)
}
time.Sleep(time.Millisecond) // Delay; test will run 100ms and should perform ~10 weight updates
}
})
}
}
// Tests two addresses with ORCA reporting disabled (should fall back to pure
// RR).
func (s) TestBalancer_TwoAddresses_ReportingDisabled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
srv1 := startServer(t, reportNone)
srv2 := startServer(t, reportNone)
sc := svcConfig(t, perCallConfig)
if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}}
srv1.R.UpdateState(resolver.State{Addresses: addrs})
// Perform many RPCs to ensure the LB policy works with 2 addresses.
for i := 0; i < 20; i++ {
roundrobin.CheckRoundRobinRPCs(ctx, srv1.Client, addrs)
}
}
// Tests two addresses with per-call ORCA reporting enabled. Checks the
// backends are called in the appropriate ratios.
func (s) TestBalancer_TwoAddresses_ReportingEnabledPerCall(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
srv1 := startServer(t, reportCall)
srv2 := startServer(t, reportCall)
// srv1 starts loaded and srv2 starts without load; ensure RPCs are routed
// disproportionately to srv2 (10:1).
srv1.callMetrics.SetQPS(10.0)
srv1.callMetrics.SetCPUUtilization(1.0)
srv2.callMetrics.SetQPS(10.0)
srv2.callMetrics.SetCPUUtilization(.1)
sc := svcConfig(t, perCallConfig)
if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}}
srv1.R.UpdateState(resolver.State{Addresses: addrs})
// Call each backend once to ensure the weights have been received.
ensureReached(ctx, t, srv1.Client, 2)
// Wait for the weight update period to allow the new weights to be processed.
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10})
}
// Tests two addresses with OOB ORCA reporting enabled. Checks the backends
// are called in the appropriate ratios.
func (s) TestBalancer_TwoAddresses_ReportingEnabledOOB(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
srv1 := startServer(t, reportOOB)
srv2 := startServer(t, reportOOB)
// srv1 starts loaded and srv2 starts without load; ensure RPCs are routed
// disproportionately to srv2 (10:1).
srv1.oobMetrics.SetQPS(10.0)
srv1.oobMetrics.SetCPUUtilization(1.0)
srv2.oobMetrics.SetQPS(10.0)
srv2.oobMetrics.SetCPUUtilization(.1)
sc := svcConfig(t, oobConfig)
if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}}
srv1.R.UpdateState(resolver.State{Addresses: addrs})
// Call each backend once to ensure the weights have been received.
ensureReached(ctx, t, srv1.Client, 2)
// Wait for the weight update period to allow the new weights to be processed.
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10})
}
// Tests two addresses with OOB ORCA reporting enabled, where the reports
// change over time. Checks the backends are called in the appropriate ratios
// before and after modifying the reports.
func (s) TestBalancer_TwoAddresses_UpdateLoads(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
srv1 := startServer(t, reportOOB)
srv2 := startServer(t, reportOOB)
// srv1 starts loaded and srv2 starts without load; ensure RPCs are routed
// disproportionately to srv2 (10:1).
srv1.oobMetrics.SetQPS(10.0)
srv1.oobMetrics.SetCPUUtilization(1.0)
srv2.oobMetrics.SetQPS(10.0)
srv2.oobMetrics.SetCPUUtilization(.1)
sc := svcConfig(t, oobConfig)
if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}}
srv1.R.UpdateState(resolver.State{Addresses: addrs})
// Call each backend once to ensure the weights have been received.
ensureReached(ctx, t, srv1.Client, 2)
// Wait for the weight update period to allow the new weights to be processed.
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10})
// Update the loads so srv2 is loaded and srv1 is not; ensure RPCs are
// routed disproportionately to srv1.
srv1.oobMetrics.SetQPS(10.0)
srv1.oobMetrics.SetCPUUtilization(.1)
srv2.oobMetrics.SetQPS(10.0)
srv2.oobMetrics.SetCPUUtilization(1.0)
// Wait for the weight update period to allow the new weights to be processed.
time.Sleep(weightUpdatePeriod + oobReportingInterval)
checkWeights(ctx, t, srvWeight{srv1, 10}, srvWeight{srv2, 1})
}
// Tests two addresses with OOB ORCA reporting enabled, then with switching to
// per-call reporting. Checks the backends are called in the appropriate
// ratios before and after the change.
func (s) TestBalancer_TwoAddresses_OOBThenPerCall(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
srv1 := startServer(t, reportBoth)
srv2 := startServer(t, reportBoth)
// srv1 starts loaded and srv2 starts without load; ensure RPCs are routed
// disproportionately to srv2 (10:1).
srv1.oobMetrics.SetQPS(10.0)
srv1.oobMetrics.SetCPUUtilization(1.0)
srv2.oobMetrics.SetQPS(10.0)
srv2.oobMetrics.SetCPUUtilization(.1)
// For per-call metrics (not used initially), srv2 reports that it is
// loaded and srv1 reports low load. After confirming OOB works, switch to
// per-call and confirm the new routing weights are applied.
srv1.callMetrics.SetQPS(10.0)
srv1.callMetrics.SetCPUUtilization(.1)
srv2.callMetrics.SetQPS(10.0)
srv2.callMetrics.SetCPUUtilization(1.0)
sc := svcConfig(t, oobConfig)
if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}}
srv1.R.UpdateState(resolver.State{Addresses: addrs})
// Call each backend once to ensure the weights have been received.
ensureReached(ctx, t, srv1.Client, 2)
// Wait for the weight update period to allow the new weights to be processed.
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10})
// Update to per-call weights.
c := svcConfig(t, perCallConfig)
parsedCfg := srv1.R.CC.ParseServiceConfig(c)
if parsedCfg.Err != nil {
panic(fmt.Sprintf("Error parsing config %q: %v", c, parsedCfg.Err))
}
srv1.R.UpdateState(resolver.State{Addresses: addrs, ServiceConfig: parsedCfg})
// Wait for the weight update period to allow the new weights to be processed.
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 10}, srvWeight{srv2, 1})
}
// Tests two addresses with OOB ORCA reporting enabled and a non-zero error
// penalty applied.
func (s) TestBalancer_TwoAddresses_ErrorPenalty(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
srv1 := startServer(t, reportOOB)
srv2 := startServer(t, reportOOB)
// srv1 starts loaded and srv2 starts without load; ensure RPCs are routed
// disproportionately to srv2 (10:1). EPS values are set (but ignored
// initially due to ErrorUtilizationPenalty=0). Later EUP will be updated
// to 0.9 which will cause the weights to be equal and RPCs to be routed
// 50/50.
srv1.oobMetrics.SetQPS(10.0)
srv1.oobMetrics.SetCPUUtilization(1.0)
srv1.oobMetrics.SetEPS(0)
// srv1 weight before: 10.0 / 1.0 = 10.0
// srv1 weight after: 10.0 / 1.0 = 10.0
srv2.oobMetrics.SetQPS(10.0)
srv2.oobMetrics.SetCPUUtilization(.1)
srv2.oobMetrics.SetEPS(10.0)
// srv2 weight before: 10.0 / 0.1 = 100.0
// srv2 weight after: 10.0 / 1.0 = 10.0
sc := svcConfig(t, oobConfig)
if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}}
srv1.R.UpdateState(resolver.State{Addresses: addrs})
// Call each backend once to ensure the weights have been received.
ensureReached(ctx, t, srv1.Client, 2)
// Wait for the weight update period to allow the new weights to be processed.
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10})
// Update to include an error penalty in the weights.
newCfg := oobConfig
newCfg.ErrorUtilizationPenalty = float64p(0.9)
c := svcConfig(t, newCfg)
parsedCfg := srv1.R.CC.ParseServiceConfig(c)
if parsedCfg.Err != nil {
panic(fmt.Sprintf("Error parsing config %q: %v", c, parsedCfg.Err))
}
srv1.R.UpdateState(resolver.State{Addresses: addrs, ServiceConfig: parsedCfg})
// Wait for the weight update period to allow the new weights to be processed.
time.Sleep(weightUpdatePeriod + oobReportingInterval)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 1})
}
// Tests that the blackout period causes backends to use 0 as their weight
// (meaning to use the average weight) until the blackout period elapses.
func (s) TestBalancer_TwoAddresses_BlackoutPeriod(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
var mu sync.Mutex
start := time.Now()
now := start
setNow := func(t time.Time) {
mu.Lock()
defer mu.Unlock()
now = t
}
iwrr.TimeNow = func() time.Time {
mu.Lock()
defer mu.Unlock()
return now
}
t.Cleanup(func() { iwrr.TimeNow = time.Now })
testCases := []struct {
blackoutPeriodCfg *time.Duration
blackoutPeriod time.Duration
}{{
blackoutPeriodCfg: durationp(time.Second),
blackoutPeriod: time.Second,
}, {
blackoutPeriodCfg: nil,
blackoutPeriod: 10 * time.Second, // the default
}}
for _, tc := range testCases {
setNow(start)
srv1 := startServer(t, reportOOB)
srv2 := startServer(t, reportOOB)
// srv1 starts loaded and srv2 starts without load; ensure RPCs are routed
// disproportionately to srv2 (10:1).
srv1.oobMetrics.SetQPS(10.0)
srv1.oobMetrics.SetCPUUtilization(1.0)
srv2.oobMetrics.SetQPS(10.0)
srv2.oobMetrics.SetCPUUtilization(.1)
cfg := oobConfig
cfg.BlackoutPeriod = tc.blackoutPeriodCfg
sc := svcConfig(t, cfg)
if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}}
srv1.R.UpdateState(resolver.State{Addresses: addrs})
// Call each backend once to ensure the weights have been received.
ensureReached(ctx, t, srv1.Client, 2)
// Wait for the weight update period to allow the new weights to be processed.
time.Sleep(weightUpdatePeriod)
// During the blackout period (1s) we should route roughly 50/50.
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 1})
// Advance time to right before the blackout period ends and the weights
// should still be zero.
setNow(start.Add(tc.blackoutPeriod - time.Nanosecond))
// Wait for the weight update period to allow the new weights to be processed.
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 1})
// Advance time to right after the blackout period ends and the weights
// should now activate.
setNow(start.Add(tc.blackoutPeriod))
// Wait for the weight update period to allow the new weights to be processed.
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10})
}
}
// Tests that the weight expiration period causes backends to use 0 as their
// weight (meaning to use the average weight) once the expiration period
// elapses.
func (s) TestBalancer_TwoAddresses_WeightExpiration(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
var mu sync.Mutex
start := time.Now()
now := start
setNow := func(t time.Time) {
mu.Lock()
defer mu.Unlock()
now = t
}
iwrr.TimeNow = func() time.Time {
mu.Lock()
defer mu.Unlock()
return now
}
t.Cleanup(func() { iwrr.TimeNow = time.Now })
srv1 := startServer(t, reportBoth)
srv2 := startServer(t, reportBoth)
// srv1 starts loaded and srv2 starts without load; ensure RPCs are routed
// disproportionately to srv2 (10:1). Because the OOB reporting interval
// is 1 minute but the weights expire in 1 second, routing will go to 50/50
// after the weights expire.
srv1.oobMetrics.SetQPS(10.0)
srv1.oobMetrics.SetCPUUtilization(1.0)
srv2.oobMetrics.SetQPS(10.0)
srv2.oobMetrics.SetCPUUtilization(.1)
cfg := oobConfig
cfg.OOBReportingPeriod = durationp(time.Minute)
sc := svcConfig(t, cfg)
if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}}
srv1.R.UpdateState(resolver.State{Addresses: addrs})
// Call each backend once to ensure the weights have been received.
ensureReached(ctx, t, srv1.Client, 2)
// Wait for the weight update period to allow the new weights to be processed.
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10})
// Advance what time.Now returns to the weight expiration time minus 1s to
// ensure all weights are still honored.
setNow(start.Add(*cfg.WeightExpirationPeriod - time.Second))
// Wait for the weight update period to allow the new weights to be processed.
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10})
// Advance what time.Now returns to the weight expiration time plus 1s to
// ensure all weights expired and addresses are routed evenly.
setNow(start.Add(*cfg.WeightExpirationPeriod + time.Second))
// Wait for the weight expiration period so the weights have expired.
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 1})
}
// Tests logic surrounding subchannel management.
func (s) TestBalancer_AddressesChanging(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
srv1 := startServer(t, reportBoth)
srv2 := startServer(t, reportBoth)
srv3 := startServer(t, reportBoth)
srv4 := startServer(t, reportBoth)
// srv1: weight 10
srv1.oobMetrics.SetQPS(10.0)
srv1.oobMetrics.SetCPUUtilization(1.0)
// srv2: weight 100
srv2.oobMetrics.SetQPS(10.0)
srv2.oobMetrics.SetCPUUtilization(.1)
// srv3: weight 20
srv3.oobMetrics.SetQPS(20.0)
srv3.oobMetrics.SetCPUUtilization(1.0)
// srv4: weight 200
srv4.oobMetrics.SetQPS(20.0)
srv4.oobMetrics.SetCPUUtilization(.1)
sc := svcConfig(t, oobConfig)
if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
srv2.Client = srv1.Client
addrs := []resolver.Address{{Addr: srv1.Address}, {Addr: srv2.Address}, {Addr: srv3.Address}}
srv1.R.UpdateState(resolver.State{Addresses: addrs})
// Call each backend once to ensure the weights have been received.
ensureReached(ctx, t, srv1.Client, 3)
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}, srvWeight{srv3, 2})
// Add backend 4
addrs = append(addrs, resolver.Address{Addr: srv4.Address})
srv1.R.UpdateState(resolver.State{Addresses: addrs})
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}, srvWeight{srv3, 2}, srvWeight{srv4, 20})
// Shutdown backend 3. RPCs will no longer be routed to it.
srv3.Stop()
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv2, 10}, srvWeight{srv4, 20})
// Remove addresses 2 and 3. RPCs will no longer be routed to 2 either.
addrs = []resolver.Address{{Addr: srv1.Address}, {Addr: srv4.Address}}
srv1.R.UpdateState(resolver.State{Addresses: addrs})
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv1, 1}, srvWeight{srv4, 20})
// Re-add 2 and remove the rest.
addrs = []resolver.Address{{Addr: srv2.Address}}
srv1.R.UpdateState(resolver.State{Addresses: addrs})
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv2, 10})
// Re-add 4.
addrs = append(addrs, resolver.Address{Addr: srv4.Address})
srv1.R.UpdateState(resolver.State{Addresses: addrs})
time.Sleep(weightUpdatePeriod)
checkWeights(ctx, t, srvWeight{srv2, 10}, srvWeight{srv4, 20})
}
func ensureReached(ctx context.Context, t *testing.T, c testgrpc.TestServiceClient, n int) {
t.Helper()
reached := make(map[string]struct{})
for len(reached) != n {
var peer peer.Peer
if _, err := c.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&peer)); err != nil {
t.Fatalf("Error from EmptyCall: %v", err)
}
reached[peer.Addr.String()] = struct{}{}
}
}
type srvWeight struct {
srv *testServer
w int
}
const rrIterations = 100
// checkWeights does rrIterations RPCs and expects the different backends to be
// routed in a ratio as deterimined by the srvWeights passed in. Allows for
// some variance (+/- 2 RPCs per backend).
func checkWeights(ctx context.Context, t *testing.T, sws ...srvWeight) {
t.Helper()
c := sws[0].srv.Client
// Replace the weights with approximate counts of RPCs wanted given the
// iterations performed.
weightSum := 0
for _, sw := range sws {
weightSum += sw.w
}
for i := range sws {
sws[i].w = rrIterations * sws[i].w / weightSum
}
for attempts := 0; attempts < 10; attempts++ {
serverCounts := make(map[string]int)
for i := 0; i < rrIterations; i++ {
var peer peer.Peer
if _, err := c.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&peer)); err != nil {
t.Fatalf("Error from EmptyCall: %v; timed out waiting for weighted RR behavior?", err)
}
serverCounts[peer.Addr.String()]++
}
if len(serverCounts) != len(sws) {
continue
}
success := true
for _, sw := range sws {
c := serverCounts[sw.srv.Address]
if c < sw.w-2 || c > sw.w+2 {
success = false
break
}
}
if success {
t.Logf("Passed iteration %v; counts: %v", attempts, serverCounts)
return
}
t.Logf("Failed iteration %v; counts: %v; want %+v", attempts, serverCounts, sws)
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("Failed to route RPCs with proper ratio")
}

View File

@ -0,0 +1,60 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package weightedroundrobin
import (
"time"
"google.golang.org/grpc/serviceconfig"
)
type lbConfig struct {
serviceconfig.LoadBalancingConfig `json:"-"`
// Whether to enable out-of-band utilization reporting collection from the
// endpoints. By default, per-request utilization reporting is used.
EnableOOBLoadReport bool `json:"enableOobLoadReport,omitempty"`
// Load reporting interval to request from the server. Note that the
// server may not provide reports as frequently as the client requests.
// Used only when enable_oob_load_report is true. Default is 10 seconds.
OOBReportingPeriod time.Duration `json:"oobReportingPeriod,omitempty"`
// A given endpoint must report load metrics continuously for at least this
// long before the endpoint weight will be used. This avoids churn when
// the set of endpoint addresses changes. Takes effect both immediately
// after we establish a connection to an endpoint and after
// weight_expiration_period has caused us to stop using the most recent
// load metrics. Default is 10 seconds.
BlackoutPeriod time.Duration `json:"blackoutPeriod,omitempty"`
// If a given endpoint has not reported load metrics in this long,
// then we stop using the reported weight. This ensures that we do
// not continue to use very stale weights. Once we stop using a stale
// value, if we later start seeing fresh reports again, the
// blackout_period applies. Defaults to 3 minutes.
WeightExpirationPeriod time.Duration `json:"weightExpirationPeriod,omitempty"`
// How often endpoint weights are recalculated. Default is 1 second.
WeightUpdatePeriod time.Duration `json:"weightUpdatePeriod,omitempty"`
// The multiplier used to adjust endpoint weights with the error rate
// calculated as eps/qps. Default is 1.0.
ErrorUtilizationPenalty float64 `json:"errorUtilizationPenalty,omitempty"`
}

View File

@ -0,0 +1,44 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package internal allows for easier testing of the weightedroundrobin
// package.
package internal
import (
"time"
)
// AllowAnyWeightUpdatePeriod permits any setting of WeightUpdatePeriod for
// testing. Normally a minimum of 100ms is applied.
var AllowAnyWeightUpdatePeriod bool
// LBConfig allows tests to produce a JSON form of the config from the struct
// instead of using a string.
type LBConfig struct {
EnableOOBLoadReport *bool `json:"enableOobLoadReport,omitempty"`
OOBReportingPeriod *time.Duration `json:"oobReportingPeriod,omitempty"`
BlackoutPeriod *time.Duration `json:"blackoutPeriod,omitempty"`
WeightExpirationPeriod *time.Duration `json:"weightExpirationPeriod,omitempty"`
WeightUpdatePeriod *time.Duration `json:"weightUpdatePeriod,omitempty"`
ErrorUtilizationPenalty *float64 `json:"errorUtilizationPenalty,omitempty"`
}
// TimeNow can be overridden by tests to return a different value for the
// current time.
var TimeNow = time.Now

View File

@ -0,0 +1,34 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package weightedroundrobin
import (
"fmt"
"google.golang.org/grpc/grpclog"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
)
const prefix = "[%p] "
var logger = grpclog.Component("weighted-round-robin")
func prefixLogger(p *wrrBalancer) *internalgrpclog.PrefixLogger {
return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(prefix, p))
}

View File

@ -0,0 +1,138 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package weightedroundrobin
import (
"math"
)
type scheduler interface {
nextIndex() int
}
// newScheduler uses scWeights to create a new scheduler for selecting subconns
// in a picker. It will return a round robin implementation if at least
// len(scWeights)-1 are zero or there is only a single subconn, otherwise it
// will return an Earliest Deadline First (EDF) scheduler implementation that
// selects the subchannels according to their weights.
func newScheduler(scWeights []float64, inc func() uint32) scheduler {
n := len(scWeights)
if n == 0 {
return nil
}
if n == 1 {
return &rrScheduler{numSCs: 1, inc: inc}
}
sum := float64(0)
numZero := 0
max := float64(0)
for _, w := range scWeights {
sum += w
if w > max {
max = w
}
if w == 0 {
numZero++
}
}
if numZero >= n-1 {
return &rrScheduler{numSCs: uint32(n), inc: inc}
}
unscaledMean := sum / float64(n-numZero)
scalingFactor := maxWeight / max
mean := uint16(math.Round(scalingFactor * unscaledMean))
weights := make([]uint16, n)
allEqual := true
for i, w := range scWeights {
if w == 0 {
// Backends with weight = 0 use the mean.
weights[i] = mean
} else {
scaledWeight := uint16(math.Round(scalingFactor * w))
weights[i] = scaledWeight
if scaledWeight != mean {
allEqual = false
}
}
}
if allEqual {
return &rrScheduler{numSCs: uint32(n), inc: inc}
}
logger.Infof("using edf scheduler with weights: %v", weights)
return &edfScheduler{weights: weights, inc: inc}
}
const maxWeight = math.MaxUint16
// edfScheduler implements EDF using the same algorithm as grpc-c++ here:
//
// https://github.com/grpc/grpc/blob/master/src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc
type edfScheduler struct {
inc func() uint32
weights []uint16
}
// Returns the index in s.weights for the picker to choose.
func (s *edfScheduler) nextIndex() int {
const offset = maxWeight / 2
for {
idx := uint64(s.inc())
// The sequence number (idx) is split in two: the lower %n gives the
// index of the backend, and the rest gives the number of times we've
// iterated through all backends. `generation` is used to
// deterministically decide whether we pick or skip the backend on this
// iteration, in proportion to the backend's weight.
backendIndex := idx % uint64(len(s.weights))
generation := idx / uint64(len(s.weights))
weight := uint64(s.weights[backendIndex])
// We pick a backend `weight` times per `maxWeight` generations. The
// multiply and modulus ~evenly spread out the picks for a given
// backend between different generations. The offset by `backendIndex`
// helps to reduce the chance of multiple consecutive non-picks: if we
// have two consecutive backends with an equal, say, 80% weight of the
// max, with no offset we would see 1/5 generations that skipped both.
// TODO(b/190488683): add test for offset efficacy.
mod := uint64(weight*generation+backendIndex*offset) % maxWeight
if mod < maxWeight-weight {
continue
}
return int(backendIndex)
}
}
// A simple RR scheduler to use for fallback when fewer than two backends have
// non-zero weights, or all backends have the the same weight, or when only one
// subconn exists.
type rrScheduler struct {
inc func() uint32
numSCs uint32
}
func (s *rrScheduler) nextIndex() int {
idx := s.inc()
return int(idx % s.numSCs)
}

View File

@ -16,16 +16,21 @@
*
*/
// Package weightedroundrobin defines a weighted roundrobin balancer.
// Package weightedroundrobin provides an implementation of the weighted round
// robin LB policy, as defined in [gRFC A58].
//
// # Experimental
//
// Notice: This package is EXPERIMENTAL and may be changed or removed in a
// later release.
//
// [gRFC A58]: https://github.com/grpc/proposal/blob/master/A58-client-side-weighted-round-robin-lb-policy.md
package weightedroundrobin
import (
"google.golang.org/grpc/resolver"
)
// Name is the name of weighted_round_robin balancer.
const Name = "weighted_round_robin"
// attributeKey is the type used as the key to store AddrInfo in the
// BalancerAttributes field of resolver.Address.
type attributeKey struct{}
@ -44,11 +49,6 @@ func (a AddrInfo) Equal(o interface{}) bool {
// SetAddrInfo returns a copy of addr in which the BalancerAttributes field is
// updated with addrInfo.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func SetAddrInfo(addr resolver.Address, addrInfo AddrInfo) resolver.Address {
addr.BalancerAttributes = addr.BalancerAttributes.WithValue(attributeKey{}, addrInfo)
return addr
@ -56,11 +56,6 @@ func SetAddrInfo(addr resolver.Address, addrInfo AddrInfo) resolver.Address {
// GetAddrInfo returns the AddrInfo stored in the BalancerAttributes field of
// addr.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func GetAddrInfo(addr resolver.Address) AddrInfo {
v := addr.BalancerAttributes.Value(attributeKey{})
ai, _ := v.(AddrInfo)

View File

@ -72,3 +72,10 @@ func Uint64() uint64 {
defer mu.Unlock()
return r.Uint64()
}
// Uint32 implements rand.Uint32 on the grpcrand global source.
func Uint32() uint32 {
mu.Lock()
defer mu.Unlock()
return r.Uint32()
}

View File

@ -199,12 +199,13 @@ func (p *producer) run(ctx context.Context, done chan struct{}, interval time.Du
// Unimplemented; do not retry.
logger.Error("Server doesn't support ORCA OOB load reporting protocol; not listening for load reports.")
return
case status.Code(err) == codes.Unavailable:
// TODO: this code should ideally log an error, too, but for now we
// receive this code when shutting down the ClientConn. Once we
// can determine the state or ensure the producer is stopped before
// the stream ends, we can log an error when it's not a natural
// shutdown.
case status.Code(err) == codes.Unavailable, status.Code(err) == codes.Canceled:
// TODO: these codes should ideally log an error, too, but for now
// we receive them when shutting down the ClientConn (Unavailable
// if the stream hasn't started yet, and Canceled if it happens
// mid-stream). Once we can determine the state or ensure the
// producer is stopped before the stream ends, we can log an error
// when it's not a natural shutdown.
default:
// Log all other errors.
logger.Error("Received unexpected stream error:", err)

View File

@ -160,7 +160,7 @@ func (d *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
d.loadStore.CallFinished(lIDStr, info.Err)
load, ok := info.ServerLoad.(*v3orcapb.OrcaLoadReport)
if !ok {
if !ok || load == nil {
return
}
d.loadStore.CallServerLoad(lIDStr, serverLoadCPUName, load.CpuUtilization)