rls: fix wrong synchronization for pickSubchannel()

`RlsPicker.pickSubchannel()` does not run in SynchronizationContext, but it calls `CachingRlsLbClient.get()` which assumed running in SynchronizationContext. Fixed by removing `synchronizationContext.throwIfNotInThisSynchronizationContext()`. `CachingRlsLbClient.get()` is actually thread-safe in the sense it's guarded by lock, and `DataCacheEntry`'s fields are final.

`ChildPolicyWrapper.picker` was not thread-safe. Fixed by making it volatile.

Changed the test a bit since the old test doesn't really test things well.
This commit is contained in:
ZHANG Dapeng 2020-09-30 15:31:09 -07:00 committed by GitHub
parent 00e2d717a2
commit f6c2d221e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 96 additions and 334 deletions

View File

@ -453,7 +453,7 @@ final class CachingRlsLbClient {
private final RouteLookupResponse response; private final RouteLookupResponse response;
private final long expireTime; private final long expireTime;
private final long staleTime; private final long staleTime;
private ChildPolicyWrapper childPolicyWrapper; private final ChildPolicyWrapper childPolicyWrapper;
DataCacheEntry(RouteLookupRequest request, final RouteLookupResponse response) { DataCacheEntry(RouteLookupRequest request, final RouteLookupResponse response) {
super(request); super(request);
@ -467,21 +467,12 @@ final class CachingRlsLbClient {
staleTime = now + staleAgeNanos; staleTime = now + staleAgeNanos;
if (childPolicyWrapper.getPicker() != null) { if (childPolicyWrapper.getPicker() != null) {
// using cached childPolicyWrapper childPolicyWrapper.refreshState();
updateLbState();
} else { } else {
createChildLbPolicy(); createChildLbPolicy();
} }
} }
private void updateLbState() {
childPolicyWrapper
.getHelper()
.updateBalancingState(
childPolicyWrapper.getConnectivityStateInfo().getState(),
childPolicyWrapper.getPicker());
}
private void createChildLbPolicy() { private void createChildLbPolicy() {
ChildLoadBalancingPolicy childPolicy = lbPolicyConfig.getLoadBalancingPolicy(); ChildLoadBalancingPolicy childPolicy = lbPolicyConfig.getLoadBalancingPolicy();
LoadBalancerProvider lbProvider = childPolicy.getEffectiveLbProvider(); LoadBalancerProvider lbProvider = childPolicy.getEffectiveLbProvider();
@ -868,19 +859,15 @@ final class CachingRlsLbClient {
} }
if (response.hasData()) { if (response.hasData()) {
ChildPolicyWrapper childPolicyWrapper = response.getChildPolicyWrapper(); ChildPolicyWrapper childPolicyWrapper = response.getChildPolicyWrapper();
ConnectivityState connectivityState = SubchannelPicker picker = childPolicyWrapper.getPicker();
childPolicyWrapper.getConnectivityStateInfo().getState(); if (picker == null) {
switch (connectivityState) { return PickResult.withNoResult();
case IDLE:
case CONNECTING:
return PickResult.withNoResult();
case READY:
return childPolicyWrapper.getPicker().pickSubchannel(args);
case TRANSIENT_FAILURE:
case SHUTDOWN:
default:
return useFallback(args);
} }
PickResult result = picker.pickSubchannel(args);
if (result.getStatus().isOk()) {
return result;
}
return useFallback(args);
} else if (response.hasError()) { } else if (response.hasError()) {
return useFallback(args); return useFallback(args);
} else { } else {
@ -898,26 +885,11 @@ final class CachingRlsLbClient {
// TODO(creamsoup) wait until lb is ready // TODO(creamsoup) wait until lb is ready
startFallbackChildPolicy(); startFallbackChildPolicy();
} }
switch (fallbackChildPolicyWrapper.getConnectivityStateInfo().getState()) { SubchannelPicker picker = fallbackChildPolicyWrapper.getPicker();
case IDLE: if (picker == null) {
// fall through return PickResult.withNoResult();
case CONNECTING:
return PickResult.withNoResult();
case TRANSIENT_FAILURE:
// fall through
case SHUTDOWN:
return
PickResult
.withError(fallbackChildPolicyWrapper.getConnectivityStateInfo().getStatus());
case READY:
SubchannelPicker picker = fallbackChildPolicyWrapper.getPicker();
if (picker == null) {
return PickResult.withNoResult();
}
return picker.pickSubchannel(args);
default:
throw new AssertionError();
} }
return picker.pickSubchannel(args);
} }
private void startFallbackChildPolicy() { private void startFallbackChildPolicy() {

View File

@ -23,19 +23,15 @@ import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects; import com.google.common.base.MoreObjects;
import io.grpc.ConnectivityState; import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.LoadBalancer.CreateSubchannelArgs;
import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerProvider;
import io.grpc.LoadBalancerRegistry; import io.grpc.LoadBalancerRegistry;
import io.grpc.internal.ObjectPool; import io.grpc.internal.ObjectPool;
import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider; import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider;
import io.grpc.rls.RlsProtoData.RouteLookupConfig; import io.grpc.rls.RlsProtoData.RouteLookupConfig;
import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.ForwardingLoadBalancerHelper;
import io.grpc.util.ForwardingSubchannel;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
@ -242,9 +238,8 @@ final class LbPolicyConfiguration {
private final String target; private final String target;
private final ChildPolicyReportingHelper helper; private final ChildPolicyReportingHelper helper;
private ConnectivityStateInfo connectivityStateInfo = private volatile SubchannelPicker picker;
ConnectivityStateInfo.forNonError(ConnectivityState.IDLE); private ConnectivityState state;
private SubchannelPicker picker;
public ChildPolicyWrapper( public ChildPolicyWrapper(
String target, String target,
@ -259,10 +254,6 @@ final class LbPolicyConfiguration {
return target; return target;
} }
void setPicker(SubchannelPicker picker) {
this.picker = checkNotNull(picker, "picker");
}
SubchannelPicker getPicker() { SubchannelPicker getPicker() {
return picker; return picker;
} }
@ -271,32 +262,8 @@ final class LbPolicyConfiguration {
return helper; return helper;
} }
void setConnectivityStateInfo(ConnectivityStateInfo connectivityStateInfo) { void refreshState() {
this.connectivityStateInfo = connectivityStateInfo; helper.updateBalancingState(state, picker);
}
ConnectivityStateInfo getConnectivityStateInfo() {
return connectivityStateInfo;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ChildPolicyWrapper that = (ChildPolicyWrapper) o;
return Objects.equals(target, that.target)
&& Objects.equals(helper, that.helper)
&& Objects.equals(connectivityStateInfo, that.connectivityStateInfo)
&& Objects.equals(picker, that.picker);
}
@Override
public int hashCode() {
return Objects.hash(target, helper, connectivityStateInfo, picker);
} }
@Override @Override
@ -304,8 +271,8 @@ final class LbPolicyConfiguration {
return MoreObjects.toStringHelper(this) return MoreObjects.toStringHelper(this)
.add("target", target) .add("target", target)
.add("helper", helper) .add("helper", helper)
.add("connectivityStateInfo", connectivityStateInfo)
.add("picker", picker) .add("picker", picker)
.add("state", state)
.toString(); .toString();
} }
@ -335,32 +302,11 @@ final class LbPolicyConfiguration {
@Override @Override
public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) {
setPicker(newPicker); picker = newPicker;
state = newState;
super.updateBalancingState(newState, newPicker); super.updateBalancingState(newState, newPicker);
listener.onStatusChanged(newState); listener.onStatusChanged(newState);
} }
@Override
public Subchannel createSubchannel(CreateSubchannelArgs args) {
final Subchannel subchannel = super.createSubchannel(args);
return new ForwardingSubchannel() {
@Override
protected Subchannel delegate() {
return subchannel;
}
@Override
public void start(final SubchannelStateListener listener) {
super.start(new SubchannelStateListener() {
@Override
public void onSubchannelState(ConnectivityStateInfo newState) {
setConnectivityStateInfo(newState);
listener.onSubchannelState(newState);
}
});
}
};
}
} }
} }

