diff --git a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java index 4649302af1..5609708492 100644 --- a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java +++ b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java @@ -284,7 +284,8 @@ public class RoundRobinLoadBalancer extends LoadBalancer { public abstract boolean isEquivalentTo(RoundRobinPicker picker); } - public static class ReadyPicker extends RoundRobinPicker { + @VisibleForTesting + static class ReadyPicker extends RoundRobinPicker { private static final AtomicIntegerFieldUpdater indexUpdater = AtomicIntegerFieldUpdater.newUpdater(ReadyPicker.class, "index"); @@ -336,7 +337,8 @@ public class RoundRobinLoadBalancer extends LoadBalancer { } } - public static final class EmptyPicker extends RoundRobinPicker { + @VisibleForTesting + static final class EmptyPicker extends RoundRobinPicker { private final Status status; diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index 5a278bc6e3..87593d5324 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -232,13 +232,13 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { } @VisibleForTesting - final class WeightedRoundRobinPicker extends ReadyPicker { + final class WeightedRoundRobinPicker extends RoundRobinPicker { private final List list; private final boolean enableOobLoadReport; private volatile EdfScheduler scheduler; WeightedRoundRobinPicker(List list, boolean enableOobLoadReport) { - super(checkNotNull(list, "list"), random.nextInt(list.size())); + checkNotNull(list, "list"); Preconditions.checkArgument(!list.isEmpty(), "empty list"); this.list = list; this.enableOobLoadReport = enableOobLoadReport; diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index 1d80a032c9..da298d3abb 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -46,12 +46,12 @@ import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; import io.grpc.services.InternalCallMetricRecorder; import io.grpc.services.MetricReport; -import io.grpc.util.RoundRobinLoadBalancer.EmptyPicker; import io.grpc.xds.WeightedRoundRobinLoadBalancer.EdfScheduler; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker; @@ -94,9 +94,9 @@ public class WeightedRoundRobinLoadBalancerTest { private LoadBalancer.PickSubchannelArgs mockArgs; @Captor - private ArgumentCaptor pickerCaptor; + private ArgumentCaptor pickerCaptor; @Captor - private ArgumentCaptor pickerCaptor2; + private ArgumentCaptor pickerCaptor2; private final List servers = Lists.newArrayList(); @@ -200,8 +200,10 @@ public class WeightedRoundRobinLoadBalancerTest { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); - assertThat(pickerCaptor.getAllValues().get(0).getList().size()).isEqualTo(1); - WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1); + WeightedRoundRobinPicker weightedPicker = + (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); + assertThat(weightedPicker.getList().size()).isEqualTo(1); + weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); assertThat(weightedPicker.getList().size()).isEqualTo(2); WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); @@ -245,7 +247,8 @@ public class WeightedRoundRobinLoadBalancerTest { .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); - WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1); + WeightedRoundRobinPicker weightedPicker = + (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport( @@ -266,7 +269,7 @@ public class WeightedRoundRobinLoadBalancerTest { .setAttributes(affinity).build())); verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor2.capture()); - weightedPicker = pickerCaptor2.getAllValues().get(2); + weightedPicker = (WeightedRoundRobinPicker) pickerCaptor2.getAllValues().get(2); pickResult = weightedPicker.pickSubchannel(mockArgs); assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1); assertThat(pickResult.getStreamTracerFactory()).isNull(); @@ -299,7 +302,8 @@ public class WeightedRoundRobinLoadBalancerTest { .forNonError(ConnectivityState.READY)); verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); - WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(2); + WeightedRoundRobinPicker weightedPicker = + (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2); @@ -364,7 +368,8 @@ public class WeightedRoundRobinLoadBalancerTest { verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); - assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class); + assertThat(pickerCaptor.getValue().getClass().getName()) + .isEqualTo("io.grpc.util.RoundRobinLoadBalancer$EmptyPicker"); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); } @@ -386,7 +391,8 @@ public class WeightedRoundRobinLoadBalancerTest { .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); - WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1); + WeightedRoundRobinPicker weightedPicker = + (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport( @@ -440,8 +446,10 @@ public class WeightedRoundRobinLoadBalancerTest { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); - assertThat(pickerCaptor.getAllValues().get(0).getList().size()).isEqualTo(1); - WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1); + WeightedRoundRobinPicker weightedPicker = + (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); + assertThat(weightedPicker.getList().size()).isEqualTo(1); + weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); assertThat(weightedPicker.getList().size()).isEqualTo(2); WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); @@ -488,7 +496,8 @@ public class WeightedRoundRobinLoadBalancerTest { .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); - WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1); + WeightedRoundRobinPicker weightedPicker = + (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport( @@ -539,7 +548,8 @@ public class WeightedRoundRobinLoadBalancerTest { .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); - WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1); + WeightedRoundRobinPicker weightedPicker = + (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); @@ -598,7 +608,8 @@ public class WeightedRoundRobinLoadBalancerTest { .forNonError(ConnectivityState.READY)); verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); - WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(2); + WeightedRoundRobinPicker weightedPicker = + (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2); @@ -640,7 +651,8 @@ public class WeightedRoundRobinLoadBalancerTest { .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); - WeightedRoundRobinPicker weightedPicker = pickerCaptor.getAllValues().get(1); + WeightedRoundRobinPicker weightedPicker = + (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); weightedSubchannel1.onLoadReport(InternalCallMetricRecorder.createMetricReport(