diff --git a/xds/build.gradle b/xds/build.gradle index 7dac6799f5..65abcc89c1 100644 --- a/xds/build.gradle +++ b/xds/build.gradle @@ -30,6 +30,8 @@ dependencies { exclude group: 'com.google.guava', module: 'guava' } + testCompile project(':grpc-core').sourceSets.test.output + compileOnly libraries.javax_annotation testCompile project(':grpc-testing') diff --git a/xds/src/main/java/io/grpc/xds/AdsStream.java b/xds/src/main/java/io/grpc/xds/AdsStream.java deleted file mode 100644 index cb198a61dd..0000000000 --- a/xds/src/main/java/io/grpc/xds/AdsStream.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright 2019 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.base.Preconditions.checkNotNull; - -import io.envoyproxy.envoy.api.v2.DiscoveryRequest; -import io.envoyproxy.envoy.api.v2.DiscoveryResponse; -import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceStub; -import io.grpc.Status; -import io.grpc.stub.StreamObserver; - -/** - * ADS client implementation. - */ -final class AdsStream implements StreamObserver { - private final AggregatedDiscoveryServiceStub stub; - - private StreamObserver xdsRequestWriter; - private boolean cancelled; - - AdsStream(AggregatedDiscoveryServiceStub stub) { - this.stub = checkNotNull(stub, "stub"); - } - - void start() { - xdsRequestWriter = stub.withWaitForReady().streamAggregatedResources(this); - } - - @Override - public void onNext(DiscoveryResponse value) { - // TODO: impl - } - - @Override - public void onError(Throwable t) { - // TODO: impl - } - - @Override - public void onCompleted() { - // TODO: impl - } - - void cancel(String message) { - if (cancelled) { - return; - } - cancelled = true; - xdsRequestWriter.onError(Status.CANCELLED.withDescription(message).asRuntimeException()); - } -} diff --git a/xds/src/main/java/io/grpc/xds/XdsComms.java b/xds/src/main/java/io/grpc/xds/XdsComms.java new file mode 100644 index 0000000000..0cb53fc870 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsComms.java @@ -0,0 +1,146 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import io.envoyproxy.envoy.api.v2.DiscoveryRequest; +import io.envoyproxy.envoy.api.v2.DiscoveryResponse; +import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc; +import io.grpc.LoadBalancer.Helper; +import io.grpc.ManagedChannel; +import io.grpc.Status; +import io.grpc.stub.StreamObserver; + +/** + * ADS client implementation. + */ +final class XdsComms { + private final ManagedChannel channel; + private final Helper helper; + + // never null + private AdsStream adsStream; + + private final class AdsStream { + + final AdsStreamCallback adsStreamCallback; + + final StreamObserver xdsRequestWriter; + + final StreamObserver xdsResponseReader = + new StreamObserver() { + + boolean firstResponseReceived; + + @Override + public void onNext(DiscoveryResponse value) { + if (!firstResponseReceived) { + firstResponseReceived = true; + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + adsStreamCallback.onWorking(); + } + }); + } + // TODO: more impl + } + + @Override + public void onError(Throwable t) { + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + closed = true; + if (cancelled) { + return; + } + adsStreamCallback.onError(); + } + }); + // TODO: more impl + } + + @Override + public void onCompleted() { + // TODO: impl + } + }; + + boolean cancelled; + boolean closed; + + AdsStream(AdsStreamCallback adsStreamCallback) { + this.adsStreamCallback = adsStreamCallback; + this.xdsRequestWriter = AggregatedDiscoveryServiceGrpc.newStub(channel).withWaitForReady() + .streamAggregatedResources(xdsResponseReader); + } + } + + /** + * Starts a new ADS streaming RPC. + */ + XdsComms( + ManagedChannel channel, Helper helper, AdsStreamCallback adsStreamCallback) { + this.channel = checkNotNull(channel, "channel"); + this.helper = checkNotNull(helper, "helper"); + this.adsStream = new AdsStream(checkNotNull(adsStreamCallback, "adsStreamCallback")); + } + + void shutdownChannel() { + channel.shutdown(); + shutdownLbRpc("Loadbalancer client shutdown"); + } + + void refreshAdsStream() { + checkState(!channel.isShutdown(), "channel is alreday shutdown"); + + if (adsStream.closed || adsStream.cancelled) { + adsStream = new AdsStream(adsStream.adsStreamCallback); + } + } + + void shutdownLbRpc(String message) { + if (adsStream.cancelled) { + return; + } + adsStream.cancelled = true; + adsStream.xdsRequestWriter.onError( + Status.CANCELLED.withDescription(message).asRuntimeException()); + } + + /** + * Callback on ADS stream events. The callback methods should be called in a proper {@link + * io.grpc.SynchronizationContext}. + */ + interface AdsStreamCallback { + + /** + * Once the response observer receives the first response. + */ + void onWorking(); + + /** + * Once an error occurs in ADS stream. + */ + void onError(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsLbState.java b/xds/src/main/java/io/grpc/xds/XdsLbState.java index 1070c78b04..b6feb78664 100644 --- a/xds/src/main/java/io/grpc/xds/XdsLbState.java +++ b/xds/src/main/java/io/grpc/xds/XdsLbState.java @@ -16,14 +16,21 @@ package io.grpc.xds; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableList; import io.grpc.Attributes; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.Subchannel; import io.grpc.ManagedChannel; import io.grpc.Status; +import io.grpc.xds.XdsComms.AdsStreamCallback; +import java.net.SocketAddress; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; /** @@ -40,42 +47,79 @@ import javax.annotation.Nullable; * do not request for endpoints. * */ -abstract class XdsLbState { +class XdsLbState { + private static final Attributes.Key> STATE_INFO = + Attributes.Key.create("io.grpc.xds.XdsLoadBalancer.stateInfo"); final String balancerName; @Nullable final Map childPolicy; - @Nullable - final Map fallbackPolicy; + private final SubchannelStore subchannelStore; + private final Helper helper; + private final AdsStreamCallback adsStreamCallback; @Nullable private XdsComms xdsComms; + XdsLbState( String balancerName, @Nullable Map childPolicy, - @Nullable Map fallbackPolicy, - @Nullable XdsComms xdsComms) { - this.balancerName = balancerName; + @Nullable XdsComms xdsComms, + Helper helper, + SubchannelStore subchannelStore, + AdsStreamCallback adsStreamCallback) { + this.balancerName = checkNotNull(balancerName, "balancerName"); this.childPolicy = childPolicy; - this.fallbackPolicy = fallbackPolicy; this.xdsComms = xdsComms; + this.helper = checkNotNull(helper, "helper"); + this.subchannelStore = checkNotNull(subchannelStore, "subchannelStore"); + this.adsStreamCallback = checkNotNull(adsStreamCallback, "adsStreamCallback"); } - abstract void handleResolvedAddressGroups( - List servers, Attributes attributes); + final void handleResolvedAddressGroups( + List servers, Attributes attributes) { - abstract void propagateError(Status error); + // start XdsComms if not already alive + if (xdsComms != null) { + xdsComms.refreshAdsStream(); + } else { + // ** This is wrong ** + // FIXME: use name resolver to resolve addresses for balancerName, and create xdsComms in + // name resolver listener callback. + // TODO: consider pass a fake EAG as a static final field visible to tests and verify + // createOobChannel() with this EAG in tests. + ManagedChannel oobChannel = helper.createOobChannel( + new EquivalentAddressGroup(ImmutableList.of(new SocketAddress() { + })), + balancerName); + xdsComms = new XdsComms(oobChannel, helper, adsStreamCallback); + } - abstract void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState); + // TODO: maybe update picker + } + + + final void handleNameResolutionError(Status error) { + if (!subchannelStore.hasNonDropBackends()) { + // TODO: maybe update picker with transient failure + } + } + + final void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { + // TODO: maybe update picker + } /** - * Shuts down subchannels and child loadbalancers, cancels fallback timeer, and cancels retry - * timer. + * Shuts down subchannels and child loadbalancers, and cancels retry timer. */ - abstract void shutdown(); + void shutdown() { + // TODO: cancel retry timer + // TODO: shutdown child balancers + subchannelStore.shutdown(); + } @Nullable final XdsComms shutdownAndReleaseXdsComms() { @@ -85,26 +129,50 @@ abstract class XdsLbState { return xdsComms; } - static final class XdsComms { - private final ManagedChannel channel; - private final AdsStream adsStream; + /** + * Manages EAG and locality info for a collection of subchannels, not including subchannels + * created by the fallback balancer. + */ + static final class SubchannelStoreImpl implements SubchannelStore { - XdsComms(ManagedChannel channel, AdsStream adsStream) { - this.channel = channel; - this.adsStream = adsStream; + SubchannelStoreImpl() {} + + @Override + public boolean hasReadyBackends() { + // TODO: impl + return false; } - void shutdownChannel() { - if (channel != null) { - channel.shutdown(); - } - shutdownLbRpc("Loadbalancer client shutdown"); + @Override + public boolean hasNonDropBackends() { + // TODO: impl + return false; } - void shutdownLbRpc(String message) { - if (adsStream != null) { - adsStream.cancel(message); - } + + @Override + public boolean hasSubchannel(Subchannel subchannel) { + // TODO: impl + return false; + } + + @Override + public void shutdown() { + // TODO: impl } } + + /** + * The interface of {@link XdsLbState.SubchannelStoreImpl} that is convenient for testing. + */ + public interface SubchannelStore { + + boolean hasReadyBackends(); + + boolean hasNonDropBackends(); + + boolean hasSubchannel(Subchannel subchannel); + + void shutdown(); + } } diff --git a/xds/src/main/java/io/grpc/xds/XdsLoadBalancer.java b/xds/src/main/java/io/grpc/xds/XdsLoadBalancer.java index 4b4dfd6771..46a0b13e22 100644 --- a/xds/src/main/java/io/grpc/xds/XdsLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/XdsLoadBalancer.java @@ -17,20 +17,29 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.ConnectivityState.IDLE; +import static io.grpc.ConnectivityState.SHUTDOWN; +import static io.grpc.internal.ServiceConfigUtil.getBalancerPolicyNameFromLoadBalancingConfig; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.grpc.Attributes; +import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.LoadBalancerRegistry; import io.grpc.Status; +import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.ServiceConfigUtil; -import io.grpc.xds.XdsLbState.XdsComms; +import io.grpc.xds.XdsComms.AdsStreamCallback; +import io.grpc.xds.XdsLbState.SubchannelStore; import java.util.List; import java.util.Map; import java.util.Objects; -import javax.annotation.CheckReturnValue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; /** @@ -38,13 +47,43 @@ import javax.annotation.Nullable; */ final class XdsLoadBalancer extends LoadBalancer { - final Helper helper; + @VisibleForTesting + static final Attributes.Key> STATE_INFO = + Attributes.Key.create("io.grpc.xds.XdsLoadBalancer.stateInfo"); + + private static final ImmutableMap DEFAULT_FALLBACK_POLICY = + ImmutableMap.of("round_robin", (Object) ImmutableMap.of()); + + private final SubchannelStore subchannelStore; + private final Helper helper; + private final LoadBalancerRegistry lbRegistry; + private final FallbackManager fallbackManager; + + private final AdsStreamCallback adsStreamCallback = new AdsStreamCallback() { + + @Override + public void onWorking() { + fallbackManager.balancerWorking = true; + fallbackManager.cancelFallback(); + } + + @Override + public void onError() { + fallbackManager.balancerWorking = false; + fallbackManager.maybeUseFallbackPolicy(); + } + }; @Nullable private XdsLbState xdsLbState; - XdsLoadBalancer(Helper helper) { + private Map fallbackPolicy; + + XdsLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry, SubchannelStore subchannelStore) { this.helper = checkNotNull(helper, "helper"); + this.lbRegistry = lbRegistry; + this.subchannelStore = subchannelStore; + fallbackManager = new FallbackManager(helper, subchannelStore, lbRegistry); } @Override @@ -52,93 +91,81 @@ final class XdsLoadBalancer extends LoadBalancer { List servers, Attributes attributes) { Map newLbConfig = checkNotNull( attributes.get(ATTR_LOAD_BALANCING_CONFIG), "ATTR_LOAD_BALANCING_CONFIG not available"); + fallbackPolicy = selectFallbackPolicy(newLbConfig, lbRegistry); + fallbackManager.updateFallbackServers(servers, attributes, fallbackPolicy); + fallbackManager.maybeStartFallbackTimer(); handleNewConfig(newLbConfig); xdsLbState.handleResolvedAddressGroups(servers, attributes); } private void handleNewConfig(Map newLbConfig) { String newBalancerName = ServiceConfigUtil.getBalancerNameFromXdsConfig(newLbConfig); - Map childPolicy = selectChildPolicy(newLbConfig); - Map fallbackPolicy = selectFallbackPolicy(newLbConfig); + Map childPolicy = selectChildPolicy(newLbConfig, lbRegistry); XdsComms xdsComms = null; if (xdsLbState != null) { // may release and re-use/shutdown xdsComms from current xdsLbState if (!newBalancerName.equals(xdsLbState.balancerName)) { xdsComms = xdsLbState.shutdownAndReleaseXdsComms(); if (xdsComms != null) { xdsComms.shutdownChannel(); + fallbackManager.balancerWorking = false; xdsComms = null; } - } else if (!Objects.equals(childPolicy, xdsLbState.childPolicy) - // There might be optimization when only fallbackPolicy is changed. - || !Objects.equals(fallbackPolicy, xdsLbState.fallbackPolicy)) { + } else if (!Objects.equals( + getPolicyNameOrNull(childPolicy), + getPolicyNameOrNull(xdsLbState.childPolicy))) { String cancelMessage = "Changing loadbalancing mode"; xdsComms = xdsLbState.shutdownAndReleaseXdsComms(); // close the stream but reuse the channel if (xdsComms != null) { xdsComms.shutdownLbRpc(cancelMessage); + fallbackManager.balancerWorking = false; + xdsComms.refreshAdsStream(); } } else { // effectively no change in policy, keep xdsLbState unchanged return; } } - xdsLbState = newXdsLbState( - newBalancerName, childPolicy, fallbackPolicy, xdsComms); - } - - @CheckReturnValue - private XdsLbState newXdsLbState( - String balancerName, - @Nullable final Map childPolicy, - @Nullable Map fallbackPolicy, - @Nullable final XdsComms xdsComms) { - - // TODO: impl - return new XdsLbState(balancerName, childPolicy, fallbackPolicy, xdsComms) { - @Override - void handleResolvedAddressGroups( - List servers, Attributes attributes) {} - - @Override - void propagateError(Status error) {} - - @Override - void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) {} - - @Override - void shutdown() {} - }; + xdsLbState = new XdsLbState( + newBalancerName, childPolicy, xdsComms, helper, subchannelStore, adsStreamCallback); } @Nullable - @VisibleForTesting - static Map selectChildPolicy(Map lbConfig) { - List> childConfigs = - ServiceConfigUtil.getChildPolicyFromXdsConfig(lbConfig); - return selectSupportedLbPolicy(childConfigs); - } - - @Nullable - @VisibleForTesting - static Map selectFallbackPolicy(Map lbConfig) { - if (lbConfig == null) { + private static String getPolicyNameOrNull(@Nullable Map config) { + if (config == null) { return null; } - List> fallbackConfigs = - ServiceConfigUtil.getFallbackPolicyFromXdsConfig(lbConfig); - return selectSupportedLbPolicy(fallbackConfigs); + return getBalancerPolicyNameFromLoadBalancingConfig(config); } @Nullable - private static Map selectSupportedLbPolicy(List> lbConfigs) { + @VisibleForTesting + static Map selectChildPolicy( + Map lbConfig, LoadBalancerRegistry lbRegistry) { + List> childConfigs = + ServiceConfigUtil.getChildPolicyFromXdsConfig(lbConfig); + return selectSupportedLbPolicy(childConfigs, lbRegistry); + } + + @VisibleForTesting + static Map selectFallbackPolicy( + Map lbConfig, LoadBalancerRegistry lbRegistry) { + List> fallbackConfigs = + ServiceConfigUtil.getFallbackPolicyFromXdsConfig(lbConfig); + Map fallbackPolicy = selectSupportedLbPolicy(fallbackConfigs, lbRegistry); + return fallbackPolicy == null ? DEFAULT_FALLBACK_POLICY : fallbackPolicy; + } + + @Nullable + private static Map selectSupportedLbPolicy( + List> lbConfigs, LoadBalancerRegistry lbRegistry) { if (lbConfigs == null) { return null; } - LoadBalancerRegistry loadBalancerRegistry = LoadBalancerRegistry.getDefaultRegistry(); for (Object lbConfig : lbConfigs) { @SuppressWarnings("unchecked") Map candidate = (Map) lbConfig; - String lbPolicy = candidate.entrySet().iterator().next().getKey(); - if (loadBalancerRegistry.getProvider(lbPolicy) != null) { + String lbPolicy = ServiceConfigUtil.getBalancerPolicyNameFromLoadBalancingConfig(candidate); + if (lbRegistry.getProvider(lbPolicy) != null) { return candidate; } } @@ -148,7 +175,11 @@ final class XdsLoadBalancer extends LoadBalancer { @Override public void handleNameResolutionError(Status error) { if (xdsLbState != null) { - xdsLbState.propagateError(error); + if (fallbackManager.fallbackBalancer != null) { + fallbackManager.fallbackBalancer.handleNameResolutionError(error); + } else { + xdsLbState.handleNameResolutionError(error); + } } // TODO: impl // else { @@ -160,7 +191,21 @@ final class XdsLoadBalancer extends LoadBalancer { public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { // xdsLbState should never be null here since handleSubchannelState cannot be called while the // lb is shutdown. - xdsLbState.handleSubchannelState(subchannel, newState); + if (newState.getState() == SHUTDOWN) { + return; + } + + if (fallbackManager.fallbackBalancer != null) { + fallbackManager.fallbackBalancer.handleSubchannelState(subchannel, newState); + } + if (subchannelStore.hasSubchannel(subchannel)) { + if (newState.getState() == IDLE) { + subchannel.requestConnection(); + } + subchannel.getAttributes().get(STATE_INFO).set(newState); + xdsLbState.handleSubchannelState(subchannel, newState); + fallbackManager.maybeUseFallbackPolicy(); + } } @Override @@ -172,6 +217,7 @@ final class XdsLoadBalancer extends LoadBalancer { } xdsLbState = null; } + fallbackManager.cancelFallback(); } @Override @@ -179,9 +225,108 @@ final class XdsLoadBalancer extends LoadBalancer { return true; } - @VisibleForTesting @Nullable - XdsLbState getXdsLbState() { + XdsLbState getXdsLbStateForTest() { return xdsLbState; } + + @VisibleForTesting + static final class FallbackManager { + + private static final long FALLBACK_TIMEOUT_MS = TimeUnit.SECONDS.toMillis(10); // same as grpclb + + private final Helper helper; + private final SubchannelStore subchannelStore; + private final LoadBalancerRegistry lbRegistry; + + private Map fallbackPolicy; + + // read-only for outer class + private LoadBalancer fallbackBalancer; + + // Scheduled only once. Never reset. + @Nullable + private ScheduledHandle fallbackTimer; + + private List fallbackServers = ImmutableList.of(); + private Attributes fallbackAttributes; + + // allow value write by outer class + private boolean balancerWorking; + + FallbackManager( + Helper helper, SubchannelStore subchannelStore, LoadBalancerRegistry lbRegistry) { + this.helper = helper; + this.subchannelStore = subchannelStore; + this.lbRegistry = lbRegistry; + } + + void cancelFallback() { + if (fallbackTimer != null) { + fallbackTimer.cancel(); + } + if (fallbackBalancer != null) { + fallbackBalancer.shutdown(); + fallbackBalancer = null; + } + } + + void maybeUseFallbackPolicy() { + if (fallbackBalancer != null) { + return; + } + if (balancerWorking || subchannelStore.hasReadyBackends()) { + return; + } + + helper.getChannelLogger().log( + ChannelLogLevel.INFO, "Using fallback policy"); + String fallbackPolicyName = ServiceConfigUtil.getBalancerPolicyNameFromLoadBalancingConfig( + fallbackPolicy); + fallbackBalancer = lbRegistry.getProvider(fallbackPolicyName) + .newLoadBalancer(helper); + fallbackBalancer.handleResolvedAddressGroups(fallbackServers, fallbackAttributes); + // TODO: maybe update picker + } + + void updateFallbackServers( + List servers, Attributes attributes, + Map fallbackPolicy) { + this.fallbackServers = servers; + this.fallbackAttributes = Attributes.newBuilder() + .setAll(attributes) + .set(ATTR_LOAD_BALANCING_CONFIG, fallbackPolicy) + .build(); + Map currentFallbackPolicy = this.fallbackPolicy; + this.fallbackPolicy = fallbackPolicy; + if (fallbackBalancer != null) { + String currentPolicyName = + ServiceConfigUtil.getBalancerPolicyNameFromLoadBalancingConfig(currentFallbackPolicy); + String newPolicyName = + ServiceConfigUtil.getBalancerPolicyNameFromLoadBalancingConfig(fallbackPolicy); + if (newPolicyName.equals(currentPolicyName)) { + fallbackBalancer.handleResolvedAddressGroups(fallbackServers, fallbackAttributes); + } else { + fallbackBalancer.shutdown(); + fallbackBalancer = null; + maybeUseFallbackPolicy(); + } + } + } + + void maybeStartFallbackTimer() { + if (fallbackTimer == null) { + class FallbackTask implements Runnable { + @Override + public void run() { + maybeUseFallbackPolicy(); + } + } + + fallbackTimer = helper.getSynchronizationContext().schedule( + new FallbackTask(), FALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS, + helper.getScheduledExecutorService()); + } + } + } } diff --git a/xds/src/main/java/io/grpc/xds/XdsLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/XdsLoadBalancerProvider.java index 29a64f32e1..eb3e4485de 100644 --- a/xds/src/main/java/io/grpc/xds/XdsLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/XdsLoadBalancerProvider.java @@ -20,6 +20,8 @@ import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.xds.XdsLbState.SubchannelStoreImpl; /** * The provider for the "xds" balancing policy. This class should not be directly referenced in @@ -46,6 +48,7 @@ public final class XdsLoadBalancerProvider extends LoadBalancerProvider { @Override public LoadBalancer newLoadBalancer(Helper helper) { - return new XdsLoadBalancer(helper); + return new XdsLoadBalancer( + helper, LoadBalancerRegistry.getDefaultRegistry(), new SubchannelStoreImpl()); } } diff --git a/xds/src/test/java/io/grpc/xds/FallbackManagerTest.java b/xds/src/test/java/io/grpc/xds/FallbackManagerTest.java new file mode 100644 index 0000000000..5520b0dfb9 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/FallbackManagerTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Matchers.eq; +import static org.mockito.Matchers.same; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import io.grpc.Attributes; +import io.grpc.ChannelLogger; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.SynchronizationContext; +import io.grpc.internal.FakeClock; +import io.grpc.xds.XdsLbState.SubchannelStoreImpl; +import io.grpc.xds.XdsLoadBalancer.FallbackManager; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Matchers; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** + * Unit test for {@link FallbackManager}. + */ +@RunWith(JUnit4.class) +public class FallbackManagerTest { + + private static final long FALLBACK_TIMEOUT_MS = TimeUnit.SECONDS.toMillis(10); + + private final FakeClock fakeClock = new FakeClock(); + private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); + + private final LoadBalancerProvider fakeLbProvider = new LoadBalancerProvider() { + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "test_policy"; + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return fakeLb; + } + }; + + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + + @Mock + private Helper helper; + @Mock + private LoadBalancer fakeLb; + @Mock + private ChannelLogger channelLogger; + + private FallbackManager fallbackManager; + private Map fallbackPolicy; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + doReturn(syncContext).when(helper).getSynchronizationContext(); + doReturn(fakeClock.getScheduledExecutorService()).when(helper).getScheduledExecutorService(); + doReturn(channelLogger).when(helper).getChannelLogger(); + fallbackManager = new FallbackManager(helper, new SubchannelStoreImpl(), lbRegistry); + fallbackPolicy = new HashMap<>(); + fallbackPolicy.put("test_policy", new HashMap<>()); + lbRegistry.register(fakeLbProvider); + } + + @After + public void tearDown() { + assertThat(fakeClock.getPendingTasks()).isEmpty(); + } + + @Test + public void useFallbackWhenTimeout() { + fallbackManager.maybeStartFallbackTimer(); + List eags = new ArrayList<>(); + fallbackManager.updateFallbackServers( + eags, Attributes.EMPTY, fallbackPolicy); + + verify(fakeLb, never()).handleResolvedAddressGroups( + Matchers.>any(), Matchers.any()); + + fakeClock.forwardTime(FALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS); + + verify(fakeLb).handleResolvedAddressGroups( + same(eags), + eq(Attributes.newBuilder() + .set(LoadBalancer.ATTR_LOAD_BALANCING_CONFIG, fallbackPolicy) + .build())); + } + + @Test + public void cancelFallback() { + fallbackManager.maybeStartFallbackTimer(); + List eags = new ArrayList<>(); + fallbackManager.updateFallbackServers( + eags, Attributes.EMPTY, fallbackPolicy); + + fallbackManager.cancelFallback(); + + fakeClock.forwardTime(FALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS); + + verify(fakeLb, never()).handleResolvedAddressGroups( + Matchers.>any(), Matchers.any()); + } +} diff --git a/xds/src/test/java/io/grpc/xds/AdsStreamTest.java b/xds/src/test/java/io/grpc/xds/XdsCommsTest.java similarity index 86% rename from xds/src/test/java/io/grpc/xds/AdsStreamTest.java rename to xds/src/test/java/io/grpc/xds/XdsCommsTest.java index 437f7e6666..d5dbb52044 100644 --- a/xds/src/test/java/io/grpc/xds/AdsStreamTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsCommsTest.java @@ -21,9 +21,8 @@ import static org.junit.Assert.assertTrue; import io.envoyproxy.envoy.api.v2.DiscoveryRequest; import io.envoyproxy.envoy.api.v2.DiscoveryResponse; -import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc; import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase; -import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceStub; +import io.grpc.LoadBalancer.Helper; import io.grpc.ManagedChannel; import io.grpc.Status; import io.grpc.inprocess.InProcessChannelBuilder; @@ -31,27 +30,36 @@ import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.testing.StreamRecorder; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; +import io.grpc.xds.XdsComms.AdsStreamCallback; import java.util.concurrent.TimeUnit; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; /** - * Unit tests for {@link AdsStream}. + * Unit tests for {@link XdsComms}. */ @RunWith(JUnit4.class) -public class AdsStreamTest { +public class XdsCommsTest { @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + @Mock + Helper helper; + @Mock + AdsStreamCallback adsStreamCallback; private final StreamRecorder streamRecorder = StreamRecorder.create(); - private AdsStream adsStream; + private XdsComms xdsComms; @Before public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + String serverName = InProcessServerBuilder.generateName(); AggregatedDiscoveryServiceImplBase serviceImpl = new AggregatedDiscoveryServiceImplBase() { @@ -87,14 +95,12 @@ public class AdsStreamTest { .start()); ManagedChannel channel = cleanupRule.register(InProcessChannelBuilder.forName(serverName).build()); - AggregatedDiscoveryServiceStub stub = AggregatedDiscoveryServiceGrpc.newStub(channel); - adsStream = new AdsStream(stub); - adsStream.start(); + xdsComms = new XdsComms(channel, helper, adsStreamCallback); } @Test public void cancel() throws Exception { - adsStream.cancel("cause1"); + xdsComms.shutdownLbRpc("cause1"); assertTrue(streamRecorder.awaitCompletion(1, TimeUnit.SECONDS)); assertEquals(Status.Code.CANCELLED, Status.fromThrowable(streamRecorder.getError()).getCode()); } diff --git a/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java b/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java index 590a5c2087..6cd3dd29b5 100644 --- a/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java @@ -19,22 +19,24 @@ package io.grpc.xds; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import io.envoyproxy.envoy.api.v2.DiscoveryRequest; import io.envoyproxy.envoy.api.v2.DiscoveryResponse; -import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc; import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase; -import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceStub; +import io.grpc.LoadBalancer.Helper; import io.grpc.ManagedChannel; import io.grpc.Status; +import io.grpc.SynchronizationContext; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.internal.FakeClock; import io.grpc.internal.testing.StreamRecorder; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; -import io.grpc.xds.XdsLbState.XdsComms; +import io.grpc.xds.XdsComms.AdsStreamCallback; import java.util.concurrent.TimeUnit; import org.junit.After; import org.junit.Before; @@ -42,6 +44,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mock; import org.mockito.MockitoAnnotations; /** @@ -51,9 +54,22 @@ import org.mockito.MockitoAnnotations; public class XdsLbStateTest { @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + @Mock + private Helper helper; + @Mock + private AdsStreamCallback adsStreamCallback; + private final FakeClock fakeClock = new FakeClock(); private final StreamRecorder streamRecorder = StreamRecorder.create(); + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + private XdsComms xdsComms; private ManagedChannel channel; @@ -61,6 +77,8 @@ public class XdsLbStateTest { @Before public void setUp() throws Exception { MockitoAnnotations.initMocks(this); + doReturn(syncContext).when(helper).getSynchronizationContext(); + doReturn(fakeClock.getScheduledExecutorService()).when(helper).getScheduledExecutorService(); String serverName = InProcessServerBuilder.generateName(); @@ -97,10 +115,7 @@ public class XdsLbStateTest { .start()); channel = cleanupRule.register(InProcessChannelBuilder.forName(serverName).build()); - AggregatedDiscoveryServiceStub stub = AggregatedDiscoveryServiceGrpc.newStub(channel); - AdsStream adsStream = new AdsStream(stub); - adsStream.start(); - xdsComms = new XdsComms(channel, adsStream); + xdsComms = new XdsComms(channel, helper, adsStreamCallback); } @After diff --git a/xds/src/test/java/io/grpc/xds/XdsLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/XdsLoadBalancerTest.java index d12c009bdf..cfb2494c34 100644 --- a/xds/src/test/java/io/grpc/xds/XdsLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsLoadBalancerTest.java @@ -18,23 +18,59 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; import static io.grpc.LoadBalancer.ATTR_LOAD_BALANCING_CONFIG; +import static io.grpc.xds.XdsLoadBalancer.STATE_INFO; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import io.envoyproxy.envoy.api.v2.DiscoveryRequest; +import io.envoyproxy.envoy.api.v2.DiscoveryResponse; +import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase; import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.ChannelLogger; +import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.grpc.SynchronizationContext; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.internal.FakeClock; import io.grpc.internal.JsonParser; +import io.grpc.internal.testing.StreamRecorder; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.xds.XdsLbState.SubchannelStore; +import io.grpc.xds.XdsLbState.SubchannelStoreImpl; import java.util.Collections; +import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.junit.After; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Matchers; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -43,10 +79,21 @@ import org.mockito.MockitoAnnotations; */ @RunWith(JUnit4.class) public class XdsLoadBalancerTest { + @Rule + public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); @Mock private Helper helper; + @Mock + private LoadBalancer fakeBalancer1; + @Mock + private LoadBalancer fakeBalancer2; private XdsLoadBalancer lb; + private final FakeClock fakeClock = new FakeClock(); + private final StreamRecorder streamRecorder = StreamRecorder.create(); + + private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); + private final LoadBalancerProvider lbProvider1 = new LoadBalancerProvider() { @Override public boolean isAvailable() { @@ -65,7 +112,7 @@ public class XdsLoadBalancerTest { @Override public LoadBalancer newLoadBalancer(Helper helper) { - return null; + return fakeBalancer1; } }; @@ -85,30 +132,121 @@ public class XdsLoadBalancerTest { return "supported_2"; } + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return fakeBalancer2; + } + }; + + private final LoadBalancerProvider roundRobin = new LoadBalancerProvider() { + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "round_robin"; + } + @Override public LoadBalancer newLoadBalancer(Helper helper) { return null; } }; + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + private final SubchannelStore fakeSubchannelStore = + mock(SubchannelStore.class, delegatesTo(new SubchannelStoreImpl())); + + private ManagedChannel oobChannel1; + private ManagedChannel oobChannel2; + private ManagedChannel oobChannel3; + + private StreamObserver serverResponseWriter; @Before - public void setUp() { + public void setUp() throws Exception { MockitoAnnotations.initMocks(this); - LoadBalancerRegistry.getDefaultRegistry().register(lbProvider1); - LoadBalancerRegistry.getDefaultRegistry().register(lbProvider2); - lb = new XdsLoadBalancer(helper); + lbRegistry.register(lbProvider1); + lbRegistry.register(lbProvider2); + lbRegistry.register(roundRobin); + lb = new XdsLoadBalancer(helper, lbRegistry, fakeSubchannelStore); + doReturn(syncContext).when(helper).getSynchronizationContext(); + doReturn(fakeClock.getScheduledExecutorService()).when(helper).getScheduledExecutorService(); + doReturn(mock(ChannelLogger.class)).when(helper).getChannelLogger(); + + String serverName = InProcessServerBuilder.generateName(); + + AggregatedDiscoveryServiceImplBase serviceImpl = new AggregatedDiscoveryServiceImplBase() { + @Override + public StreamObserver streamAggregatedResources( + final StreamObserver responseObserver) { + serverResponseWriter = responseObserver; + + return new StreamObserver() { + + @Override + public void onNext(DiscoveryRequest value) { + streamRecorder.onNext(value); + } + + @Override + public void onError(Throwable t) { + streamRecorder.onError(t); + } + + @Override + public void onCompleted() { + streamRecorder.onCompleted(); + responseObserver.onCompleted(); + } + }; + } + }; + + cleanupRule.register( + InProcessServerBuilder + .forName(serverName) + .directExecutor() + .addService(serviceImpl) + .build() + .start()); + + InProcessChannelBuilder channelBuilder = + InProcessChannelBuilder.forName(serverName).directExecutor(); + oobChannel1 = mock( + ManagedChannel.class, + delegatesTo(cleanupRule.register(channelBuilder.build()))); + oobChannel2 = mock( + ManagedChannel.class, + delegatesTo(cleanupRule.register(channelBuilder.build()))); + oobChannel3 = mock( + ManagedChannel.class, + delegatesTo(cleanupRule.register(channelBuilder.build()))); + + doReturn(oobChannel1).doReturn(oobChannel2).doReturn(oobChannel3) + .when(helper).createOobChannel(Matchers.any(), anyString()); } @After public void tearDown() { - LoadBalancerRegistry.getDefaultRegistry().deregister(lbProvider1); - LoadBalancerRegistry.getDefaultRegistry().deregister(lbProvider2); + lb.shutdown(); } @Test - @SuppressWarnings("unchecked") public void selectChildPolicy() throws Exception { String lbConfigRaw = "{\"xds_experimental\" : { " + "\"balancerName\" : \"dns:///balancer.example.com:8080\"," @@ -116,17 +254,18 @@ public class XdsLoadBalancerTest { + "{\"supported_2\" : {\"key\" : \"val\"}}]," + "\"fallbackPolicy\" : [{\"lbPolicy3\" : {\"key\" : \"val\"}}, {\"lbPolicy4\" : {}}]" + "}}"; + @SuppressWarnings("unchecked") Map expectedChildPolicy = (Map) JsonParser.parse( "{\"supported_1\" : {\"key\" : \"val\"}}"); + @SuppressWarnings("unchecked") Map childPolicy = XdsLoadBalancer - .selectChildPolicy((Map) JsonParser.parse(lbConfigRaw)); + .selectChildPolicy((Map) JsonParser.parse(lbConfigRaw), lbRegistry); assertEquals(expectedChildPolicy, childPolicy); } @Test - @SuppressWarnings("unchecked") public void selectFallBackPolicy() throws Exception { String lbConfigRaw = "{\"xds_experimental\" : { " + "\"balancerName\" : \"dns:///balancer.example.com:8080\"," @@ -134,11 +273,30 @@ public class XdsLoadBalancerTest { + "\"fallbackPolicy\" : [{\"unsupported\" : {}}, {\"supported_1\" : {\"key\" : \"val\"}}," + "{\"supported_2\" : {\"key\" : \"val\"}}]" + "}}"; + @SuppressWarnings("unchecked") Map expectedFallbackPolicy = (Map) JsonParser.parse( "{\"supported_1\" : {\"key\" : \"val\"}}"); + @SuppressWarnings("unchecked") Map fallbackPolicy = XdsLoadBalancer - .selectFallbackPolicy((Map) JsonParser.parse(lbConfigRaw)); + .selectFallbackPolicy((Map) JsonParser.parse(lbConfigRaw), lbRegistry); + + assertEquals(expectedFallbackPolicy, fallbackPolicy); + } + + @Test + public void selectFallBackPolicy_roundRobinIsDefault() throws Exception { + String lbConfigRaw = "{\"xds_experimental\" : { " + + "\"balancerName\" : \"dns:///balancer.example.com:8080\"," + + "\"childPolicy\" : [{\"lbPolicy3\" : {\"key\" : \"val\"}}, {\"lbPolicy4\" : {}}]" + + "}}"; + @SuppressWarnings("unchecked") + Map expectedFallbackPolicy = (Map) JsonParser.parse( + "{\"round_robin\" : {}}"); + + @SuppressWarnings("unchecked") + Map fallbackPolicy = XdsLoadBalancer + .selectFallbackPolicy((Map) JsonParser.parse(lbConfigRaw), lbRegistry); assertEquals(expectedFallbackPolicy, fallbackPolicy); } @@ -149,121 +307,308 @@ public class XdsLoadBalancerTest { } @Test - @SuppressWarnings("unchecked") public void resolverEvent_standardModeToStandardMode() throws Exception { String lbConfigRaw = "{\"xds_experimental\" : { " + "\"balancerName\" : \"dns:///balancer.example.com:8080\"," + "\"childPolicy\" : [{\"unsupported\" : {\"key\" : \"val\"}}, {\"unsupported_2\" : {}}]," + "\"fallbackPolicy\" : [{\"unsupported\" : {}}, {\"supported_1\" : {\"key\" : \"val\"}}]" + "}}"; + @SuppressWarnings("unchecked") Map lbConfig = (Map) JsonParser.parse(lbConfigRaw); Attributes attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig).build(); lb.handleResolvedAddressGroups(Collections.emptyList(), attrs); - assertThat(lb.getXdsLbState().childPolicy).isNull(); + XdsLbState xdsLbState1 = lb.getXdsLbStateForTest(); + assertThat(xdsLbState1.childPolicy).isNull(); + verify(helper).createOobChannel(Matchers.any(), anyString()); + verify(oobChannel1) + .newCall(Matchers.>any(), Matchers.any()); + lbConfigRaw = "{\"xds_experimental\" : { " + "\"balancerName\" : \"dns:///balancer.example.com:8080\"," + "\"fallbackPolicy\" : [{\"unsupported\" : {}}, {\"supported_1\" : {\"key\" : \"val\"}}]" + "}}"; - lbConfig = (Map) JsonParser.parse(lbConfigRaw); - attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig).build(); + @SuppressWarnings("unchecked") + Map lbConfig2 = (Map) JsonParser.parse(lbConfigRaw); + attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig2).build(); lb.handleResolvedAddressGroups(Collections.emptyList(), attrs); - assertThat(lb.getXdsLbState().childPolicy).isNull(); + XdsLbState xdsLbState2 = lb.getXdsLbStateForTest(); + assertThat(xdsLbState2.childPolicy).isNull(); + assertThat(xdsLbState2).isSameAs(xdsLbState1); - // TODO(zdapeng): test adsStream is unchanged. + // verify oobChannel is unchanged + verify(helper).createOobChannel(Matchers.any(), anyString()); + // verify ADS stream is unchanged + verify(oobChannel1) + .newCall(Matchers.>any(), Matchers.any()); } @Test - @SuppressWarnings("unchecked") public void resolverEvent_standardModeToCustomMode() throws Exception { String lbConfigRaw = "{\"xds_experimental\" : { " + "\"balancerName\" : \"dns:///balancer.example.com:8080\"," + "\"childPolicy\" : [{\"unsupported\" : {\"key\" : \"val\"}}, {\"unsupported_2\" : {}}]," + "\"fallbackPolicy\" : [{\"unsupported\" : {}}, {\"supported_1\" : {\"key\" : \"val\"}}]" + "}}"; + @SuppressWarnings("unchecked") Map lbConfig = (Map) JsonParser.parse(lbConfigRaw); Attributes attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig).build(); lb.handleResolvedAddressGroups(Collections.emptyList(), attrs); + verify(helper).createOobChannel(Matchers.any(), anyString()); + verify(oobChannel1) + .newCall(Matchers.>any(), Matchers.any()); lbConfigRaw = "{\"xds_experimental\" : { " + "\"balancerName\" : \"dns:///balancer.example.com:8080\"," + "\"childPolicy\" : [{\"supported_1\" : {\"key\" : \"val\"}}, {\"unsupported_2\" : {}}]," + "\"fallbackPolicy\" : [{\"unsupported\" : {}}, {\"supported_1\" : {\"key\" : \"val\"}}]" + "}}"; - lbConfig = (Map) JsonParser.parse(lbConfigRaw); - attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig).build(); + @SuppressWarnings("unchecked") + Map lbConfig2 = (Map) JsonParser.parse(lbConfigRaw); + attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig2).build(); lb.handleResolvedAddressGroups(Collections.emptyList(), attrs); - assertThat(lb.getXdsLbState().childPolicy).isNotNull(); + assertThat(lb.getXdsLbStateForTest().childPolicy).isNotNull(); - // TODO(zdapeng): test adsStream is reset, channel is unchanged. + // verify oobChannel is unchanged + verify(helper).createOobChannel(Matchers.any(), anyString()); + // verify ADS stream is reset + verify(oobChannel1, times(2)) + .newCall(Matchers.>any(), Matchers.any()); } @Test - @SuppressWarnings("unchecked") public void resolverEvent_customModeToStandardMode() throws Exception { String lbConfigRaw = "{\"xds_experimental\" : { " + "\"balancerName\" : \"dns:///balancer.example.com:8080\"," + "\"childPolicy\" : [{\"supported_1\" : {\"key\" : \"val\"}}, {\"unsupported_2\" : {}}]," + "\"fallbackPolicy\" : [{\"unsupported\" : {}}, {\"supported_1\" : {\"key\" : \"val\"}}]" + "}}"; + @SuppressWarnings("unchecked") Map lbConfig = (Map) JsonParser.parse(lbConfigRaw); Attributes attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig).build(); lb.handleResolvedAddressGroups(Collections.emptyList(), attrs); + verify(helper).createOobChannel(Matchers.any(), anyString()); + verify(oobChannel1) + .newCall(Matchers.>any(), Matchers.any()); - assertThat(lb.getXdsLbState().childPolicy).isNotNull(); + assertThat(lb.getXdsLbStateForTest().childPolicy).isNotNull(); lbConfigRaw = "{\"xds_experimental\" : { " + "\"balancerName\" : \"dns:///balancer.example.com:8080\"," + "\"childPolicy\" : [{\"unsupported\" : {\"key\" : \"val\"}}, {\"unsupported_2\" : {}}]," + "\"fallbackPolicy\" : [{\"unsupported\" : {}}, {\"supported_1\" : {\"key\" : \"val\"}}]" + "}}"; - lbConfig = (Map) JsonParser.parse(lbConfigRaw); - attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig).build(); + @SuppressWarnings("unchecked") + Map lbConfig2 = (Map) JsonParser.parse(lbConfigRaw); + attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig2).build(); lb.handleResolvedAddressGroups(Collections.emptyList(), attrs); - assertThat(lb.getXdsLbState().childPolicy).isNull(); + assertThat(lb.getXdsLbStateForTest().childPolicy).isNull(); - // TODO(zdapeng): test adsStream is unchanged. + // verify oobChannel is unchanged + verify(helper).createOobChannel(Matchers.any(), anyString()); + // verify ADS stream is reset + verify(oobChannel1, times(2)) + .newCall(Matchers.>any(), Matchers.any()); } @Test - @SuppressWarnings("unchecked") public void resolverEvent_customModeToCustomMode() throws Exception { String lbConfigRaw = "{\"xds_experimental\" : { " + "\"balancerName\" : \"dns:///balancer.example.com:8080\"," + "\"childPolicy\" : [{\"supported_1\" : {\"key\" : \"val\"}}, {\"unsupported_2\" : {}}]," + "\"fallbackPolicy\" : [{\"unsupported\" : {}}, {\"supported_1\" : {\"key\" : \"val\"}}]" + "}}"; + @SuppressWarnings("unchecked") Map lbConfig = (Map) JsonParser.parse(lbConfigRaw); Attributes attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig).build(); lb.handleResolvedAddressGroups(Collections.emptyList(), attrs); - assertThat(lb.getXdsLbState().childPolicy).isNotNull(); + assertThat(lb.getXdsLbStateForTest().childPolicy).isNotNull(); + verify(helper).createOobChannel(Matchers.any(), anyString()); + verify(oobChannel1) + .newCall(Matchers.>any(), Matchers.any()); lbConfigRaw = "{\"xds_experimental\" : { " + "\"balancerName\" : \"dns:///balancer.example.com:8080\"," + "\"childPolicy\" : [{\"supported_2\" : {\"key\" : \"val\"}}, {\"unsupported_1\" : {}}]," + "\"fallbackPolicy\" : [{\"unsupported\" : {}}, {\"supported_1\" : {\"key\" : \"val\"}}]" + "}}"; - lbConfig = (Map) JsonParser.parse(lbConfigRaw); - attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig).build(); + @SuppressWarnings("unchecked") + Map lbConfig2 = (Map) JsonParser.parse(lbConfigRaw); + attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig2).build(); lb.handleResolvedAddressGroups(Collections.emptyList(), attrs); - assertThat(lb.getXdsLbState().childPolicy).isNotNull(); - - // TODO(zdapeng): test adsStream is reset, channel is unchanged. + assertThat(lb.getXdsLbStateForTest().childPolicy).isNotNull(); + // verify oobChannel is unchanged + verify(helper).createOobChannel(Matchers.any(), anyString()); + // verify ADS stream is reset + verify(oobChannel1, times(2)) + .newCall(Matchers.>any(), Matchers.any()); } - // TODO(zdapeng): test balancer name change + @Test + public void resolverEvent_balancerNameChange() throws Exception { + String lbConfigRaw = "{\"xds_experimental\" : { " + + "\"balancerName\" : \"dns:///balancer.example.com:8080\"," + + "\"childPolicy\" : [{\"unsupported\" : {\"key\" : \"val\"}}, {\"unsupported_2\" : {}}]," + + "\"fallbackPolicy\" : [{\"unsupported\" : {}}, {\"supported_1\" : {\"key\" : \"val\"}}]" + + "}}"; + @SuppressWarnings("unchecked") + Map lbConfig = (Map) JsonParser.parse(lbConfigRaw); + Attributes attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig).build(); + + lb.handleResolvedAddressGroups(Collections.emptyList(), attrs); + verify(helper).createOobChannel(Matchers.any(), anyString()); + verify(oobChannel1) + .newCall(Matchers.>any(), Matchers.any()); + + lbConfigRaw = "{\"xds_experimental\" : { " + + "\"balancerName\" : \"dns:///balancer.example.com:8443\"," + + "\"childPolicy\" : [{\"supported_1\" : {\"key\" : \"val\"}}, {\"unsupported_2\" : {}}]," + + "\"fallbackPolicy\" : [{\"unsupported\" : {}}, {\"supported_1\" : {\"key\" : \"val\"}}]" + + "}}"; + @SuppressWarnings("unchecked") + Map lbConfig2 = (Map) JsonParser.parse(lbConfigRaw); + attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig2).build(); + + lb.handleResolvedAddressGroups(Collections.emptyList(), attrs); + + assertThat(lb.getXdsLbStateForTest().childPolicy).isNotNull(); + + // verify oobChannel is unchanged + verify(helper, times(2)).createOobChannel(Matchers.any(), anyString()); + verify(oobChannel1) + .newCall(Matchers.>any(), Matchers.any()); + verify(oobChannel2) + .newCall(Matchers.>any(), Matchers.any()); + verifyNoMoreInteractions(oobChannel3); + } + + @Test + public void fallback_AdsNotWorkingYetTimerExpired() throws Exception { + lb.handleResolvedAddressGroups( + Collections.emptyList(), standardModeWithFallback1Attributes()); + + assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); + assertThat(fakeClock.getPendingTasks()).isEmpty(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Attributes.class); + verify(fakeBalancer1).handleResolvedAddressGroups( + Matchers.>any(), captor.capture()); + assertThat(captor.getValue().get(ATTR_LOAD_BALANCING_CONFIG)) + .containsExactly("supported_1", new HashMap()); + } + + @Test + public void fallback_AdsWorkingTimerCancelled() throws Exception { + lb.handleResolvedAddressGroups( + Collections.emptyList(), standardModeWithFallback1Attributes()); + + serverResponseWriter.onNext(DiscoveryResponse.getDefaultInstance()); + + assertThat(fakeClock.getPendingTasks()).isEmpty(); + verify(fakeBalancer1, never()).handleResolvedAddressGroups( + Matchers.>any(), Matchers.any()); + } + + @Test + public void fallback_AdsErrorAndNoActiveSubchannel() throws Exception { + lb.handleResolvedAddressGroups( + Collections.emptyList(), standardModeWithFallback1Attributes()); + + serverResponseWriter.onError(new Exception("fake error")); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Attributes.class); + verify(fakeBalancer1).handleResolvedAddressGroups( + Matchers.>any(), captor.capture()); + assertThat(captor.getValue().get(ATTR_LOAD_BALANCING_CONFIG)) + .containsExactly("supported_1", new HashMap()); + + assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); + assertThat(fakeClock.getPendingTasks()).isEmpty(); + + // verify handleResolvedAddressGroups() is not called again + verify(fakeBalancer1).handleResolvedAddressGroups( + Matchers.>any(), Matchers.any()); + } + + @Test + public void fallback_AdsErrorWithActiveSubchannel() throws Exception { + lb.handleResolvedAddressGroups( + Collections.emptyList(), standardModeWithFallback1Attributes()); + + serverResponseWriter.onNext(DiscoveryResponse.getDefaultInstance()); + doReturn(true).when(fakeSubchannelStore).hasReadyBackends(); + serverResponseWriter.onError(new Exception("fake error")); + + verify(fakeBalancer1, never()).handleResolvedAddressGroups( + Matchers.>any(), Matchers.any()); + + Subchannel subchannel = new Subchannel() { + @Override + public void shutdown() {} + + @Override + public void requestConnection() {} + + @Override + public Attributes getAttributes() { + return Attributes.newBuilder() + .set( + STATE_INFO, + new AtomicReference<>(ConnectivityStateInfo.forNonError(ConnectivityState.READY))) + .build(); + } + }; + + doReturn(true).when(fakeSubchannelStore).hasSubchannel(subchannel); + doReturn(false).when(fakeSubchannelStore).hasReadyBackends(); + lb.handleSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure( + Status.UNAVAILABLE)); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Attributes.class); + verify(fakeBalancer1).handleResolvedAddressGroups( + Matchers.>any(), captor.capture()); + assertThat(captor.getValue().get(ATTR_LOAD_BALANCING_CONFIG)) + .containsExactly("supported_1", new HashMap()); + } + + private static Attributes standardModeWithFallback1Attributes() throws Exception { + String lbConfigRaw = "{\"xds_experimental\" : { " + + "\"balancerName\" : \"dns:///balancer.example.com:8080\"," + + "\"fallbackPolicy\" : [{\"supported_1\" : {}}]" + + "}}"; + @SuppressWarnings("unchecked") + Map lbConfig = (Map) JsonParser.parse(lbConfigRaw); + return Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig).build(); + } + + @Test + public void shutdown_cleanupTimers() throws Exception { + String lbConfigRaw = "{\"xds_experimental\" : { " + + "\"balancerName\" : \"dns:///balancer.example.com:8080\"," + + "\"childPolicy\" : [{\"unsupported\" : {\"key\" : \"val\"}}, {\"unsupported_2\" : {}}]," + + "\"fallbackPolicy\" : [{\"unsupported\" : {}}, {\"supported_1\" : {\"key\" : \"val\"}}]" + + "}}"; + @SuppressWarnings("unchecked") + Map lbConfig = (Map) JsonParser.parse(lbConfigRaw); + Attributes attrs = Attributes.newBuilder().set(ATTR_LOAD_BALANCING_CONFIG, lbConfig).build(); + lb.handleResolvedAddressGroups(Collections.emptyList(), attrs); + + assertThat(fakeClock.getPendingTasks()).isNotEmpty(); + lb.shutdown(); + assertThat(fakeClock.getPendingTasks()).isEmpty(); + } }