xds: replace random with atomic sequence in WRR (#10458)

This commit is contained in:
Tony An 2023-08-07 16:19:50 -07:00 committed by GitHub
parent 4049f89e13
commit 40bff673c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 24 deletions

View File

@ -65,7 +65,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
private final ScheduledExecutorService timeService; private final ScheduledExecutorService timeService;
private ScheduledHandle weightUpdateTimer; private ScheduledHandle weightUpdateTimer;
private final Runnable updateWeightTask; private final Runnable updateWeightTask;
private final Random random; private final AtomicInteger sequence;
private final long infTime; private final long infTime;
private final Ticker ticker; private final Ticker ticker;
@ -81,7 +81,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
this.updateWeightTask = new UpdateWeightTask(); this.updateWeightTask = new UpdateWeightTask();
this.random = random; this.sequence = new AtomicInteger(random.nextInt());
log.log(Level.FINE, "weighted_round_robin LB created"); log.log(Level.FINE, "weighted_round_robin LB created");
} }
@ -294,9 +294,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
double newWeight = subchannel.getWeight(); double newWeight = subchannel.getWeight();
newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f; newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
} }
this.scheduler = new StaticStrideScheduler(newWeights, sequence);
StaticStrideScheduler scheduler = new StaticStrideScheduler(newWeights, random);
this.scheduler = scheduler;
} }
@Override @Override
@ -353,7 +351,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
private final AtomicInteger sequence; private final AtomicInteger sequence;
private static final int K_MAX_WEIGHT = 0xFFFF; private static final int K_MAX_WEIGHT = 0xFFFF;
StaticStrideScheduler(float[] weights, Random random) { StaticStrideScheduler(float[] weights, AtomicInteger sequence) {
checkArgument(weights.length >= 1, "Couldn't build scheduler: requires at least one weight"); checkArgument(weights.length >= 1, "Couldn't build scheduler: requires at least one weight");
int numChannels = weights.length; int numChannels = weights.length;
int numWeightedChannels = 0; int numWeightedChannels = 0;
@ -386,7 +384,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
} }
this.scaledWeights = scaledWeights; this.scaledWeights = scaledWeights;
this.sequence = new AtomicInteger(random.nextInt()); this.sequence = sequence;
} }

View File