View File

@ -418,7 +418,14 @@ public class CachingRlsLbClientTest {
@Override @Override
public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
// TODO: make the picker accessible // TODO: make the picker accessible
helper.updateBalancingState(ConnectivityState.READY, mock(SubchannelPicker.class)); helper.updateBalancingState(
ConnectivityState.READY,
new SubchannelPicker() {
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
return PickResult.withSubchannel(mock(Subchannel.class));
}
});
} }
@Override @Override

View File

@ -16,25 +16,17 @@
package io.grpc.rls; package io.grpc.rls;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import io.grpc.Attributes;
import io.grpc.ConnectivityState; import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer.CreateSubchannelArgs;
import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerProvider;
import io.grpc.LoadBalancerRegistry; import io.grpc.LoadBalancerRegistry;
import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider; import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider;
@ -44,7 +36,6 @@ import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper;
import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper.ChildPolicyReportingHelper; import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper.ChildPolicyReportingHelper;
import io.grpc.rls.LbPolicyConfiguration.InvalidChildPolicyConfigException; import io.grpc.rls.LbPolicyConfiguration.InvalidChildPolicyConfigException;
import io.grpc.rls.LbPolicyConfiguration.RefCountedChildPolicyWrapperFactory; import io.grpc.rls.LbPolicyConfiguration.RefCountedChildPolicyWrapperFactory;
import java.net.SocketAddress;
import java.util.Map; import java.util.Map;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -133,31 +124,6 @@ public class LbPolicyConfigurationTest {
} }
} }
@Test
public void subchannelStateChange_updateChildPolicyWrapper() {
ChildPolicyWrapper childPolicyWrapper = factory.createOrGet("foo.google.com");
ChildPolicyReportingHelper childPolicyReportingHelper = childPolicyWrapper.getHelper();
FakeSubchannel fakeSubchannel = new FakeSubchannel();
when(helper.createSubchannel(any(CreateSubchannelArgs.class))).thenReturn(fakeSubchannel);
Subchannel subchannel =
childPolicyReportingHelper
.createSubchannel(
CreateSubchannelArgs.newBuilder()
.setAddresses(new EquivalentAddressGroup(mock(SocketAddress.class)))
.build());
subchannel.start(new SubchannelStateListener() {
@Override
public void onSubchannelState(ConnectivityStateInfo newState) {
// no-op
}
});
fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING));
assertThat(childPolicyWrapper.getConnectivityStateInfo())
.isEqualTo(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING));
}
@Test @Test
public void updateBalancingState_triggersListener() { public void updateBalancingState_triggersListener() {
ChildPolicyWrapper childPolicyWrapper = factory.createOrGet("foo.google.com"); ChildPolicyWrapper childPolicyWrapper = factory.createOrGet("foo.google.com");
@ -171,30 +137,4 @@ public class LbPolicyConfigurationTest {
// picker governs childPickers will be reported to parent LB // picker governs childPickers will be reported to parent LB
verify(helper).updateBalancingState(ConnectivityState.READY, picker); verify(helper).updateBalancingState(ConnectivityState.READY, picker);
} }
private static class FakeSubchannel extends Subchannel {
private SubchannelStateListener listener;
@Override
public void start(SubchannelStateListener listener) {
this.listener = listener;
}
void updateState(ConnectivityStateInfo newState) {
checkState(listener != null, "channel is not started yet");
listener.onSubchannelState(newState);
}
@Override
public void shutdown() {}
@Override
public void requestConnection() {}
@Override
public Attributes getAttributes() {
return null;
}
}
} }

