xds: fix wrr stuck in rr mode (#10061)

This commit is contained in:
yifeizhuang 2023-04-18 16:39:51 -07:00 committed by GitHub
parent 35852130d9
commit 111ff60e1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 18 deletions

View File

@ -46,6 +46,8 @@ import java.util.PriorityQueue;
import java.util.Random; import java.util.Random;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
/** /**
* A {@link LoadBalancer} that provides weighted-round-robin load-balancing over * A {@link LoadBalancer} that provides weighted-round-robin load-balancing over
@ -54,6 +56,8 @@ import java.util.concurrent.TimeUnit;
*/ */
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885") @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885")
final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
private static final Logger log = Logger.getLogger(
WeightedRoundRobinLoadBalancer.class.getName());
private volatile WeightedRoundRobinLoadBalancerConfig config; private volatile WeightedRoundRobinLoadBalancerConfig config;
private final SynchronizationContext syncContext; private final SynchronizationContext syncContext;
private final ScheduledExecutorService timeService; private final ScheduledExecutorService timeService;
@ -76,6 +80,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
this.updateWeightTask = new UpdateWeightTask(); this.updateWeightTask = new UpdateWeightTask();
this.random = random; this.random = random;
log.log(Level.FINE, "weighted_round_robin LB created");
} }
@VisibleForTesting @VisibleForTesting
@ -230,7 +235,6 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
final class WeightedRoundRobinPicker extends ReadyPicker { final class WeightedRoundRobinPicker extends ReadyPicker {
private final List<Subchannel> list; private final List<Subchannel> list;
private volatile EdfScheduler scheduler; private volatile EdfScheduler scheduler;
private volatile boolean rrMode;
WeightedRoundRobinPicker(List<Subchannel> list) { WeightedRoundRobinPicker(List<Subchannel> list) {
super(checkNotNull(list, "list"), random.nextInt(list.size())); super(checkNotNull(list, "list"), random.nextInt(list.size()));
@ -241,16 +245,11 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
@Override @Override
public PickResult pickSubchannel(PickSubchannelArgs args) { public PickResult pickSubchannel(PickSubchannelArgs args) {
if (rrMode) { Subchannel subchannel = list.get(scheduler.pick());
return super.pickSubchannel(args);
}
int pickIndex = scheduler.pick();
WrrSubchannel subchannel = (WrrSubchannel) list.get(pickIndex);
if (!config.enableOobLoadReport) { if (!config.enableOobLoadReport) {
return PickResult.withSubchannel( return PickResult.withSubchannel(subchannel,
subchannel, OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( ((WrrSubchannel)subchannel).perRpcListener));
subchannel.perRpcListener));
} else { } else {
return PickResult.withSubchannel(subchannel); return PickResult.withSubchannel(subchannel);
} }
@ -266,25 +265,24 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
weightedChannelCount++; weightedChannelCount++;
} }
} }
if (weightedChannelCount < 2) {
rrMode = true;
return;
}
EdfScheduler scheduler = new EdfScheduler(list.size(), random); EdfScheduler scheduler = new EdfScheduler(list.size(), random);
avgWeight /= 1.0 * weightedChannelCount; if (weightedChannelCount >= 1) {
avgWeight /= 1.0 * weightedChannelCount;
} else {
avgWeight = 1;
}
for (int i = 0; i < list.size(); i++) { for (int i = 0; i < list.size(); i++) {
WrrSubchannel subchannel = (WrrSubchannel) list.get(i); WrrSubchannel subchannel = (WrrSubchannel) list.get(i);
double newWeight = subchannel.getWeight(); double newWeight = subchannel.getWeight();
scheduler.add(i, newWeight > 0 ? newWeight : avgWeight); scheduler.add(i, newWeight > 0 ? newWeight : avgWeight);
} }
this.scheduler = scheduler; this.scheduler = scheduler;
rrMode = false;
} }
@Override @Override
public String toString() { public String toString() {
return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class) return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)
.add("list", list).add("rrMode", rrMode).toString(); .add("list", list).toString();
} }
@VisibleForTesting @VisibleForTesting

View File

@ -29,6 +29,7 @@ import static org.mockito.Mockito.when;
import com.github.xds.data.orca.v3.OrcaLoadReport; import com.github.xds.data.orca.v3.OrcaLoadReport;
import com.github.xds.service.orca.v3.OrcaLoadReportRequest; import com.github.xds.service.orca.v3.OrcaLoadReportRequest;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Maps; import com.google.common.collect.Maps;
import com.google.protobuf.Duration; import com.google.protobuf.Duration;
@ -514,6 +515,62 @@ public class WeightedRoundRobinLoadBalancerTest {
.isAtMost(0.001); .isAtMost(0.001);
} }
@Test
public void rrFallback() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1);
assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1);
WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
Map<WrrSubchannel, Integer> qpsByChannel = ImmutableMap.of(weightedSubchannel1, 2,
weightedSubchannel2, 1);
Map<Subchannel, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
PickResult pickResult = weightedPicker.pickSubchannel(mockArgs);
pickCount.put(pickResult.getSubchannel(),
pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1);
assertThat(pickResult.getStreamTracerFactory()).isNotNull();
WrrSubchannel subchannel = (WrrSubchannel)pickResult.getSubchannel();
subchannel.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.1, 0.1, qpsByChannel.get(subchannel), new HashMap<>(), new HashMap<>()));
}
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 1.0 / 2))
.isAtMost(0.1);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 2))
.isAtMost(0.1);
pickCount.clear();
for (int i = 0; i < 1000; i++) {
PickResult pickResult = weightedPicker.pickSubchannel(mockArgs);
pickCount.put(pickResult.getSubchannel(),
pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1);
assertThat(pickResult.getStreamTracerFactory()).isNotNull();
WrrSubchannel subchannel = (WrrSubchannel)pickResult.getSubchannel();
subchannel.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.1, 0.1, qpsByChannel.get(subchannel), new HashMap<>(), new HashMap<>()));
fakeClock.forwardTime(50, TimeUnit.MILLISECONDS);
}
assertThat(pickCount.size()).isEqualTo(2);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3))
.isAtMost(0.1);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3))
.isAtMost(0.1);
}
@Test @Test
public void unknownWeightIsAvgWeight() { public void unknownWeightIsAvgWeight() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
@ -584,7 +641,6 @@ public class WeightedRoundRobinLoadBalancerTest {
0.1, 0.1, 1, new HashMap<>(), new HashMap<>())); 0.1, 0.1, 1, new HashMap<>(), new HashMap<>()));
weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport( weightedSubchannel2.onLoadReport(InternalCallMetricRecorder.createMetricReport(
0.2, 0.1, 1, new HashMap<>(), new HashMap<>())); 0.2, 0.1, 1, new HashMap<>(), new HashMap<>()));
assertThat(weightedPicker.toString()).contains("rrMode=true");
CyclicBarrier barrier = new CyclicBarrier(2); CyclicBarrier barrier = new CyclicBarrier(2);
Map<Subchannel, AtomicInteger> pickCount = new ConcurrentHashMap<>(); Map<Subchannel, AtomicInteger> pickCount = new ConcurrentHashMap<>();
pickCount.put(weightedSubchannel1, new AtomicInteger(0)); pickCount.put(weightedSubchannel1, new AtomicInteger(0));