From 0d47f5bd1baff87d412223c4ac22ea061eafb506 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 12 Aug 2024 11:23:37 -0700 Subject: [PATCH] xds: WRRPicker must not access unsynchronized data in ChildLbState There was no point to using subchannels as keys to subchannelToReportListenerMap, as the listener is per-child. That meant the keys would be guaranteed to be known ahead-of-time and the unsynchronized getOrCreateOrcaListener() during picking was unnecessary. The picker still stores ChildLbStates to make sure that updating weights uses the correct children, but the picker itself no longer references ChildLbStates except in the constructor. That means weight calculation is moved into the LB policy, as child.getWeight() is unsynchronized, and the picker no longer needs a reference to helper. --- .../xds/WeightedRoundRobinLoadBalancer.java | 132 +++++++++--------- .../WeightedRoundRobinLoadBalancerTest.java | 2 +- 2 files changed, 67 insertions(+), 67 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index f45bb571a3..e4502da874 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -44,11 +44,10 @@ import io.grpc.xds.orca.OrcaOobUtil; import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener; import io.grpc.xds.orca.OrcaPerRequestUtil; import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener; +import java.util.ArrayList; import java.util.Collection; -import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Random; import java.util.Set; import java.util.concurrent.ScheduledExecutorService; @@ -233,9 +232,44 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer { } private SubchannelPicker createReadyPicker(Collection activeList) { - return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList), - config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, getHelper(), - locality); + WeightedRoundRobinPicker picker = new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList), + config.enableOobLoadReport, config.errorUtilizationPenalty, sequence); + updateWeight(picker); + return picker; + } + + private void updateWeight(WeightedRoundRobinPicker picker) { + Helper helper = getHelper(); + float[] newWeights = new float[picker.children.size()]; + AtomicInteger staleEndpoints = new AtomicInteger(); + AtomicInteger notYetUsableEndpoints = new AtomicInteger(); + for (int i = 0; i < picker.children.size(); i++) { + double newWeight = ((WeightedChildLbState) picker.children.get(i)).getWeight(staleEndpoints, + notYetUsableEndpoints); + helper.getMetricRecorder() + .recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight, + ImmutableList.of(helper.getChannelTarget()), + ImmutableList.of(locality)); + newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f; + } + + if (staleEndpoints.get() > 0) { + helper.getMetricRecorder() + .addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(), + ImmutableList.of(helper.getChannelTarget()), + ImmutableList.of(locality)); + } + if (notYetUsableEndpoints.get() > 0) { + helper.getMetricRecorder() + .addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(), + ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality)); + } + boolean weightsEffective = picker.updateWeight(newWeights); + if (!weightsEffective) { + helper.getMetricRecorder() + .addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()), + ImmutableList.of(locality)); + } } private void updateBalancingState(ConnectivityState state, SubchannelPicker picker) { @@ -345,7 +379,7 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer { @Override public void run() { if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) { - ((WeightedRoundRobinPicker) currentPicker).updateWeight(); + updateWeight((WeightedRoundRobinPicker) currentPicker); } weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos, TimeUnit.NANOSECONDS, timeService); @@ -415,53 +449,50 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer { @VisibleForTesting static final class WeightedRoundRobinPicker extends SubchannelPicker { - private final List children; - private final Map subchannelToReportListenerMap = - new HashMap<>(); + // Parallel lists (column-based storage instead of normal row-based storage of List). + // The ith element of children corresponds to the ith element of pickers, listeners, and even + // updateWeight(float[]). + private final List children; // May only be accessed from sync context + private final List pickers; + private final List reportListeners; private final boolean enableOobLoadReport; private final float errorUtilizationPenalty; private final AtomicInteger sequence; private final int hashCode; - private final LoadBalancer.Helper helper; - private final String locality; private volatile StaticStrideScheduler scheduler; WeightedRoundRobinPicker(List children, boolean enableOobLoadReport, - float errorUtilizationPenalty, AtomicInteger sequence, LoadBalancer.Helper helper, - String locality) { + float errorUtilizationPenalty, AtomicInteger sequence) { checkNotNull(children, "children"); Preconditions.checkArgument(!children.isEmpty(), "empty child list"); this.children = children; + List pickers = new ArrayList<>(children.size()); + List reportListeners = new ArrayList<>(children.size()); for (ChildLbState child : children) { WeightedChildLbState wChild = (WeightedChildLbState) child; - for (WrrSubchannel subchannel : wChild.subchannels) { - this.subchannelToReportListenerMap - .put(subchannel, wChild.getOrCreateOrcaListener(errorUtilizationPenalty)); - } + pickers.add(wChild.getCurrentPicker()); + reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty)); } + this.pickers = pickers; + this.reportListeners = reportListeners; this.enableOobLoadReport = enableOobLoadReport; this.errorUtilizationPenalty = errorUtilizationPenalty; this.sequence = checkNotNull(sequence, "sequence"); - this.helper = helper; - this.locality = checkNotNull(locality, "locality"); - // For equality we treat children as a set; use hash code as defined by Set + // For equality we treat pickers as a set; use hash code as defined by Set int sum = 0; - for (ChildLbState child : children) { - sum += child.hashCode(); + for (SubchannelPicker picker : pickers) { + sum += picker.hashCode(); } this.hashCode = sum ^ Boolean.hashCode(enableOobLoadReport) ^ Float.hashCode(errorUtilizationPenalty); - - updateWeight(); } @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - ChildLbState childLbState = children.get(scheduler.pick()); - WeightedChildLbState wChild = (WeightedChildLbState) childLbState; - PickResult pickResult = childLbState.getCurrentPicker().pickSubchannel(args); + int pick = scheduler.pick(); + PickResult pickResult = pickers.get(pick).pickSubchannel(args); Subchannel subchannel = pickResult.getSubchannel(); if (subchannel == null) { return pickResult; @@ -469,48 +500,16 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer { if (!enableOobLoadReport) { return PickResult.withSubchannel(subchannel, OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( - subchannelToReportListenerMap.getOrDefault(subchannel, - wChild.getOrCreateOrcaListener(errorUtilizationPenalty)))); + reportListeners.get(pick))); } else { return PickResult.withSubchannel(subchannel); } } - private void updateWeight() { - float[] newWeights = new float[children.size()]; - AtomicInteger staleEndpoints = new AtomicInteger(); - AtomicInteger notYetUsableEndpoints = new AtomicInteger(); - for (int i = 0; i < children.size(); i++) { - double newWeight = ((WeightedChildLbState) children.get(i)).getWeight(staleEndpoints, - notYetUsableEndpoints); - // TODO: add locality label once available - helper.getMetricRecorder() - .recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight, - ImmutableList.of(helper.getChannelTarget()), - ImmutableList.of(locality)); - newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f; - } - if (staleEndpoints.get() > 0) { - // TODO: add locality label once available - helper.getMetricRecorder() - .addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(), - ImmutableList.of(helper.getChannelTarget()), - ImmutableList.of(locality)); - } - if (notYetUsableEndpoints.get() > 0) { - // TODO: add locality label once available - helper.getMetricRecorder() - .addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(), - ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality)); - } - + /** Returns {@code true} if weights are different than round_robin. */ + private boolean updateWeight(float[] newWeights) { this.scheduler = new StaticStrideScheduler(newWeights, sequence); - if (this.scheduler.usesRoundRobin()) { - // TODO: locality label once available - helper.getMetricRecorder() - .addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()), - ImmutableList.of(locality)); - } + return !this.scheduler.usesRoundRobin(); } @Override @@ -518,7 +517,8 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer { return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class) .add("enableOobLoadReport", enableOobLoadReport) .add("errorUtilizationPenalty", errorUtilizationPenalty) - .add("list", children).toString(); + .add("pickers", pickers) + .toString(); } @VisibleForTesting @@ -545,8 +545,8 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer { && sequence == other.sequence && enableOobLoadReport == other.enableOobLoadReport && Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0 - && children.size() == other.children.size() - && new HashSet<>(children).containsAll(other.children); + && pickers.size() == other.pickers.size() + && new HashSet<>(pickers).containsAll(other.pickers); } } diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index dd98f1e1ae..05ad1f56ec 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -244,7 +244,7 @@ public class WeightedRoundRobinLoadBalancerTest { String weightedPickerStr = weightedPicker.toString(); assertThat(weightedPickerStr).contains("enableOobLoadReport=false"); assertThat(weightedPickerStr).contains("errorUtilizationPenalty=1.0"); - assertThat(weightedPickerStr).contains("list="); + assertThat(weightedPickerStr).contains("pickers="); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);