@ -836,7 +836,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void emptyWeights() { public void emptyWeights() {
float[] weights = {}; float[] weights = {};
Random random = new Random(); Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
sss.pick(); sss.pick();
} }
@ -844,7 +845,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testPicksEqualsWeights() { public void testPicksEqualsWeights() {
float[] weights = {1.0f, 2.0f, 3.0f}; float[] weights = {1.0f, 2.0f, 3.0f};
Random random = new Random(); Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
int[] expectedPicks = new int[] {1, 2, 3}; int[] expectedPicks = new int[] {1, 2, 3};
int[] picks = new int[3]; int[] picks = new int[3];
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
@ -857,7 +859,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testContainsZeroWeightUseMean() { public void testContainsZeroWeightUseMean() {
float[] weights = {3.0f, 0.0f, 1.0f}; float[] weights = {3.0f, 0.0f, 1.0f};
Random random = new Random(); Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
int[] expectedPicks = new int[] {3, 2, 1}; int[] expectedPicks = new int[] {3, 2, 1};
int[] picks = new int[3]; int[] picks = new int[3];
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
@ -870,7 +873,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testContainsNegativeWeightUseMean() { public void testContainsNegativeWeightUseMean() {
float[] weights = {3.0f, -1.0f, 1.0f}; float[] weights = {3.0f, -1.0f, 1.0f};
Random random = new Random(); Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
int[] expectedPicks = new int[] {3, 2, 1}; int[] expectedPicks = new int[] {3, 2, 1};
int[] picks = new int[3]; int[] picks = new int[3];
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
@ -883,7 +887,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testAllSameWeights() { public void testAllSameWeights() {
float[] weights = {1.0f, 1.0f, 1.0f}; float[] weights = {1.0f, 1.0f, 1.0f};
Random random = new Random(); Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
int[] expectedPicks = new int[] {2, 2, 2}; int[] expectedPicks = new int[] {2, 2, 2};
int[] picks = new int[3]; int[] picks = new int[3];
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
@ -896,7 +901,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testAllZeroWeightsUseOne() { public void testAllZeroWeightsUseOne() {
float[] weights = {0.0f, 0.0f, 0.0f}; float[] weights = {0.0f, 0.0f, 0.0f};
Random random = new Random(); Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
int[] expectedPicks = new int[] {2, 2, 2}; int[] expectedPicks = new int[] {2, 2, 2};
int[] picks = new int[3]; int[] picks = new int[3];
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
@ -909,7 +915,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testAllInvalidWeightsUseOne() { public void testAllInvalidWeightsUseOne() {
float[] weights = {-3.1f, -0.0f, 0.0f}; float[] weights = {-3.1f, -0.0f, 0.0f};
Random random = new Random(); Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
int[] expectedPicks = new int[] {2, 2, 2}; int[] expectedPicks = new int[] {2, 2, 2};
int[] picks = new int[3]; int[] picks = new int[3];
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
@ -923,7 +930,8 @@ public class WeightedRoundRobinLoadBalancerTest {
float[] weights = {1.0f, 2.0f, 3.0f}; float[] weights = {1.0f, 2.0f, 3.0f};
int largestWeightIndex = 2; int largestWeightIndex = 2;
Random random = new Random(); Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
int largestWeightPickCount = 0; int largestWeightPickCount = 0;
int kMaxWeight = 65535; int kMaxWeight = 65535;
for (int i = 0; i < largestWeightIndex * kMaxWeight; i++) { for (int i = 0; i < largestWeightIndex * kMaxWeight; i++) {
@ -938,7 +946,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testStaticStrideSchedulerNonIntegers1() { public void testStaticStrideSchedulerNonIntegers1() {
float[] weights = {2.0f, (float) (10.0 / 3.0), 1.0f}; float[] weights = {2.0f, (float) (10.0 / 3.0), 1.0f};
Random random = new Random(); Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
double totalWeight = 2 + 10.0 / 3.0 + 1.0; double totalWeight = 2 + 10.0 / 3.0 + 1.0;
Map<Integer, Integer> pickCount = new HashMap<>(); Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) { for (int i = 0; i < 1000; i++) {
@ -955,7 +964,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testStaticStrideSchedulerNonIntegers2() { public void testStaticStrideSchedulerNonIntegers2() {
float[] weights = {0.5f, 0.3f, 1.0f}; float[] weights = {0.5f, 0.3f, 1.0f};
Random random = new Random(); Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
double totalWeight = 1.8; double totalWeight = 1.8;
Map<Integer, Integer> pickCount = new HashMap<>(); Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) { for (int i = 0; i < 1000; i++) {
@ -972,7 +982,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testTwoWeights() { public void testTwoWeights() {
float[] weights = {1.0f, 2.0f}; float[] weights = {1.0f, 2.0f};
Random random = new Random(); Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
double totalWeight = 3; double totalWeight = 3;
Map<Integer, Integer> pickCount = new HashMap<>(); Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) { for (int i = 0; i < 1000; i++) {
@ -989,7 +1000,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testManyWeights() { public void testManyWeights() {
float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
Random random = new Random(); Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
double totalWeight = 15; double totalWeight = 15;
Map<Integer, Integer> pickCount = new HashMap<>(); Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) { for (int i = 0; i < 1000; i++) {
@ -1006,7 +1018,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testManyComplexWeights() { public void testManyComplexWeights() {
float[] weights = {1.2f, 2.4f, 222.56f, 1.1f, 15.0f, 226342.0f, 5123.0f, 532.2f}; float[] weights = {1.2f, 2.4f, 222.56f, 1.1f, 15.0f, 226342.0f, 5123.0f, 532.2f};
Random random = new Random(); Random random = new Random();
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
double totalWeight = 1.2 + 2.4 + 222.56 + 15.0 + 226342.0 + 5123.0 + 0.0001; double totalWeight = 1.2 + 2.4 + 222.56 + 15.0 + 226342.0 + 5123.0 + 0.0001;
Map<Integer, Integer> pickCount = new HashMap<>(); Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) { for (int i = 0; i < 1000; i++) {
@ -1023,7 +1036,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testDeterministicPicks() { public void testDeterministicPicks() {
float[] weights = {2.0f, 3.0f, 6.0f}; float[] weights = {2.0f, 3.0f, 6.0f};
Random random = new FakeRandom(0); Random random = new FakeRandom(0);
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
assertThat(sss.getSequence()).isEqualTo(0); assertThat(sss.getSequence()).isEqualTo(0);
assertThat(sss.pick()).isEqualTo(1); assertThat(sss.pick()).isEqualTo(1);
assertThat(sss.getSequence()).isEqualTo(2); assertThat(sss.getSequence()).isEqualTo(2);
@ -1043,7 +1057,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testImmediateWraparound() { public void testImmediateWraparound() {
float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
Random random = new FakeRandom(-1); Random random = new FakeRandom(-1);
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
double totalWeight = 15; double totalWeight = 15;
Map<Integer, Integer> pickCount = new HashMap<>(); Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) { for (int i = 0; i < 1000; i++) {
@ -1060,7 +1075,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testWraparound() { public void testWraparound() {
float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
Random random = new FakeRandom(-500); Random random = new FakeRandom(-500);
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
double totalWeight = 15; double totalWeight = 15;
Map<Integer, Integer> pickCount = new HashMap<>(); Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) { for (int i = 0; i < 1000; i++) {
@ -1077,7 +1093,8 @@ public class WeightedRoundRobinLoadBalancerTest {
public void testDeterministicWraparound() { public void testDeterministicWraparound() {
float[] weights = {2.0f, 3.0f, 6.0f}; float[] weights = {2.0f, 3.0f, 6.0f};
Random random = new FakeRandom(-1); Random random = new FakeRandom(-1);
StaticStrideScheduler sss = new StaticStrideScheduler(weights, random); StaticStrideScheduler sss = new StaticStrideScheduler(weights,
new AtomicInteger(random.nextInt()));
assertThat(sss.getSequence()).isEqualTo(0xFFFF_FFFFL); assertThat(sss.getSequence()).isEqualTo(0xFFFF_FFFFL);
assertThat(sss.pick()).isEqualTo(1); assertThat(sss.pick()).isEqualTo(1);
assertThat(sss.getSequence()).isEqualTo(2); assertThat(sss.getSequence()).isEqualTo(2);