Prepare to switch flag to use new PickFirstLeafLoadBalancer by default (#10998)

* Fix PickFirstLeafLoadBalancer and tests to work when it is used.
* Actually use EAG attributes for subchannels.
This commit is contained in:
Larry Safran 2024-03-11 21:12:56 +00:00 committed by GitHub
parent ebbe0673f3
commit d1c406bd23
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 344 additions and 218 deletions

View File

@ -954,6 +954,9 @@ public final class GrpcUtil {
if (envVar == null) { if (envVar == null) {
envVar = System.getProperty(envVarName); envVar = System.getProperty(envVarName);
} }
if (envVar != null) {
envVar = envVar.trim();
}
if (enableByDefault) { if (enableByDefault) {
return Strings.isNullOrEmpty(envVar) || Boolean.parseBoolean(envVar); return Strings.isNullOrEmpty(envVar) || Boolean.parseBoolean(envVar);
} else { } else {

View File

@ -131,8 +131,14 @@ final class PickFirstLeafLoadBalancer extends LoadBalancer {
// If the previous ready subchannel exists in new address list, // If the previous ready subchannel exists in new address list,
// keep this connection and don't create new subchannels // keep this connection and don't create new subchannels
SocketAddress previousAddress = addressIndex.getCurrentAddress(); SocketAddress previousAddress = addressIndex.getCurrentAddress();
Attributes prevEagAttrs = addressIndex.getCurrentEagAttributes();
addressIndex.updateGroups(newImmutableAddressGroups); addressIndex.updateGroups(newImmutableAddressGroups);
if (addressIndex.seekTo(previousAddress)) { if (addressIndex.seekTo(previousAddress)) {
if (!addressIndex.getCurrentEagAttributes().equals(prevEagAttrs)) {
log.log(Level.FINE, "EAG attributes changed, need to update subchannel");
SubchannelData subchannelData = subchannels.get(previousAddress);
subchannelData.getSubchannel().updateAddresses(addressIndex.getCurrentEagAsList());
}
return Status.OK; return Status.OK;
} else { } else {
addressIndex.reset(); // Previous ready subchannel not in the new list of addresses addressIndex.reset(); // Previous ready subchannel not in the new list of addresses
@ -354,7 +360,7 @@ final class PickFirstLeafLoadBalancer extends LoadBalancer {
currentAddress = addressIndex.getCurrentAddress(); currentAddress = addressIndex.getCurrentAddress();
subchannel = subchannels.containsKey(currentAddress) subchannel = subchannels.containsKey(currentAddress)
? subchannels.get(currentAddress).getSubchannel() ? subchannels.get(currentAddress).getSubchannel()
: createNewSubchannel(currentAddress); : createNewSubchannel(currentAddress, addressIndex.getCurrentEagAttributes());
ConnectivityState subchannelState = subchannels.get(currentAddress).getState(); ConnectivityState subchannelState = subchannels.get(currentAddress).getState();
switch (subchannelState) { switch (subchannelState) {
@ -418,12 +424,12 @@ final class PickFirstLeafLoadBalancer extends LoadBalancer {
} }
} }
private Subchannel createNewSubchannel(SocketAddress addr) { private Subchannel createNewSubchannel(SocketAddress addr, Attributes attrs) {
HealthListener hcListener = new HealthListener(); HealthListener hcListener = new HealthListener();
final Subchannel subchannel = helper.createSubchannel( final Subchannel subchannel = helper.createSubchannel(
CreateSubchannelArgs.newBuilder() CreateSubchannelArgs.newBuilder()
.setAddresses(Lists.newArrayList( .setAddresses(Lists.newArrayList(
new EquivalentAddressGroup(addr))) new EquivalentAddressGroup(addr, attrs)))
.addOption(HEALTH_CONSUMER_LISTENER_ARG_KEY, hcListener) .addOption(HEALTH_CONSUMER_LISTENER_ARG_KEY, hcListener)
.build()); .build());
if (subchannel == null) { if (subchannel == null) {
@ -433,8 +439,8 @@ final class PickFirstLeafLoadBalancer extends LoadBalancer {
SubchannelData subchannelData = new SubchannelData(subchannel, IDLE, hcListener); SubchannelData subchannelData = new SubchannelData(subchannel, IDLE, hcListener);
hcListener.subchannelData = subchannelData; hcListener.subchannelData = subchannelData;
subchannels.put(addr, subchannelData); subchannels.put(addr, subchannelData);
Attributes attrs = subchannel.getAttributes(); Attributes scAttrs = subchannel.getAttributes();
if (attrs.get(LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY) == null) { if (scAttrs.get(LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY) == null) {
hcListener.healthStateInfo = ConnectivityStateInfo.forNonError(READY); hcListener.healthStateInfo = ConnectivityStateInfo.forNonError(READY);
} }
subchannel.start(stateInfo -> processSubchannelState(subchannel, stateInfo)); subchannel.start(stateInfo -> processSubchannelState(subchannel, stateInfo));
@ -584,6 +590,11 @@ final class PickFirstLeafLoadBalancer extends LoadBalancer {
return addressGroups.get(groupIndex).getAttributes(); return addressGroups.get(groupIndex).getAttributes();
} }
public List<EquivalentAddressGroup> getCurrentEagAsList() {
return Collections.singletonList(
new EquivalentAddressGroup(getCurrentAddress(), getCurrentEagAttributes()));
}
/** /**
* Update to new groups, resetting the current index. * Update to new groups, resetting the current index.
*/ */

View File

@ -16,12 +16,13 @@
package io.grpc.internal; package io.grpc.internal;
import com.google.common.base.Strings; import com.google.common.annotations.VisibleForTesting;
import io.grpc.LoadBalancer; import io.grpc.LoadBalancer;
import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerProvider;
import io.grpc.NameResolver; import io.grpc.NameResolver;
import io.grpc.NameResolver.ConfigOrError; import io.grpc.NameResolver.ConfigOrError;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.PickFirstLeafLoadBalancer.PickFirstLeafLoadBalancerConfig;
import io.grpc.internal.PickFirstLoadBalancer.PickFirstLoadBalancerConfig; import io.grpc.internal.PickFirstLoadBalancer.PickFirstLoadBalancerConfig;
import java.util.Map; import java.util.Map;
@ -35,8 +36,7 @@ public final class PickFirstLoadBalancerProvider extends LoadBalancerProvider {
private static final String SHUFFLE_ADDRESS_LIST_KEY = "shuffleAddressList"; private static final String SHUFFLE_ADDRESS_LIST_KEY = "shuffleAddressList";
static boolean enableNewPickFirst = static boolean enableNewPickFirst =
!Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST")) GrpcUtil.getFlag("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST", false);
&& Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST"));
@Override @Override
public boolean isAvailable() { public boolean isAvailable() {
@ -63,16 +63,28 @@ public final class PickFirstLoadBalancerProvider extends LoadBalancerProvider {
} }
@Override @Override
public ConfigOrError parseLoadBalancingPolicyConfig( public ConfigOrError parseLoadBalancingPolicyConfig(Map<String, ?> rawLbPolicyConfig) {
Map<String, ?> rawLoadBalancingPolicyConfig) {
try { try {
return ConfigOrError.fromConfig( Object config = getLbPolicyConfig(rawLbPolicyConfig);
new PickFirstLoadBalancerConfig(JsonUtil.getBoolean(rawLoadBalancingPolicyConfig, return ConfigOrError.fromConfig(config);
SHUFFLE_ADDRESS_LIST_KEY)));
} catch (RuntimeException e) { } catch (RuntimeException e) {
return ConfigOrError.fromError( return ConfigOrError.fromError(
Status.UNAVAILABLE.withCause(e).withDescription( Status.UNAVAILABLE.withCause(e).withDescription(
"Failed parsing configuration for " + getPolicyName())); "Failed parsing configuration for " + getPolicyName()));
} }
} }
private static Object getLbPolicyConfig(Map<String, ?> rawLbPolicyConfig) {
Boolean shuffleAddressList = JsonUtil.getBoolean(rawLbPolicyConfig, SHUFFLE_ADDRESS_LIST_KEY);
if (enableNewPickFirst) {
return new PickFirstLeafLoadBalancerConfig(shuffleAddressList);
} else {
return new PickFirstLoadBalancerConfig(shuffleAddressList);
}
}
@VisibleForTesting
public static boolean isEnabledNewPickFirst() {
return enableNewPickFirst;
}
} }

View File

@ -49,6 +49,7 @@ import io.grpc.LoadBalancerRegistry;
import io.grpc.NameResolver.ConfigOrError; import io.grpc.NameResolver.ConfigOrError;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.AutoConfiguredLoadBalancerFactory.AutoConfiguredLoadBalancer; import io.grpc.internal.AutoConfiguredLoadBalancerFactory.AutoConfiguredLoadBalancer;
import io.grpc.internal.PickFirstLeafLoadBalancer.PickFirstLeafLoadBalancerConfig;
import io.grpc.internal.PickFirstLoadBalancer.PickFirstLoadBalancerConfig; import io.grpc.internal.PickFirstLoadBalancer.PickFirstLoadBalancerConfig;
import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.internal.ServiceConfigUtil.PolicySelection;
import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.ForwardingLoadBalancerHelper;
@ -95,6 +96,11 @@ public class AutoConfiguredLoadBalancerFactoryTest {
delegatesTo( delegatesTo(
new FakeLoadBalancerProvider("test_lb2", testLbBalancer2, nextParsedConfigOrError2))); new FakeLoadBalancerProvider("test_lb2", testLbBalancer2, nextParsedConfigOrError2)));
private final Class<? extends LoadBalancer> pfLbClass =
PickFirstLoadBalancerProvider.enableNewPickFirst
? PickFirstLeafLoadBalancer.class
: PickFirstLoadBalancer.class;
@Before @Before
public void setUp() { public void setUp() {
when(testLbBalancer.acceptResolvedAddresses(isA(ResolvedAddresses.class))).thenReturn( when(testLbBalancer.acceptResolvedAddresses(isA(ResolvedAddresses.class))).thenReturn(
@ -429,7 +435,7 @@ public class AutoConfiguredLoadBalancerFactoryTest {
.setLoadBalancingPolicyConfig(null) .setLoadBalancingPolicyConfig(null)
.build()); .build());
assertThat(addressesAcceptanceStatus.isOk()).isTrue(); assertThat(addressesAcceptanceStatus.isOk()).isTrue();
assertThat(lb.getDelegate()).isInstanceOf(PickFirstLoadBalancer.class); assertThat(lb.getDelegate()).isInstanceOf(pfLbClass);
} }
@Test @Test
@ -484,7 +490,7 @@ public class AutoConfiguredLoadBalancerFactoryTest {
verify(channelLogger).log( verify(channelLogger).log(
eq(ChannelLogLevel.INFO), eq(ChannelLogLevel.INFO),
eq("Load balancer changed from {0} to {1}"), eq("Load balancer changed from {0} to {1}"),
eq("PickFirstLoadBalancer"), eq(pfLbClass.getSimpleName()),
eq(testLbBalancer.getClass().getSimpleName())); eq(testLbBalancer.getClass().getSimpleName()));
verify(channelLogger).log( verify(channelLogger).log(
@ -628,8 +634,15 @@ public class AutoConfiguredLoadBalancerFactoryTest {
assertThat(parsed.getConfig()).isNotNull(); assertThat(parsed.getConfig()).isNotNull();
PolicySelection policySelection = (PolicySelection) parsed.getConfig(); PolicySelection policySelection = (PolicySelection) parsed.getConfig();
assertThat(policySelection.provider).isInstanceOf(PickFirstLoadBalancerProvider.class); assertThat(policySelection.provider).isInstanceOf(PickFirstLoadBalancerProvider.class);
assertThat(policySelection.config).isInstanceOf(PickFirstLoadBalancerConfig.class); if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) {
assertThat(((PickFirstLoadBalancerConfig) policySelection.config).shuffleAddressList).isTrue(); assertThat(policySelection.config).isInstanceOf(PickFirstLeafLoadBalancerConfig.class);
assertThat(((PickFirstLeafLoadBalancerConfig) policySelection.config).shuffleAddressList)
.isTrue();
} else {
assertThat(policySelection.config).isInstanceOf(PickFirstLoadBalancerConfig.class);
assertThat(((PickFirstLoadBalancerConfig) policySelection.config).shuffleAddressList)
.isTrue();
}
verifyNoInteractions(channelLogger); verifyNoInteractions(channelLogger);
} }

View File

@ -19,6 +19,7 @@ package io.grpc.internal;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import io.grpc.NameResolver.ConfigOrError; import io.grpc.NameResolver.ConfigOrError;
import io.grpc.internal.PickFirstLeafLoadBalancer.PickFirstLeafLoadBalancerConfig;
import io.grpc.internal.PickFirstLoadBalancer.PickFirstLoadBalancerConfig; import io.grpc.internal.PickFirstLoadBalancer.PickFirstLoadBalancerConfig;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -35,10 +36,23 @@ public class PickFirstLoadBalancerProviderTest {
rawConfig.put("shuffleAddressList", true); rawConfig.put("shuffleAddressList", true);
ConfigOrError parsedConfig = new PickFirstLoadBalancerProvider().parseLoadBalancingPolicyConfig( ConfigOrError parsedConfig = new PickFirstLoadBalancerProvider().parseLoadBalancingPolicyConfig(
rawConfig); rawConfig);
PickFirstLoadBalancerConfig config = (PickFirstLoadBalancerConfig) parsedConfig.getConfig();
assertThat(config.shuffleAddressList).isTrue(); Boolean shuffleAddressList;
assertThat(config.randomSeed).isNull(); Long randomSeed;
if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) {
PickFirstLeafLoadBalancerConfig config =
(PickFirstLeafLoadBalancerConfig) parsedConfig.getConfig();
shuffleAddressList = config.shuffleAddressList;
randomSeed = config.randomSeed;
} else {
PickFirstLoadBalancerConfig config = (PickFirstLoadBalancerConfig) parsedConfig.getConfig();
shuffleAddressList = config.shuffleAddressList;
randomSeed = config.randomSeed;
}
assertThat(shuffleAddressList).isTrue();
assertThat(randomSeed).isNull();
} }
@Test @Test
@ -46,9 +60,22 @@ public class PickFirstLoadBalancerProviderTest {
Map<String, Object> rawConfig = new HashMap<>(); Map<String, Object> rawConfig = new HashMap<>();
ConfigOrError parsedConfig = new PickFirstLoadBalancerProvider().parseLoadBalancingPolicyConfig( ConfigOrError parsedConfig = new PickFirstLoadBalancerProvider().parseLoadBalancingPolicyConfig(
rawConfig); rawConfig);
PickFirstLoadBalancerConfig config = (PickFirstLoadBalancerConfig) parsedConfig.getConfig();
assertThat(config.shuffleAddressList).isNull(); Boolean shuffleAddressList;
assertThat(config.randomSeed).isNull(); Long randomSeed;
if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) {
PickFirstLeafLoadBalancerConfig config =
(PickFirstLeafLoadBalancerConfig) parsedConfig.getConfig();
shuffleAddressList = config.shuffleAddressList;
randomSeed = config.randomSeed;
} else {
PickFirstLoadBalancerConfig config = (PickFirstLoadBalancerConfig) parsedConfig.getConfig();
shuffleAddressList = config.shuffleAddressList;
randomSeed = config.randomSeed;
}
assertThat(shuffleAddressList).isNull();
assertThat(randomSeed).isNull();
} }
} }

View File

@ -21,9 +21,9 @@ import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import com.google.common.base.Converter; import com.google.common.base.Converter;
@ -106,8 +106,9 @@ public class RlsLoadBalancerTest {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
}); });
private final FakeHelper helperDelegate = new FakeHelper();
private final Helper helper = private final Helper helper =
mock(Helper.class, AdditionalAnswers.delegatesTo(new FakeHelper())); mock(Helper.class, AdditionalAnswers.delegatesTo(helperDelegate));
private final FakeRlsServerImpl fakeRlsServerImpl = new FakeRlsServerImpl(); private final FakeRlsServerImpl fakeRlsServerImpl = new FakeRlsServerImpl();
private final Deque<FakeSubchannel> subchannels = new LinkedList<>(); private final Deque<FakeSubchannel> subchannels = new LinkedList<>();
private final FakeThrottler fakeThrottler = new FakeThrottler(); private final FakeThrottler fakeThrottler = new FakeThrottler();
@ -119,6 +120,8 @@ public class RlsLoadBalancerTest {
private MethodDescriptor<Object, Object> fakeRescueMethod; private MethodDescriptor<Object, Object> fakeRescueMethod;
private RlsLoadBalancer rlsLb; private RlsLoadBalancer rlsLb;
private String defaultTarget = "defaultTarget"; private String defaultTarget = "defaultTarget";
private PickSubchannelArgsImpl searchSubchannelArgs;
private PickSubchannelArgsImpl rescueSubchannelArgs;
@Before @Before
public void setUp() { public void setUp() {
@ -159,6 +162,13 @@ public class RlsLoadBalancerTest {
.setTicker(fakeClock.getTicker()); .setTicker(fakeClock.getTicker());
} }
}; };
Metadata headers = new Metadata();
searchSubchannelArgs =
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT);
rescueSubchannelArgs =
new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT);
} }
@After @After
@ -176,13 +186,13 @@ public class RlsLoadBalancerTest {
Metadata headers = new Metadata(); Metadata headers = new Metadata();
PickSubchannelArgsImpl fakeSearchMethodArgs = PickSubchannelArgsImpl fakeSearchMethodArgs =
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT); new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT);
PickResult res = picker.pickSubchannel(fakeSearchMethodArgs); picker.pickSubchannel(fakeSearchMethodArgs); // Will create the subchannel
FakeSubchannel subchannel = (FakeSubchannel) res.getSubchannel(); FakeSubchannel subchannel = subchannels.peek();
assertThat(subchannel).isNotNull(); assertThat(subchannel).isNotNull();
// Ensure happy path is unaffected // Ensure happy path is unaffected
subchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); subchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
res = picker.pickSubchannel(fakeSearchMethodArgs); PickResult res = picker.pickSubchannel(fakeSearchMethodArgs);
assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.OK); assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.OK);
// Check on conversion // Check on conversion
@ -203,34 +213,28 @@ public class RlsLoadBalancerTest {
inOrder.verify(helper) inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
SubchannelPicker picker = pickerCaptor.getValue(); SubchannelPicker picker = pickerCaptor.getValue();
Metadata headers = new Metadata(); PickResult res = picker.pickSubchannel(searchSubchannelArgs);
PickResult res = picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(helper) inOrder.verify(helper, atLeast(0))
.updateBalancingState(eq(ConnectivityState.CONNECTING), any(SubchannelPicker.class)); .updateBalancingState(eq(ConnectivityState.CONNECTING), any(SubchannelPicker.class));
inOrder.verifyNoMoreInteractions(); inOrder.verifyNoMoreInteractions();
assertThat(res.getStatus().isOk()).isTrue(); assertThat(res.getStatus().isOk()).isTrue();
assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
assertThat(subchannels).hasSize(1); assertThat(subchannels).hasSize(1);
FakeSubchannel searchSubchannel = subchannels.getLast(); FakeSubchannel searchSubchannel = subchannels.getLast();
assertThat(subchannelIsReady(searchSubchannel)).isFalse();
searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
inOrder.verify(helper) inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
inOrder.verifyNoMoreInteractions(); inOrder.verifyNoMoreInteractions();
res = picker.pickSubchannel(searchSubchannelArgs);
assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); assertThat(subchannelIsReady(res.getSubchannel())).isTrue();
assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses()); assertThat(res.getSubchannel()).isSameInstanceAs(searchSubchannel);
assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes());
// rescue should be pending status although the overall channel state is READY // rescue should be pending status although the overall channel state is READY
res = picker.pickSubchannel( res = picker.pickSubchannel(rescueSubchannelArgs);
new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT));
inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
// other rls picker itself is ready due to first channel. // other rls picker itself is ready due to first channel.
inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
inOrder.verifyNoMoreInteractions();
assertThat(res.getStatus().isOk()).isTrue(); assertThat(res.getStatus().isOk()).isTrue();
assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
assertThat(subchannels).hasSize(2); assertThat(subchannels).hasSize(2);
@ -238,7 +242,6 @@ public class RlsLoadBalancerTest {
// search subchannel is down, rescue subchannel is connecting // search subchannel is down, rescue subchannel is connecting
searchSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); searchSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE));
inOrder.verify(helper) inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
@ -248,8 +251,7 @@ public class RlsLoadBalancerTest {
// search again, verify that it doesn't use fallback, since RLS server responded, even though // search again, verify that it doesn't use fallback, since RLS server responded, even though
// subchannel is in failure mode // subchannel is in failure mode
res = picker.pickSubchannel( res = picker.pickSubchannel(searchSubchannelArgs);
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE);
assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
} }
@ -263,52 +265,41 @@ public class RlsLoadBalancerTest {
inOrder.verify(helper) inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
SubchannelPicker picker = pickerCaptor.getValue(); SubchannelPicker picker = pickerCaptor.getValue();
Metadata headers = new Metadata();
PickResult res;
// Search that when the RLS server doesn't respond, that fallback is used // Search that when the RLS server doesn't respond, that fallback is used
res = picker.pickSubchannel( PickResult res = picker.pickSubchannel(searchSubchannelArgs); // create subchannel
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
FakeSubchannel fallbackSubchannel = (FakeSubchannel) res.getSubchannel();
assertThat(fallbackSubchannel).isNotNull();
assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.OK); assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.OK);
assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); FakeSubchannel fallbackSubchannel =
inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); (FakeSubchannel) markReadyAndGetPickResult(inOrder, searchSubchannelArgs).getSubchannel();
fallbackSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); assertThat(fallbackSubchannel).isNotNull();
inOrder.verify(helper, times(1)) assertThat(subchannelIsReady(fallbackSubchannel)).isTrue();
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
inOrder.verifyNoMoreInteractions(); inOrder.verifyNoMoreInteractions();
res = picker.pickSubchannel( Subchannel subchannel = picker.pickSubchannel(searchSubchannelArgs).getSubchannel();
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); assertThat(subchannelIsReady(subchannel)).isTrue();
assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); assertThat(subchannel).isSameInstanceAs(fallbackSubchannel);
assertThat(res.getSubchannel()).isSameInstanceAs(fallbackSubchannel);
res = picker.pickSubchannel( subchannel = picker.pickSubchannel(searchSubchannelArgs).getSubchannel();
new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT)); assertThat(subchannelIsReady(subchannel)).isTrue();
assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); assertThat(subchannel).isSameInstanceAs(fallbackSubchannel);
assertThat(res.getSubchannel()).isSameInstanceAs(fallbackSubchannel);
// Make sure that when RLS starts communicating that default stops being used // Make sure that when RLS starts communicating that default stops being used
fakeThrottler.nextResult = false; fakeThrottler.nextResult = false;
fakeClock.forwardTime(2, TimeUnit.SECONDS); // Expires backoff cache entries fakeClock.forwardTime(2, TimeUnit.SECONDS); // Expires backoff cache entries
// Create search subchannel
res = picker.pickSubchannel( picker.pickSubchannel(searchSubchannelArgs);// Create search subchannel
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); FakeSubchannel searchSubchannel =
assertThat(res.getSubchannel()).isNotSameInstanceAs(fallbackSubchannel); (FakeSubchannel) markReadyAndGetPickResult(inOrder, searchSubchannelArgs).getSubchannel();
FakeSubchannel searchSubchannel = (FakeSubchannel) res.getSubchannel();
assertThat(searchSubchannel).isNotNull(); assertThat(searchSubchannel).isNotNull();
searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); assertThat(searchSubchannel).isNotSameInstanceAs(fallbackSubchannel);
// create rescue subchannel // create rescue subchannel
res = picker.pickSubchannel( picker.pickSubchannel(rescueSubchannelArgs);
new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT)); FakeSubchannel rescueSubchannel =
assertThat(res.getSubchannel()).isNotSameInstanceAs(fallbackSubchannel); (FakeSubchannel) markReadyAndGetPickResult(inOrder, rescueSubchannelArgs).getSubchannel();
assertThat(res.getSubchannel()).isNotSameInstanceAs(searchSubchannel);
FakeSubchannel rescueSubchannel = (FakeSubchannel) res.getSubchannel();
assertThat(rescueSubchannel).isNotNull(); assertThat(rescueSubchannel).isNotNull();
rescueSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); assertThat(rescueSubchannel).isNotSameInstanceAs(fallbackSubchannel);
assertThat(rescueSubchannel).isNotSameInstanceAs(searchSubchannel);
// all channels are failed // all channels are failed
rescueSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); rescueSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE));
@ -316,7 +307,7 @@ public class RlsLoadBalancerTest {
fallbackSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); fallbackSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE));
res = picker.pickSubchannel( res = picker.pickSubchannel(
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); searchSubchannelArgs);
assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE);
assertThat(res.getSubchannel()).isNull(); assertThat(res.getSubchannel()).isNull();
} }
@ -330,37 +321,29 @@ public class RlsLoadBalancerTest {
.updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
SubchannelPicker picker = pickerCaptor.getValue(); SubchannelPicker picker = pickerCaptor.getValue();
Metadata headers = new Metadata(); Metadata headers = new Metadata();
PickResult res = picker.pickSubchannel( PickResult res = picker.pickSubchannel(searchSubchannelArgs);
new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT));
inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(helper) inOrder.verify(helper, atLeast(0))
.updateBalancingState(eq(ConnectivityState.CONNECTING), any(SubchannelPicker.class)); .updateBalancingState(eq(ConnectivityState.CONNECTING), any(SubchannelPicker.class));
inOrder.verifyNoMoreInteractions(); inOrder.verifyNoMoreInteractions();
assertThat(res.getStatus().isOk()).isTrue(); assertThat(res.getStatus().isOk()).isTrue();
assertThat(subchannels).hasSize(1); assertThat(subchannels).hasSize(1);
FakeSubchannel searchSubchannel = subchannels.getLast(); FakeSubchannel searchSubchannel =
searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); (FakeSubchannel) markReadyAndGetPickResult(inOrder, searchSubchannelArgs).getSubchannel();
inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
inOrder.verifyNoMoreInteractions(); inOrder.verifyNoMoreInteractions();
assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); assertThat(subchannelIsReady(searchSubchannel)).isTrue();
assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses()); assertThat(subchannels.getLast()).isSameInstanceAs(searchSubchannel);
assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes());
// rescue should be pending status although the overall channel state is READY // rescue should be pending status although the overall channel state is READY
picker = pickerCaptor.getValue(); picker = pickerCaptor.getValue();
res = picker.pickSubchannel( res = picker.pickSubchannel(rescueSubchannelArgs);
new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT));
inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
// other rls picker itself is ready due to first channel. // other rls picker itself is ready due to first channel.
inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
inOrder.verifyNoMoreInteractions();
assertThat(res.getStatus().isOk()).isTrue(); assertThat(res.getStatus().isOk()).isTrue();
assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
assertThat(subchannels).hasSize(2); assertThat(subchannels).hasSize(2);
FakeSubchannel rescueSubchannel = subchannels.getLast(); FakeSubchannel rescueSubchannel = subchannels.getLast();
assertThat(subchannelIsReady(rescueSubchannel)).isFalse();
// search subchannel is down, rescue subchannel is still connecting // search subchannel is down, rescue subchannel is still connecting
searchSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND)); searchSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND));
@ -388,6 +371,7 @@ public class RlsLoadBalancerTest {
rescueSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND)); rescueSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND));
inOrder.verify(helper) inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); .updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture());
inOrder.verify(helper, atLeast(0)).refreshNameResolution();
inOrder.verifyNoMoreInteractions(); inOrder.verifyNoMoreInteractions();
} }
@ -406,10 +390,7 @@ public class RlsLoadBalancerTest {
assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(helper)
.updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
assertThat(subchannels).hasSize(1); assertThat(subchannels).hasSize(1);
inOrder.verifyNoMoreInteractions();
FakeSubchannel searchSubchannel = subchannels.getLast(); FakeSubchannel searchSubchannel = subchannels.getLast();
searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
@ -438,6 +419,16 @@ public class RlsLoadBalancerTest {
assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); assertThat(subchannelIsReady(res.getSubchannel())).isFalse();
} }
private PickResult markReadyAndGetPickResult(InOrder inOrder,
PickSubchannelArgsImpl pickSubchannelArgs) {
subchannels.getLast().updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
inOrder.verify(helper, atLeast(1))
.updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture());
PickResult pickResult = pickerCaptor.getValue().pickSubchannel(pickSubchannelArgs);
inOrder.verify(helper, atLeast(0)).getChannelLogger();
return pickResult;
}
private void deliverResolvedAddresses() throws Exception { private void deliverResolvedAddresses() throws Exception {
ConfigOrError parsedConfigOrError = ConfigOrError parsedConfigOrError =
provider.parseLoadBalancingPolicyConfig(getServiceConfig()); provider.parseLoadBalancingPolicyConfig(getServiceConfig());

View File

@ -41,6 +41,7 @@ import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup; import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer; import io.grpc.LoadBalancer;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.PickFirstLoadBalancerProvider;
import io.grpc.util.AbstractTestHelper.FakeSocketAddress; import io.grpc.util.AbstractTestHelper.FakeSocketAddress;
import io.grpc.util.MultiChildLoadBalancer.ChildLbState; import io.grpc.util.MultiChildLoadBalancer.ChildLbState;
import io.grpc.util.MultiChildLoadBalancer.Endpoint; import io.grpc.util.MultiChildLoadBalancer.Endpoint;
@ -81,8 +82,8 @@ public class MultiChildLoadBalancerTest {
private ArgumentCaptor<ConnectivityState> stateCaptor; private ArgumentCaptor<ConnectivityState> stateCaptor;
@Captor @Captor
private ArgumentCaptor<LoadBalancer.CreateSubchannelArgs> createArgsCaptor; private ArgumentCaptor<LoadBalancer.CreateSubchannelArgs> createArgsCaptor;
private TestHelper testHelperInst = new TestHelper(); private final TestHelper testHelperInst = new TestHelper();
private LoadBalancer.Helper mockHelper = private final LoadBalancer.Helper mockHelper =
mock(LoadBalancer.Helper.class, delegatesTo(testHelperInst)); mock(LoadBalancer.Helper.class, delegatesTo(testHelperInst));
private TestLb loadBalancer; private TestLb loadBalancer;
@ -99,7 +100,7 @@ public class MultiChildLoadBalancerTest {
} }
@Test @Test
public void pickAfterResolved() throws Exception { public void pickAfterResolved() {
Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
LoadBalancer.ResolvedAddresses.newBuilder().setAddresses(servers).build()); LoadBalancer.ResolvedAddresses.newBuilder().setAddresses(servers).build());
assertThat(addressesAcceptanceStatus.isOk()).isTrue(); assertThat(addressesAcceptanceStatus.isOk()).isTrue();
@ -131,7 +132,7 @@ public class MultiChildLoadBalancerTest {
} }
@Test @Test
public void pickAfterResolvedUpdatedHosts() throws Exception { public void pickAfterResolvedUpdatedHosts() {
Attributes.Key<String> key = Attributes.Key.create("check-that-it-is-propagated"); Attributes.Key<String> key = Attributes.Key.create("check-that-it-is-propagated");
FakeSocketAddress removedAddr = new FakeSocketAddress("removed"); FakeSocketAddress removedAddr = new FakeSocketAddress("removed");
EquivalentAddressGroup removedEag = new EquivalentAddressGroup(removedAddr); EquivalentAddressGroup removedEag = new EquivalentAddressGroup(removedAddr);
@ -195,7 +196,7 @@ public class MultiChildLoadBalancerTest {
} }
@Test @Test
public void pickFromMultiAddressEags() throws Exception { public void pickFromMultiAddressEags() {
List<SocketAddress> addressList1 = new ArrayList<>(); List<SocketAddress> addressList1 = new ArrayList<>();
List<SocketAddress> addressList2 = new ArrayList<>(); List<SocketAddress> addressList2 = new ArrayList<>();
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
@ -215,7 +216,7 @@ public class MultiChildLoadBalancerTest {
LoadBalancer.ResolvedAddresses.newBuilder().setAddresses(multiGroups).build()); LoadBalancer.ResolvedAddresses.newBuilder().setAddresses(multiGroups).build());
assertTrue(addressesAcceptanceStatus.isOk()); assertTrue(addressesAcceptanceStatus.isOk());
LoadBalancer.Subchannel evens = subchannels.get(Collections.singletonList(eag1)); LoadBalancer.Subchannel evens = getSubchannel(eag1);
deliverSubchannelState(evens, ConnectivityStateInfo.forNonError(READY)); deliverSubchannelState(evens, ConnectivityStateInfo.forNonError(READY));
verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(TestLb.TestSubchannelPicker.class); assertThat(pickerCaptor.getValue()).isInstanceOf(TestLb.TestSubchannelPicker.class);
@ -321,8 +322,20 @@ public class MultiChildLoadBalancerTest {
return new Endpoint(eag); return new Endpoint(eag);
} }
private LoadBalancer.Subchannel getSubchannel(EquivalentAddressGroup removedEag) { private LoadBalancer.Subchannel getSubchannel(EquivalentAddressGroup eag) {
return subchannels.get(Collections.singletonList(removedEag)); if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) {
for (SocketAddress addr : eag.getAddresses()) {
LoadBalancer.Subchannel subchannel = subchannels.get(
Arrays.asList(new EquivalentAddressGroup(addr, eag.getAttributes())));
if (subchannel != null) {
return subchannel;
}
}
} else {
return subchannels.get(Collections.singletonList(eag));
}
return null;
} }
private static List<Object> getChildEags(MultiChildLoadBalancer loadBalancer) { private static List<Object> getChildEags(MultiChildLoadBalancer loadBalancer) {

View File

@ -53,6 +53,7 @@ import io.grpc.Status;
import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext;
import io.grpc.internal.FakeClock; import io.grpc.internal.FakeClock;
import io.grpc.internal.FakeClock.ScheduledTask; import io.grpc.internal.FakeClock.ScheduledTask;
import io.grpc.internal.PickFirstLoadBalancerProvider;
import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.internal.ServiceConfigUtil.PolicySelection;
import io.grpc.internal.TestUtils.StandardLoadBalancerProvider; import io.grpc.internal.TestUtils.StandardLoadBalancerProvider;
import io.grpc.util.OutlierDetectionLoadBalancer.EndpointTracker; import io.grpc.util.OutlierDetectionLoadBalancer.EndpointTracker;
@ -409,6 +410,9 @@ public class OutlierDetectionLoadBalancerTest {
SubchannelPicker picker = pickerCaptor.getAllValues().get(2); SubchannelPicker picker = pickerCaptor.getAllValues().get(2);
PickResult pickResult = picker.pickSubchannel(mock(PickSubchannelArgs.class)); PickResult pickResult = picker.pickSubchannel(mock(PickSubchannelArgs.class));
Subchannel s = ((OutlierDetectionSubchannel) pickResult.getSubchannel()).delegate(); Subchannel s = ((OutlierDetectionSubchannel) pickResult.getSubchannel()).delegate();
if (s instanceof HealthProducerHelper.HealthProducerSubchannel) {
s = ((HealthProducerHelper.HealthProducerSubchannel) s).delegate();
}
assertThat(s).isEqualTo(readySubchannel); assertThat(s).isEqualTo(readySubchannel);
} }
@ -564,7 +568,9 @@ public class OutlierDetectionLoadBalancerTest {
loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers));
generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 12); // The PickFirstLeafLB has an extra level of indirection because of health
int expectedStateChanges = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 16 : 12;
generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), expectedStateChanges);
// Move forward in time to a point where the detection timer has fired. // Move forward in time to a point where the detection timer has fired.
forwardTime(config); forwardTime(config);
@ -597,8 +603,9 @@ public class OutlierDetectionLoadBalancerTest {
// The one subchannel that was returning errors should be ejected. // The one subchannel that was returning errors should be ejected.
assertEjectedSubchannels(ImmutableSet.of(ImmutableSet.copyOf(servers.get(0).getAddresses()))); assertEjectedSubchannels(ImmutableSet.of(ImmutableSet.copyOf(servers.get(0).getAddresses())));
// Now we produce more load, but the subchannel start working and is no longer an outlier. // Now we produce more load, but the subchannel has started working and is no longer an outlier.
generateLoad(ImmutableMap.of(), 12); int expectedStateChanges = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 16 : 12;
generateLoad(ImmutableMap.of(), expectedStateChanges);
// Move forward in time to a point where the detection timer has fired. // Move forward in time to a point where the detection timer has fired.
fakeClock.forwardTime(config.maxEjectionTimeNanos + 1, TimeUnit.NANOSECONDS); fakeClock.forwardTime(config.maxEjectionTimeNanos + 1, TimeUnit.NANOSECONDS);

View File

@ -30,6 +30,8 @@ import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
@ -244,7 +246,7 @@ public class RoundRobinLoadBalancerTest {
// TODO figure out if this method testing the right things // TODO figure out if this method testing the right things
ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next();
Subchannel subchannel = childLbState.getCurrentPicker().pickSubchannel(null).getSubchannel(); Subchannel subchannel = subchannels.get(Arrays.asList(childLbState.getEag()));
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING); assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING);
@ -258,16 +260,14 @@ public class RoundRobinLoadBalancerTest {
deliverSubchannelState(subchannel, deliverSubchannelState(subchannel,
ConnectivityStateInfo.forTransientFailure(error)); ConnectivityStateInfo.forTransientFailure(error));
assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE);
inOrder.verify(mockHelper).refreshNameResolution(); AbstractTestHelper.refreshInvokedAndUpdateBS(inOrder, CONNECTING, mockHelper, pickerCaptor);
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class); assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class);
deliverSubchannelState(subchannel, deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE));
ConnectivityStateInfo.forNonError(IDLE));
inOrder.verify(mockHelper).refreshNameResolution(); inOrder.verify(mockHelper).refreshNameResolution();
assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE);
verify(subchannel, times(2)).requestConnection(); verify(subchannel, atLeastOnce()).requestConnection();
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verifyNoMoreInteractions(mockHelper); verifyNoMoreInteractions(mockHelper);
} }
@ -302,20 +302,20 @@ public class RoundRobinLoadBalancerTest {
Map<ChildLbState, Subchannel> childToSubChannelMap = new HashMap<>(); Map<ChildLbState, Subchannel> childToSubChannelMap = new HashMap<>();
// Simulate state transitions for each subchannel individually. // Simulate state transitions for each subchannel individually.
for ( ChildLbState child : loadBalancer.getChildLbStates()) { for ( ChildLbState child : loadBalancer.getChildLbStates()) {
Subchannel sc = child.getSubchannels(mockArgs); Subchannel sc = subchannels.get(Arrays.asList(child.getEag()));
childToSubChannelMap.put(child, sc); childToSubChannelMap.put(child, sc);
Status error = Status.UNKNOWN.withDescription("connection broken"); Status error = Status.UNKNOWN.withDescription("connection broken");
deliverSubchannelState( deliverSubchannelState(
sc, sc,
ConnectivityStateInfo.forTransientFailure(error)); ConnectivityStateInfo.forTransientFailure(error));
assertEquals(TRANSIENT_FAILURE, child.getCurrentState()); assertEquals(TRANSIENT_FAILURE, child.getCurrentState());
inOrder.verify(mockHelper).refreshNameResolution();
deliverSubchannelState( deliverSubchannelState(
sc, sc,
ConnectivityStateInfo.forNonError(CONNECTING)); ConnectivityStateInfo.forNonError(CONNECTING));
assertEquals(TRANSIENT_FAILURE, child.getCurrentState()); assertEquals(TRANSIENT_FAILURE, child.getCurrentState());
} }
inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(ReadyPicker.class)); inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(ReadyPicker.class));
inOrder.verify(mockHelper, atLeast(0)).refreshNameResolution();
inOrder.verifyNoMoreInteractions(); inOrder.verifyNoMoreInteractions();
ChildLbState child = loadBalancer.getChildLbStates().iterator().next(); ChildLbState child = loadBalancer.getChildLbStates().iterator().next();
@ -325,7 +325,8 @@ public class RoundRobinLoadBalancerTest {
inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class)); inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class));
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verifyNoMoreInteractions(mockHelper); inOrder.verify(mockHelper, atLeast(0)).refreshNameResolution();
inOrder.verifyNoMoreInteractions();
} }
@Test @Test
@ -339,7 +340,7 @@ public class RoundRobinLoadBalancerTest {
// Simulate state transitions for each subchannel individually. // Simulate state transitions for each subchannel individually.
for (ChildLbState child : loadBalancer.getChildLbStates()) { for (ChildLbState child : loadBalancer.getChildLbStates()) {
Subchannel sc = child.getSubchannels(mockArgs); Subchannel sc = subchannels.get(Arrays.asList(child.getEag()));
verify(sc).requestConnection(); verify(sc).requestConnection();
deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING)); deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING));
Status error = Status.UNKNOWN.withDescription("connection broken"); Status error = Status.UNKNOWN.withDescription("connection broken");

