xds: implement WeightedTargetLoadBalancer

This commit is contained in:
ZHANG Dapeng 2020-03-11 15:35:31 -07:00 committed by GitHub
parent 3b8e36358c
commit 5e7b8c672f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 959 additions and 100 deletions

View File

@ -48,13 +48,12 @@ import io.grpc.xds.EnvoyProtoData.DropOverload;
import io.grpc.xds.EnvoyProtoData.LbEndpoint; import io.grpc.xds.EnvoyProtoData.LbEndpoint;
import io.grpc.xds.EnvoyProtoData.Locality; import io.grpc.xds.EnvoyProtoData.Locality;
import io.grpc.xds.EnvoyProtoData.LocalityLbEndpoints; import io.grpc.xds.EnvoyProtoData.LocalityLbEndpoints;
import io.grpc.xds.InterLocalityPicker.WeightedChildPicker;
import io.grpc.xds.OrcaOobUtil.OrcaReportingConfig; import io.grpc.xds.OrcaOobUtil.OrcaReportingConfig;
import io.grpc.xds.OrcaOobUtil.OrcaReportingHelperWrapper; import io.grpc.xds.OrcaOobUtil.OrcaReportingHelperWrapper;
import io.grpc.xds.WeightedRandomPicker.WeightedChildPicker;
import io.grpc.xds.XdsLogger.XdsLogLevel; import io.grpc.xds.XdsLogger.XdsLogLevel;
import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; import io.grpc.xds.XdsSubchannelPickers.ErrorPicker;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
@ -109,7 +108,6 @@ interface LocalityStore {
private final XdsLogger logger; private final XdsLogger logger;
private final Helper helper; private final Helper helper;
private final PickerFactory pickerFactory;
private final LoadBalancerProvider loadBalancerProvider; private final LoadBalancerProvider loadBalancerProvider;
private final ThreadSafeRandom random; private final ThreadSafeRandom random;
private final LoadStatsStore loadStatsStore; private final LoadStatsStore loadStatsStore;
@ -130,7 +128,6 @@ interface LocalityStore {
this( this(
logId, logId,
helper, helper,
pickerFactoryImpl,
lbRegistry, lbRegistry,
ThreadSafeRandom.ThreadSafeRandomImpl.instance, ThreadSafeRandom.ThreadSafeRandomImpl.instance,
loadStatsStore, loadStatsStore,
@ -142,14 +139,12 @@ interface LocalityStore {
LocalityStoreImpl( LocalityStoreImpl(
InternalLogId logId, InternalLogId logId,
Helper helper, Helper helper,
PickerFactory pickerFactory,
LoadBalancerRegistry lbRegistry, LoadBalancerRegistry lbRegistry,
ThreadSafeRandom random, ThreadSafeRandom random,
LoadStatsStore loadStatsStore, LoadStatsStore loadStatsStore,
OrcaPerRequestUtil orcaPerRequestUtil, OrcaPerRequestUtil orcaPerRequestUtil,
OrcaOobUtil orcaOobUtil) { OrcaOobUtil orcaOobUtil) {
this.helper = checkNotNull(helper, "helper"); this.helper = checkNotNull(helper, "helper");
this.pickerFactory = checkNotNull(pickerFactory, "pickerFactory");
loadBalancerProvider = checkNotNull( loadBalancerProvider = checkNotNull(
lbRegistry.getProvider(ROUND_ROBIN), lbRegistry.getProvider(ROUND_ROBIN),
"Unable to find '%s' LoadBalancer", ROUND_ROBIN); "Unable to find '%s' LoadBalancer", ROUND_ROBIN);
@ -160,11 +155,6 @@ interface LocalityStore {
logger = XdsLogger.withLogId(checkNotNull(logId, "logId")); logger = XdsLogger.withLogId(checkNotNull(logId, "logId"));
} }
@VisibleForTesting // Introduced for testing only.
interface PickerFactory {
SubchannelPicker picker(List<WeightedChildPicker> childPickers);
}
private final class DroppablePicker extends SubchannelPicker { private final class DroppablePicker extends SubchannelPicker {
final List<DropOverload> dropOverloads; final List<DropOverload> dropOverloads;
@ -206,14 +196,6 @@ interface LocalityStore {
} }
} }
private static final PickerFactory pickerFactoryImpl =
new PickerFactory() {
@Override
public SubchannelPicker picker(List<WeightedChildPicker> childPickers) {
return new InterLocalityPicker(childPickers);
}
};
@Override @Override
public void reset() { public void reset() {
for (Locality locality : localityMap.keySet()) { for (Locality locality : localityMap.keySet()) {
@ -335,7 +317,6 @@ interface LocalityStore {
private void updatePicker( private void updatePicker(
@Nullable ConnectivityState state, List<WeightedChildPicker> childPickers) { @Nullable ConnectivityState state, List<WeightedChildPicker> childPickers) {
childPickers = Collections.unmodifiableList(childPickers);
SubchannelPicker picker; SubchannelPicker picker;
if (childPickers.isEmpty()) { if (childPickers.isEmpty()) {
if (state == TRANSIENT_FAILURE) { if (state == TRANSIENT_FAILURE) {
@ -344,7 +325,7 @@ interface LocalityStore {
picker = XdsSubchannelPickers.BUFFER_PICKER; picker = XdsSubchannelPickers.BUFFER_PICKER;
} }
} else { } else {
picker = pickerFactory.picker(childPickers); picker = new WeightedRandomPicker(childPickers);
} }
if (!dropOverloads.isEmpty()) { if (!dropOverloads.isEmpty()) {

View File

@ -21,21 +21,24 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects; import com.google.common.base.MoreObjects;
import com.google.common.collect.ImmutableList;
import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelPicker;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects;
final class InterLocalityPicker extends SubchannelPicker { final class WeightedRandomPicker extends SubchannelPicker {
@VisibleForTesting
final List<WeightedChildPicker> weightedChildPickers;
private final List<WeightedChildPicker> weightedChildPickers;
private final ThreadSafeRandom random; private final ThreadSafeRandom random;
private final int totalWeight; private final int totalWeight;
static final class WeightedChildPicker { static final class WeightedChildPicker {
final int weight; private final int weight;
final SubchannelPicker childPicker; private final SubchannelPicker childPicker;
WeightedChildPicker(int weight, SubchannelPicker childPicker) { WeightedChildPicker(int weight, SubchannelPicker childPicker) {
checkArgument(weight >= 0, "weight is negative"); checkArgument(weight >= 0, "weight is negative");
@ -53,6 +56,23 @@ final class InterLocalityPicker extends SubchannelPicker {
return childPicker; return childPicker;
} }
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
WeightedChildPicker that = (WeightedChildPicker) o;
return weight == that.weight && Objects.equals(childPicker, that.childPicker);
}
@Override
public int hashCode() {
return Objects.hash(weight, childPicker);
}
@Override @Override
public String toString() { public String toString() {
return MoreObjects.toStringHelper(this) return MoreObjects.toStringHelper(this)
@ -62,16 +82,16 @@ final class InterLocalityPicker extends SubchannelPicker {
} }
} }
InterLocalityPicker(List<WeightedChildPicker> weightedChildPickers) { WeightedRandomPicker(List<WeightedChildPicker> weightedChildPickers) {
this(weightedChildPickers, ThreadSafeRandom.ThreadSafeRandomImpl.instance); this(weightedChildPickers, ThreadSafeRandom.ThreadSafeRandomImpl.instance);
} }
@VisibleForTesting @VisibleForTesting
InterLocalityPicker(List<WeightedChildPicker> weightedChildPickers, ThreadSafeRandom random) { WeightedRandomPicker(List<WeightedChildPicker> weightedChildPickers, ThreadSafeRandom random) {
checkNotNull(weightedChildPickers, "weightedChildPickers in null"); checkNotNull(weightedChildPickers, "weightedChildPickers in null");
checkArgument(!weightedChildPickers.isEmpty(), "weightedChildPickers is empty"); checkArgument(!weightedChildPickers.isEmpty(), "weightedChildPickers is empty");
this.weightedChildPickers = ImmutableList.copyOf(weightedChildPickers); this.weightedChildPickers = Collections.unmodifiableList(weightedChildPickers);
int totalWeight = 0; int totalWeight = 0;
for (WeightedChildPicker weightedChildPicker : weightedChildPickers) { for (WeightedChildPicker weightedChildPicker : weightedChildPickers) {

View File

@ -0,0 +1,193 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import static io.grpc.xds.XdsSubchannelPickers.BUFFER_PICKER;
import com.google.common.collect.ImmutableMap;
import io.grpc.ConnectivityState;
import io.grpc.InternalLogId;
import io.grpc.LoadBalancer;
import io.grpc.Status;
import io.grpc.util.ForwardingLoadBalancerHelper;
import io.grpc.util.GracefulSwitchLoadBalancer;
import io.grpc.xds.WeightedRandomPicker.WeightedChildPicker;
import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection;
import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig;
import io.grpc.xds.XdsLogger.XdsLogLevel;
import io.grpc.xds.XdsSubchannelPickers.ErrorPicker;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
/** Load balancer for weighted_target policy. */
final class WeightedTargetLoadBalancer extends LoadBalancer {
private final XdsLogger logger;
private final Map<String, GracefulSwitchLoadBalancer> childBalancers = new HashMap<>();
private final Map<String, ChildHelper> childHelpers = new HashMap<>();
private final Helper helper;
private Map<String, WeightedPolicySelection> targets = ImmutableMap.of();
WeightedTargetLoadBalancer(Helper helper) {
this.helper = helper;
logger = XdsLogger.withLogId(
InternalLogId.allocate("weighted-target-lb", helper.getAuthority()));
logger.log(XdsLogLevel.INFO, "Created");
}
@Override
public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses);
Object lbConfig = resolvedAddresses.getLoadBalancingPolicyConfig();
checkNotNull(lbConfig, "missing weighted_target lb config");
WeightedTargetConfig weightedTargetConfig = (WeightedTargetConfig) lbConfig;
Map<String, WeightedPolicySelection> newTargets = weightedTargetConfig.targets;
for (String targetName : newTargets.keySet()) {
WeightedPolicySelection weightedChildLbConfig = newTargets.get(targetName);
if (!targets.containsKey(targetName)) {
ChildHelper childHelper = new ChildHelper();
GracefulSwitchLoadBalancer childBalancer = new GracefulSwitchLoadBalancer(childHelper);
childBalancer.switchTo(weightedChildLbConfig.policySelection.getProvider());
childHelpers.put(targetName, childHelper);
childBalancers.put(targetName, childBalancer);
} else if (!weightedChildLbConfig.policySelection.getProvider().equals(
targets.get(targetName).policySelection.getProvider())) {
childBalancers.get(targetName)
.switchTo(weightedChildLbConfig.policySelection.getProvider());
}
}
targets = newTargets;
for (String targetName : targets.keySet()) {
childBalancers.get(targetName).handleResolvedAddresses(
resolvedAddresses.toBuilder()
.setLoadBalancingPolicyConfig(targets.get(targetName).policySelection.getConfig())
.build());
}
// Cleanup removed targets.
// TODO(zdapeng): cache removed target for 15 minutes.
for (String targetName : childBalancers.keySet()) {
if (!targets.containsKey(targetName)) {
childBalancers.get(targetName).shutdown();
}
}
childBalancers.keySet().retainAll(targets.keySet());
childHelpers.keySet().retainAll(targets.keySet());
}
@Override
public void handleNameResolutionError(Status error) {
logger.log(XdsLogLevel.WARNING, "Received name resolution error: {0}", error);
if (childBalancers.isEmpty()) {
helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(error));
}
for (LoadBalancer childBalancer : childBalancers.values()) {
childBalancer.handleNameResolutionError(error);
}
}
@Override
public boolean canHandleEmptyAddressListFromNameResolution() {
return true;
}
@Override
public void shutdown() {
logger.log(XdsLogLevel.INFO, "Shutdown");
for (LoadBalancer childBalancer : childBalancers.values()) {
childBalancer.shutdown();
}
}
private void updateOverallBalancingState() {
List<WeightedChildPicker> childPickers = new ArrayList<>();
ConnectivityState overallState = null;
for (String name : targets.keySet()) {
ChildHelper childHelper = childHelpers.get(name);
ConnectivityState childState = childHelper.currentState;
overallState = aggregateState(overallState, childState);
if (READY == childState) {
int weight = targets.get(name).weight;
childPickers.add(new WeightedChildPicker(weight, childHelper.currentPicker));
}
}
SubchannelPicker picker;
if (childPickers.isEmpty()) {
if (overallState == TRANSIENT_FAILURE) {
picker = new ErrorPicker(Status.UNAVAILABLE); // TODO: more details in status
} else {
picker = XdsSubchannelPickers.BUFFER_PICKER;
}
} else {
picker = new WeightedRandomPicker(childPickers);
}
if (overallState != null) {
helper.updateBalancingState(overallState, picker);
}
}
@Nullable
private ConnectivityState aggregateState(
@Nullable ConnectivityState overallState, ConnectivityState childState) {
if (overallState == null) {
return childState;
}
if (overallState == READY || childState == READY) {
return READY;
}
if (overallState == CONNECTING || childState == CONNECTING) {
return CONNECTING;
}
if (overallState == IDLE || childState == IDLE) {
return IDLE;
}
return overallState;
}
private final class ChildHelper extends ForwardingLoadBalancerHelper {
ConnectivityState currentState = CONNECTING;
SubchannelPicker currentPicker = BUFFER_PICKER;
@Override
public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) {
currentState = newState;
currentPicker = newPicker;
updateOverallBalancingState();
}
@Override
protected Helper delegate() {
return helper;
}
}
}

View File

@ -0,0 +1,198 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import io.grpc.Internal;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancerProvider;
import io.grpc.LoadBalancerRegistry;
import io.grpc.NameResolver.ConfigOrError;
import io.grpc.Status;
import io.grpc.internal.JsonUtil;
import io.grpc.internal.ServiceConfigUtil;
import io.grpc.internal.ServiceConfigUtil.LbConfig;
import io.grpc.internal.ServiceConfigUtil.PolicySelection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import javax.annotation.Nullable;
/**
* The provider for the weighted_target balancing policy. This class should not be
* directly referenced in code. The policy should be accessed through {@link
* LoadBalancerRegistry#getProvider} with the name "weighted_target_experimental".
*/
@Internal
public final class WeightedTargetLoadBalancerProvider extends LoadBalancerProvider {
static final String WEIGHTED_TARGET_POLICY_NAME = "weighted_target_experimental";
@Nullable
private final LoadBalancerRegistry lbRegistry;
// We can not call this(LoadBalancerRegistry.getDefaultRegistry()), because it will get stuck
// recursively loading LoadBalancerRegistry and WeightedTargetLoadBalancerProvider.
public WeightedTargetLoadBalancerProvider() {
this(null);
}
@VisibleForTesting
WeightedTargetLoadBalancerProvider(@Nullable LoadBalancerRegistry lbRegistry) {
this.lbRegistry = lbRegistry;
}
@Override
public boolean isAvailable() {
return true;
}
@Override
public int getPriority() {
return 5;
}
@Override
public String getPolicyName() {
return WEIGHTED_TARGET_POLICY_NAME;
}
@Override
public LoadBalancer newLoadBalancer(Helper helper) {
return new WeightedTargetLoadBalancer(helper);
}
@Override
public ConfigOrError parseLoadBalancingPolicyConfig(Map<String, ?> rawConfig) {
try {
Map<String, ?> targets = JsonUtil.getObject(rawConfig, "targets");
if (targets == null || targets.isEmpty()) {
return ConfigOrError.fromError(Status.INTERNAL.withDescription(
"No targets provided for weighted_target LB policy:\n " + rawConfig));
}
Map<String, WeightedPolicySelection> parsedChildConfigs = new LinkedHashMap<>();
for (String name : targets.keySet()) {
Map<String, ?> rawWeightedTarget = JsonUtil.getObject(targets, name);
if (rawWeightedTarget == null || rawWeightedTarget.isEmpty()) {
return ConfigOrError.fromError(Status.INTERNAL.withDescription(
"No config for target " + name + " in weighted_target LB policy:\n " + rawConfig));
}
Integer weight = JsonUtil.getNumberAsInteger(rawWeightedTarget, "weight");
if (weight == null || weight < 1) {
return ConfigOrError.fromError(Status.INTERNAL.withDescription(
"Wrong weight for target " + name + " in weighted_target LB policy:\n " + rawConfig));
}
List<LbConfig> childConfigCandidates = ServiceConfigUtil.unwrapLoadBalancingConfigList(
JsonUtil.getListOfObjects(rawWeightedTarget, "childPolicy"));
if (childConfigCandidates == null || childConfigCandidates.isEmpty()) {
return ConfigOrError.fromError(Status.INTERNAL.withDescription(
"No child policy for target " + name + " in weighted_target LB policy:\n "
+ rawConfig));
}
LoadBalancerRegistry lbRegistry =
this.lbRegistry == null ? LoadBalancerRegistry.getDefaultRegistry() : this.lbRegistry;
ConfigOrError selectedConfig =
ServiceConfigUtil.selectLbPolicyFromList(childConfigCandidates, lbRegistry);
if (selectedConfig.getError() != null) {
return selectedConfig;
}
PolicySelection policySelection = (PolicySelection) selectedConfig.getConfig();
parsedChildConfigs.put(name, new WeightedPolicySelection(weight, policySelection));
}
return ConfigOrError.fromConfig(new WeightedTargetConfig(parsedChildConfigs));
} catch (RuntimeException e) {
return ConfigOrError.fromError(
Status.fromThrowable(e).withDescription(
"Failed to parse weighted_target LB config: " + rawConfig));
}
}
static final class WeightedPolicySelection {
final int weight;
final PolicySelection policySelection;
@VisibleForTesting
WeightedPolicySelection(int weight, PolicySelection policySelection) {
this.weight = weight;
this.policySelection = policySelection;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
WeightedPolicySelection that = (WeightedPolicySelection) o;
return weight == that.weight && Objects.equals(policySelection, that.policySelection);
}
@Override
public int hashCode() {
return Objects.hash(weight, policySelection);
}
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("weight", weight)
.add("policySelection", policySelection)
.toString();
}
}
/** The lb config for WeightedTargetLoadBalancer. */
static final class WeightedTargetConfig {
final Map<String, WeightedPolicySelection> targets;
@VisibleForTesting
WeightedTargetConfig(Map<String, WeightedPolicySelection> targets) {
this.targets = targets;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
WeightedTargetConfig that = (WeightedTargetConfig) o;
return Objects.equals(targets, that.targets);
}
@Override
public int hashCode() {
return Objects.hash(targets);
}
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("targets", targets)
.toString();
}
}
}

View File

@ -1,4 +1,5 @@
io.grpc.xds.CdsLoadBalancerProvider io.grpc.xds.CdsLoadBalancerProvider
io.grpc.xds.EdsLoadBalancerProvider io.grpc.xds.EdsLoadBalancerProvider
io.grpc.xds.WeightedTargetLoadBalancerProvider
io.grpc.xds.XdsLoadBalancerProvider io.grpc.xds.XdsLoadBalancerProvider
io.grpc.xds.XdsRoutingLoadBalancerProvider io.grpc.xds.XdsRoutingLoadBalancerProvider

View File

@ -68,20 +68,18 @@ import io.grpc.xds.EnvoyProtoData.DropOverload;
import io.grpc.xds.EnvoyProtoData.LbEndpoint; import io.grpc.xds.EnvoyProtoData.LbEndpoint;
import io.grpc.xds.EnvoyProtoData.Locality; import io.grpc.xds.EnvoyProtoData.Locality;
import io.grpc.xds.EnvoyProtoData.LocalityLbEndpoints; import io.grpc.xds.EnvoyProtoData.LocalityLbEndpoints;
import io.grpc.xds.InterLocalityPicker.WeightedChildPicker;
import io.grpc.xds.LocalityStore.LocalityStoreImpl; import io.grpc.xds.LocalityStore.LocalityStoreImpl;
import io.grpc.xds.LocalityStore.LocalityStoreImpl.PickerFactory;
import io.grpc.xds.OrcaOobUtil.OrcaOobReportListener; import io.grpc.xds.OrcaOobUtil.OrcaOobReportListener;
import io.grpc.xds.OrcaOobUtil.OrcaReportingConfig; import io.grpc.xds.OrcaOobUtil.OrcaReportingConfig;
import io.grpc.xds.OrcaOobUtil.OrcaReportingHelperWrapper; import io.grpc.xds.OrcaOobUtil.OrcaReportingHelperWrapper;
import io.grpc.xds.OrcaPerRequestUtil.OrcaPerRequestReportListener; import io.grpc.xds.OrcaPerRequestUtil.OrcaPerRequestReportListener;
import io.grpc.xds.WeightedRandomPicker.WeightedChildPicker;
import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; import io.grpc.xds.XdsSubchannelPickers.ErrorPicker;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -108,25 +106,6 @@ public class LocalityStoreTest {
@Rule @Rule
public final MockitoRule mockitoRule = MockitoJUnit.rule(); public final MockitoRule mockitoRule = MockitoJUnit.rule();
private static final class FakePickerFactory implements PickerFactory {
int totalReadyLocalities;
int nextIndex;
List<WeightedChildPicker> perLocalitiesPickers;
@Override
public SubchannelPicker picker(final List<WeightedChildPicker> childPickers) {
totalReadyLocalities = childPickers.size();
perLocalitiesPickers = Collections.unmodifiableList(childPickers);
return new SubchannelPicker() {
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
return childPickers.get(nextIndex).getPicker().pickSubchannel(args);
}
};
}
}
private final SynchronizationContext syncContext = new SynchronizationContext( private final SynchronizationContext syncContext = new SynchronizationContext(
new Thread.UncaughtExceptionHandler() { new Thread.UncaughtExceptionHandler() {
@Override @Override
@ -182,8 +161,6 @@ public class LocalityStoreTest {
} }
}; };
private final FakePickerFactory pickerFactory = new FakePickerFactory();
private final Locality locality1 = new Locality("r1", "z1", "sz1"); private final Locality locality1 = new Locality("r1", "z1", "sz1");
private final Locality locality2 = new Locality("r2", "z2", "sz2"); private final Locality locality2 = new Locality("r2", "z2", "sz2");
private final Locality locality3 = new Locality("r3", "z3", "sz3"); private final Locality locality3 = new Locality("r3", "z3", "sz3");
@ -253,7 +230,7 @@ public class LocalityStoreTest {
}); });
lbRegistry.register(lbProvider); lbRegistry.register(lbProvider);
localityStore = localityStore =
new LocalityStoreImpl(logId, helper, pickerFactory, lbRegistry, random, loadStatsStore, new LocalityStoreImpl(logId, helper, lbRegistry, random, loadStatsStore,
orcaPerRequestUtil, orcaOobUtil); orcaPerRequestUtil, orcaOobUtil);
} }
@ -302,7 +279,6 @@ public class LocalityStoreTest {
// Two child balancers are created. // Two child balancers are created.
assertThat(loadBalancers).hasSize(2); assertThat(loadBalancers).hasSize(2);
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0);
ClientStreamTracer.Factory metricsTracingFactory1 = mock(ClientStreamTracer.Factory.class); ClientStreamTracer.Factory metricsTracingFactory1 = mock(ClientStreamTracer.Factory.class);
ClientStreamTracer.Factory metricsTracingFactory2 = mock(ClientStreamTracer.Factory.class); ClientStreamTracer.Factory metricsTracingFactory2 = mock(ClientStreamTracer.Factory.class);
@ -326,10 +302,11 @@ public class LocalityStoreTest {
childHelpers.get("sz1").updateBalancingState(READY, subchannelPicker1); childHelpers.get("sz1").updateBalancingState(READY, subchannelPicker1);
childHelpers.get("sz2").updateBalancingState(READY, subchannelPicker2); childHelpers.get("sz2").updateBalancingState(READY, subchannelPicker2);
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(2);
ArgumentCaptor<SubchannelPicker> interLocalityPickerCaptor = ArgumentCaptor.forClass(null); ArgumentCaptor<SubchannelPicker> interLocalityPickerCaptor = ArgumentCaptor.forClass(null);
verify(helper, times(2)).updateBalancingState(eq(READY), interLocalityPickerCaptor.capture()); verify(helper, times(2)).updateBalancingState(eq(READY), interLocalityPickerCaptor.capture());
SubchannelPicker interLocalityPicker = interLocalityPickerCaptor.getValue(); WeightedRandomPicker interLocalityPicker =
(WeightedRandomPicker) interLocalityPickerCaptor.getValue();
assertThat(interLocalityPicker.weightedChildPickers).hasSize(2);
// Verify each PickResult picked is intercepted with client stream tracer factory for // Verify each PickResult picked is intercepted with client stream tracer factory for
// recording load and backend metrics. // recording load and backend metrics.
@ -337,9 +314,9 @@ public class LocalityStoreTest {
= ImmutableMap.of(subchannel1, locality1, subchannel2, locality2); = ImmutableMap.of(subchannel1, locality1, subchannel2, locality2);
Map<Subchannel, ClientStreamTracer.Factory> metricsTracingFactoriesBySubchannel Map<Subchannel, ClientStreamTracer.Factory> metricsTracingFactoriesBySubchannel
= ImmutableMap.of(subchannel1, metricsTracingFactory1, subchannel2, metricsTracingFactory2); = ImmutableMap.of(subchannel1, metricsTracingFactory1, subchannel2, metricsTracingFactory2);
for (int i = 0; i < pickerFactory.totalReadyLocalities; i++) { for (int i = 0; i < interLocalityPicker.weightedChildPickers.size(); i++) {
pickerFactory.nextIndex = i; PickResult pickResult = interLocalityPicker.weightedChildPickers.get(i).getPicker()
PickResult pickResult = interLocalityPicker.pickSubchannel(pickSubchannelArgs); .pickSubchannel(pickSubchannelArgs);
Subchannel expectedSubchannel = pickResult.getSubchannel(); Subchannel expectedSubchannel = pickResult.getSubchannel();
Locality expectedLocality = localitiesBySubchannel.get(expectedSubchannel); Locality expectedLocality = localitiesBySubchannel.get(expectedSubchannel);
ArgumentCaptor<OrcaPerRequestReportListener> listenerCaptor = ArgumentCaptor.forClass(null); ArgumentCaptor<OrcaPerRequestReportListener> listenerCaptor = ArgumentCaptor.forClass(null);
@ -466,7 +443,6 @@ public class LocalityStoreTest {
ArgumentCaptor.forClass(ResolvedAddresses.class); ArgumentCaptor.forClass(ResolvedAddresses.class);
verify(loadBalancers.get("sz3")).handleResolvedAddresses(resolvedAddressesCaptor3.capture()); verify(loadBalancers.get("sz3")).handleResolvedAddresses(resolvedAddressesCaptor3.capture());
assertThat(resolvedAddressesCaptor3.getValue().getAddresses()).containsExactly(eag31, eag32); assertThat(resolvedAddressesCaptor3.getValue().getAddresses()).containsExactly(eag31, eag32);
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0);
// verify no more updateBalancingState except the initial CONNECTING state // verify no more updateBalancingState except the initial CONNECTING state
verify(helper, times(1)).updateBalancingState( verify(helper, times(1)).updateBalancingState(
any(ConnectivityState.class), any(SubchannelPicker.class)); any(ConnectivityState.class), any(SubchannelPicker.class));
@ -484,7 +460,6 @@ public class LocalityStoreTest {
ArgumentCaptor.forClass(SubchannelPicker.class); ArgumentCaptor.forClass(SubchannelPicker.class);
verify(helper, times(2)).updateBalancingState( verify(helper, times(2)).updateBalancingState(
same(CONNECTING), subchannelPickerCaptor12.capture()); same(CONNECTING), subchannelPickerCaptor12.capture());
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0);
assertThat(subchannelPickerCaptor12.getValue().pickSubchannel(pickSubchannelArgs)) assertThat(subchannelPickerCaptor12.getValue().pickSubchannel(pickSubchannelArgs))
.isEqualTo(PickResult.withNoResult()); .isEqualTo(PickResult.withNoResult());
@ -500,21 +475,22 @@ public class LocalityStoreTest {
ArgumentCaptor<SubchannelPicker> subchannelPickerCaptor = ArgumentCaptor<SubchannelPicker> subchannelPickerCaptor =
ArgumentCaptor.forClass(null); ArgumentCaptor.forClass(null);
verify(helper).updateBalancingState(same(READY), subchannelPickerCaptor.capture()); verify(helper).updateBalancingState(same(READY), subchannelPickerCaptor.capture());
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(1); WeightedRandomPicker interLocalityPicker =
pickerFactory.nextIndex = 0; (WeightedRandomPicker) subchannelPickerCaptor.getValue();
assertThat(subchannelPickerCaptor.getValue().pickSubchannel(pickSubchannelArgs).getSubchannel()) assertThat(interLocalityPicker.weightedChildPickers).hasSize(1);
assertThat(interLocalityPicker.pickSubchannel(pickSubchannelArgs).getSubchannel())
.isEqualTo(subchannel31); .isEqualTo(subchannel31);
// subchannel12 goes to READY // subchannel12 goes to READY
childHelpers.get("sz1").updateBalancingState(READY, subchannelPicker12); childHelpers.get("sz1").updateBalancingState(READY, subchannelPicker12);
verify(helper, times(2)).updateBalancingState(same(READY), subchannelPickerCaptor.capture()); verify(helper, times(2)).updateBalancingState(same(READY), subchannelPickerCaptor.capture());
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(2); interLocalityPicker = (WeightedRandomPicker) subchannelPickerCaptor.getValue();
assertThat(interLocalityPicker.weightedChildPickers).hasSize(2);
SubchannelPicker interLocalityPicker = subchannelPickerCaptor.getValue();
Set<Subchannel> pickedReadySubchannels = new HashSet<>(); Set<Subchannel> pickedReadySubchannels = new HashSet<>();
for (int i = 0; i < pickerFactory.totalReadyLocalities; i++) { for (int i = 0; i < interLocalityPicker.weightedChildPickers.size(); i++) {
pickerFactory.nextIndex = i; PickResult result = interLocalityPicker.weightedChildPickers.get(i).getPicker()
PickResult result = interLocalityPicker.pickSubchannel(pickSubchannelArgs); .pickSubchannel(pickSubchannelArgs);
pickedReadySubchannels.add(result.getSubchannel()); pickedReadySubchannels.add(result.getSubchannel());
} }
assertThat(pickedReadySubchannels).containsExactly(subchannel31, subchannel12); assertThat(pickedReadySubchannels).containsExactly(subchannel31, subchannel12);
@ -539,7 +515,9 @@ public class LocalityStoreTest {
verify(loadBalancers.get("sz1"), times(2)) verify(loadBalancers.get("sz1"), times(2))
.handleResolvedAddresses(resolvedAddressesCaptor1.capture()); .handleResolvedAddresses(resolvedAddressesCaptor1.capture());
assertThat(resolvedAddressesCaptor1.getValue().getAddresses()).containsExactly(eag11); assertThat(resolvedAddressesCaptor1.getValue().getAddresses()).containsExactly(eag11);
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(1); verify(helper, times(3)).updateBalancingState(same(READY), subchannelPickerCaptor.capture());
interLocalityPicker = (WeightedRandomPicker) subchannelPickerCaptor.getValue();
assertThat(interLocalityPicker.weightedChildPickers).hasSize(1);
fakeClock.forwardTime(14, TimeUnit.MINUTES); fakeClock.forwardTime(14, TimeUnit.MINUTES);
verify(loadBalancers.get("sz3"), never()).shutdown(); verify(loadBalancers.get("sz3"), never()).shutdown();
@ -598,9 +576,10 @@ public class LocalityStoreTest {
// helper updated multiple times. Don't care how many times, just capture the latest picker // helper updated multiple times. Don't care how many times, just capture the latest picker
verify(helper, atLeastOnce()).updateBalancingState( verify(helper, atLeastOnce()).updateBalancingState(
same(READY), subchannelPickerCaptor.capture()); same(READY), subchannelPickerCaptor.capture());
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(1); WeightedRandomPicker interLocalityPicker =
pickerFactory.nextIndex = 0; (WeightedRandomPicker) subchannelPickerCaptor.getValue();
assertThat(subchannelPickerCaptor.getValue().pickSubchannel(pickSubchannelArgs).getSubchannel()) assertThat(interLocalityPicker.weightedChildPickers).hasSize(1);
assertThat(interLocalityPicker.pickSubchannel(pickSubchannelArgs).getSubchannel())
.isEqualTo(subchannel3); .isEqualTo(subchannel3);
// verify no traffic will go to deactivated locality // verify no traffic will go to deactivated locality
@ -614,9 +593,10 @@ public class LocalityStoreTest {
childHelpers.get("sz2").updateBalancingState(READY, subchannelPicker2); childHelpers.get("sz2").updateBalancingState(READY, subchannelPicker2);
verify(helper, atLeastOnce()).updateBalancingState( verify(helper, atLeastOnce()).updateBalancingState(
same(READY), subchannelPickerCaptor.capture()); same(READY), subchannelPickerCaptor.capture());
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(1); interLocalityPicker =
pickerFactory.nextIndex = 0; (WeightedRandomPicker) subchannelPickerCaptor.getValue();
assertThat(subchannelPickerCaptor.getValue().pickSubchannel(pickSubchannelArgs).getSubchannel()) assertThat(interLocalityPicker.weightedChildPickers).hasSize(1);
assertThat(interLocalityPicker.pickSubchannel(pickSubchannelArgs).getSubchannel())
.isEqualTo(subchannel3); .isEqualTo(subchannel3);
// update localities, reactivating sz1 // update localities, reactivating sz1
@ -625,13 +605,13 @@ public class LocalityStoreTest {
localityStore.updateLocalityStore(localityInfoMap); localityStore.updateLocalityStore(localityInfoMap);
verify(helper, atLeastOnce()).updateBalancingState( verify(helper, atLeastOnce()).updateBalancingState(
same(READY), subchannelPickerCaptor.capture()); same(READY), subchannelPickerCaptor.capture());
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(2); interLocalityPicker =
pickerFactory.nextIndex = 0; (WeightedRandomPicker) subchannelPickerCaptor.getValue();
assertThat(subchannelPickerCaptor.getValue().pickSubchannel(pickSubchannelArgs).getSubchannel()) assertThat(interLocalityPicker.weightedChildPickers).hasSize(2);
.isEqualTo(subchannel1); assertThat(interLocalityPicker.weightedChildPickers.get(0).getPicker()
pickerFactory.nextIndex = 1; .pickSubchannel(pickSubchannelArgs).getSubchannel()).isEqualTo(subchannel1);
assertThat(subchannelPickerCaptor.getValue().pickSubchannel(pickSubchannelArgs).getSubchannel()) assertThat(interLocalityPicker.weightedChildPickers.get(1).getPicker()
.isEqualTo(subchannel3); .pickSubchannel(pickSubchannelArgs).getSubchannel()).isEqualTo(subchannel3);
verify(lb2, never()).shutdown(); verify(lb2, never()).shutdown();
// delayed deletion timer expires, no reactivation // delayed deletion timer expires, no reactivation
@ -648,9 +628,10 @@ public class LocalityStoreTest {
verify(helper, atLeastOnce()).updateBalancingState( verify(helper, atLeastOnce()).updateBalancingState(
same(READY), subchannelPickerCaptor.capture()); same(READY), subchannelPickerCaptor.capture());
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(1); interLocalityPicker =
pickerFactory.nextIndex = 0; (WeightedRandomPicker) subchannelPickerCaptor.getValue();
assertThat(subchannelPickerCaptor.getValue().pickSubchannel(pickSubchannelArgs).getSubchannel()) assertThat(interLocalityPicker.weightedChildPickers).hasSize(1);
assertThat(interLocalityPicker.pickSubchannel(pickSubchannelArgs).getSubchannel())
.isEqualTo(subchannel1); .isEqualTo(subchannel1);
// sz3, sz4 pending removal // sz3, sz4 pending removal
assertThat(fakeClock.getPendingTasks(deactivationTaskFilter)).hasSize(2); assertThat(fakeClock.getPendingTasks(deactivationTaskFilter)).hasSize(2);
@ -701,7 +682,6 @@ public class LocalityStoreTest {
ArgumentCaptor.forClass(ResolvedAddresses.class); ArgumentCaptor.forClass(ResolvedAddresses.class);
verify(loadBalancers.get("sz3")).handleResolvedAddresses(resolvedAddressesCaptor3.capture()); verify(loadBalancers.get("sz3")).handleResolvedAddresses(resolvedAddressesCaptor3.capture());
assertThat(resolvedAddressesCaptor3.getValue().getAddresses()).containsExactly(eag31, eag32); assertThat(resolvedAddressesCaptor3.getValue().getAddresses()).containsExactly(eag31, eag32);
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0);
ArgumentCaptor<SubchannelPicker> subchannelPickerCaptor = ArgumentCaptor<SubchannelPicker> subchannelPickerCaptor =
ArgumentCaptor.forClass(SubchannelPicker.class); ArgumentCaptor.forClass(SubchannelPicker.class);
verify(helper).updateBalancingState(same(CONNECTING), subchannelPickerCaptor.capture()); verify(helper).updateBalancingState(same(CONNECTING), subchannelPickerCaptor.capture());
@ -908,7 +888,6 @@ public class LocalityStoreTest {
assertThat(loadBalancers).hasSize(3); assertThat(loadBalancers).hasSize(3);
assertThat(loadBalancers.keySet()).containsExactly("sz1", "sz2", "sz3"); assertThat(loadBalancers.keySet()).containsExactly("sz1", "sz2", "sz3");
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0);
// Update locality weights before any subchannel becomes READY. // Update locality weights before any subchannel becomes READY.
localityInfo1 = new LocalityLbEndpoints(ImmutableList.of(lbEndpoint11, lbEndpoint12), 4, 0); localityInfo1 = new LocalityLbEndpoints(ImmutableList.of(lbEndpoint11, lbEndpoint12), 4, 0);
@ -918,8 +897,6 @@ public class LocalityStoreTest {
locality1, localityInfo1, locality2, localityInfo2, locality3, localityInfo3); locality1, localityInfo1, locality2, localityInfo2, locality3, localityInfo3);
localityStore.updateLocalityStore(localityInfoMap); localityStore.updateLocalityStore(localityInfoMap);
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0);
final Map<Subchannel, Locality> localitiesBySubchannel = new HashMap<>(); final Map<Subchannel, Locality> localitiesBySubchannel = new HashMap<>();
for (final Helper h : childHelpers.values()) { for (final Helper h : childHelpers.values()) {
h.updateBalancingState(READY, new SubchannelPicker() { h.updateBalancingState(READY, new SubchannelPicker() {
@ -932,10 +909,16 @@ public class LocalityStoreTest {
}); });
} }
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(3); ArgumentCaptor<SubchannelPicker> subchannelPickerCaptor =
for (int i = 0; i < pickerFactory.totalReadyLocalities; i++) { ArgumentCaptor.forClass(SubchannelPicker.class);
verify(helper, atLeastOnce()).updateBalancingState(
same(READY), subchannelPickerCaptor.capture());
WeightedRandomPicker interLocalityPicker =
(WeightedRandomPicker) subchannelPickerCaptor.getValue();
assertThat(interLocalityPicker.weightedChildPickers).hasSize(3);
for (int i = 0; i < interLocalityPicker.weightedChildPickers.size(); i++) {
WeightedChildPicker weightedChildPicker WeightedChildPicker weightedChildPicker
= pickerFactory.perLocalitiesPickers.get(i); = interLocalityPicker.weightedChildPickers.get(i);
Subchannel subchannel Subchannel subchannel
= weightedChildPicker.getPicker().pickSubchannel(pickSubchannelArgs).getSubchannel(); = weightedChildPicker.getPicker().pickSubchannel(pickSubchannelArgs).getSubchannel();
assertThat(weightedChildPicker.getWeight()) assertThat(weightedChildPicker.getWeight())

View File

@ -24,7 +24,7 @@ import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.xds.InterLocalityPicker.WeightedChildPicker; import io.grpc.xds.WeightedRandomPicker.WeightedChildPicker;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -38,10 +38,10 @@ import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule; import org.mockito.junit.MockitoRule;
/** /**
* Tests for {@link InterLocalityPicker}. * Tests for {@link WeightedRandomPicker}.
*/ */
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class InterLocalityPickerTest { public class WeightedRandomPickerTest {
@Rule @Rule
public final ExpectedException thrown = ExpectedException.none(); public final ExpectedException thrown = ExpectedException.none();
@ -105,7 +105,7 @@ public class InterLocalityPickerTest {
List<WeightedChildPicker> emptyList = new ArrayList<>(); List<WeightedChildPicker> emptyList = new ArrayList<>();
thrown.expect(IllegalArgumentException.class); thrown.expect(IllegalArgumentException.class);
new InterLocalityPicker(emptyList); new WeightedRandomPicker(emptyList);
} }
@Test @Test
@ -121,7 +121,7 @@ public class InterLocalityPickerTest {
WeightedChildPicker weightedChildPicker2 = new WeightedChildPicker(0, childPicker2); WeightedChildPicker weightedChildPicker2 = new WeightedChildPicker(0, childPicker2);
WeightedChildPicker weightedChildPicker3 = new WeightedChildPicker(10, childPicker3); WeightedChildPicker weightedChildPicker3 = new WeightedChildPicker(10, childPicker3);
InterLocalityPicker xdsPicker = new InterLocalityPicker( WeightedRandomPicker xdsPicker = new WeightedRandomPicker(
Arrays.asList( Arrays.asList(
weightedChildPicker0, weightedChildPicker0,
weightedChildPicker1, weightedChildPicker1,
@ -157,7 +157,7 @@ public class InterLocalityPickerTest {
WeightedChildPicker weightedChildPicker2 = new WeightedChildPicker(0, childPicker2); WeightedChildPicker weightedChildPicker2 = new WeightedChildPicker(0, childPicker2);
WeightedChildPicker weightedChildPicker3 = new WeightedChildPicker(0, childPicker3); WeightedChildPicker weightedChildPicker3 = new WeightedChildPicker(0, childPicker3);
InterLocalityPicker xdsPicker = new InterLocalityPicker( WeightedRandomPicker xdsPicker = new WeightedRandomPicker(
Arrays.asList( Arrays.asList(
weightedChildPicker0, weightedChildPicker0,
weightedChildPicker1, weightedChildPicker1,

View File

@ -0,0 +1,139 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.mock;
import com.google.common.collect.ImmutableMap;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancerProvider;
import io.grpc.LoadBalancerRegistry;
import io.grpc.NameResolver.ConfigOrError;
import io.grpc.internal.JsonParser;
import io.grpc.internal.ServiceConfigUtil.PolicySelection;
import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection;
import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link WeightedTargetLoadBalancerProvider}. */
@RunWith(JUnit4.class)
public class WeightedTargetLoadBalancerProviderTest {
@Test
public void parseWeightedTargetConfig() throws Exception {
LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry();
WeightedTargetLoadBalancerProvider weightedTargetLoadBalancerProvider =
new WeightedTargetLoadBalancerProvider(lbRegistry);
final Object fooConfig = new Object();
LoadBalancerProvider lbProviderFoo = new LoadBalancerProvider() {
@Override
public boolean isAvailable() {
return true;
}
@Override
public int getPriority() {
return 5;
}
@Override
public String getPolicyName() {
return "foo_policy";
}
@Override
public LoadBalancer newLoadBalancer(Helper helper) {
return mock(LoadBalancer.class);
}
@Override
public ConfigOrError parseLoadBalancingPolicyConfig(Map<String, ?> rawConfig) {
return ConfigOrError.fromConfig(fooConfig);
}
};
final Object barConfig = new Object();
LoadBalancerProvider lbProviderBar = new LoadBalancerProvider() {
@Override
public boolean isAvailable() {
return true;
}
@Override
public int getPriority() {
return 5;
}
@Override
public String getPolicyName() {
return "bar_policy";
}
@Override
public LoadBalancer newLoadBalancer(Helper helper) {
return mock(LoadBalancer.class);
}
@Override
public ConfigOrError parseLoadBalancingPolicyConfig(Map<String, ?> rawConfig) {
return ConfigOrError.fromConfig(barConfig);
}
};
lbRegistry.register(lbProviderFoo);
lbRegistry.register(lbProviderBar);
String weightedTargetConfigJson = ("{"
+ " 'targets' : {"
+ " 'target_1' : {"
+ " 'weight' : 10,"
+ " 'childPolicy' : ["
+ " {'unsupported_policy' : {}},"
+ " {'foo_policy' : {}}"
+ " ]"
+ " },"
+ " 'target_2' : {"
+ " 'weight' : 20,"
+ " 'childPolicy' : ["
+ " {'unsupported_policy' : {}},"
+ " {'bar_policy' : {}}"
+ " ]"
+ " }"
+ " }"
+ "}").replace("'", "\"");
@SuppressWarnings("unchecked")
Map<String, ?> rawLbConfigMap = (Map<String, ?>) JsonParser.parse(weightedTargetConfigJson);
ConfigOrError parsedConfig =
weightedTargetLoadBalancerProvider.parseLoadBalancingPolicyConfig(rawLbConfigMap);
ConfigOrError expectedConfig = ConfigOrError.fromConfig(
new WeightedTargetConfig(ImmutableMap.of(
"target_1",
new WeightedPolicySelection(
10,
new PolicySelection(lbProviderFoo, new HashMap<String, Object>(), fooConfig)),
"target_2",
new WeightedPolicySelection(
20,
new PolicySelection(lbProviderBar, new HashMap<String, Object>(), barConfig)))));
assertThat(parsedConfig).isEqualTo(expectedConfig);
}
}

View File

@ -0,0 +1,344 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.READY;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import static io.grpc.xds.XdsSubchannelPickers.BUFFER_PICKER;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.grpc.Attributes;
import io.grpc.ChannelLogger;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.LoadBalancer.ResolvedAddresses;
import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancerProvider;
import io.grpc.LoadBalancerRegistry;
import io.grpc.Status;
import io.grpc.internal.ServiceConfigUtil.PolicySelection;
import io.grpc.xds.WeightedRandomPicker.WeightedChildPicker;
import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection;
import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig;
import io.grpc.xds.XdsSubchannelPickers.ErrorPicker;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/** Tests for {@link WeightedTargetLoadBalancer}. */
@RunWith(JUnit4.class)
public class WeightedTargetLoadBalancerTest {
private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry();
private final List<LoadBalancer> childBalancers = new ArrayList<>();
private final List<Helper> childHelpers = new ArrayList<>();
private final int[] weights = new int[]{10, 20, 30, 40};
private final Object[] configs = new Object[]{"config0", "config1", "config3", "config4"};
private final LoadBalancerProvider fooLbProvider = new LoadBalancerProvider() {
@Override
public boolean isAvailable() {
return true;
}
@Override
public int getPriority() {
return 5;
}
@Override
public String getPolicyName() {
return "foo_policy";
}
@Override
public LoadBalancer newLoadBalancer(Helper helper) {
childHelpers.add(helper);
LoadBalancer childBalancer = mock(LoadBalancer.class);
childBalancers.add(childBalancer);
fooLbCreated++;
return childBalancer;
}
};
private final LoadBalancerProvider barLbProvider = new LoadBalancerProvider() {
@Override
public boolean isAvailable() {
return true;
}
@Override
public int getPriority() {
return 5;
}
@Override
public String getPolicyName() {
return "bar_policy";
}
@Override
public LoadBalancer newLoadBalancer(Helper helper) {
childHelpers.add(helper);
LoadBalancer childBalancer = mock(LoadBalancer.class);
childBalancers.add(childBalancer);
barLbCreated++;
return childBalancer;
}
};
private final WeightedPolicySelection weightedLbConfig0 = new WeightedPolicySelection(
weights[0], new PolicySelection(fooLbProvider, null, configs[0]));
private final WeightedPolicySelection weightedLbConfig1 = new WeightedPolicySelection(
weights[1], new PolicySelection(barLbProvider, null, configs[1]));
private final WeightedPolicySelection weightedLbConfig2 = new WeightedPolicySelection(
weights[2], new PolicySelection(barLbProvider, null, configs[2]));
private final WeightedPolicySelection weightedLbConfig3 = new WeightedPolicySelection(
weights[3], new PolicySelection(fooLbProvider, null, configs[3]));
@Mock
private Helper helper;
@Mock
private ChannelLogger channelLogger;
private LoadBalancer weightedTargetLb;
private int fooLbCreated;
private int barLbCreated;
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
doReturn(channelLogger).when(helper).getChannelLogger();
lbRegistry.register(fooLbProvider);
lbRegistry.register(barLbProvider);
weightedTargetLb = new WeightedTargetLoadBalancer(helper);
}
@After
public void tearDown() {
weightedTargetLb.shutdown();
for (LoadBalancer childBalancer : childBalancers) {
verify(childBalancer).shutdown();
}
}
@Test
public void handleResolvedAddresses() {
ArgumentCaptor<ResolvedAddresses> resolvedAddressesCaptor = ArgumentCaptor.forClass(null);
Attributes.Key<Object> fakeKey = Attributes.Key.create("fake_key");
Object fakeValue = new Object();
Map<String, WeightedPolicySelection> targets = ImmutableMap.of(
// {foo, 10, config0}
"target0", weightedLbConfig0,
// {bar, 20, config1}
"target1", weightedLbConfig1,
// {bar, 30, config2}
"target2", weightedLbConfig2,
// {foo, 40, config3}
"target3", weightedLbConfig3);
weightedTargetLb.handleResolvedAddresses(
ResolvedAddresses.newBuilder()
.setAddresses(ImmutableList.<EquivalentAddressGroup>of())
.setAttributes(Attributes.newBuilder().set(fakeKey, fakeValue).build())
.setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets))
.build());
assertThat(childBalancers).hasSize(4);
assertThat(childHelpers).hasSize(4);
assertThat(fooLbCreated).isEqualTo(2);
assertThat(barLbCreated).isEqualTo(2);
for (int i = 0; i < childBalancers.size(); i++) {
verify(childBalancers.get(i)).handleResolvedAddresses(resolvedAddressesCaptor.capture());
assertThat(resolvedAddressesCaptor.getValue().getLoadBalancingPolicyConfig())
.isEqualTo(configs[i]);
assertThat(resolvedAddressesCaptor.getValue().getAttributes().get(fakeKey))
.isEqualTo(fakeValue);
}
// Update new weighted target config for a typical workflow.
// target0 removed. target1, target2, target3 changed weight and config. target4 added.
int[] newWeights = new int[]{11, 22, 33, 44};
Object[] newConfigs = new Object[]{"newConfig1", "newConfig2", "newConfig3", "newConfig4"};
Map<String, WeightedPolicySelection> newTargets = ImmutableMap.of(
"target1",
new WeightedPolicySelection(
newWeights[0], new PolicySelection(barLbProvider, null, newConfigs[0])),
"target2",
new WeightedPolicySelection(
newWeights[1], new PolicySelection(barLbProvider, null, newConfigs[1])),
"target3",
new WeightedPolicySelection(
newWeights[2], new PolicySelection(fooLbProvider, null, newConfigs[2])),
"target4",
new WeightedPolicySelection(
newWeights[3], new PolicySelection(fooLbProvider, null, newConfigs[3])));
weightedTargetLb.handleResolvedAddresses(
ResolvedAddresses.newBuilder()
.setAddresses(ImmutableList.<EquivalentAddressGroup>of())
.setLoadBalancingPolicyConfig(new WeightedTargetConfig(newTargets))
.build());
assertThat(childBalancers).hasSize(5);
assertThat(childHelpers).hasSize(5);
assertThat(fooLbCreated).isEqualTo(3); // One more foo LB created for target4
assertThat(barLbCreated).isEqualTo(2);
verify(childBalancers.get(0)).shutdown();
for (int i = 1; i < childBalancers.size(); i++) {
verify(childBalancers.get(i), atLeastOnce())
.handleResolvedAddresses(resolvedAddressesCaptor.capture());
assertThat(resolvedAddressesCaptor.getValue().getLoadBalancingPolicyConfig())
.isEqualTo(newConfigs[i - 1]);
}
}
@Test
public void handleNameResolutionError() {
ArgumentCaptor<SubchannelPicker> pickerCaptor = ArgumentCaptor.forClass(null);
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(null);
// Error before any child balancer created.
weightedTargetLb.handleNameResolutionError(Status.DATA_LOSS);
verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class));
assertThat(pickResult.getStatus().getCode()).isEqualTo(Status.Code.DATA_LOSS);
// Child configs updated.
Map<String, WeightedPolicySelection> targets = ImmutableMap.of(
// {foo, 10, config0}
"target0", weightedLbConfig0,
// {bar, 20, config1}
"target1", weightedLbConfig1,
// {bar, 30, config2}
"target2", weightedLbConfig2,
// {foo, 40, config3}
"target3", weightedLbConfig3);
weightedTargetLb.handleResolvedAddresses(
ResolvedAddresses.newBuilder()
.setAddresses(ImmutableList.<EquivalentAddressGroup>of())
.setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets))
.build());
// Error after child balancers created.
weightedTargetLb.handleNameResolutionError(Status.ABORTED);
for (LoadBalancer childBalancer : childBalancers) {
verify(childBalancer).handleNameResolutionError(statusCaptor.capture());
assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.ABORTED);
}
}
@Test
public void balancingStateUpdatedFromChildBalancers() {
Map<String, WeightedPolicySelection> targets = ImmutableMap.of(
// {foo, 10, config0}
"target0", weightedLbConfig0,
// {bar, 20, config1}
"target1", weightedLbConfig1,
// {bar, 30, config2}
"target2", weightedLbConfig2,
// {foo, 40, config3}
"target3", weightedLbConfig3);
weightedTargetLb.handleResolvedAddresses(
ResolvedAddresses.newBuilder()
.setAddresses(ImmutableList.<EquivalentAddressGroup>of())
.setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets))
.build());
// Subchannels to be created for each child balancer.
final SubchannelPicker[] subchannelPickers = new SubchannelPicker[]{
mock(SubchannelPicker.class),
mock(SubchannelPicker.class),
mock(SubchannelPicker.class),
mock(SubchannelPicker.class)};
ArgumentCaptor<SubchannelPicker> pickerCaptor = ArgumentCaptor.forClass(null);
// One child balancer goes to TRANSIENT_FAILURE.
childHelpers.get(1).updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.ABORTED));
verify(helper, never()).updateBalancingState(
eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
verify(helper).updateBalancingState(eq(CONNECTING), eq(BUFFER_PICKER));
// Another child balancer goes to READY.
childHelpers.get(2).updateBalancingState(READY, subchannelPickers[2]);
verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(WeightedRandomPicker.class);
WeightedRandomPicker overallPicker = (WeightedRandomPicker) pickerCaptor.getValue();
assertThat(overallPicker.weightedChildPickers).isEqualTo(
ImmutableList.of(new WeightedChildPicker(weights[2], subchannelPickers[2])));
// Another child balancer goes to READY.
childHelpers.get(3).updateBalancingState(READY, subchannelPickers[3]);
verify(helper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture());
overallPicker = (WeightedRandomPicker) pickerCaptor.getValue();
assertThat(overallPicker.weightedChildPickers).isEqualTo(
ImmutableList.of(
new WeightedChildPicker(weights[2], subchannelPickers[2]),
new WeightedChildPicker(weights[3], subchannelPickers[3])));
// Another child balancer goes to READY.
childHelpers.get(0).updateBalancingState(READY, subchannelPickers[0]);
verify(helper, times(3)).updateBalancingState(eq(READY), pickerCaptor.capture());
overallPicker = (WeightedRandomPicker) pickerCaptor.getValue();
assertThat(overallPicker.weightedChildPickers).isEqualTo(
ImmutableList.of(
new WeightedChildPicker(weights[0], subchannelPickers[0]),
new WeightedChildPicker(weights[2], subchannelPickers[2]),
new WeightedChildPicker(weights[3], subchannelPickers[3])));
// One of READY child balancers goes to TRANSIENT_FAILURE.
childHelpers.get(2).updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.DATA_LOSS));
verify(helper, times(4)).updateBalancingState(eq(READY), pickerCaptor.capture());
overallPicker = (WeightedRandomPicker) pickerCaptor.getValue();
assertThat(overallPicker.weightedChildPickers).isEqualTo(
ImmutableList.of(
new WeightedChildPicker(weights[0], subchannelPickers[0]),
new WeightedChildPicker(weights[3], subchannelPickers[3])));
// All child balancers go to TRANSIENT_FAILURE.
childHelpers.get(3).updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.DATA_LOSS));
childHelpers.get(0).updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.CANCELLED));
verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
}
}