rls: fix wrong synchronization for pickSubchannel()

`RlsPicker.pickSubchannel()` does not run in SynchronizationContext, but it calls `CachingRlsLbClient.get()` which assumed running in SynchronizationContext. Fixed by removing `synchronizationContext.throwIfNotInThisSynchronizationContext()`. `CachingRlsLbClient.get()` is actually thread-safe in the sense it's guarded by lock, and `DataCacheEntry`'s fields are final.

`ChildPolicyWrapper.picker` was not thread-safe. Fixed by making it volatile.

Changed the test a bit since the old test doesn't really test things well.
This commit is contained in:
ZHANG Dapeng 2020-09-30 15:31:09 -07:00 committed by GitHub
parent 00e2d717a2
commit f6c2d221e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 96 additions and 334 deletions

View File

@ -453,7 +453,7 @@ final class CachingRlsLbClient {
private final RouteLookupResponse response;
private final long expireTime;
private final long staleTime;
private ChildPolicyWrapper childPolicyWrapper;
private final ChildPolicyWrapper childPolicyWrapper;
DataCacheEntry(RouteLookupRequest request, final RouteLookupResponse response) {
super(request);
@ -467,21 +467,12 @@ final class CachingRlsLbClient {
staleTime = now + staleAgeNanos;
if (childPolicyWrapper.getPicker() != null) {
// using cached childPolicyWrapper
updateLbState();
childPolicyWrapper.refreshState();
} else {
createChildLbPolicy();
}
}
private void updateLbState() {
childPolicyWrapper
.getHelper()
.updateBalancingState(
childPolicyWrapper.getConnectivityStateInfo().getState(),
childPolicyWrapper.getPicker());
}
private void createChildLbPolicy() {
ChildLoadBalancingPolicy childPolicy = lbPolicyConfig.getLoadBalancingPolicy();
LoadBalancerProvider lbProvider = childPolicy.getEffectiveLbProvider();
@ -868,19 +859,15 @@ final class CachingRlsLbClient {
}
if (response.hasData()) {
ChildPolicyWrapper childPolicyWrapper = response.getChildPolicyWrapper();
ConnectivityState connectivityState =
childPolicyWrapper.getConnectivityStateInfo().getState();
switch (connectivityState) {
case IDLE:
case CONNECTING:
return PickResult.withNoResult();
case READY:
return childPolicyWrapper.getPicker().pickSubchannel(args);
case TRANSIENT_FAILURE:
case SHUTDOWN:
default:
return useFallback(args);
SubchannelPicker picker = childPolicyWrapper.getPicker();
if (picker == null) {
return PickResult.withNoResult();
}
PickResult result = picker.pickSubchannel(args);
if (result.getStatus().isOk()) {
return result;
}
return useFallback(args);
} else if (response.hasError()) {
return useFallback(args);
} else {
@ -898,26 +885,11 @@ final class CachingRlsLbClient {
// TODO(creamsoup) wait until lb is ready
startFallbackChildPolicy();
}
switch (fallbackChildPolicyWrapper.getConnectivityStateInfo().getState()) {
case IDLE:
// fall through
case CONNECTING:
return PickResult.withNoResult();
case TRANSIENT_FAILURE:
// fall through
case SHUTDOWN:
return
PickResult
.withError(fallbackChildPolicyWrapper.getConnectivityStateInfo().getStatus());
case READY:
SubchannelPicker picker = fallbackChildPolicyWrapper.getPicker();
if (picker == null) {
return PickResult.withNoResult();
}
return picker.pickSubchannel(args);
default:
throw new AssertionError();
SubchannelPicker picker = fallbackChildPolicyWrapper.getPicker();
if (picker == null) {
return PickResult.withNoResult();
}
return picker.pickSubchannel(args);
}
private void startFallbackChildPolicy() {

View File

@ -23,19 +23,15 @@ import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.LoadBalancer.CreateSubchannelArgs;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.LoadBalancerProvider;
import io.grpc.LoadBalancerRegistry;
import io.grpc.internal.ObjectPool;
import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider;
import io.grpc.rls.RlsProtoData.RouteLookupConfig;
import io.grpc.util.ForwardingLoadBalancerHelper;
import io.grpc.util.ForwardingSubchannel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
@ -242,9 +238,8 @@ final class LbPolicyConfiguration {
private final String target;
private final ChildPolicyReportingHelper helper;
private ConnectivityStateInfo connectivityStateInfo =
ConnectivityStateInfo.forNonError(ConnectivityState.IDLE);
private SubchannelPicker picker;
private volatile SubchannelPicker picker;
private ConnectivityState state;
public ChildPolicyWrapper(
String target,
@ -259,10 +254,6 @@ final class LbPolicyConfiguration {
return target;
}
void setPicker(SubchannelPicker picker) {
this.picker = checkNotNull(picker, "picker");
}
SubchannelPicker getPicker() {
return picker;
}
@ -271,32 +262,8 @@ final class LbPolicyConfiguration {
return helper;
}
void setConnectivityStateInfo(ConnectivityStateInfo connectivityStateInfo) {
this.connectivityStateInfo = connectivityStateInfo;
}
ConnectivityStateInfo getConnectivityStateInfo() {
return connectivityStateInfo;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ChildPolicyWrapper that = (ChildPolicyWrapper) o;
return Objects.equals(target, that.target)
&& Objects.equals(helper, that.helper)
&& Objects.equals(connectivityStateInfo, that.connectivityStateInfo)
&& Objects.equals(picker, that.picker);
}
@Override
public int hashCode() {
return Objects.hash(target, helper, connectivityStateInfo, picker);
void refreshState() {
helper.updateBalancingState(state, picker);
}
@Override
@ -304,8 +271,8 @@ final class LbPolicyConfiguration {
return MoreObjects.toStringHelper(this)
.add("target", target)
.add("helper", helper)
.add("connectivityStateInfo", connectivityStateInfo)
.add("picker", picker)
.add("state", state)
.toString();
}
@ -335,32 +302,11 @@ final class LbPolicyConfiguration {
@Override
public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) {
setPicker(newPicker);
picker = newPicker;
state = newState;
super.updateBalancingState(newState, newPicker);
listener.onStatusChanged(newState);
}
@Override
public Subchannel createSubchannel(CreateSubchannelArgs args) {
final Subchannel subchannel = super.createSubchannel(args);
return new ForwardingSubchannel() {
@Override
protected Subchannel delegate() {
return subchannel;
}
@Override
public void start(final SubchannelStateListener listener) {
super.start(new SubchannelStateListener() {
@Override
public void onSubchannelState(ConnectivityStateInfo newState) {
setConnectivityStateInfo(newState);
listener.onSubchannelState(newState);
}
});
}
};
}
}
}

View File

@ -418,7 +418,14 @@ public class CachingRlsLbClientTest {
@Override
public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
// TODO: make the picker accessible
helper.updateBalancingState(ConnectivityState.READY, mock(SubchannelPicker.class));
helper.updateBalancingState(
ConnectivityState.READY,
new SubchannelPicker() {
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
return PickResult.withSubchannel(mock(Subchannel.class));
}
});
}
@Override

View File

@ -16,25 +16,17 @@
package io.grpc.rls;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.grpc.Attributes;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer.CreateSubchannelArgs;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.LoadBalancerProvider;
import io.grpc.LoadBalancerRegistry;
import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider;
@ -44,7 +36,6 @@ import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper;
import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper.ChildPolicyReportingHelper;
import io.grpc.rls.LbPolicyConfiguration.InvalidChildPolicyConfigException;
import io.grpc.rls.LbPolicyConfiguration.RefCountedChildPolicyWrapperFactory;
import java.net.SocketAddress;
import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -133,31 +124,6 @@ public class LbPolicyConfigurationTest {
}
}
@Test
public void subchannelStateChange_updateChildPolicyWrapper() {
ChildPolicyWrapper childPolicyWrapper = factory.createOrGet("foo.google.com");
ChildPolicyReportingHelper childPolicyReportingHelper = childPolicyWrapper.getHelper();
FakeSubchannel fakeSubchannel = new FakeSubchannel();
when(helper.createSubchannel(any(CreateSubchannelArgs.class))).thenReturn(fakeSubchannel);
Subchannel subchannel =
childPolicyReportingHelper
.createSubchannel(
CreateSubchannelArgs.newBuilder()
.setAddresses(new EquivalentAddressGroup(mock(SocketAddress.class)))
.build());
subchannel.start(new SubchannelStateListener() {
@Override
public void onSubchannelState(ConnectivityStateInfo newState) {
// no-op
}
});
fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING));
assertThat(childPolicyWrapper.getConnectivityStateInfo())
.isEqualTo(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING));
}
@Test
public void updateBalancingState_triggersListener() {
ChildPolicyWrapper childPolicyWrapper = factory.createOrGet("foo.google.com");
@ -171,30 +137,4 @@ public class LbPolicyConfigurationTest {
// picker governs childPickers will be reported to parent LB
verify(helper).updateBalancingState(ConnectivityState.READY, picker);
}
private static class FakeSubchannel extends Subchannel {
private SubchannelStateListener listener;
@Override
public void start(SubchannelStateListener listener) {
this.listener = listener;
}
void updateState(ConnectivityStateInfo newState) {
checkState(listener != null, "channel is not started yet");
listener.onSubchannelState(newState);
}
@Override
public void shutdown() {}
@Override
public void requestConnection() {}
@Override
public Attributes getAttributes() {
return null;
}
}
}

View File

@ -30,7 +30,6 @@ import static org.mockito.Mockito.verify;
import com.google.common.base.Converter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.Attributes;
import io.grpc.CallOptions;
import io.grpc.ChannelLogger;
@ -51,8 +50,8 @@ import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.NameResolver.ConfigOrError;
import io.grpc.NameResolver;
import io.grpc.NameResolver.ConfigOrError;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.inprocess.InProcessChannelBuilder;
@ -60,7 +59,6 @@ import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.JsonParser;
import io.grpc.internal.PickSubchannelArgsImpl;
import io.grpc.lookup.v1.RouteLookupServiceGrpc;
import io.grpc.rls.CachingRlsLbClient.RlsPicker;
import io.grpc.rls.RlsLoadBalancer.CachingRlsLbClientBuilderProvider;
import io.grpc.rls.RlsProtoConverters.RouteLookupResponseConverter;
import io.grpc.rls.RlsProtoData.RouteLookupRequest;
@ -76,7 +74,6 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nonnull;
import org.junit.After;
import org.junit.Before;
@ -180,140 +177,78 @@ public class RlsLoadBalancerTest {
}
@Test
public void lb_working() throws Exception {
final InOrder inOrder = inOrder(helper);
public void lb_working() {
InOrder inOrder = inOrder(helper);
inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(RlsPicker.class);
final RlsPicker picker = (RlsPicker) pickerCaptor.getValue();
final Metadata headers = new Metadata();
blockingRunInSyncContext(
new Runnable() {
@Override
public void run() {
PickResult res =
picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
// verify pending
assertThat(res.getSubchannel()).isNull();
assertThat(res.getStatus().isOk()).isTrue();
}
});
SubchannelPicker picker = pickerCaptor.getValue();
Metadata headers = new Metadata();
PickResult res = picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
assertThat(subchannels).hasSize(1);
.updateBalancingState(eq(ConnectivityState.CONNECTING), any(SubchannelPicker.class));
inOrder.verifyNoMoreInteractions();
assertThat(res.getStatus().isOk()).isTrue();
assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
final FakeSubchannel searchSubchannel = subchannels.getLast();
assertThat(subchannels).hasSize(1);
FakeSubchannel searchSubchannel = subchannels.getLast();
searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(RlsPicker.class);
final RlsPicker picker2 = (RlsPicker) pickerCaptor.getValue();
assertThat(picker2).isEqualTo(picker);
blockingRunInSyncContext(
new Runnable() {
@Override
public void run() {
PickResult res = picker2.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
// verify success. Subchannel is wrapped, so checking attributes.
assertThat(res.getSubchannel()).isNotNull();
assertThat(res.getSubchannel().getAddresses())
.isEqualTo(searchSubchannel.getAddresses());
assertThat(res.getSubchannel().getAttributes())
.isEqualTo(searchSubchannel.getAttributes());
assertThat(res.getStatus().isOk()).isTrue();
}
});
inOrder.verifyNoMoreInteractions();
assertThat(subchannelIsReady(res.getSubchannel())).isTrue();
assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses());
assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes());
// rescue should be pending status
blockingRunInSyncContext(
new Runnable() {
@Override
public void run() {
PickResult res =
picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT));
assertThat(res.getSubchannel()).isNull();
assertThat(res.getStatus().isOk()).isTrue();
}
});
// rescue should be pending status although the overall channel state is READY
res = picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT));
inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
// other rls picker itself is ready due to first channel.
inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(subchannels).hasSize(2);
inOrder.verifyNoMoreInteractions();
assertThat(res.getStatus().isOk()).isTrue();
assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
assertThat(subchannels).hasSize(2);
FakeSubchannel rescueSubchannel = subchannels.getLast();
// rescue subchannel is connecting
// search subchannel is down, rescue subchannel is connecting
searchSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND));
inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
final FakeSubchannel rescueSubchannel = subchannels.getLast();
rescueSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
// search again, use pending fallback because searchSubchannel is in failure mode
blockingRunInSyncContext(
new Runnable() {
@Override
public void run() {
PickResult res =
picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
assertThat(res.getSubchannel()).isNull();
assertThat(res.getStatus().isOk()).isTrue();
}
});
res = picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
assertThat(res.getStatus().isOk()).isTrue();
assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
assertThat(subchannels).hasSize(3);
final FakeSubchannel fallbackSubchannel = subchannels.getLast();
FakeSubchannel fallbackSubchannel = subchannels.getLast();
fallbackSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
inOrder.verify(helper, times(2))
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
inOrder.verifyNoMoreInteractions();
blockingRunInSyncContext(
new Runnable() {
@Override
public void run() {
PickResult res =
picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
assertThat(res.getSubchannel().getAddresses())
.isEqualTo(fallbackSubchannel.getAddresses());
assertThat(res.getSubchannel().getAttributes())
.isEqualTo(fallbackSubchannel.getAttributes());
assertThat(res.getStatus().isOk()).isTrue();
}
});
blockingRunInSyncContext(
new Runnable() {
@Override
public void run() {
PickResult res =
picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT));
assertThat(res.getSubchannel().getAddresses())
.isEqualTo(rescueSubchannel.getAddresses());
assertThat(res.getSubchannel().getAttributes())
.isEqualTo(rescueSubchannel.getAttributes());
assertThat(res.getStatus().isOk()).isTrue();
}
});
res = picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
assertThat(subchannelIsReady(res.getSubchannel())).isTrue();
assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses());
assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes());
res = picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT));
assertThat(subchannelIsReady(res.getSubchannel())).isTrue();
assertThat(res.getSubchannel().getAddresses()).isEqualTo(rescueSubchannel.getAddresses());
assertThat(res.getSubchannel().getAttributes()).isEqualTo(rescueSubchannel.getAttributes());
// all channels are failed
rescueSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND));
@ -326,27 +261,18 @@ public class RlsLoadBalancerTest {
}
@Test
public void lb_nameResolutionFailed() throws Exception {
final InOrder inOrder = inOrder(helper);
public void lb_nameResolutionFailed() {
InOrder inOrder = inOrder(helper);
inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(RlsPicker.class);
final RlsPicker picker = (RlsPicker) pickerCaptor.getValue();
final Metadata headers = new Metadata();
blockingRunInSyncContext(
new Runnable() {
@Override
public void run() {
PickResult res =
picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
// verify pending
assertThat(res.getSubchannel()).isNull();
assertThat(res.getStatus().isOk()).isTrue();
}
});
SubchannelPicker picker = pickerCaptor.getValue();
Metadata headers = new Metadata();
PickResult res =
picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
assertThat(res.getStatus().isOk()).isTrue();
assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(helper)
@ -354,29 +280,19 @@ public class RlsLoadBalancerTest {
assertThat(subchannels).hasSize(1);
inOrder.verifyNoMoreInteractions();
final FakeSubchannel searchSubchannel = subchannels.getLast();
FakeSubchannel searchSubchannel = subchannels.getLast();
searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(RlsPicker.class);
final RlsPicker picker2 = (RlsPicker) pickerCaptor.getValue();
SubchannelPicker picker2 = pickerCaptor.getValue();
assertThat(picker2).isEqualTo(picker);
blockingRunInSyncContext(
new Runnable() {
@Override
public void run() {
PickResult res = picker2.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
// verify success. Subchannel is wrapped, so checking attributes.
assertThat(res.getSubchannel()).isNotNull();
assertThat(res.getSubchannel().getAddresses())
.isEqualTo(searchSubchannel.getAddresses());
assertThat(res.getSubchannel().getAttributes())
.isEqualTo(searchSubchannel.getAttributes());
assertThat(res.getStatus().isOk()).isTrue();
}
});
res = picker2.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
// verify success. Subchannel is wrapped, so checking attributes.
assertThat(subchannelIsReady(res.getSubchannel())).isTrue();
assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses());
assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes());
inOrder.verifyNoMoreInteractions();
@ -384,36 +300,11 @@ public class RlsLoadBalancerTest {
verify(helper)
.updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture());
final SubchannelPicker failedPicker = pickerCaptor.getValue();
blockingRunInSyncContext(
new Runnable() {
@Override
public void run() {
PickResult res = failedPicker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
assertThat(res.getSubchannel()).isNull();
assertThat(res.getStatus().isOk()).isFalse();
}
});
}
private void blockingRunInSyncContext(final Runnable command) throws Exception {
final SettableFuture<Exception> exceptionFuture = SettableFuture.create();
syncContext.execute(new Runnable() {
@Override
public void run() {
try {
command.run();
exceptionFuture.set(null);
} catch (Exception e) {
exceptionFuture.set(e);
}
}
});
Exception exception = exceptionFuture.get(5, TimeUnit.SECONDS);
if (exception != null) {
throw exception;
}
SubchannelPicker failedPicker = pickerCaptor.getValue();
res = failedPicker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
assertThat(res.getStatus().isOk()).isFalse();
assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
}
@SuppressWarnings("unchecked")
@ -566,6 +457,7 @@ public class RlsLoadBalancerTest {
private final Attributes attributes;
private List<EquivalentAddressGroup> eags;
private SubchannelStateListener listener;
private boolean isReady;
public FakeSubchannel(List<EquivalentAddressGroup> eags, Attributes attributes) {
this.eags = Collections.unmodifiableList(eags);
@ -602,6 +494,11 @@ public class RlsLoadBalancerTest {
public void updateState(ConnectivityStateInfo newState) {
listener.onSubchannelState(newState);
isReady = newState.getState().equals(ConnectivityState.READY);
}
}
private static boolean subchannelIsReady(Subchannel subchannel) {
return subchannel instanceof FakeSubchannel && ((FakeSubchannel) subchannel).isReady;
}
}