View File

@ -17,6 +17,7 @@
package io.grpc.util; package io.grpc.util;
import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import com.google.common.collect.Maps; import com.google.common.collect.Maps;
@ -31,11 +32,14 @@ import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.internal.PickFirstLoadBalancerProvider;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
/** /**
* A real class that can be used as a delegate of a mock Helper to provide more real representation * A real class that can be used as a delegate of a mock Helper to provide more real representation
@ -129,6 +133,22 @@ public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper {
return "Test Helper"; return "Test Helper";
} }
public static void refreshInvokedAndUpdateBS(InOrder inOrder, ConnectivityState state,
Helper helper,
ArgumentCaptor<SubchannelPicker> pickerCaptor) {
// Old PF and new PF reverse calling order of updateBlaancingState and refreshNameResolution
if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) {
inOrder.verify(helper).updateBalancingState(eq(state), pickerCaptor.capture());
}
inOrder.verify(helper).refreshNameResolution();
if (!PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) {
inOrder.verify(helper).updateBalancingState(eq(state), pickerCaptor.capture());
}
}
protected class TestSubchannel extends ForwardingSubchannel { protected class TestSubchannel extends ForwardingSubchannel {
CreateSubchannelArgs args; CreateSubchannelArgs args;
Channel channel; Channel channel;

View File

@ -530,16 +530,17 @@ final class RingHashLoadBalancer extends MultiChildLoadBalancer {
@Override @Override
public void updateBalancingState(final ConnectivityState newState, public void updateBalancingState(final ConnectivityState newState,
final SubchannelPicker newPicker) { final SubchannelPicker newPicker) {
// Subchannel picker and state are saved, but will only be propagated to the channel
// when the child instance exits deactivated state.
setCurrentState(newState);
setCurrentPicker(newPicker);
// If we are already in the process of resolving addresses, the overall balancing state // If we are already in the process of resolving addresses, the overall balancing state
// will be updated at the end of it, and we don't need to trigger that update here. // will be updated at the end of it, and we don't need to trigger that update here.
if (getChildLbState(getKey()) == null) { if (getChildLbState(getKey()) == null) {
return; return;
} }
// Subchannel picker and state are saved, but will only be propagated to the channel
// when the child instance exits deactivated state.
setCurrentState(newState);
setCurrentPicker(newPicker);
if (!isDeactivated() && !resolvingAddresses) { if (!isDeactivated() && !resolvingAddresses) {
updateOverallBalancingState(); updateOverallBalancingState();
} }

View File

@ -31,6 +31,7 @@ import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
@ -58,6 +59,7 @@ import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.PickFirstLoadBalancerProvider;
import io.grpc.util.AbstractTestHelper; import io.grpc.util.AbstractTestHelper;
import io.grpc.util.MultiChildLoadBalancer.ChildLbState; import io.grpc.util.MultiChildLoadBalancer.ChildLbState;
import io.grpc.xds.LeastRequestLoadBalancer.EmptyPicker; import io.grpc.xds.LeastRequestLoadBalancer.EmptyPicker;
@ -266,28 +268,25 @@ public class LeastRequestLoadBalancerTest {
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING); assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING);
deliverSubchannelState(subchannel, deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
ConnectivityStateInfo.forNonError(READY));
inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class); assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class);
assertThat(childLbState.getCurrentState()).isEqualTo(READY); assertThat(childLbState.getCurrentState()).isEqualTo(READY);
Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"); Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯");
deliverSubchannelState(subchannel, deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error));
ConnectivityStateInfo.forTransientFailure(error));
assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE);
assertThat(childLbState.getCurrentPicker().toString()).contains(error.toString()); assertThat(childLbState.getCurrentPicker().toString()).contains(error.toString());
inOrder.verify(helper).refreshNameResolution(); refreshInvokedAndUpdateBS(inOrder, CONNECTING);
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class); assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class);
deliverSubchannelState(subchannel, deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE));
ConnectivityStateInfo.forNonError(IDLE));
inOrder.verify(helper).refreshNameResolution(); inOrder.verify(helper).refreshNameResolution();
assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE);
assertThat(childLbState.getCurrentPicker().toString()).contains(error.toString()); assertThat(childLbState.getCurrentPicker().toString()).contains(error.toString());
verify(subchannel, times(2)).requestConnection(); int expectedCount = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 1 : 2;
verify(subchannel, times(expectedCount)).requestConnection();
verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verifyNoMoreInteractions(helper); verifyNoMoreInteractions(helper);
} }
@ -358,14 +357,15 @@ public class LeastRequestLoadBalancerTest {
Subchannel sc = getSubchannel(childLbState); Subchannel sc = getSubchannel(childLbState);
Status error = Status.UNKNOWN.withDescription("connection broken"); Status error = Status.UNKNOWN.withDescription("connection broken");
deliverSubchannelState(sc, ConnectivityStateInfo.forTransientFailure(error)); deliverSubchannelState(sc, ConnectivityStateInfo.forTransientFailure(error));
inOrder.verify(helper).refreshNameResolution();
deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING)); deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING));
assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE);
} }
inOrder.verify(helper)
.updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); verify(helper, atLeast(loadBalancer.getChildLbStates().size())).refreshNameResolution();
inOrder.verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
assertThat(getStatusString(pickerCaptor.getValue())) assertThat(getStatusString(pickerCaptor.getValue()))
.contains("Status{code=UNKNOWN, description=connection broken"); .contains("Status{code=UNKNOWN, description=connection broken");
inOrder.verify(helper, atLeast(0)).refreshNameResolution();
inOrder.verifyNoMoreInteractions(); inOrder.verifyNoMoreInteractions();
ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next();
@ -660,6 +660,19 @@ public class LeastRequestLoadBalancerTest {
testHelperInstance.deliverSubchannelState(subchannel, newState); testHelperInstance.deliverSubchannelState(subchannel, newState);
} }
// Old PF and new PF reverse calling order of updateBlaancingState and refreshNameResolution
private void refreshInvokedAndUpdateBS(InOrder inOrder, ConnectivityState state) {
if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) {
inOrder.verify(helper).updateBalancingState(eq(state), pickerCaptor.capture());
}
inOrder.verify(helper).refreshNameResolution();
if (!PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) {
inOrder.verify(helper).updateBalancingState(eq(state), pickerCaptor.capture());
}
}
private static class FakeSocketAddress extends SocketAddress { private static class FakeSocketAddress extends SocketAddress {
final String name; final String name;

View File

@ -17,6 +17,7 @@
package io.grpc.xds; package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.READY;
@ -30,6 +31,7 @@ import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.STAY_IN_C
import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -43,6 +45,7 @@ import com.google.common.collect.Iterables;
import com.google.common.primitives.UnsignedInteger; import com.google.common.primitives.UnsignedInteger;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo; import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup; import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.CreateSubchannelArgs;
@ -56,6 +59,7 @@ import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.Status.Code; import io.grpc.Status.Code;
import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext;
import io.grpc.internal.PickFirstLoadBalancerProvider;
import io.grpc.internal.PickSubchannelArgsImpl; import io.grpc.internal.PickSubchannelArgsImpl;
import io.grpc.testing.TestMethodDescriptors; import io.grpc.testing.TestMethodDescriptors;
import io.grpc.util.AbstractTestHelper; import io.grpc.util.AbstractTestHelper;
@ -89,6 +93,9 @@ import org.mockito.junit.MockitoRule;
public class RingHashLoadBalancerTest { public class RingHashLoadBalancerTest {
private static final String AUTHORITY = "foo.googleapis.com"; private static final String AUTHORITY = "foo.googleapis.com";
private static final Attributes.Key<String> CUSTOM_KEY = Attributes.Key.create("custom-key"); private static final Attributes.Key<String> CUSTOM_KEY = Attributes.Key.create("custom-key");
private static final ConnectivityStateInfo CSI_CONNECTING =
ConnectivityStateInfo.forNonError(CONNECTING);
public static final ConnectivityStateInfo CSI_READY = ConnectivityStateInfo.forNonError(READY);
@Rule @Rule
public final MockitoRule mocks = MockitoJUnit.rule(); public final MockitoRule mocks = MockitoJUnit.rule();
@ -145,11 +152,12 @@ public class RingHashLoadBalancerTest {
verify(subchannel).requestConnection(); verify(subchannel).requestConnection();
verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING)); deliverSubchannelState(subchannel, CSI_CONNECTING);
verify(helper, times(2)).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); int expectedCount = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 1 : 2;
verify(helper, times(expectedCount)).updateBalancingState(eq(CONNECTING), any());
// Subchannel becomes ready, triggers pick again. // Subchannel becomes ready, triggers pick again.
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); deliverSubchannelState(subchannel, CSI_READY);
verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
result = pickerCaptor.getValue().pickSubchannel(args); result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getSubchannel()).isSameInstanceAs(subchannel); assertThat(result.getSubchannel()).isSameInstanceAs(subchannel);
@ -174,11 +182,13 @@ public class RingHashLoadBalancerTest {
PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
pickerCaptor.getValue().pickSubchannel(args); pickerCaptor.getValue().pickSubchannel(args);
assertThat(childLbState.isDeactivated()).isFalse(); assertThat(childLbState.isDeactivated()).isFalse();
assertThat(childLbState.getLb().delegateType()).isEqualTo("PickFirstLoadBalancer"); String expectedLbType = PickFirstLoadBalancerProvider.isEnabledNewPickFirst()
? "PickFirstLeafLoadBalancer" : "PickFirstLoadBalancer";
assertThat(childLbState.getLb().delegateType()).isEqualTo(expectedLbType);
Subchannel subchannel = subchannels.get(Collections.singletonList(childLbState.getEag())); Subchannel subchannel = subchannels.get(Collections.singletonList(childLbState.getEag()));
InOrder inOrder = Mockito.inOrder(helper, subchannel); InOrder inOrder = Mockito.inOrder(helper, subchannel);
inOrder.verify(subchannel).requestConnection(); inOrder.verify(subchannel).requestConnection();
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); deliverSubchannelState(subchannel, CSI_READY);
inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class));
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE));
inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
@ -198,50 +208,51 @@ public class RingHashLoadBalancerTest {
initializeLbSubchannels(config, servers); initializeLbSubchannels(config, servers);
// one in CONNECTING, one in IDLE // one in CONNECTING, one in IDLE
deliverSubchannelState( deliverSubchannelState(getSubchannel(servers, 0), CSI_CONNECTING);
subchannels.get(Collections.singletonList(servers.get(0))),
ConnectivityStateInfo.forNonError(CONNECTING));
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
verifyConnection(0); verifyConnection(0);
// two in CONNECTING // two in CONNECTING
deliverSubchannelState( deliverSubchannelState(getSubchannel(servers, 1), CSI_CONNECTING);
subchannels.get(Collections.singletonList(servers.get(1))),
ConnectivityStateInfo.forNonError(CONNECTING));
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
verifyConnection(0); verifyConnection(0);
// one in CONNECTING, one in READY // one in CONNECTING, one in READY
deliverSubchannelState( deliverSubchannelState(getSubchannel(servers, 1), CSI_READY);
subchannels.get(Collections.singletonList(servers.get(1))),
ConnectivityStateInfo.forNonError(READY));
inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class));
verifyConnection(0); verifyConnection(0);
// one in TRANSIENT_FAILURE, one in READY // one in TRANSIENT_FAILURE, one in READY
deliverSubchannelState( deliverSubchannelState(
subchannels.get(Collections.singletonList(servers.get(0))), getSubchannel(servers, 0),
ConnectivityStateInfo.forTransientFailure( ConnectivityStateInfo.forTransientFailure(
Status.UNKNOWN.withDescription("unknown failure"))); Status.UNKNOWN.withDescription("unknown failure")));
inOrder.verify(helper).refreshNameResolution(); if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) {
inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); inOrder.verify(helper).updateBalancingState(eq(READY), any());
} else {
inOrder.verify(helper).refreshNameResolution();
inOrder.verify(helper).updateBalancingState(eq(READY), any());
}
verifyConnection(0); verifyConnection(0);
// one in TRANSIENT_FAILURE, one in IDLE // one in TRANSIENT_FAILURE, one in IDLE
deliverSubchannelState( deliverSubchannelState(
subchannels.get(Collections.singletonList(servers.get(1))), getSubchannel(servers, 1),
ConnectivityStateInfo.forNonError(IDLE)); ConnectivityStateInfo.forNonError(IDLE));
inOrder.verify(helper).refreshNameResolution(); if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) {
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any());
} else {
inOrder.verify(helper).refreshNameResolution();
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any());
}
verifyConnection(0); verifyConnection(0);
verifyNoMoreInteractions(helper);
} }
private void verifyConnection(int times) { private void verifyConnection(int times) {
for (int i = 0; i < times; i++) { for (int i = 0; i < times; i++) {
Subchannel connectOnce = connectionRequestedQueue.poll(); Subchannel connectOnce = connectionRequestedQueue.poll();
assertThat(connectOnce).isNotNull(); assertWithMessage("Null connection is at (%s) of (%s)", i, times)
.that(connectOnce).isNotNull();
clearInvocations(connectOnce); clearInvocations(connectOnce);
} }
assertThat(connectionRequestedQueue.poll()).isNull(); assertThat(connectionRequestedQueue.poll()).isNull();
@ -261,37 +272,48 @@ public class RingHashLoadBalancerTest {
// one in TRANSIENT_FAILURE, three in CONNECTING // one in TRANSIENT_FAILURE, three in CONNECTING
deliverNotFound(subChannelList, 0); deliverNotFound(subChannelList, 0);
inOrder.verify(helper).refreshNameResolution(); refreshInvokedButNotUpdateBS(inOrder, TRANSIENT_FAILURE);
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
// two in TRANSIENT_FAILURE, two in CONNECTING // two in TRANSIENT_FAILURE, two in CONNECTING
deliverNotFound(subChannelList, 1); deliverNotFound(subChannelList, 1);
inOrder.verify(helper).refreshNameResolution(); refreshInvokedAndUpdateBS(inOrder, TRANSIENT_FAILURE);
inOrder.verify(helper)
.updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
// All 4 in TF switch to TF // All 4 in TF switch to TF
deliverNotFound(subChannelList, 2); deliverNotFound(subChannelList, 2);
inOrder.verify(helper).refreshNameResolution(); refreshInvokedAndUpdateBS(inOrder, TRANSIENT_FAILURE);
inOrder.verify(helper)
.updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
deliverNotFound(subChannelList, 3); deliverNotFound(subChannelList, 3);
inOrder.verify(helper).refreshNameResolution(); refreshInvokedAndUpdateBS(inOrder, TRANSIENT_FAILURE);
inOrder.verify(helper)
.updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
// reset subchannel to CONNECTING - shouldn't change anything since PF hides the state change // reset subchannel to CONNECTING - shouldn't change anything since PF hides the state change
deliverSubchannelState(subChannelList.get(2), ConnectivityStateInfo.forNonError(CONNECTING)); deliverSubchannelState(subChannelList.get(2), CSI_CONNECTING);
inOrder.verify(helper, never()) inOrder.verify(helper, never())
.updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
inOrder.verify(subChannelList.get(2), never()).requestConnection(); inOrder.verify(subChannelList.get(2), never()).requestConnection();
// three in TRANSIENT_FAILURE, one in READY // three in TRANSIENT_FAILURE, one in READY
deliverSubchannelState(subChannelList.get(2), ConnectivityStateInfo.forNonError(READY)); deliverSubchannelState(subChannelList.get(2), CSI_READY);
inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class));
inOrder.verify(subChannelList.get(2), never()).requestConnection(); inOrder.verify(subChannelList.get(2), never()).requestConnection();
}
verifyNoMoreInteractions(helper); // Old PF and new PF reverse calling order of updateBlaancingState and refreshNameResolution
private void refreshInvokedButNotUpdateBS(InOrder inOrder, ConnectivityState state) {
inOrder.verify(helper, never()).updateBalancingState(eq(state), any(SubchannelPicker.class));
inOrder.verify(helper).refreshNameResolution();
inOrder.verify(helper, never()).updateBalancingState(eq(state), any(SubchannelPicker.class));
}
// Old PF and new PF reverse calling order of updateBlaancingState and refreshNameResolution
private void refreshInvokedAndUpdateBS(InOrder inOrder, ConnectivityState state) {
if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) {
inOrder.verify(helper).updateBalancingState(eq(state), any());
}
inOrder.verify(helper).refreshNameResolution();
if (!PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) {
inOrder.verify(helper).updateBalancingState(eq(state), any());
}
} }
@Test @Test
@ -319,7 +341,7 @@ public class RingHashLoadBalancerTest {
// Bring all subchannels to READY so that next pick always succeeds. // Bring all subchannels to READY so that next pick always succeeds.
for (Subchannel subchannel : subchannels.values()) { for (Subchannel subchannel : subchannels.values()) {
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); deliverSubchannelState(subchannel, CSI_READY);
inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
} }
@ -336,8 +358,8 @@ public class RingHashLoadBalancerTest {
Attributes attr = addr.getAttributes().toBuilder().set(CUSTOM_KEY, "custom value").build(); Attributes attr = addr.getAttributes().toBuilder().set(CUSTOM_KEY, "custom value").build();
updatedServers.add(new EquivalentAddressGroup(addr.getAddresses(), attr)); updatedServers.add(new EquivalentAddressGroup(addr.getAddresses(), attr));
} }
Subchannel subchannel0_old = subchannels.get(Collections.singletonList(servers.get(0))); Subchannel subchannel0_old = getSubchannel(servers, 0);
Subchannel subchannel1_old = subchannels.get(Collections.singletonList(servers.get(1))); Subchannel subchannel1_old = getSubchannel(servers, 1);
Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder() ResolvedAddresses.newBuilder()
.setAddresses(updatedServers).setLoadBalancingPolicyConfig(config).build()); .setAddresses(updatedServers).setLoadBalancingPolicyConfig(config).build());
@ -360,7 +382,7 @@ public class RingHashLoadBalancerTest {
// Bring all subchannels to READY so that next pick always succeeds. // Bring all subchannels to READY so that next pick always succeeds.
for (Subchannel subchannel : subchannels.values()) { for (Subchannel subchannel : subchannels.values()) {
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); deliverSubchannelState(subchannel, CSI_READY);
inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
} }
@ -425,16 +447,15 @@ public class RingHashLoadBalancerTest {
verify(getSubChannel(servers.get(1))).requestConnection(); // kicked off connection to server2 verify(getSubChannel(servers.get(1))).requestConnection(); // kicked off connection to server2
assertThat(subchannels.size()).isEqualTo(2); // no excessive connection assertThat(subchannels.size()).isEqualTo(2); // no excessive connection
reset(helper); deliverSubchannelState(getSubChannel(servers.get(1)), CSI_CONNECTING);
deliverSubchannelState(getSubChannel(servers.get(1)), verify(helper, atLeast(1))
ConnectivityStateInfo.forNonError(CONNECTING)); .updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
result = pickerCaptor.getValue().pickSubchannel(args); result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isTrue(); assertThat(result.getStatus().isOk()).isTrue();
assertThat(result.getSubchannel()).isNull(); // buffer request assertThat(result.getSubchannel()).isNull(); // buffer request
deliverSubchannelState(getSubChannel(servers.get(1)), ConnectivityStateInfo.forNonError(READY)); deliverSubchannelState(getSubChannel(servers.get(1)), CSI_READY);
verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
result = pickerCaptor.getValue().pickSubchannel(args); result = pickerCaptor.getValue().pickSubchannel(args);
@ -471,21 +492,22 @@ public class RingHashLoadBalancerTest {
// Bring down server0 and server2 to force trying server1. // Bring down server0 and server2 to force trying server1.
deliverSubchannelState( deliverSubchannelState(
subchannels.get(Collections.singletonList(servers.get(1))), getSubchannel(servers, 1),
ConnectivityStateInfo.forTransientFailure( ConnectivityStateInfo.forTransientFailure(
Status.UNAVAILABLE.withDescription("unreachable"))); Status.UNAVAILABLE.withDescription("unreachable")));
deliverSubchannelState( deliverSubchannelState(
subchannels.get(Collections.singletonList(servers.get(2))), getSubchannel(servers, 2),
ConnectivityStateInfo.forTransientFailure( ConnectivityStateInfo.forTransientFailure(
Status.PERMISSION_DENIED.withDescription("permission denied"))); Status.PERMISSION_DENIED.withDescription("permission denied")));
verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
verifyConnection(0); verifyConnection(0);
PickResult result = pickerCaptor.getValue().pickSubchannel(args); // activate last subchannel PickResult result = pickerCaptor.getValue().pickSubchannel(args); // activate last subchannel
assertThat(result.getStatus().isOk()).isTrue(); assertThat(result.getStatus().isOk()).isTrue();
verifyConnection(1); int expectedCount = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 0 : 1;
verifyConnection(expectedCount);
deliverSubchannelState( deliverSubchannelState(
subchannels.get(Collections.singletonList(servers.get(0))), getSubchannel(servers, 0),
ConnectivityStateInfo.forTransientFailure( ConnectivityStateInfo.forTransientFailure(
Status.PERMISSION_DENIED.withDescription("permission denied again"))); Status.PERMISSION_DENIED.withDescription("permission denied again")));
verify(helper, times(2)).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); verify(helper, times(2)).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
@ -496,9 +518,7 @@ public class RingHashLoadBalancerTest {
assertThat(result.getStatus().getDescription()).isEqualTo("unreachable"); assertThat(result.getStatus().getDescription()).isEqualTo("unreachable");
// Now connecting to server1. // Now connecting to server1.
deliverSubchannelState( deliverSubchannelState(getSubchannel(servers, 1), CSI_CONNECTING);
subchannels.get(Collections.singletonList(servers.get(1))),
ConnectivityStateInfo.forNonError(CONNECTING));
reset(helper); reset(helper);
@ -509,9 +529,7 @@ public class RingHashLoadBalancerTest {
assertThat(result.getStatus().getDescription()).isEqualTo("unreachable"); assertThat(result.getStatus().getDescription()).isEqualTo("unreachable");
// Simulate server1 becomes READY. // Simulate server1 becomes READY.
deliverSubchannelState( deliverSubchannelState(getSubchannel(servers, 1), CSI_READY);
subchannels.get(Collections.singletonList(servers.get(1))),
ConnectivityStateInfo.forNonError(READY));
verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
SubchannelPicker picker = pickerCaptor.getValue(); SubchannelPicker picker = pickerCaptor.getValue();
@ -574,7 +592,7 @@ public class RingHashLoadBalancerTest {
initializeLbSubchannels(config, servers); initializeLbSubchannels(config, servers);
// Go to TF does nothing, though PF will try to reconnect after backoff // Go to TF does nothing, though PF will try to reconnect after backoff
deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(1))), deliverSubchannelState(getSubchannel(servers, 1),
ConnectivityStateInfo.forTransientFailure( ConnectivityStateInfo.forTransientFailure(
Status.UNAVAILABLE.withDescription("unreachable"))); Status.UNAVAILABLE.withDescription("unreachable")));
verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
@ -594,22 +612,21 @@ public class RingHashLoadBalancerTest {
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1);
initializeLbSubchannels(config, servers); initializeLbSubchannels(config, servers);
deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(0))), deliverSubchannelState(getSubchannel(servers, 0), CSI_CONNECTING);
ConnectivityStateInfo.forNonError(CONNECTING)); deliverSubchannelState(getSubchannel(servers, 1), CSI_CONNECTING);
deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(1))),
ConnectivityStateInfo.forNonError(CONNECTING));
verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
// Picking subchannel triggers connection. // Picking subchannel triggers connection.
PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
PickResult result = pickerCaptor.getValue().pickSubchannel(args); PickResult result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isTrue(); assertThat(result.getStatus().isOk()).isTrue();
verify(subchannels.get(Collections.singletonList(servers.get(0))), never()) verify(getSubchannel(servers, 0), never()).requestConnection();
.requestConnection(); verify(getSubchannel(servers, 1), never()).requestConnection();
verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) verify(getSubchannel(servers, 2), never()).requestConnection();
.requestConnection(); }
verify(subchannels.get(Collections.singletonList(servers.get(2))), never())
.requestConnection(); private Subchannel getSubchannel(List<EquivalentAddressGroup> servers, int serverIndex) {
return subchannels.get(Collections.singletonList(servers.get(serverIndex)));
} }
@Test @Test
@ -656,17 +673,16 @@ public class RingHashLoadBalancerTest {
// "FakeSocketAddress-server0_0" // "FakeSocketAddress-server0_0"
// "FakeSocketAddress-server2_0" // "FakeSocketAddress-server2_0"
Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); Subchannel firstSubchannel = getSubchannel(servers, 0);
deliverSubchannelUnreachable(firstSubchannel); deliverSubchannelUnreachable(firstSubchannel);
verifyConnection(0); verifyConnection(0);
deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))), deliverSubchannelState(getSubchannel(servers, 2), CSI_CONNECTING);
ConnectivityStateInfo.forNonError(CONNECTING));
verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
verifyConnection(0); verifyConnection(0);
// Picking subchannel when idle triggers connection. // Picking subchannel when idle triggers connection.
deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))), deliverSubchannelState(getSubchannel(servers, 2),
ConnectivityStateInfo.forNonError(IDLE)); ConnectivityStateInfo.forNonError(IDLE));
verifyConnection(0); verifyConnection(0);
PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
@ -688,9 +704,9 @@ public class RingHashLoadBalancerTest {
// "FakeSocketAddress-server0_0" // "FakeSocketAddress-server0_0"
// "FakeSocketAddress-server2_0" // "FakeSocketAddress-server2_0"
Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); Subchannel firstSubchannel = getSubchannel(servers, 0);
deliverSubchannelUnreachable(firstSubchannel); deliverSubchannelUnreachable(firstSubchannel);
deliverSubchannelUnreachable(subchannels.get(Collections.singletonList(servers.get(2)))); deliverSubchannelUnreachable(getSubchannel(servers, 2));
verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
verifyConnection(0); verifyConnection(0);
@ -698,7 +714,7 @@ public class RingHashLoadBalancerTest {
PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
PickResult result = pickerCaptor.getValue().pickSubchannel(args); PickResult result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isTrue(); assertThat(result.getStatus().isOk()).isTrue();
verify(subchannels.get(Collections.singletonList(servers.get(1)))).requestConnection(); verify(getSubchannel(servers, 1)).requestConnection();
verifyConnection(1); verifyConnection(1);
} }
@ -715,12 +731,11 @@ public class RingHashLoadBalancerTest {
// "FakeSocketAddress-server0_0" // "FakeSocketAddress-server0_0"
// "FakeSocketAddress-server2_0" // "FakeSocketAddress-server2_0"
Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); Subchannel firstSubchannel = getSubchannel(servers, 0);
deliverSubchannelUnreachable(firstSubchannel); deliverSubchannelUnreachable(firstSubchannel);
deliverSubchannelUnreachable(subchannels.get(Collections.singletonList(servers.get(2)))); deliverSubchannelUnreachable(getSubchannel(servers, 2));
deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(1))), deliverSubchannelState(getSubchannel(servers, 1), CSI_CONNECTING);
ConnectivityStateInfo.forNonError(CONNECTING));
verify(helper, atLeastOnce()) verify(helper, atLeastOnce())
.updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
verifyConnection(0); verifyConnection(0);
@ -741,7 +756,7 @@ public class RingHashLoadBalancerTest {
initializeLbSubchannels(config, servers); initializeLbSubchannels(config, servers);
// Bring one subchannel to TRANSIENT_FAILURE. // Bring one subchannel to TRANSIENT_FAILURE.
Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); Subchannel firstSubchannel = getSubchannel(servers, 0);
deliverSubchannelUnreachable(firstSubchannel); deliverSubchannelUnreachable(firstSubchannel);
verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
@ -752,15 +767,14 @@ public class RingHashLoadBalancerTest {
// Should not have called updateBalancingState on the helper again because PickFirst is // Should not have called updateBalancingState on the helper again because PickFirst is
// shielding the higher level from the state change. // shielding the higher level from the state change.
verify(helper, never()).updateBalancingState(any(), any()); verify(helper, never()).updateBalancingState(any(), any());
verifyConnection(1); verifyConnection(PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 0 : 1);
// Picking subchannel triggers connection on second address. RPC hash hits server0. // Picking subchannel triggers connection on second address. RPC hash hits server0.
PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
PickResult result = pickerCaptor.getValue().pickSubchannel(args); PickResult result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isTrue(); assertThat(result.getStatus().isOk()).isTrue();
verify(subchannels.get(Collections.singletonList(servers.get(1)))).requestConnection(); verify(getSubchannel(servers, 1)).requestConnection();
verify(subchannels.get(Collections.singletonList(servers.get(2))), never()) verify(getSubchannel(servers, 2), never()).requestConnection();
.requestConnection();
} }
@Test @Test
@ -811,7 +825,7 @@ public class RingHashLoadBalancerTest {
// Bring all subchannels to READY. // Bring all subchannels to READY.
Map<EquivalentAddressGroup, Integer> pickCounts = new HashMap<>(); Map<EquivalentAddressGroup, Integer> pickCounts = new HashMap<>();
for (Subchannel subchannel : subchannels.values()) { for (Subchannel subchannel : subchannels.values()) {
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); deliverSubchannelState(subchannel, CSI_READY);
pickCounts.put(subchannel.getAddresses(), 0); pickCounts.put(subchannel.getAddresses(), 0);
} }
verify(helper, times(3)).updateBalancingState(eq(READY), pickerCaptor.capture()); verify(helper, times(3)).updateBalancingState(eq(READY), pickerCaptor.capture());
@ -858,7 +872,7 @@ public class RingHashLoadBalancerTest {
pickerCaptor.getValue().pickSubchannel(args); pickerCaptor.getValue().pickSubchannel(args);
verify(helper, never()).updateBalancingState(eq(READY), any(SubchannelPicker.class)); verify(helper, never()).updateBalancingState(eq(READY), any(SubchannelPicker.class));
deliverSubchannelState( deliverSubchannelState(
Iterables.getOnlyElement(subchannels.values()), ConnectivityStateInfo.forNonError(READY)); Iterables.getOnlyElement(subchannels.values()), CSI_READY);
verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class));
reset(helper); reset(helper);