View File

@ -30,7 +30,6 @@ import static org.mockito.Mockito.verify;
import com.google.common.base.Converter; import com.google.common.base.Converter;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.ChannelLogger; import io.grpc.ChannelLogger;
@ -51,8 +50,8 @@ import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.MethodDescriptor.MethodType; import io.grpc.MethodDescriptor.MethodType;
import io.grpc.NameResolver.ConfigOrError;
import io.grpc.NameResolver; import io.grpc.NameResolver;
import io.grpc.NameResolver.ConfigOrError;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext;
import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessChannelBuilder;
@ -60,7 +59,6 @@ import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.JsonParser; import io.grpc.internal.JsonParser;
import io.grpc.internal.PickSubchannelArgsImpl; import io.grpc.internal.PickSubchannelArgsImpl;
import io.grpc.lookup.v1.RouteLookupServiceGrpc; import io.grpc.lookup.v1.RouteLookupServiceGrpc;
import io.grpc.rls.CachingRlsLbClient.RlsPicker;
import io.grpc.rls.RlsLoadBalancer.CachingRlsLbClientBuilderProvider; import io.grpc.rls.RlsLoadBalancer.CachingRlsLbClientBuilderProvider;
import io.grpc.rls.RlsProtoConverters.RouteLookupResponseConverter; import io.grpc.rls.RlsProtoConverters.RouteLookupResponseConverter;
import io.grpc.rls.RlsProtoData.RouteLookupRequest; import io.grpc.rls.RlsProtoData.RouteLookupRequest;
@ -76,7 +74,6 @@ import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
@ -180,140 +177,78 @@ public class RlsLoadBalancerTest {
} }
@Test @Test
public void lb_working() throws Exception { public void lb_working() {
final InOrder inOrder = inOrder(helper); InOrder inOrder = inOrder(helper);
inOrder.verify(helper) inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(RlsPicker.class); SubchannelPicker picker = pickerCaptor.getValue();
final RlsPicker picker = (RlsPicker) pickerCaptor.getValue(); Metadata headers = new Metadata();
final Metadata headers = new Metadata(); PickResult res = picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
blockingRunInSyncContext(
new Runnable() {
@Override
public void run() {
PickResult res =
picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
// verify pending
assertThat(res.getSubchannel()).isNull();
assertThat(res.getStatus().isOk()).isTrue();
}
});
inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(helper) inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.CONNECTING), any(SubchannelPicker.class));
assertThat(subchannels).hasSize(1);
inOrder.verifyNoMoreInteractions(); inOrder.verifyNoMoreInteractions();
assertThat(res.getStatus().isOk()).isTrue();
assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
final FakeSubchannel searchSubchannel = subchannels.getLast(); assertThat(subchannels).hasSize(1);
FakeSubchannel searchSubchannel = subchannels.getLast();
searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
inOrder.verify(helper) inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(RlsPicker.class);
final RlsPicker picker2 = (RlsPicker) pickerCaptor.getValue();
assertThat(picker2).isEqualTo(picker);
blockingRunInSyncContext(
new Runnable() {
@Override
public void run() {
PickResult res = picker2.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
// verify success. Subchannel is wrapped, so checking attributes.
assertThat(res.getSubchannel()).isNotNull();
assertThat(res.getSubchannel().getAddresses())
.isEqualTo(searchSubchannel.getAddresses());
assertThat(res.getSubchannel().getAttributes())
.isEqualTo(searchSubchannel.getAttributes());
assertThat(res.getStatus().isOk()).isTrue();
}
});
inOrder.verifyNoMoreInteractions(); inOrder.verifyNoMoreInteractions();
assertThat(subchannelIsReady(res.getSubchannel())).isTrue();
assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses());
assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes());
// rescue should be pending status // rescue should be pending status although the overall channel state is READY
blockingRunInSyncContext( res = picker.pickSubchannel(
new Runnable() { new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT));
@Override
public void run() {
PickResult res =
picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT));
assertThat(res.getSubchannel()).isNull();
assertThat(res.getStatus().isOk()).isTrue();
}
});
inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
// other rls picker itself is ready due to first channel. // other rls picker itself is ready due to first channel.
inOrder.verify(helper) inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(subchannels).hasSize(2);
inOrder.verifyNoMoreInteractions(); inOrder.verifyNoMoreInteractions();
assertThat(res.getStatus().isOk()).isTrue();
assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
assertThat(subchannels).hasSize(2);
FakeSubchannel rescueSubchannel = subchannels.getLast();
// rescue subchannel is connecting // search subchannel is down, rescue subchannel is connecting
searchSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND)); searchSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND));
inOrder.verify(helper) inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
final FakeSubchannel rescueSubchannel = subchannels.getLast();
rescueSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); rescueSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
inOrder.verify(helper) inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
// search again, use pending fallback because searchSubchannel is in failure mode // search again, use pending fallback because searchSubchannel is in failure mode
blockingRunInSyncContext( res = picker.pickSubchannel(
new Runnable() { new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
@Override assertThat(res.getStatus().isOk()).isTrue();
public void run() { assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
PickResult res =
picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
assertThat(res.getSubchannel()).isNull();
assertThat(res.getStatus().isOk()).isTrue();
}
});
inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
assertThat(subchannels).hasSize(3); assertThat(subchannels).hasSize(3);
final FakeSubchannel fallbackSubchannel = subchannels.getLast(); FakeSubchannel fallbackSubchannel = subchannels.getLast();
fallbackSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); fallbackSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
inOrder.verify(helper, times(2)) inOrder.verify(helper, times(2))
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
inOrder.verifyNoMoreInteractions(); inOrder.verifyNoMoreInteractions();
blockingRunInSyncContext( res = picker.pickSubchannel(
new Runnable() { new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
@Override assertThat(subchannelIsReady(res.getSubchannel())).isTrue();
public void run() { assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses());
PickResult res = assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes());
picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); res = picker.pickSubchannel(
assertThat(res.getSubchannel().getAddresses()) new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT));
.isEqualTo(fallbackSubchannel.getAddresses()); assertThat(subchannelIsReady(res.getSubchannel())).isTrue();
assertThat(res.getSubchannel().getAttributes()) assertThat(res.getSubchannel().getAddresses()).isEqualTo(rescueSubchannel.getAddresses());
.isEqualTo(fallbackSubchannel.getAttributes()); assertThat(res.getSubchannel().getAttributes()).isEqualTo(rescueSubchannel.getAttributes());
assertThat(res.getStatus().isOk()).isTrue();
}
});
blockingRunInSyncContext(
new Runnable() {
@Override
public void run() {
PickResult res =
picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT));
assertThat(res.getSubchannel().getAddresses())
.isEqualTo(rescueSubchannel.getAddresses());
assertThat(res.getSubchannel().getAttributes())
.isEqualTo(rescueSubchannel.getAttributes());
assertThat(res.getStatus().isOk()).isTrue();
}
});
// all channels are failed // all channels are failed
rescueSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND)); rescueSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND));
@ -326,27 +261,18 @@ public class RlsLoadBalancerTest {
} }
@Test @Test
public void lb_nameResolutionFailed() throws Exception { public void lb_nameResolutionFailed() {
final InOrder inOrder = inOrder(helper); InOrder inOrder = inOrder(helper);
inOrder.verify(helper) inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(RlsPicker.class); SubchannelPicker picker = pickerCaptor.getValue();
final RlsPicker picker = (RlsPicker) pickerCaptor.getValue(); Metadata headers = new Metadata();
final Metadata headers = new Metadata(); PickResult res =
picker.pickSubchannel(
blockingRunInSyncContext( new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
new Runnable() { assertThat(res.getStatus().isOk()).isTrue();
@Override assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
public void run() {
PickResult res =
picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
// verify pending
assertThat(res.getSubchannel()).isNull();
assertThat(res.getStatus().isOk()).isTrue();
}
});
inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(helper) inOrder.verify(helper)
@ -354,29 +280,19 @@ public class RlsLoadBalancerTest {
assertThat(subchannels).hasSize(1); assertThat(subchannels).hasSize(1);
inOrder.verifyNoMoreInteractions(); inOrder.verifyNoMoreInteractions();
final FakeSubchannel searchSubchannel = subchannels.getLast(); FakeSubchannel searchSubchannel = subchannels.getLast();
searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
inOrder.verify(helper) inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(RlsPicker.class); SubchannelPicker picker2 = pickerCaptor.getValue();
final RlsPicker picker2 = (RlsPicker) pickerCaptor.getValue();
assertThat(picker2).isEqualTo(picker); assertThat(picker2).isEqualTo(picker);
blockingRunInSyncContext( res = picker2.pickSubchannel(
new Runnable() { new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
@Override // verify success. Subchannel is wrapped, so checking attributes.
public void run() { assertThat(subchannelIsReady(res.getSubchannel())).isTrue();
PickResult res = picker2.pickSubchannel( assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses());
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes());
// verify success. Subchannel is wrapped, so checking attributes.
assertThat(res.getSubchannel()).isNotNull();
assertThat(res.getSubchannel().getAddresses())
.isEqualTo(searchSubchannel.getAddresses());
assertThat(res.getSubchannel().getAttributes())
.isEqualTo(searchSubchannel.getAttributes());
assertThat(res.getStatus().isOk()).isTrue();
}
});
inOrder.verifyNoMoreInteractions(); inOrder.verifyNoMoreInteractions();
@ -384,36 +300,11 @@ public class RlsLoadBalancerTest {
verify(helper) verify(helper)
.updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture());
final SubchannelPicker failedPicker = pickerCaptor.getValue(); SubchannelPicker failedPicker = pickerCaptor.getValue();
blockingRunInSyncContext( res = failedPicker.pickSubchannel(
new Runnable() { new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
@Override assertThat(res.getStatus().isOk()).isFalse();
public void run() { assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
PickResult res = failedPicker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
assertThat(res.getSubchannel()).isNull();
assertThat(res.getStatus().isOk()).isFalse();
}
});
}
private void blockingRunInSyncContext(final Runnable command) throws Exception {
final SettableFuture<Exception> exceptionFuture = SettableFuture.create();
syncContext.execute(new Runnable() {
@Override
public void run() {
try {
command.run();
exceptionFuture.set(null);
} catch (Exception e) {
exceptionFuture.set(e);
}
}
});
Exception exception = exceptionFuture.get(5, TimeUnit.SECONDS);
if (exception != null) {
throw exception;
}
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -566,6 +457,7 @@ public class RlsLoadBalancerTest {
private final Attributes attributes; private final Attributes attributes;
private List<EquivalentAddressGroup> eags; private List<EquivalentAddressGroup> eags;
private SubchannelStateListener listener; private SubchannelStateListener listener;
private boolean isReady;
public FakeSubchannel(List<EquivalentAddressGroup> eags, Attributes attributes) { public FakeSubchannel(List<EquivalentAddressGroup> eags, Attributes attributes) {
this.eags = Collections.unmodifiableList(eags); this.eags = Collections.unmodifiableList(eags);
@ -602,6 +494,11 @@ public class RlsLoadBalancerTest {
public void updateState(ConnectivityStateInfo newState) { public void updateState(ConnectivityStateInfo newState) {
listener.onSubchannelState(newState); listener.onSubchannelState(newState);
isReady = newState.getState().equals(ConnectivityState.READY);
} }
} }
private static boolean subchannelIsReady(Subchannel subchannel) {
return subchannel instanceof FakeSubchannel && ((FakeSubchannel) subchannel).isReady;
}
} }