From f6c2d221e2b6c975c6cf465d68fe11ab12dabe55 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Wed, 30 Sep 2020 15:31:09 -0700 Subject: [PATCH] rls: fix wrong synchronization for pickSubchannel() `RlsPicker.pickSubchannel()` does not run in SynchronizationContext, but it calls `CachingRlsLbClient.get()` which assumed running in SynchronizationContext. Fixed by removing `synchronizationContext.throwIfNotInThisSynchronizationContext()`. `CachingRlsLbClient.get()` is actually thread-safe in the sense it's guarded by lock, and `DataCacheEntry`'s fields are final. `ChildPolicyWrapper.picker` was not thread-safe. Fixed by making it volatile. Changed the test a bit since the old test doesn't really test things well. --- .../java/io/grpc/rls/CachingRlsLbClient.java | 56 ++--- .../io/grpc/rls/LbPolicyConfiguration.java | 68 +---- .../io/grpc/rls/CachingRlsLbClientTest.java | 9 +- .../grpc/rls/LbPolicyConfigurationTest.java | 60 ----- .../java/io/grpc/rls/RlsLoadBalancerTest.java | 237 +++++------------- 5 files changed, 96 insertions(+), 334 deletions(-) diff --git a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java index e1a4152d89..aefbf7d9d1 100644 --- a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java +++ b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java @@ -453,7 +453,7 @@ final class CachingRlsLbClient { private final RouteLookupResponse response; private final long expireTime; private final long staleTime; - private ChildPolicyWrapper childPolicyWrapper; + private final ChildPolicyWrapper childPolicyWrapper; DataCacheEntry(RouteLookupRequest request, final RouteLookupResponse response) { super(request); @@ -467,21 +467,12 @@ final class CachingRlsLbClient { staleTime = now + staleAgeNanos; if (childPolicyWrapper.getPicker() != null) { - // using cached childPolicyWrapper - updateLbState(); + childPolicyWrapper.refreshState(); } else { createChildLbPolicy(); } } - private void updateLbState() { - childPolicyWrapper - .getHelper() - .updateBalancingState( - childPolicyWrapper.getConnectivityStateInfo().getState(), - childPolicyWrapper.getPicker()); - } - private void createChildLbPolicy() { ChildLoadBalancingPolicy childPolicy = lbPolicyConfig.getLoadBalancingPolicy(); LoadBalancerProvider lbProvider = childPolicy.getEffectiveLbProvider(); @@ -868,19 +859,15 @@ final class CachingRlsLbClient { } if (response.hasData()) { ChildPolicyWrapper childPolicyWrapper = response.getChildPolicyWrapper(); - ConnectivityState connectivityState = - childPolicyWrapper.getConnectivityStateInfo().getState(); - switch (connectivityState) { - case IDLE: - case CONNECTING: - return PickResult.withNoResult(); - case READY: - return childPolicyWrapper.getPicker().pickSubchannel(args); - case TRANSIENT_FAILURE: - case SHUTDOWN: - default: - return useFallback(args); + SubchannelPicker picker = childPolicyWrapper.getPicker(); + if (picker == null) { + return PickResult.withNoResult(); } + PickResult result = picker.pickSubchannel(args); + if (result.getStatus().isOk()) { + return result; + } + return useFallback(args); } else if (response.hasError()) { return useFallback(args); } else { @@ -898,26 +885,11 @@ final class CachingRlsLbClient { // TODO(creamsoup) wait until lb is ready startFallbackChildPolicy(); } - switch (fallbackChildPolicyWrapper.getConnectivityStateInfo().getState()) { - case IDLE: - // fall through - case CONNECTING: - return PickResult.withNoResult(); - case TRANSIENT_FAILURE: - // fall through - case SHUTDOWN: - return - PickResult - .withError(fallbackChildPolicyWrapper.getConnectivityStateInfo().getStatus()); - case READY: - SubchannelPicker picker = fallbackChildPolicyWrapper.getPicker(); - if (picker == null) { - return PickResult.withNoResult(); - } - return picker.pickSubchannel(args); - default: - throw new AssertionError(); + SubchannelPicker picker = fallbackChildPolicyWrapper.getPicker(); + if (picker == null) { + return PickResult.withNoResult(); } + return picker.pickSubchannel(args); } private void startFallbackChildPolicy() { diff --git a/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java b/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java index 20515181fd..d4cc7672b4 100644 --- a/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java +++ b/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java @@ -23,19 +23,15 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import io.grpc.ConnectivityState; -import io.grpc.ConnectivityStateInfo; -import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; -import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.internal.ObjectPool; import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider; import io.grpc.rls.RlsProtoData.RouteLookupConfig; import io.grpc.util.ForwardingLoadBalancerHelper; -import io.grpc.util.ForwardingSubchannel; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -242,9 +238,8 @@ final class LbPolicyConfiguration { private final String target; private final ChildPolicyReportingHelper helper; - private ConnectivityStateInfo connectivityStateInfo = - ConnectivityStateInfo.forNonError(ConnectivityState.IDLE); - private SubchannelPicker picker; + private volatile SubchannelPicker picker; + private ConnectivityState state; public ChildPolicyWrapper( String target, @@ -259,10 +254,6 @@ final class LbPolicyConfiguration { return target; } - void setPicker(SubchannelPicker picker) { - this.picker = checkNotNull(picker, "picker"); - } - SubchannelPicker getPicker() { return picker; } @@ -271,32 +262,8 @@ final class LbPolicyConfiguration { return helper; } - void setConnectivityStateInfo(ConnectivityStateInfo connectivityStateInfo) { - this.connectivityStateInfo = connectivityStateInfo; - } - - ConnectivityStateInfo getConnectivityStateInfo() { - return connectivityStateInfo; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ChildPolicyWrapper that = (ChildPolicyWrapper) o; - return Objects.equals(target, that.target) - && Objects.equals(helper, that.helper) - && Objects.equals(connectivityStateInfo, that.connectivityStateInfo) - && Objects.equals(picker, that.picker); - } - - @Override - public int hashCode() { - return Objects.hash(target, helper, connectivityStateInfo, picker); + void refreshState() { + helper.updateBalancingState(state, picker); } @Override @@ -304,8 +271,8 @@ final class LbPolicyConfiguration { return MoreObjects.toStringHelper(this) .add("target", target) .add("helper", helper) - .add("connectivityStateInfo", connectivityStateInfo) .add("picker", picker) + .add("state", state) .toString(); } @@ -335,32 +302,11 @@ final class LbPolicyConfiguration { @Override public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { - setPicker(newPicker); + picker = newPicker; + state = newState; super.updateBalancingState(newState, newPicker); listener.onStatusChanged(newState); } - - @Override - public Subchannel createSubchannel(CreateSubchannelArgs args) { - final Subchannel subchannel = super.createSubchannel(args); - return new ForwardingSubchannel() { - @Override - protected Subchannel delegate() { - return subchannel; - } - - @Override - public void start(final SubchannelStateListener listener) { - super.start(new SubchannelStateListener() { - @Override - public void onSubchannelState(ConnectivityStateInfo newState) { - setConnectivityStateInfo(newState); - listener.onSubchannelState(newState); - } - }); - } - }; - } } } diff --git a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java index 8d7a4febba..2a12e5c2c1 100644 --- a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java +++ b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java @@ -418,7 +418,14 @@ public class CachingRlsLbClientTest { @Override public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { // TODO: make the picker accessible - helper.updateBalancingState(ConnectivityState.READY, mock(SubchannelPicker.class)); + helper.updateBalancingState( + ConnectivityState.READY, + new SubchannelPicker() { + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + return PickResult.withSubchannel(mock(Subchannel.class)); + } + }); } @Override diff --git a/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java b/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java index 25fc9de29f..7c6be039c2 100644 --- a/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java +++ b/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java @@ -16,25 +16,17 @@ package io.grpc.rls; -import static com.google.common.base.Preconditions.checkState; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.grpc.Attributes; import io.grpc.ConnectivityState; -import io.grpc.ConnectivityStateInfo; -import io.grpc.EquivalentAddressGroup; -import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; -import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; -import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider; @@ -44,7 +36,6 @@ import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper; import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper.ChildPolicyReportingHelper; import io.grpc.rls.LbPolicyConfiguration.InvalidChildPolicyConfigException; import io.grpc.rls.LbPolicyConfiguration.RefCountedChildPolicyWrapperFactory; -import java.net.SocketAddress; import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; @@ -133,31 +124,6 @@ public class LbPolicyConfigurationTest { } } - @Test - public void subchannelStateChange_updateChildPolicyWrapper() { - ChildPolicyWrapper childPolicyWrapper = factory.createOrGet("foo.google.com"); - ChildPolicyReportingHelper childPolicyReportingHelper = childPolicyWrapper.getHelper(); - FakeSubchannel fakeSubchannel = new FakeSubchannel(); - when(helper.createSubchannel(any(CreateSubchannelArgs.class))).thenReturn(fakeSubchannel); - Subchannel subchannel = - childPolicyReportingHelper - .createSubchannel( - CreateSubchannelArgs.newBuilder() - .setAddresses(new EquivalentAddressGroup(mock(SocketAddress.class))) - .build()); - subchannel.start(new SubchannelStateListener() { - @Override - public void onSubchannelState(ConnectivityStateInfo newState) { - // no-op - } - }); - - fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); - - assertThat(childPolicyWrapper.getConnectivityStateInfo()) - .isEqualTo(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); - } - @Test public void updateBalancingState_triggersListener() { ChildPolicyWrapper childPolicyWrapper = factory.createOrGet("foo.google.com"); @@ -171,30 +137,4 @@ public class LbPolicyConfigurationTest { // picker governs childPickers will be reported to parent LB verify(helper).updateBalancingState(ConnectivityState.READY, picker); } - - private static class FakeSubchannel extends Subchannel { - - private SubchannelStateListener listener; - - @Override - public void start(SubchannelStateListener listener) { - this.listener = listener; - } - - void updateState(ConnectivityStateInfo newState) { - checkState(listener != null, "channel is not started yet"); - listener.onSubchannelState(newState); - } - - @Override - public void shutdown() {} - - @Override - public void requestConnection() {} - - @Override - public Attributes getAttributes() { - return null; - } - } } diff --git a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java index 045c00c303..b92963063b 100644 --- a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java @@ -30,7 +30,6 @@ import static org.mockito.Mockito.verify; import com.google.common.base.Converter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ChannelLogger; @@ -51,8 +50,8 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.MethodType; -import io.grpc.NameResolver.ConfigOrError; import io.grpc.NameResolver; +import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.inprocess.InProcessChannelBuilder; @@ -60,7 +59,6 @@ import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.JsonParser; import io.grpc.internal.PickSubchannelArgsImpl; import io.grpc.lookup.v1.RouteLookupServiceGrpc; -import io.grpc.rls.CachingRlsLbClient.RlsPicker; import io.grpc.rls.RlsLoadBalancer.CachingRlsLbClientBuilderProvider; import io.grpc.rls.RlsProtoConverters.RouteLookupResponseConverter; import io.grpc.rls.RlsProtoData.RouteLookupRequest; @@ -76,7 +74,6 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; import javax.annotation.Nonnull; import org.junit.After; import org.junit.Before; @@ -180,140 +177,78 @@ public class RlsLoadBalancerTest { } @Test - public void lb_working() throws Exception { - final InOrder inOrder = inOrder(helper); - + public void lb_working() { + InOrder inOrder = inOrder(helper); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); - assertThat(pickerCaptor.getValue()).isInstanceOf(RlsPicker.class); - final RlsPicker picker = (RlsPicker) pickerCaptor.getValue(); - final Metadata headers = new Metadata(); - - blockingRunInSyncContext( - new Runnable() { - @Override - public void run() { - PickResult res = - picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); - // verify pending - assertThat(res.getSubchannel()).isNull(); - assertThat(res.getStatus().isOk()).isTrue(); - } - }); - + SubchannelPicker picker = pickerCaptor.getValue(); + Metadata headers = new Metadata(); + PickResult res = picker.pickSubchannel( + new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper) - .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); - assertThat(subchannels).hasSize(1); + .updateBalancingState(eq(ConnectivityState.CONNECTING), any(SubchannelPicker.class)); inOrder.verifyNoMoreInteractions(); + assertThat(res.getStatus().isOk()).isTrue(); + assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); - final FakeSubchannel searchSubchannel = subchannels.getLast(); + assertThat(subchannels).hasSize(1); + FakeSubchannel searchSubchannel = subchannels.getLast(); searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); - - assertThat(pickerCaptor.getValue()).isInstanceOf(RlsPicker.class); - final RlsPicker picker2 = (RlsPicker) pickerCaptor.getValue(); - assertThat(picker2).isEqualTo(picker); - blockingRunInSyncContext( - new Runnable() { - @Override - public void run() { - PickResult res = picker2.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); - // verify success. Subchannel is wrapped, so checking attributes. - assertThat(res.getSubchannel()).isNotNull(); - assertThat(res.getSubchannel().getAddresses()) - .isEqualTo(searchSubchannel.getAddresses()); - assertThat(res.getSubchannel().getAttributes()) - .isEqualTo(searchSubchannel.getAttributes()); - assertThat(res.getStatus().isOk()).isTrue(); - } - }); - inOrder.verifyNoMoreInteractions(); + assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); + assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses()); + assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes()); - // rescue should be pending status - blockingRunInSyncContext( - new Runnable() { - @Override - public void run() { - PickResult res = - picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT)); - assertThat(res.getSubchannel()).isNull(); - assertThat(res.getStatus().isOk()).isTrue(); - } - }); - + // rescue should be pending status although the overall channel state is READY + res = picker.pickSubchannel( + new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT)); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); // other rls picker itself is ready due to first channel. inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); - assertThat(subchannels).hasSize(2); inOrder.verifyNoMoreInteractions(); + assertThat(res.getStatus().isOk()).isTrue(); + assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); + assertThat(subchannels).hasSize(2); + FakeSubchannel rescueSubchannel = subchannels.getLast(); - // rescue subchannel is connecting + // search subchannel is down, rescue subchannel is connecting searchSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND)); - inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); - final FakeSubchannel rescueSubchannel = subchannels.getLast(); rescueSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); // search again, use pending fallback because searchSubchannel is in failure mode - blockingRunInSyncContext( - new Runnable() { - @Override - public void run() { - PickResult res = - picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); - assertThat(res.getSubchannel()).isNull(); - assertThat(res.getStatus().isOk()).isTrue(); - } - }); + res = picker.pickSubchannel( + new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + assertThat(res.getStatus().isOk()).isTrue(); + assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); assertThat(subchannels).hasSize(3); - final FakeSubchannel fallbackSubchannel = subchannels.getLast(); + FakeSubchannel fallbackSubchannel = subchannels.getLast(); fallbackSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); inOrder.verify(helper, times(2)) .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); inOrder.verifyNoMoreInteractions(); - blockingRunInSyncContext( - new Runnable() { - @Override - public void run() { - PickResult res = - picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); - assertThat(res.getSubchannel().getAddresses()) - .isEqualTo(fallbackSubchannel.getAddresses()); - assertThat(res.getSubchannel().getAttributes()) - .isEqualTo(fallbackSubchannel.getAttributes()); - assertThat(res.getStatus().isOk()).isTrue(); - } - }); - blockingRunInSyncContext( - new Runnable() { - @Override - public void run() { - PickResult res = - picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT)); - assertThat(res.getSubchannel().getAddresses()) - .isEqualTo(rescueSubchannel.getAddresses()); - assertThat(res.getSubchannel().getAttributes()) - .isEqualTo(rescueSubchannel.getAttributes()); - assertThat(res.getStatus().isOk()).isTrue(); - } - }); + res = picker.pickSubchannel( + new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); + assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses()); + assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes()); + + res = picker.pickSubchannel( + new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT)); + assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); + assertThat(res.getSubchannel().getAddresses()).isEqualTo(rescueSubchannel.getAddresses()); + assertThat(res.getSubchannel().getAttributes()).isEqualTo(rescueSubchannel.getAttributes()); // all channels are failed rescueSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND)); @@ -326,27 +261,18 @@ public class RlsLoadBalancerTest { } @Test - public void lb_nameResolutionFailed() throws Exception { - final InOrder inOrder = inOrder(helper); + public void lb_nameResolutionFailed() { + InOrder inOrder = inOrder(helper); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); - assertThat(pickerCaptor.getValue()).isInstanceOf(RlsPicker.class); - final RlsPicker picker = (RlsPicker) pickerCaptor.getValue(); - final Metadata headers = new Metadata(); - - blockingRunInSyncContext( - new Runnable() { - @Override - public void run() { - PickResult res = - picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); - // verify pending - assertThat(res.getSubchannel()).isNull(); - assertThat(res.getStatus().isOk()).isTrue(); - } - }); + SubchannelPicker picker = pickerCaptor.getValue(); + Metadata headers = new Metadata(); + PickResult res = + picker.pickSubchannel( + new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + assertThat(res.getStatus().isOk()).isTrue(); + assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper) @@ -354,29 +280,19 @@ public class RlsLoadBalancerTest { assertThat(subchannels).hasSize(1); inOrder.verifyNoMoreInteractions(); - final FakeSubchannel searchSubchannel = subchannels.getLast(); + FakeSubchannel searchSubchannel = subchannels.getLast(); searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); - assertThat(pickerCaptor.getValue()).isInstanceOf(RlsPicker.class); - final RlsPicker picker2 = (RlsPicker) pickerCaptor.getValue(); + SubchannelPicker picker2 = pickerCaptor.getValue(); assertThat(picker2).isEqualTo(picker); - blockingRunInSyncContext( - new Runnable() { - @Override - public void run() { - PickResult res = picker2.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); - // verify success. Subchannel is wrapped, so checking attributes. - assertThat(res.getSubchannel()).isNotNull(); - assertThat(res.getSubchannel().getAddresses()) - .isEqualTo(searchSubchannel.getAddresses()); - assertThat(res.getSubchannel().getAttributes()) - .isEqualTo(searchSubchannel.getAttributes()); - assertThat(res.getStatus().isOk()).isTrue(); - } - }); + res = picker2.pickSubchannel( + new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + // verify success. Subchannel is wrapped, so checking attributes. + assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); + assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses()); + assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes()); inOrder.verifyNoMoreInteractions(); @@ -384,36 +300,11 @@ public class RlsLoadBalancerTest { verify(helper) .updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - final SubchannelPicker failedPicker = pickerCaptor.getValue(); - blockingRunInSyncContext( - new Runnable() { - @Override - public void run() { - PickResult res = failedPicker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); - assertThat(res.getSubchannel()).isNull(); - assertThat(res.getStatus().isOk()).isFalse(); - } - }); - } - - private void blockingRunInSyncContext(final Runnable command) throws Exception { - final SettableFuture exceptionFuture = SettableFuture.create(); - syncContext.execute(new Runnable() { - @Override - public void run() { - try { - command.run(); - exceptionFuture.set(null); - } catch (Exception e) { - exceptionFuture.set(e); - } - } - }); - Exception exception = exceptionFuture.get(5, TimeUnit.SECONDS); - if (exception != null) { - throw exception; - } + SubchannelPicker failedPicker = pickerCaptor.getValue(); + res = failedPicker.pickSubchannel( + new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + assertThat(res.getStatus().isOk()).isFalse(); + assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); } @SuppressWarnings("unchecked") @@ -566,6 +457,7 @@ public class RlsLoadBalancerTest { private final Attributes attributes; private List eags; private SubchannelStateListener listener; + private boolean isReady; public FakeSubchannel(List eags, Attributes attributes) { this.eags = Collections.unmodifiableList(eags); @@ -602,6 +494,11 @@ public class RlsLoadBalancerTest { public void updateState(ConnectivityStateInfo newState) { listener.onSubchannelState(newState); + isReady = newState.getState().equals(ConnectivityState.READY); } } + + private static boolean subchannelIsReady(Subchannel subchannel) { + return subchannel instanceof FakeSubchannel && ((FakeSubchannel) subchannel).isReady; + } }