From 5e7b8c672fab7d67d642fc5d718b2b5c571f5429 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Wed, 11 Mar 2020 15:35:31 -0700 Subject: [PATCH] xds: implement WeightedTargetLoadBalancer --- .../main/java/io/grpc/xds/LocalityStore.java | 23 +- ...yPicker.java => WeightedRandomPicker.java} | 36 +- .../grpc/xds/WeightedTargetLoadBalancer.java | 193 ++++++++++ .../WeightedTargetLoadBalancerProvider.java | 198 ++++++++++ .../services/io.grpc.LoadBalancerProvider | 1 + .../java/io/grpc/xds/LocalityStoreTest.java | 113 +++--- ...est.java => WeightedRandomPickerTest.java} | 12 +- ...eightedTargetLoadBalancerProviderTest.java | 139 +++++++ .../xds/WeightedTargetLoadBalancerTest.java | 344 ++++++++++++++++++ 9 files changed, 959 insertions(+), 100 deletions(-) rename xds/src/main/java/io/grpc/xds/{InterLocalityPicker.java => WeightedRandomPicker.java} (77%) create mode 100644 xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java create mode 100644 xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java rename xds/src/test/java/io/grpc/xds/{InterLocalityPickerTest.java => WeightedRandomPickerTest.java} (95%) create mode 100644 xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerProviderTest.java create mode 100644 xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java diff --git a/xds/src/main/java/io/grpc/xds/LocalityStore.java b/xds/src/main/java/io/grpc/xds/LocalityStore.java index d66203f116..2cfa1eeb9f 100644 --- a/xds/src/main/java/io/grpc/xds/LocalityStore.java +++ b/xds/src/main/java/io/grpc/xds/LocalityStore.java @@ -48,13 +48,12 @@ import io.grpc.xds.EnvoyProtoData.DropOverload; import io.grpc.xds.EnvoyProtoData.LbEndpoint; import io.grpc.xds.EnvoyProtoData.Locality; import io.grpc.xds.EnvoyProtoData.LocalityLbEndpoints; -import io.grpc.xds.InterLocalityPicker.WeightedChildPicker; import io.grpc.xds.OrcaOobUtil.OrcaReportingConfig; import io.grpc.xds.OrcaOobUtil.OrcaReportingHelperWrapper; +import io.grpc.xds.WeightedRandomPicker.WeightedChildPicker; import io.grpc.xds.XdsLogger.XdsLogLevel; import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -109,7 +108,6 @@ interface LocalityStore { private final XdsLogger logger; private final Helper helper; - private final PickerFactory pickerFactory; private final LoadBalancerProvider loadBalancerProvider; private final ThreadSafeRandom random; private final LoadStatsStore loadStatsStore; @@ -130,7 +128,6 @@ interface LocalityStore { this( logId, helper, - pickerFactoryImpl, lbRegistry, ThreadSafeRandom.ThreadSafeRandomImpl.instance, loadStatsStore, @@ -142,14 +139,12 @@ interface LocalityStore { LocalityStoreImpl( InternalLogId logId, Helper helper, - PickerFactory pickerFactory, LoadBalancerRegistry lbRegistry, ThreadSafeRandom random, LoadStatsStore loadStatsStore, OrcaPerRequestUtil orcaPerRequestUtil, OrcaOobUtil orcaOobUtil) { this.helper = checkNotNull(helper, "helper"); - this.pickerFactory = checkNotNull(pickerFactory, "pickerFactory"); loadBalancerProvider = checkNotNull( lbRegistry.getProvider(ROUND_ROBIN), "Unable to find '%s' LoadBalancer", ROUND_ROBIN); @@ -160,11 +155,6 @@ interface LocalityStore { logger = XdsLogger.withLogId(checkNotNull(logId, "logId")); } - @VisibleForTesting // Introduced for testing only. - interface PickerFactory { - SubchannelPicker picker(List childPickers); - } - private final class DroppablePicker extends SubchannelPicker { final List dropOverloads; @@ -206,14 +196,6 @@ interface LocalityStore { } } - private static final PickerFactory pickerFactoryImpl = - new PickerFactory() { - @Override - public SubchannelPicker picker(List childPickers) { - return new InterLocalityPicker(childPickers); - } - }; - @Override public void reset() { for (Locality locality : localityMap.keySet()) { @@ -335,7 +317,6 @@ interface LocalityStore { private void updatePicker( @Nullable ConnectivityState state, List childPickers) { - childPickers = Collections.unmodifiableList(childPickers); SubchannelPicker picker; if (childPickers.isEmpty()) { if (state == TRANSIENT_FAILURE) { @@ -344,7 +325,7 @@ interface LocalityStore { picker = XdsSubchannelPickers.BUFFER_PICKER; } } else { - picker = pickerFactory.picker(childPickers); + picker = new WeightedRandomPicker(childPickers); } if (!dropOverloads.isEmpty()) { diff --git a/xds/src/main/java/io/grpc/xds/InterLocalityPicker.java b/xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java similarity index 77% rename from xds/src/main/java/io/grpc/xds/InterLocalityPicker.java rename to xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java index 48f9bb19eb..1f5fc6d01d 100644 --- a/xds/src/main/java/io/grpc/xds/InterLocalityPicker.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRandomPicker.java @@ -21,21 +21,24 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; -import com.google.common.collect.ImmutableList; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.SubchannelPicker; +import java.util.Collections; import java.util.List; +import java.util.Objects; -final class InterLocalityPicker extends SubchannelPicker { +final class WeightedRandomPicker extends SubchannelPicker { + + @VisibleForTesting + final List weightedChildPickers; - private final List weightedChildPickers; private final ThreadSafeRandom random; private final int totalWeight; static final class WeightedChildPicker { - final int weight; - final SubchannelPicker childPicker; + private final int weight; + private final SubchannelPicker childPicker; WeightedChildPicker(int weight, SubchannelPicker childPicker) { checkArgument(weight >= 0, "weight is negative"); @@ -53,6 +56,23 @@ final class InterLocalityPicker extends SubchannelPicker { return childPicker; } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + WeightedChildPicker that = (WeightedChildPicker) o; + return weight == that.weight && Objects.equals(childPicker, that.childPicker); + } + + @Override + public int hashCode() { + return Objects.hash(weight, childPicker); + } + @Override public String toString() { return MoreObjects.toStringHelper(this) @@ -62,16 +82,16 @@ final class InterLocalityPicker extends SubchannelPicker { } } - InterLocalityPicker(List weightedChildPickers) { + WeightedRandomPicker(List weightedChildPickers) { this(weightedChildPickers, ThreadSafeRandom.ThreadSafeRandomImpl.instance); } @VisibleForTesting - InterLocalityPicker(List weightedChildPickers, ThreadSafeRandom random) { + WeightedRandomPicker(List weightedChildPickers, ThreadSafeRandom random) { checkNotNull(weightedChildPickers, "weightedChildPickers in null"); checkArgument(!weightedChildPickers.isEmpty(), "weightedChildPickers is empty"); - this.weightedChildPickers = ImmutableList.copyOf(weightedChildPickers); + this.weightedChildPickers = Collections.unmodifiableList(weightedChildPickers); int totalWeight = 0; for (WeightedChildPicker weightedChildPicker : weightedChildPickers) { diff --git a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java new file mode 100644 index 0000000000..23d32e4cc1 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java @@ -0,0 +1,193 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.ConnectivityState.CONNECTING; +import static io.grpc.ConnectivityState.IDLE; +import static io.grpc.ConnectivityState.READY; +import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.xds.XdsSubchannelPickers.BUFFER_PICKER; + +import com.google.common.collect.ImmutableMap; +import io.grpc.ConnectivityState; +import io.grpc.InternalLogId; +import io.grpc.LoadBalancer; +import io.grpc.Status; +import io.grpc.util.ForwardingLoadBalancerHelper; +import io.grpc.util.GracefulSwitchLoadBalancer; +import io.grpc.xds.WeightedRandomPicker.WeightedChildPicker; +import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; +import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; +import io.grpc.xds.XdsLogger.XdsLogLevel; +import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; + +/** Load balancer for weighted_target policy. */ +final class WeightedTargetLoadBalancer extends LoadBalancer { + + private final XdsLogger logger; + private final Map childBalancers = new HashMap<>(); + private final Map childHelpers = new HashMap<>(); + private final Helper helper; + + private Map targets = ImmutableMap.of(); + + WeightedTargetLoadBalancer(Helper helper) { + this.helper = helper; + logger = XdsLogger.withLogId( + InternalLogId.allocate("weighted-target-lb", helper.getAuthority())); + logger.log(XdsLogLevel.INFO, "Created"); + } + + @Override + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); + Object lbConfig = resolvedAddresses.getLoadBalancingPolicyConfig(); + checkNotNull(lbConfig, "missing weighted_target lb config"); + + WeightedTargetConfig weightedTargetConfig = (WeightedTargetConfig) lbConfig; + Map newTargets = weightedTargetConfig.targets; + + for (String targetName : newTargets.keySet()) { + WeightedPolicySelection weightedChildLbConfig = newTargets.get(targetName); + if (!targets.containsKey(targetName)) { + ChildHelper childHelper = new ChildHelper(); + GracefulSwitchLoadBalancer childBalancer = new GracefulSwitchLoadBalancer(childHelper); + childBalancer.switchTo(weightedChildLbConfig.policySelection.getProvider()); + childHelpers.put(targetName, childHelper); + childBalancers.put(targetName, childBalancer); + } else if (!weightedChildLbConfig.policySelection.getProvider().equals( + targets.get(targetName).policySelection.getProvider())) { + childBalancers.get(targetName) + .switchTo(weightedChildLbConfig.policySelection.getProvider()); + } + } + + targets = newTargets; + + for (String targetName : targets.keySet()) { + childBalancers.get(targetName).handleResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(targets.get(targetName).policySelection.getConfig()) + .build()); + } + + // Cleanup removed targets. + // TODO(zdapeng): cache removed target for 15 minutes. + for (String targetName : childBalancers.keySet()) { + if (!targets.containsKey(targetName)) { + childBalancers.get(targetName).shutdown(); + } + } + childBalancers.keySet().retainAll(targets.keySet()); + childHelpers.keySet().retainAll(targets.keySet()); + } + + @Override + public void handleNameResolutionError(Status error) { + logger.log(XdsLogLevel.WARNING, "Received name resolution error: {0}", error); + if (childBalancers.isEmpty()) { + helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(error)); + } + for (LoadBalancer childBalancer : childBalancers.values()) { + childBalancer.handleNameResolutionError(error); + } + } + + @Override + public boolean canHandleEmptyAddressListFromNameResolution() { + return true; + } + + @Override + public void shutdown() { + logger.log(XdsLogLevel.INFO, "Shutdown"); + for (LoadBalancer childBalancer : childBalancers.values()) { + childBalancer.shutdown(); + } + } + + private void updateOverallBalancingState() { + List childPickers = new ArrayList<>(); + + ConnectivityState overallState = null; + for (String name : targets.keySet()) { + ChildHelper childHelper = childHelpers.get(name); + ConnectivityState childState = childHelper.currentState; + overallState = aggregateState(overallState, childState); + if (READY == childState) { + int weight = targets.get(name).weight; + childPickers.add(new WeightedChildPicker(weight, childHelper.currentPicker)); + } + } + + SubchannelPicker picker; + if (childPickers.isEmpty()) { + if (overallState == TRANSIENT_FAILURE) { + picker = new ErrorPicker(Status.UNAVAILABLE); // TODO: more details in status + } else { + picker = XdsSubchannelPickers.BUFFER_PICKER; + } + } else { + picker = new WeightedRandomPicker(childPickers); + } + + if (overallState != null) { + helper.updateBalancingState(overallState, picker); + } + } + + @Nullable + private ConnectivityState aggregateState( + @Nullable ConnectivityState overallState, ConnectivityState childState) { + if (overallState == null) { + return childState; + } + if (overallState == READY || childState == READY) { + return READY; + } + if (overallState == CONNECTING || childState == CONNECTING) { + return CONNECTING; + } + if (overallState == IDLE || childState == IDLE) { + return IDLE; + } + return overallState; + } + + private final class ChildHelper extends ForwardingLoadBalancerHelper { + ConnectivityState currentState = CONNECTING; + SubchannelPicker currentPicker = BUFFER_PICKER; + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + currentState = newState; + currentPicker = newPicker; + updateOverallBalancingState(); + } + + @Override + protected Helper delegate() { + return helper; + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java new file mode 100644 index 0000000000..5081fe985d --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java @@ -0,0 +1,198 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; +import io.grpc.Internal; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Status; +import io.grpc.internal.JsonUtil; +import io.grpc.internal.ServiceConfigUtil; +import io.grpc.internal.ServiceConfigUtil.LbConfig; +import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import javax.annotation.Nullable; + +/** + * The provider for the weighted_target balancing policy. This class should not be + * directly referenced in code. The policy should be accessed through {@link + * LoadBalancerRegistry#getProvider} with the name "weighted_target_experimental". + */ +@Internal +public final class WeightedTargetLoadBalancerProvider extends LoadBalancerProvider { + + static final String WEIGHTED_TARGET_POLICY_NAME = "weighted_target_experimental"; + + @Nullable + private final LoadBalancerRegistry lbRegistry; + + // We can not call this(LoadBalancerRegistry.getDefaultRegistry()), because it will get stuck + // recursively loading LoadBalancerRegistry and WeightedTargetLoadBalancerProvider. + public WeightedTargetLoadBalancerProvider() { + this(null); + } + + @VisibleForTesting + WeightedTargetLoadBalancerProvider(@Nullable LoadBalancerRegistry lbRegistry) { + this.lbRegistry = lbRegistry; + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return WEIGHTED_TARGET_POLICY_NAME; + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return new WeightedTargetLoadBalancer(helper); + } + + @Override + public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { + try { + Map targets = JsonUtil.getObject(rawConfig, "targets"); + if (targets == null || targets.isEmpty()) { + return ConfigOrError.fromError(Status.INTERNAL.withDescription( + "No targets provided for weighted_target LB policy:\n " + rawConfig)); + } + Map parsedChildConfigs = new LinkedHashMap<>(); + for (String name : targets.keySet()) { + Map rawWeightedTarget = JsonUtil.getObject(targets, name); + if (rawWeightedTarget == null || rawWeightedTarget.isEmpty()) { + return ConfigOrError.fromError(Status.INTERNAL.withDescription( + "No config for target " + name + " in weighted_target LB policy:\n " + rawConfig)); + } + Integer weight = JsonUtil.getNumberAsInteger(rawWeightedTarget, "weight"); + if (weight == null || weight < 1) { + return ConfigOrError.fromError(Status.INTERNAL.withDescription( + "Wrong weight for target " + name + " in weighted_target LB policy:\n " + rawConfig)); + } + List childConfigCandidates = ServiceConfigUtil.unwrapLoadBalancingConfigList( + JsonUtil.getListOfObjects(rawWeightedTarget, "childPolicy")); + if (childConfigCandidates == null || childConfigCandidates.isEmpty()) { + return ConfigOrError.fromError(Status.INTERNAL.withDescription( + "No child policy for target " + name + " in weighted_target LB policy:\n " + + rawConfig)); + } + LoadBalancerRegistry lbRegistry = + this.lbRegistry == null ? LoadBalancerRegistry.getDefaultRegistry() : this.lbRegistry; + ConfigOrError selectedConfig = + ServiceConfigUtil.selectLbPolicyFromList(childConfigCandidates, lbRegistry); + if (selectedConfig.getError() != null) { + return selectedConfig; + } + PolicySelection policySelection = (PolicySelection) selectedConfig.getConfig(); + parsedChildConfigs.put(name, new WeightedPolicySelection(weight, policySelection)); + } + return ConfigOrError.fromConfig(new WeightedTargetConfig(parsedChildConfigs)); + } catch (RuntimeException e) { + return ConfigOrError.fromError( + Status.fromThrowable(e).withDescription( + "Failed to parse weighted_target LB config: " + rawConfig)); + } + } + + static final class WeightedPolicySelection { + + final int weight; + final PolicySelection policySelection; + + @VisibleForTesting + WeightedPolicySelection(int weight, PolicySelection policySelection) { + this.weight = weight; + this.policySelection = policySelection; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + WeightedPolicySelection that = (WeightedPolicySelection) o; + return weight == that.weight && Objects.equals(policySelection, that.policySelection); + } + + @Override + public int hashCode() { + return Objects.hash(weight, policySelection); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("weight", weight) + .add("policySelection", policySelection) + .toString(); + } + } + + /** The lb config for WeightedTargetLoadBalancer. */ + static final class WeightedTargetConfig { + + final Map targets; + + @VisibleForTesting + WeightedTargetConfig(Map targets) { + this.targets = targets; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + WeightedTargetConfig that = (WeightedTargetConfig) o; + return Objects.equals(targets, that.targets); + } + + @Override + public int hashCode() { + return Objects.hash(targets); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("targets", targets) + .toString(); + } + } +} diff --git a/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider b/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider index d430ca289a..c85e6d0c4a 100644 --- a/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider +++ b/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider @@ -1,4 +1,5 @@ io.grpc.xds.CdsLoadBalancerProvider io.grpc.xds.EdsLoadBalancerProvider +io.grpc.xds.WeightedTargetLoadBalancerProvider io.grpc.xds.XdsLoadBalancerProvider io.grpc.xds.XdsRoutingLoadBalancerProvider diff --git a/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java b/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java index ddf38a73a4..9d3d6054ee 100644 --- a/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java +++ b/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java @@ -68,20 +68,18 @@ import io.grpc.xds.EnvoyProtoData.DropOverload; import io.grpc.xds.EnvoyProtoData.LbEndpoint; import io.grpc.xds.EnvoyProtoData.Locality; import io.grpc.xds.EnvoyProtoData.LocalityLbEndpoints; -import io.grpc.xds.InterLocalityPicker.WeightedChildPicker; import io.grpc.xds.LocalityStore.LocalityStoreImpl; -import io.grpc.xds.LocalityStore.LocalityStoreImpl.PickerFactory; import io.grpc.xds.OrcaOobUtil.OrcaOobReportListener; import io.grpc.xds.OrcaOobUtil.OrcaReportingConfig; import io.grpc.xds.OrcaOobUtil.OrcaReportingHelperWrapper; import io.grpc.xds.OrcaPerRequestUtil.OrcaPerRequestReportListener; +import io.grpc.xds.WeightedRandomPicker.WeightedChildPicker; import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; import java.net.InetSocketAddress; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -108,25 +106,6 @@ public class LocalityStoreTest { @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule(); - private static final class FakePickerFactory implements PickerFactory { - int totalReadyLocalities; - int nextIndex; - List perLocalitiesPickers; - - @Override - public SubchannelPicker picker(final List childPickers) { - totalReadyLocalities = childPickers.size(); - perLocalitiesPickers = Collections.unmodifiableList(childPickers); - - return new SubchannelPicker() { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return childPickers.get(nextIndex).getPicker().pickSubchannel(args); - } - }; - } - } - private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @Override @@ -182,8 +161,6 @@ public class LocalityStoreTest { } }; - private final FakePickerFactory pickerFactory = new FakePickerFactory(); - private final Locality locality1 = new Locality("r1", "z1", "sz1"); private final Locality locality2 = new Locality("r2", "z2", "sz2"); private final Locality locality3 = new Locality("r3", "z3", "sz3"); @@ -253,7 +230,7 @@ public class LocalityStoreTest { }); lbRegistry.register(lbProvider); localityStore = - new LocalityStoreImpl(logId, helper, pickerFactory, lbRegistry, random, loadStatsStore, + new LocalityStoreImpl(logId, helper, lbRegistry, random, loadStatsStore, orcaPerRequestUtil, orcaOobUtil); } @@ -302,7 +279,6 @@ public class LocalityStoreTest { // Two child balancers are created. assertThat(loadBalancers).hasSize(2); - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0); ClientStreamTracer.Factory metricsTracingFactory1 = mock(ClientStreamTracer.Factory.class); ClientStreamTracer.Factory metricsTracingFactory2 = mock(ClientStreamTracer.Factory.class); @@ -326,10 +302,11 @@ public class LocalityStoreTest { childHelpers.get("sz1").updateBalancingState(READY, subchannelPicker1); childHelpers.get("sz2").updateBalancingState(READY, subchannelPicker2); - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(2); ArgumentCaptor interLocalityPickerCaptor = ArgumentCaptor.forClass(null); verify(helper, times(2)).updateBalancingState(eq(READY), interLocalityPickerCaptor.capture()); - SubchannelPicker interLocalityPicker = interLocalityPickerCaptor.getValue(); + WeightedRandomPicker interLocalityPicker = + (WeightedRandomPicker) interLocalityPickerCaptor.getValue(); + assertThat(interLocalityPicker.weightedChildPickers).hasSize(2); // Verify each PickResult picked is intercepted with client stream tracer factory for // recording load and backend metrics. @@ -337,9 +314,9 @@ public class LocalityStoreTest { = ImmutableMap.of(subchannel1, locality1, subchannel2, locality2); Map metricsTracingFactoriesBySubchannel = ImmutableMap.of(subchannel1, metricsTracingFactory1, subchannel2, metricsTracingFactory2); - for (int i = 0; i < pickerFactory.totalReadyLocalities; i++) { - pickerFactory.nextIndex = i; - PickResult pickResult = interLocalityPicker.pickSubchannel(pickSubchannelArgs); + for (int i = 0; i < interLocalityPicker.weightedChildPickers.size(); i++) { + PickResult pickResult = interLocalityPicker.weightedChildPickers.get(i).getPicker() + .pickSubchannel(pickSubchannelArgs); Subchannel expectedSubchannel = pickResult.getSubchannel(); Locality expectedLocality = localitiesBySubchannel.get(expectedSubchannel); ArgumentCaptor listenerCaptor = ArgumentCaptor.forClass(null); @@ -466,7 +443,6 @@ public class LocalityStoreTest { ArgumentCaptor.forClass(ResolvedAddresses.class); verify(loadBalancers.get("sz3")).handleResolvedAddresses(resolvedAddressesCaptor3.capture()); assertThat(resolvedAddressesCaptor3.getValue().getAddresses()).containsExactly(eag31, eag32); - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0); // verify no more updateBalancingState except the initial CONNECTING state verify(helper, times(1)).updateBalancingState( any(ConnectivityState.class), any(SubchannelPicker.class)); @@ -484,7 +460,6 @@ public class LocalityStoreTest { ArgumentCaptor.forClass(SubchannelPicker.class); verify(helper, times(2)).updateBalancingState( same(CONNECTING), subchannelPickerCaptor12.capture()); - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0); assertThat(subchannelPickerCaptor12.getValue().pickSubchannel(pickSubchannelArgs)) .isEqualTo(PickResult.withNoResult()); @@ -500,21 +475,22 @@ public class LocalityStoreTest { ArgumentCaptor subchannelPickerCaptor = ArgumentCaptor.forClass(null); verify(helper).updateBalancingState(same(READY), subchannelPickerCaptor.capture()); - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(1); - pickerFactory.nextIndex = 0; - assertThat(subchannelPickerCaptor.getValue().pickSubchannel(pickSubchannelArgs).getSubchannel()) + WeightedRandomPicker interLocalityPicker = + (WeightedRandomPicker) subchannelPickerCaptor.getValue(); + assertThat(interLocalityPicker.weightedChildPickers).hasSize(1); + assertThat(interLocalityPicker.pickSubchannel(pickSubchannelArgs).getSubchannel()) .isEqualTo(subchannel31); // subchannel12 goes to READY childHelpers.get("sz1").updateBalancingState(READY, subchannelPicker12); verify(helper, times(2)).updateBalancingState(same(READY), subchannelPickerCaptor.capture()); - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(2); + interLocalityPicker = (WeightedRandomPicker) subchannelPickerCaptor.getValue(); + assertThat(interLocalityPicker.weightedChildPickers).hasSize(2); - SubchannelPicker interLocalityPicker = subchannelPickerCaptor.getValue(); Set pickedReadySubchannels = new HashSet<>(); - for (int i = 0; i < pickerFactory.totalReadyLocalities; i++) { - pickerFactory.nextIndex = i; - PickResult result = interLocalityPicker.pickSubchannel(pickSubchannelArgs); + for (int i = 0; i < interLocalityPicker.weightedChildPickers.size(); i++) { + PickResult result = interLocalityPicker.weightedChildPickers.get(i).getPicker() + .pickSubchannel(pickSubchannelArgs); pickedReadySubchannels.add(result.getSubchannel()); } assertThat(pickedReadySubchannels).containsExactly(subchannel31, subchannel12); @@ -539,7 +515,9 @@ public class LocalityStoreTest { verify(loadBalancers.get("sz1"), times(2)) .handleResolvedAddresses(resolvedAddressesCaptor1.capture()); assertThat(resolvedAddressesCaptor1.getValue().getAddresses()).containsExactly(eag11); - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(1); + verify(helper, times(3)).updateBalancingState(same(READY), subchannelPickerCaptor.capture()); + interLocalityPicker = (WeightedRandomPicker) subchannelPickerCaptor.getValue(); + assertThat(interLocalityPicker.weightedChildPickers).hasSize(1); fakeClock.forwardTime(14, TimeUnit.MINUTES); verify(loadBalancers.get("sz3"), never()).shutdown(); @@ -598,9 +576,10 @@ public class LocalityStoreTest { // helper updated multiple times. Don't care how many times, just capture the latest picker verify(helper, atLeastOnce()).updateBalancingState( same(READY), subchannelPickerCaptor.capture()); - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(1); - pickerFactory.nextIndex = 0; - assertThat(subchannelPickerCaptor.getValue().pickSubchannel(pickSubchannelArgs).getSubchannel()) + WeightedRandomPicker interLocalityPicker = + (WeightedRandomPicker) subchannelPickerCaptor.getValue(); + assertThat(interLocalityPicker.weightedChildPickers).hasSize(1); + assertThat(interLocalityPicker.pickSubchannel(pickSubchannelArgs).getSubchannel()) .isEqualTo(subchannel3); // verify no traffic will go to deactivated locality @@ -614,9 +593,10 @@ public class LocalityStoreTest { childHelpers.get("sz2").updateBalancingState(READY, subchannelPicker2); verify(helper, atLeastOnce()).updateBalancingState( same(READY), subchannelPickerCaptor.capture()); - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(1); - pickerFactory.nextIndex = 0; - assertThat(subchannelPickerCaptor.getValue().pickSubchannel(pickSubchannelArgs).getSubchannel()) + interLocalityPicker = + (WeightedRandomPicker) subchannelPickerCaptor.getValue(); + assertThat(interLocalityPicker.weightedChildPickers).hasSize(1); + assertThat(interLocalityPicker.pickSubchannel(pickSubchannelArgs).getSubchannel()) .isEqualTo(subchannel3); // update localities, reactivating sz1 @@ -625,13 +605,13 @@ public class LocalityStoreTest { localityStore.updateLocalityStore(localityInfoMap); verify(helper, atLeastOnce()).updateBalancingState( same(READY), subchannelPickerCaptor.capture()); - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(2); - pickerFactory.nextIndex = 0; - assertThat(subchannelPickerCaptor.getValue().pickSubchannel(pickSubchannelArgs).getSubchannel()) - .isEqualTo(subchannel1); - pickerFactory.nextIndex = 1; - assertThat(subchannelPickerCaptor.getValue().pickSubchannel(pickSubchannelArgs).getSubchannel()) - .isEqualTo(subchannel3); + interLocalityPicker = + (WeightedRandomPicker) subchannelPickerCaptor.getValue(); + assertThat(interLocalityPicker.weightedChildPickers).hasSize(2); + assertThat(interLocalityPicker.weightedChildPickers.get(0).getPicker() + .pickSubchannel(pickSubchannelArgs).getSubchannel()).isEqualTo(subchannel1); + assertThat(interLocalityPicker.weightedChildPickers.get(1).getPicker() + .pickSubchannel(pickSubchannelArgs).getSubchannel()).isEqualTo(subchannel3); verify(lb2, never()).shutdown(); // delayed deletion timer expires, no reactivation @@ -648,9 +628,10 @@ public class LocalityStoreTest { verify(helper, atLeastOnce()).updateBalancingState( same(READY), subchannelPickerCaptor.capture()); - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(1); - pickerFactory.nextIndex = 0; - assertThat(subchannelPickerCaptor.getValue().pickSubchannel(pickSubchannelArgs).getSubchannel()) + interLocalityPicker = + (WeightedRandomPicker) subchannelPickerCaptor.getValue(); + assertThat(interLocalityPicker.weightedChildPickers).hasSize(1); + assertThat(interLocalityPicker.pickSubchannel(pickSubchannelArgs).getSubchannel()) .isEqualTo(subchannel1); // sz3, sz4 pending removal assertThat(fakeClock.getPendingTasks(deactivationTaskFilter)).hasSize(2); @@ -701,7 +682,6 @@ public class LocalityStoreTest { ArgumentCaptor.forClass(ResolvedAddresses.class); verify(loadBalancers.get("sz3")).handleResolvedAddresses(resolvedAddressesCaptor3.capture()); assertThat(resolvedAddressesCaptor3.getValue().getAddresses()).containsExactly(eag31, eag32); - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0); ArgumentCaptor subchannelPickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); verify(helper).updateBalancingState(same(CONNECTING), subchannelPickerCaptor.capture()); @@ -908,7 +888,6 @@ public class LocalityStoreTest { assertThat(loadBalancers).hasSize(3); assertThat(loadBalancers.keySet()).containsExactly("sz1", "sz2", "sz3"); - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0); // Update locality weights before any subchannel becomes READY. localityInfo1 = new LocalityLbEndpoints(ImmutableList.of(lbEndpoint11, lbEndpoint12), 4, 0); @@ -918,8 +897,6 @@ public class LocalityStoreTest { locality1, localityInfo1, locality2, localityInfo2, locality3, localityInfo3); localityStore.updateLocalityStore(localityInfoMap); - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0); - final Map localitiesBySubchannel = new HashMap<>(); for (final Helper h : childHelpers.values()) { h.updateBalancingState(READY, new SubchannelPicker() { @@ -932,10 +909,16 @@ public class LocalityStoreTest { }); } - assertThat(pickerFactory.totalReadyLocalities).isEqualTo(3); - for (int i = 0; i < pickerFactory.totalReadyLocalities; i++) { + ArgumentCaptor subchannelPickerCaptor = + ArgumentCaptor.forClass(SubchannelPicker.class); + verify(helper, atLeastOnce()).updateBalancingState( + same(READY), subchannelPickerCaptor.capture()); + WeightedRandomPicker interLocalityPicker = + (WeightedRandomPicker) subchannelPickerCaptor.getValue(); + assertThat(interLocalityPicker.weightedChildPickers).hasSize(3); + for (int i = 0; i < interLocalityPicker.weightedChildPickers.size(); i++) { WeightedChildPicker weightedChildPicker - = pickerFactory.perLocalitiesPickers.get(i); + = interLocalityPicker.weightedChildPickers.get(i); Subchannel subchannel = weightedChildPicker.getPicker().pickSubchannel(pickSubchannelArgs).getSubchannel(); assertThat(weightedChildPicker.getWeight()) diff --git a/xds/src/test/java/io/grpc/xds/InterLocalityPickerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java similarity index 95% rename from xds/src/test/java/io/grpc/xds/InterLocalityPickerTest.java rename to xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java index 4b7f752259..e88f9cbfba 100644 --- a/xds/src/test/java/io/grpc/xds/InterLocalityPickerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java @@ -24,7 +24,7 @@ import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.Status; -import io.grpc.xds.InterLocalityPicker.WeightedChildPicker; +import io.grpc.xds.WeightedRandomPicker.WeightedChildPicker; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -38,10 +38,10 @@ import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; /** - * Tests for {@link InterLocalityPicker}. + * Tests for {@link WeightedRandomPicker}. */ @RunWith(JUnit4.class) -public class InterLocalityPickerTest { +public class WeightedRandomPickerTest { @Rule public final ExpectedException thrown = ExpectedException.none(); @@ -105,7 +105,7 @@ public class InterLocalityPickerTest { List emptyList = new ArrayList<>(); thrown.expect(IllegalArgumentException.class); - new InterLocalityPicker(emptyList); + new WeightedRandomPicker(emptyList); } @Test @@ -121,7 +121,7 @@ public class InterLocalityPickerTest { WeightedChildPicker weightedChildPicker2 = new WeightedChildPicker(0, childPicker2); WeightedChildPicker weightedChildPicker3 = new WeightedChildPicker(10, childPicker3); - InterLocalityPicker xdsPicker = new InterLocalityPicker( + WeightedRandomPicker xdsPicker = new WeightedRandomPicker( Arrays.asList( weightedChildPicker0, weightedChildPicker1, @@ -157,7 +157,7 @@ public class InterLocalityPickerTest { WeightedChildPicker weightedChildPicker2 = new WeightedChildPicker(0, childPicker2); WeightedChildPicker weightedChildPicker3 = new WeightedChildPicker(0, childPicker3); - InterLocalityPicker xdsPicker = new InterLocalityPicker( + WeightedRandomPicker xdsPicker = new WeightedRandomPicker( Arrays.asList( weightedChildPicker0, weightedChildPicker1, diff --git a/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerProviderTest.java new file mode 100644 index 0000000000..bcdf6a42a9 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerProviderTest.java @@ -0,0 +1,139 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; + +import com.google.common.collect.ImmutableMap; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.internal.JsonParser; +import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; +import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link WeightedTargetLoadBalancerProvider}. */ +@RunWith(JUnit4.class) +public class WeightedTargetLoadBalancerProviderTest { + + @Test + public void parseWeightedTargetConfig() throws Exception { + LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); + WeightedTargetLoadBalancerProvider weightedTargetLoadBalancerProvider = + new WeightedTargetLoadBalancerProvider(lbRegistry); + final Object fooConfig = new Object(); + LoadBalancerProvider lbProviderFoo = new LoadBalancerProvider() { + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "foo_policy"; + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return mock(LoadBalancer.class); + } + + @Override + public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { + return ConfigOrError.fromConfig(fooConfig); + } + }; + final Object barConfig = new Object(); + LoadBalancerProvider lbProviderBar = new LoadBalancerProvider() { + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "bar_policy"; + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return mock(LoadBalancer.class); + } + + @Override + public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { + return ConfigOrError.fromConfig(barConfig); + } + }; + lbRegistry.register(lbProviderFoo); + lbRegistry.register(lbProviderBar); + + String weightedTargetConfigJson = ("{" + + " 'targets' : {" + + " 'target_1' : {" + + " 'weight' : 10," + + " 'childPolicy' : [" + + " {'unsupported_policy' : {}}," + + " {'foo_policy' : {}}" + + " ]" + + " }," + + " 'target_2' : {" + + " 'weight' : 20," + + " 'childPolicy' : [" + + " {'unsupported_policy' : {}}," + + " {'bar_policy' : {}}" + + " ]" + + " }" + + " }" + + "}").replace("'", "\""); + + @SuppressWarnings("unchecked") + Map rawLbConfigMap = (Map) JsonParser.parse(weightedTargetConfigJson); + ConfigOrError parsedConfig = + weightedTargetLoadBalancerProvider.parseLoadBalancingPolicyConfig(rawLbConfigMap); + ConfigOrError expectedConfig = ConfigOrError.fromConfig( + new WeightedTargetConfig(ImmutableMap.of( + "target_1", + new WeightedPolicySelection( + 10, + new PolicySelection(lbProviderFoo, new HashMap(), fooConfig)), + "target_2", + new WeightedPolicySelection( + 20, + new PolicySelection(lbProviderBar, new HashMap(), barConfig))))); + assertThat(parsedConfig).isEqualTo(expectedConfig); + } +} diff --git a/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java new file mode 100644 index 0000000000..48c2286fdd --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java @@ -0,0 +1,344 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.ConnectivityState.CONNECTING; +import static io.grpc.ConnectivityState.READY; +import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.xds.XdsSubchannelPickers.BUFFER_PICKER; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.grpc.Attributes; +import io.grpc.ChannelLogger; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; +import io.grpc.LoadBalancer.ResolvedAddresses; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.Status; +import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.xds.WeightedRandomPicker.WeightedChildPicker; +import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; +import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; +import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** Tests for {@link WeightedTargetLoadBalancer}. */ +@RunWith(JUnit4.class) +public class WeightedTargetLoadBalancerTest { + + private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); + private final List childBalancers = new ArrayList<>(); + private final List childHelpers = new ArrayList<>(); + private final int[] weights = new int[]{10, 20, 30, 40}; + private final Object[] configs = new Object[]{"config0", "config1", "config3", "config4"}; + + private final LoadBalancerProvider fooLbProvider = new LoadBalancerProvider() { + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "foo_policy"; + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + childHelpers.add(helper); + LoadBalancer childBalancer = mock(LoadBalancer.class); + childBalancers.add(childBalancer); + fooLbCreated++; + return childBalancer; + } + }; + + private final LoadBalancerProvider barLbProvider = new LoadBalancerProvider() { + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "bar_policy"; + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + childHelpers.add(helper); + LoadBalancer childBalancer = mock(LoadBalancer.class); + childBalancers.add(childBalancer); + barLbCreated++; + return childBalancer; + } + }; + + private final WeightedPolicySelection weightedLbConfig0 = new WeightedPolicySelection( + weights[0], new PolicySelection(fooLbProvider, null, configs[0])); + private final WeightedPolicySelection weightedLbConfig1 = new WeightedPolicySelection( + weights[1], new PolicySelection(barLbProvider, null, configs[1])); + private final WeightedPolicySelection weightedLbConfig2 = new WeightedPolicySelection( + weights[2], new PolicySelection(barLbProvider, null, configs[2])); + private final WeightedPolicySelection weightedLbConfig3 = new WeightedPolicySelection( + weights[3], new PolicySelection(fooLbProvider, null, configs[3])); + + @Mock + private Helper helper; + @Mock + private ChannelLogger channelLogger; + + private LoadBalancer weightedTargetLb; + private int fooLbCreated; + private int barLbCreated; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + + doReturn(channelLogger).when(helper).getChannelLogger(); + lbRegistry.register(fooLbProvider); + lbRegistry.register(barLbProvider); + + weightedTargetLb = new WeightedTargetLoadBalancer(helper); + } + + @After + public void tearDown() { + weightedTargetLb.shutdown(); + for (LoadBalancer childBalancer : childBalancers) { + verify(childBalancer).shutdown(); + } + } + + @Test + public void handleResolvedAddresses() { + ArgumentCaptor resolvedAddressesCaptor = ArgumentCaptor.forClass(null); + Attributes.Key fakeKey = Attributes.Key.create("fake_key"); + Object fakeValue = new Object(); + + Map targets = ImmutableMap.of( + // {foo, 10, config0} + "target0", weightedLbConfig0, + // {bar, 20, config1} + "target1", weightedLbConfig1, + // {bar, 30, config2} + "target2", weightedLbConfig2, + // {foo, 40, config3} + "target3", weightedLbConfig3); + weightedTargetLb.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setAttributes(Attributes.newBuilder().set(fakeKey, fakeValue).build()) + .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) + .build()); + + assertThat(childBalancers).hasSize(4); + assertThat(childHelpers).hasSize(4); + assertThat(fooLbCreated).isEqualTo(2); + assertThat(barLbCreated).isEqualTo(2); + + for (int i = 0; i < childBalancers.size(); i++) { + verify(childBalancers.get(i)).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + assertThat(resolvedAddressesCaptor.getValue().getLoadBalancingPolicyConfig()) + .isEqualTo(configs[i]); + assertThat(resolvedAddressesCaptor.getValue().getAttributes().get(fakeKey)) + .isEqualTo(fakeValue); + } + + // Update new weighted target config for a typical workflow. + // target0 removed. target1, target2, target3 changed weight and config. target4 added. + int[] newWeights = new int[]{11, 22, 33, 44}; + Object[] newConfigs = new Object[]{"newConfig1", "newConfig2", "newConfig3", "newConfig4"}; + Map newTargets = ImmutableMap.of( + "target1", + new WeightedPolicySelection( + newWeights[0], new PolicySelection(barLbProvider, null, newConfigs[0])), + "target2", + new WeightedPolicySelection( + newWeights[1], new PolicySelection(barLbProvider, null, newConfigs[1])), + "target3", + new WeightedPolicySelection( + newWeights[2], new PolicySelection(fooLbProvider, null, newConfigs[2])), + "target4", + new WeightedPolicySelection( + newWeights[3], new PolicySelection(fooLbProvider, null, newConfigs[3]))); + weightedTargetLb.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(new WeightedTargetConfig(newTargets)) + .build()); + + assertThat(childBalancers).hasSize(5); + assertThat(childHelpers).hasSize(5); + assertThat(fooLbCreated).isEqualTo(3); // One more foo LB created for target4 + assertThat(barLbCreated).isEqualTo(2); + + verify(childBalancers.get(0)).shutdown(); + for (int i = 1; i < childBalancers.size(); i++) { + verify(childBalancers.get(i), atLeastOnce()) + .handleResolvedAddresses(resolvedAddressesCaptor.capture()); + assertThat(resolvedAddressesCaptor.getValue().getLoadBalancingPolicyConfig()) + .isEqualTo(newConfigs[i - 1]); + } + } + + @Test + public void handleNameResolutionError() { + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(null); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(null); + + // Error before any child balancer created. + weightedTargetLb.handleNameResolutionError(Status.DATA_LOSS); + + verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); + assertThat(pickResult.getStatus().getCode()).isEqualTo(Status.Code.DATA_LOSS); + + // Child configs updated. + Map targets = ImmutableMap.of( + // {foo, 10, config0} + "target0", weightedLbConfig0, + // {bar, 20, config1} + "target1", weightedLbConfig1, + // {bar, 30, config2} + "target2", weightedLbConfig2, + // {foo, 40, config3} + "target3", weightedLbConfig3); + weightedTargetLb.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) + .build()); + + // Error after child balancers created. + weightedTargetLb.handleNameResolutionError(Status.ABORTED); + + for (LoadBalancer childBalancer : childBalancers) { + verify(childBalancer).handleNameResolutionError(statusCaptor.capture()); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.ABORTED); + } + } + + @Test + public void balancingStateUpdatedFromChildBalancers() { + Map targets = ImmutableMap.of( + // {foo, 10, config0} + "target0", weightedLbConfig0, + // {bar, 20, config1} + "target1", weightedLbConfig1, + // {bar, 30, config2} + "target2", weightedLbConfig2, + // {foo, 40, config3} + "target3", weightedLbConfig3); + weightedTargetLb.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) + .build()); + + // Subchannels to be created for each child balancer. + final SubchannelPicker[] subchannelPickers = new SubchannelPicker[]{ + mock(SubchannelPicker.class), + mock(SubchannelPicker.class), + mock(SubchannelPicker.class), + mock(SubchannelPicker.class)}; + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(null); + + // One child balancer goes to TRANSIENT_FAILURE. + childHelpers.get(1).updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.ABORTED)); + verify(helper, never()).updateBalancingState( + eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + verify(helper).updateBalancingState(eq(CONNECTING), eq(BUFFER_PICKER)); + + // Another child balancer goes to READY. + childHelpers.get(2).updateBalancingState(READY, subchannelPickers[2]); + verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + assertThat(pickerCaptor.getValue()).isInstanceOf(WeightedRandomPicker.class); + WeightedRandomPicker overallPicker = (WeightedRandomPicker) pickerCaptor.getValue(); + assertThat(overallPicker.weightedChildPickers).isEqualTo( + ImmutableList.of(new WeightedChildPicker(weights[2], subchannelPickers[2]))); + + // Another child balancer goes to READY. + childHelpers.get(3).updateBalancingState(READY, subchannelPickers[3]); + verify(helper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); + overallPicker = (WeightedRandomPicker) pickerCaptor.getValue(); + assertThat(overallPicker.weightedChildPickers).isEqualTo( + ImmutableList.of( + new WeightedChildPicker(weights[2], subchannelPickers[2]), + new WeightedChildPicker(weights[3], subchannelPickers[3]))); + + // Another child balancer goes to READY. + childHelpers.get(0).updateBalancingState(READY, subchannelPickers[0]); + verify(helper, times(3)).updateBalancingState(eq(READY), pickerCaptor.capture()); + overallPicker = (WeightedRandomPicker) pickerCaptor.getValue(); + assertThat(overallPicker.weightedChildPickers).isEqualTo( + ImmutableList.of( + new WeightedChildPicker(weights[0], subchannelPickers[0]), + new WeightedChildPicker(weights[2], subchannelPickers[2]), + new WeightedChildPicker(weights[3], subchannelPickers[3]))); + + // One of READY child balancers goes to TRANSIENT_FAILURE. + childHelpers.get(2).updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.DATA_LOSS)); + verify(helper, times(4)).updateBalancingState(eq(READY), pickerCaptor.capture()); + overallPicker = (WeightedRandomPicker) pickerCaptor.getValue(); + assertThat(overallPicker.weightedChildPickers).isEqualTo( + ImmutableList.of( + new WeightedChildPicker(weights[0], subchannelPickers[0]), + new WeightedChildPicker(weights[3], subchannelPickers[3]))); + + // All child balancers go to TRANSIENT_FAILURE. + childHelpers.get(3).updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.DATA_LOSS)); + childHelpers.get(0).updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.CANCELLED)); + verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + } +}