diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer.java index cf9e7991df..d0b9248ac4 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer.java @@ -32,6 +32,7 @@ import io.grpc.Status; import io.grpc.internal.ObjectPool; import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.util.ForwardingLoadBalancerHelper; +import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; import io.grpc.xds.EdsLoadBalancerProvider.EdsConfig; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; @@ -44,24 +45,31 @@ import io.grpc.xds.internal.sds.TlsContextManager; import io.grpc.xds.internal.sds.TlsContextManagerImpl; import java.util.ArrayList; import java.util.List; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; /** - * Load balancer for cds_experimental LB policy. One instance per cluster. + * Load balancer for cds_experimental LB policy. */ -final class CdsLoadBalancer extends LoadBalancer { +public final class CdsLoadBalancer extends LoadBalancer { private final XdsLogger logger; - private final Helper helper; private final LoadBalancerRegistry lbRegistry; + private final GracefulSwitchLoadBalancer switchingLoadBalancer; private final TlsContextManager tlsContextManager; // TODO(sanjaypujare): remove once xds security is released private boolean enableXdsSecurity; private static final String XDS_SECURITY_ENV_VAR = "GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT"; + + // The following fields become non-null once handleResolvedAddresses() successfully. + + // Most recent cluster name. + @Nullable private String clusterName; + @Nullable private ObjectPool xdsClientPool; + @Nullable private XdsClient xdsClient; - private ChildLbState childLbState; - private ResolvedAddresses resolvedAddresses; CdsLoadBalancer(Helper helper) { this(helper, LoadBalancerRegistry.getDefaultRegistry(), TlsContextManagerImpl.getInstance()); @@ -70,8 +78,9 @@ final class CdsLoadBalancer extends LoadBalancer { @VisibleForTesting CdsLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry, TlsContextManager tlsContextManager) { - this.helper = checkNotNull(helper, "helper"); + checkNotNull(helper, "helper"); this.lbRegistry = lbRegistry; + this.switchingLoadBalancer = new GracefulSwitchLoadBalancer(helper); this.tlsContextManager = tlsContextManager; logger = XdsLogger.withLogId(InternalLogId.allocate("cds-lb", helper.getAuthority())); logger.log(XdsLogLevel.INFO, "Created"); @@ -79,32 +88,33 @@ final class CdsLoadBalancer extends LoadBalancer { @Override public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { - if (clusterName != null) { - return; - } logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); - this.resolvedAddresses = resolvedAddresses; - xdsClientPool = resolvedAddresses.getAttributes().get(XdsAttributes.XDS_CLIENT_POOL); - checkNotNull(xdsClientPool, "missing xDS client pool"); - xdsClient = xdsClientPool.getObject(); + if (xdsClientPool == null) { + xdsClientPool = resolvedAddresses.getAttributes().get(XdsAttributes.XDS_CLIENT_POOL); + checkNotNull(xdsClientPool, "missing xDS client pool"); + xdsClient = xdsClientPool.getObject(); + } + Object lbConfig = resolvedAddresses.getLoadBalancingPolicyConfig(); checkNotNull(lbConfig, "missing CDS lb config"); CdsConfig newCdsConfig = (CdsConfig) lbConfig; logger.log( XdsLogLevel.INFO, "Received CDS lb config: cluster={0}", newCdsConfig.name); + + // If cluster is changed, do a graceful switch. + if (!newCdsConfig.name.equals(clusterName)) { + LoadBalancer.Factory clusterBalancerFactory = new ClusterBalancerFactory(newCdsConfig.name); + switchingLoadBalancer.switchTo(clusterBalancerFactory); + } + switchingLoadBalancer.handleResolvedAddresses(resolvedAddresses); clusterName = newCdsConfig.name; - childLbState = new ChildLbState(); } @Override public void handleNameResolutionError(Status error) { logger.log(XdsLogLevel.WARNING, "Received name resolution error: {0}", error); - if (childLbState != null) { - childLbState.propagateError(error); - } else { - helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(error)); - } + switchingLoadBalancer.handleNameResolutionError(error); } @Override @@ -115,9 +125,7 @@ final class CdsLoadBalancer extends LoadBalancer { @Override public void shutdown() { logger.log(XdsLogLevel.INFO, "Shutdown"); - if (childLbState != null) { - childLbState.shutdown(); - } + switchingLoadBalancer.shutdown(); if (xdsClientPool != null) { xdsClientPool.returnObject(xdsClient); } @@ -134,25 +142,107 @@ final class CdsLoadBalancer extends LoadBalancer { enableXdsSecurity = enable; } - private final class ChannelSecurityLbHelper extends ForwardingLoadBalancerHelper { - @Nullable - private SslContextProvider sslContextProvider; + /** + * A load balancer factory that provides a load balancer for a given cluster. + */ + private final class ClusterBalancerFactory extends LoadBalancer.Factory { + + final String clusterName; + + ClusterBalancerFactory(String clusterName) { + this.clusterName = clusterName; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof ClusterBalancerFactory)) { + return false; + } + ClusterBalancerFactory that = (ClusterBalancerFactory) o; + return clusterName.equals(that.clusterName); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), clusterName); + } + + @Override + public LoadBalancer newLoadBalancer(final Helper helper) { + return new LoadBalancer() { + // Becomes non-null once handleResolvedAddresses() successfully. + // Assigned at most once. + @Nullable + ClusterWatcherImpl clusterWatcher; + + @Override + public void handleNameResolutionError(Status error) { + if (clusterWatcher == null || clusterWatcher.edsBalancer == null) { + // Go into TRANSIENT_FAILURE if we have not yet received any cluster resource. + // Otherwise, we keep running with the data we had previously. + helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(error)); + } + } + + @Override + public boolean canHandleEmptyAddressListFromNameResolution() { + return true; + } + + @Override + public void shutdown() { + if (clusterWatcher != null) { + if (clusterWatcher.edsBalancer != null) { + clusterWatcher.edsBalancer.shutdown(); + } + xdsClient.cancelClusterDataWatch(clusterName, clusterWatcher); + logger.log( + XdsLogLevel.INFO, + "Cancelled cluster watcher on {0} with xDS client {1}", + clusterName, xdsClient); + } + } + + @Override + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + if (clusterWatcher == null) { + clusterWatcher = new ClusterWatcherImpl(helper, resolvedAddresses); + logger.log( + XdsLogLevel.INFO, + "Start cluster watcher on {0} with xDS client {1}", + clusterName, xdsClient); + xdsClient.watchClusterData(clusterName, clusterWatcher); + } + } + }; + } + } + + private static final class EdsLoadBalancingHelper extends ForwardingLoadBalancerHelper { + private final Helper delegate; + private final AtomicReference sslContextProvider; + + EdsLoadBalancingHelper(Helper helper, + AtomicReference sslContextProvider) { + this.delegate = helper; + this.sslContextProvider = sslContextProvider; + } @Override public Subchannel createSubchannel(CreateSubchannelArgs createSubchannelArgs) { - if (sslContextProvider != null) { + if (sslContextProvider.get() != null) { createSubchannelArgs = createSubchannelArgs .toBuilder() .setAddresses( addUpstreamTlsContext(createSubchannelArgs.getAddresses(), - sslContextProvider.getUpstreamTlsContext())) + sslContextProvider.get().getUpstreamTlsContext())) .build(); } - return delegate().createSubchannel(createSubchannelArgs); + return delegate.createSubchannel(createSubchannelArgs); } - private List addUpstreamTlsContext( + private static List addUpstreamTlsContext( List addresses, UpstreamTlsContext upstreamTlsContext) { if (upstreamTlsContext == null || addresses == null) { @@ -174,19 +264,22 @@ final class CdsLoadBalancer extends LoadBalancer { @Override protected Helper delegate() { - return helper; + return delegate; } } - private final class ChildLbState implements ClusterWatcher { - private final ChannelSecurityLbHelper lbHelper = new ChannelSecurityLbHelper(); + private final class ClusterWatcherImpl implements ClusterWatcher { + + final EdsLoadBalancingHelper helper; + final ResolvedAddresses resolvedAddresses; + @Nullable LoadBalancer edsBalancer; - private ChildLbState() { - xdsClient.watchClusterData(clusterName, this); - logger.log(XdsLogLevel.INFO, - "Started watcher for cluster {0} with xDS client {1}", clusterName, xdsClient); + ClusterWatcherImpl(Helper helper, ResolvedAddresses resolvedAddresses) { + this.helper = new EdsLoadBalancingHelper(helper, + new AtomicReference()); + this.resolvedAddresses = resolvedAddresses; } @Override @@ -215,7 +308,7 @@ final class CdsLoadBalancer extends LoadBalancer { updateSslContextProvider(newUpdate.getUpstreamTlsContext()); } if (edsBalancer == null) { - edsBalancer = lbRegistry.getProvider(EDS_POLICY_NAME).newLoadBalancer(lbHelper); + edsBalancer = lbRegistry.getProvider(EDS_POLICY_NAME).newLoadBalancer(helper); } edsBalancer.handleResolvedAddresses( resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(edsConfig).build()); @@ -223,7 +316,8 @@ final class CdsLoadBalancer extends LoadBalancer { /** For new UpstreamTlsContext value, release old SslContextProvider. */ private void updateSslContextProvider(UpstreamTlsContext newUpstreamTlsContext) { - SslContextProvider oldSslContextProvider = lbHelper.sslContextProvider; + SslContextProvider oldSslContextProvider = + helper.sslContextProvider.get(); if (oldSslContextProvider != null) { UpstreamTlsContext oldUpstreamTlsContext = oldSslContextProvider.getUpstreamTlsContext(); @@ -233,10 +327,11 @@ final class CdsLoadBalancer extends LoadBalancer { tlsContextManager.releaseClientSslContextProvider(oldSslContextProvider); } if (newUpstreamTlsContext != null) { - lbHelper.sslContextProvider = + SslContextProvider newSslContextProvider = tlsContextManager.findOrCreateClientSslContextProvider(newUpstreamTlsContext); + helper.sslContextProvider.set(newSslContextProvider); } else { - lbHelper.sslContextProvider = null; + helper.sslContextProvider.set(null); } } @@ -265,22 +360,6 @@ final class CdsLoadBalancer extends LoadBalancer { helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(error)); } } - - void shutdown() { - xdsClient.cancelClusterDataWatch(clusterName, this); - logger.log(XdsLogLevel.INFO, - "Cancelled watcher for cluster {0} with xDS client {1}", clusterName, xdsClient); - if (edsBalancer != null) { - edsBalancer.shutdown(); - } - } - - void propagateError(Status error) { - if (edsBalancer != null) { - edsBalancer.handleNameResolutionError(error); - } else { - helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(error)); - } - } } + } diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java index 1f2261cbde..9a9bb69241 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancerTest.java @@ -17,10 +17,26 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.ConnectivityState.CONNECTING; +import static io.grpc.ConnectivityState.READY; +import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.xds.XdsLbPolicies.EDS_POLICY_NAME; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.BAD_CLIENT_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; +import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; -import com.google.common.collect.Iterables; +import com.google.common.collect.ImmutableList; import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.EquivalentAddressGroup; @@ -30,34 +46,38 @@ import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; -import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; -import io.grpc.ManagedChannel; -import io.grpc.NameResolver; +import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; -import io.grpc.internal.ObjectPool; +import io.grpc.internal.FakeClock; +import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; import io.grpc.xds.EdsLoadBalancerProvider.EdsConfig; -import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.XdsClient.ClusterUpdate; +import io.grpc.xds.XdsClient.ClusterWatcher; +import io.grpc.xds.XdsClient.RefCountedXdsClientObjectPool; +import io.grpc.xds.XdsClient.XdsClientFactory; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.SslContextProvider; import io.grpc.xds.internal.sds.TlsContextManager; -import java.net.SocketAddress; +import java.net.InetSocketAddress; +import java.util.ArrayDeque; import java.util.ArrayList; -import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; import java.util.List; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; -import org.junit.After; +import java.util.Map; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; import org.mockito.MockitoAnnotations; /** @@ -65,323 +85,19 @@ import org.mockito.MockitoAnnotations; */ @RunWith(JUnit4.class) public class CdsLoadBalancerTest { - private static final String AUTHORITY = "api.google.com"; - private static final String CLUSTER = "cluster-foo.googleapis.com"; - private final SynchronizationContext syncContext = new SynchronizationContext( - new Thread.UncaughtExceptionHandler() { + + private final RefCountedXdsClientObjectPool xdsClientPool = new RefCountedXdsClientObjectPool( + new XdsClientFactory() { @Override - public void uncaughtException(Thread t, Throwable e) { - throw new AssertionError(e); + XdsClient createXdsClient() { + xdsClient = mock(XdsClient.class); + return xdsClient; } - }); - private final List childBalancers = new ArrayList<>(); - private final FakeXdsClient xdsClient = new FakeXdsClient(); - private final TlsContextManager tlsContextManager = new FakeTlsContextManager(); - private LoadBalancer.Helper helper = new FakeLbHelper(); - private int xdsClientRefs; - private ConnectivityState currentState; - private SubchannelPicker currentPicker; - private CdsLoadBalancer loadBalancer; - - @Before - public void setUp() { - MockitoAnnotations.initMocks(this); - - LoadBalancerRegistry registry = new LoadBalancerRegistry(); - registry.register(new FakeLoadBalancerProvider(XdsLbPolicies.EDS_POLICY_NAME)); - registry.register(new FakeLoadBalancerProvider("round_robin")); - ObjectPool xdsClientPool = new ObjectPool() { - @Override - public XdsClient getObject() { - xdsClientRefs++; - return xdsClient; } + ); - @Override - public XdsClient returnObject(Object object) { - assertThat(xdsClientRefs).isGreaterThan(0); - xdsClientRefs--; - return null; - } - }; - loadBalancer = new CdsLoadBalancer(helper, registry, tlsContextManager); - loadBalancer.handleResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(Collections.emptyList()) - .setLoadBalancingPolicyConfig(new CdsConfig(CLUSTER)) - .setAttributes( - Attributes.newBuilder().set(XdsAttributes.XDS_CLIENT_POOL, xdsClientPool).build()) - .build()); - assertThat(xdsClient.watcher).isNotNull(); - } - - @After - public void tearDown() { - loadBalancer.shutdown(); - assertThat(xdsClient.watcher).isNull(); - assertThat(xdsClientRefs).isEqualTo(0); - } - - - @Test - public void receiveFirstClusterResourceInfo() { - xdsClient.deliverClusterInfo(null, null); - assertThat(childBalancers).hasSize(1); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(childBalancer.name).isEqualTo(XdsLbPolicies.EDS_POLICY_NAME); - assertThat(childBalancer.config).isNotNull(); - EdsConfig edsConfig = (EdsConfig) childBalancer.config; - assertThat(edsConfig.clusterName).isEqualTo(CLUSTER); - assertThat(edsConfig.edsServiceName).isNull(); - assertThat(edsConfig.lrsServerName).isNull(); - assertThat(edsConfig.endpointPickingPolicy.getProvider().getPolicyName()) - .isEqualTo("round_robin"); - } - - @Test - public void clusterResourceNeverExist() { - xdsClient.deliverResourceNotFound(); - assertThat(childBalancers).isEmpty(); - assertThat(currentState).isEqualTo(ConnectivityState.TRANSIENT_FAILURE); - PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); - assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(result.getStatus().getDescription()) - .isEqualTo("Resource " + CLUSTER + " is unavailable"); - } - - @Test - public void clusterResourceRemoved() { - xdsClient.deliverClusterInfo(null, null); - assertThat(childBalancers).hasSize(1); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(childBalancer.shutdown).isFalse(); - - xdsClient.deliverResourceNotFound(); - assertThat(childBalancer.shutdown).isTrue(); - assertThat(currentState).isEqualTo(ConnectivityState.TRANSIENT_FAILURE); - PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); - assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(result.getStatus().getDescription()) - .isEqualTo("Resource " + CLUSTER + " is unavailable"); - } - - @Test - public void clusterResourceUpdated() { - xdsClient.deliverClusterInfo(null, null); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - EdsConfig edsConfig = (EdsConfig) childBalancer.config; - assertThat(edsConfig.clusterName).isEqualTo(CLUSTER); - assertThat(edsConfig.edsServiceName).isNull(); - assertThat(edsConfig.lrsServerName).isNull(); - assertThat(edsConfig.endpointPickingPolicy.getProvider().getPolicyName()) - .isEqualTo("round_robin"); - - String edsService = "service-bar.googleapis.com"; - String loadReportServer = "lrs-server.googleapis.com"; - xdsClient.deliverClusterInfo(edsService, loadReportServer); - assertThat(childBalancers).containsExactly(childBalancer); - edsConfig = (EdsConfig) childBalancer.config; - assertThat(edsConfig.clusterName).isEqualTo(CLUSTER); - assertThat(edsConfig.edsServiceName).isEqualTo(edsService); - assertThat(edsConfig.lrsServerName).isEqualTo(loadReportServer); - assertThat(edsConfig.endpointPickingPolicy.getProvider().getPolicyName()) - .isEqualTo("round_robin"); - } - - @Test - public void receiveClusterResourceInfoWithUpstreamTlsContext() { - loadBalancer.setXdsSecurity(true); - UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( - CommonTlsContextTestsUtil.CLIENT_KEY_FILE, - CommonTlsContextTestsUtil.CLIENT_PEM_FILE, - CommonTlsContextTestsUtil.CA_PEM_FILE); - xdsClient.deliverClusterInfo(null, null, upstreamTlsContext); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - List addresses = createEndpointAddresses(2); - CreateSubchannelArgs args = - CreateSubchannelArgs.newBuilder() - .setAddresses(addresses) - .build(); - Subchannel subchannel = childBalancer.helper.createSubchannel(args); - for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(XdsAttributes.ATTR_UPSTREAM_TLS_CONTEXT)) - .isEqualTo(upstreamTlsContext); - } - - xdsClient.deliverClusterInfo(null, null); - subchannel = childBalancer.helper.createSubchannel(args); - for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(XdsAttributes.ATTR_UPSTREAM_TLS_CONTEXT)).isNull(); - } - - upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( - CommonTlsContextTestsUtil.BAD_CLIENT_KEY_FILE, - CommonTlsContextTestsUtil.BAD_CLIENT_PEM_FILE, - CommonTlsContextTestsUtil.CA_PEM_FILE); - xdsClient.deliverClusterInfo(null, null, upstreamTlsContext); - subchannel = childBalancer.helper.createSubchannel(args); - for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(XdsAttributes.ATTR_UPSTREAM_TLS_CONTEXT)) - .isEqualTo(upstreamTlsContext); - } - } - - @Test - public void subchannelStatePropagateFromDownstreamToUpstream() { - xdsClient.deliverClusterInfo(null, null); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - List addresses = createEndpointAddresses(2); - CreateSubchannelArgs args = - CreateSubchannelArgs.newBuilder() - .setAddresses(addresses) - .build(); - Subchannel subchannel = childBalancer.helper.createSubchannel(args); - childBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY); - assertThat(currentState).isEqualTo(ConnectivityState.READY); - assertThat(currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)).getSubchannel()) - .isSameInstanceAs(subchannel); - } - - @Test - public void clusterDiscoveryError_beforeChildPolicyInstantiated_propagateToUpstream() { - xdsClient.deliverError(Status.UNAUTHENTICATED.withDescription("permission denied")); - assertThat(currentState).isEqualTo(ConnectivityState.TRANSIENT_FAILURE); - PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); - assertThat(result.getStatus().isOk()).isFalse(); - assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAUTHENTICATED); - assertThat(result.getStatus().getDescription()).isEqualTo("permission denied"); - } - - @Test - public void clusterDiscoveryError_afterChildPolicyInstantiated_keepUsingCurrentCluster() { - xdsClient.deliverClusterInfo(null, null); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - xdsClient.deliverError(Status.UNAVAILABLE.withDescription("unreachable")); - assertThat(currentState).isNull(); - assertThat(currentPicker).isNull(); - assertThat(childBalancer.shutdown).isFalse(); - } - - @Test - public void nameResolutionError_beforeChildPolicyInstantiated_returnErrorPickerToUpstream() { - loadBalancer.handleNameResolutionError( - Status.UNIMPLEMENTED.withDescription("not found")); - assertThat(currentState).isEqualTo(ConnectivityState.TRANSIENT_FAILURE); - PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); - assertThat(result.getStatus().isOk()).isFalse(); - assertThat(result.getStatus().getCode()).isEqualTo(Code.UNIMPLEMENTED); - assertThat(result.getStatus().getDescription()).isEqualTo("not found"); - } - - @Test - public void nameResolutionError_afterChildPolicyInstantiated_propagateToDownstream() { - xdsClient.deliverClusterInfo(null, null); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - loadBalancer.handleNameResolutionError( - Status.UNAVAILABLE.withDescription("cannot reach server")); - assertThat(childBalancer.upstreamError.getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(childBalancer.upstreamError.getDescription()) - .isEqualTo("cannot reach server"); - } - - private static List createEndpointAddresses(int n) { - List list = new ArrayList<>(); - for (int i = 0; i < n; i++) { - list.add(new EquivalentAddressGroup(mock(SocketAddress.class))); - } - return list; - } - - private final class FakeXdsClient extends XdsClient { - private ClusterWatcher watcher; - - @Override - void watchClusterData(String clusterName, ClusterWatcher watcher) { - assertThat(clusterName).isEqualTo(CLUSTER); - this.watcher = watcher; - } - - @Override - void cancelClusterDataWatch(String clusterName, ClusterWatcher watcher) { - assertThat(clusterName).isEqualTo(CLUSTER); - assertThat(watcher).isSameInstanceAs(this.watcher); - this.watcher = null; - } - - @Override - void shutdown() { - // no-op - } - - void deliverClusterInfo( - @Nullable final String edsServiceName, @Nullable final String lrsServerName) { - syncContext.execute(new Runnable() { - @Override - public void run() { - watcher.onClusterChanged( - ClusterUpdate.newBuilder() - .setClusterName(CLUSTER) - .setEdsServiceName(edsServiceName) - .setLbPolicy("round_robin") // only supported policy - .setLrsServerName(lrsServerName) - .build()); - } - }); - } - - void deliverClusterInfo( - @Nullable final String edsServiceName, @Nullable final String lrsServerName, - final UpstreamTlsContext tlsContext) { - syncContext.execute(new Runnable() { - @Override - public void run() { - watcher.onClusterChanged( - ClusterUpdate.newBuilder() - .setClusterName(CLUSTER) - .setEdsServiceName(edsServiceName) - .setLbPolicy("round_robin") // only supported policy - .setLrsServerName(lrsServerName) - .setUpstreamTlsContext(tlsContext) - .build()); - } - }); - } - - void deliverResourceNotFound() { - syncContext.execute(new Runnable() { - @Override - public void run() { - watcher.onResourceDoesNotExist(CLUSTER); - } - }); - } - - void deliverError(final Status error) { - syncContext.execute(new Runnable() { - @Override - public void run() { - watcher.onError(error); - } - }); - } - } - - private final class FakeLoadBalancerProvider extends LoadBalancerProvider { - private final String policyName; - - FakeLoadBalancerProvider(String policyName) { - this.policyName = policyName; - } - - @Override - public LoadBalancer newLoadBalancer(Helper helper) { - FakeLoadBalancer balancer = new FakeLoadBalancer(policyName, helper); - childBalancers.add(balancer); - return balancer; - } - + private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); + private final LoadBalancerProvider fakeEdsLoadBlancerProvider = new LoadBalancerProvider() { @Override public boolean isAvailable() { return true; @@ -389,138 +105,488 @@ public class CdsLoadBalancerTest { @Override public int getPriority() { - return 0; // doesn't matter + return 5; } @Override public String getPolicyName() { - return policyName; - } - } - - private final class FakeLoadBalancer extends LoadBalancer { - private final String name; - private final Helper helper; - private Object config; - private Status upstreamError; - private boolean shutdown; - - FakeLoadBalancer(String name, Helper helper) { - this.name = name; - this.helper = helper; + return EDS_POLICY_NAME; } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { - config = resolvedAddresses.getLoadBalancingPolicyConfig(); + public LoadBalancer newLoadBalancer(Helper helper) { + edsLbHelpers.add(helper); + LoadBalancer edsLoadBalancer = mock(LoadBalancer.class); + edsLoadBalancers.add(edsLoadBalancer); + return edsLoadBalancer; + } + }; + + private final LoadBalancerProvider fakeRoundRobinLbProvider = new LoadBalancerProvider() { + @Override + public boolean isAvailable() { + return true; } @Override - public void handleNameResolutionError(Status error) { - upstreamError = error; + public int getPriority() { + return 5; } @Override - public void shutdown() { - shutdown = true; - childBalancers.remove(this); + public String getPolicyName() { + return "round_robin"; } - void deliverSubchannelState(final Subchannel subchannel, ConnectivityState state) { - SubchannelPicker picker = new SubchannelPicker() { + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return mock(LoadBalancer.class); + } + + @Override + public ConfigOrError parseLoadBalancingPolicyConfig( + Map rawLoadBalancingPolicyConfig) { + return ConfigOrError.fromConfig("fake round robin config"); + } + }; + + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withSubchannel(subchannel); + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); } - }; - helper.updateBalancingState(state, picker); - } + }); + + private final FakeClock fakeClock = new FakeClock(); + private final Deque edsLoadBalancers = new ArrayDeque<>(); + private final Deque edsLbHelpers = new ArrayDeque<>(); + + @Mock + private Helper helper; + + private LoadBalancer cdsLoadBalancer; + private XdsClient xdsClient; + + @Mock + private TlsContextManager mockTlsContextManager; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + + doReturn(syncContext).when(helper).getSynchronizationContext(); + doReturn(fakeClock.getScheduledExecutorService()).when(helper).getScheduledExecutorService(); + lbRegistry.register(fakeEdsLoadBlancerProvider); + lbRegistry.register(fakeRoundRobinLbProvider); + cdsLoadBalancer = new CdsLoadBalancer(helper, lbRegistry, mockTlsContextManager); } - private final class FakeLbHelper extends LoadBalancer.Helper { - - @Override - public void updateBalancingState( - @Nonnull ConnectivityState newState, @Nonnull SubchannelPicker newPicker) { - currentState = newState; - currentPicker = newPicker; - } - - @Override - public Subchannel createSubchannel(CreateSubchannelArgs args) { - return new FakeSubchannel(args.getAddresses()); - } - - @Override - public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) { - throw new UnsupportedOperationException("should not be called"); - } - - @Deprecated - @Override - public NameResolver.Factory getNameResolverFactory() { - throw new UnsupportedOperationException("should not be called"); - } - - @Override - public String getAuthority() { - return AUTHORITY; - } + @Test + public void canHandleEmptyAddressListFromNameResolution() { + assertThat(cdsLoadBalancer.canHandleEmptyAddressListFromNameResolution()).isTrue(); } - private static final class FakeSubchannel extends Subchannel { - private final List eags; + @Test + public void handleResolutionErrorBeforeOrAfterCdsWorking() { + ResolvedAddresses resolvedAddresses1 = ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setAttributes(Attributes.newBuilder() + .set(XdsAttributes.XDS_CLIENT_POOL, xdsClientPool) + .build()) + .setLoadBalancingPolicyConfig(new CdsConfig("foo.googleapis.com")) + .build(); + cdsLoadBalancer.handleResolvedAddresses(resolvedAddresses1); + ArgumentCaptor clusterWatcherCaptor1 = ArgumentCaptor.forClass(null); + verify(xdsClient).watchClusterData(eq("foo.googleapis.com"), clusterWatcherCaptor1.capture()); + ClusterWatcher clusterWatcher1 = clusterWatcherCaptor1.getValue(); - private FakeSubchannel(List eags) { - this.eags = eags; - } + // handleResolutionError() before receiving any CDS response. + cdsLoadBalancer.handleNameResolutionError(Status.DATA_LOSS.withDescription("fake status")); + verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); - @Override - public void shutdown() { - } + // CDS response received. + clusterWatcher1.onClusterChanged( + ClusterUpdate.newBuilder() + .setClusterName("foo.googleapis.com") + .setEdsServiceName("edsServiceFoo.googleapis.com") + .setLbPolicy("round_robin") + .build()); + verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); - @Override - public void requestConnection() { - } - - @Override - public List getAllAddresses() { - return eags; - } - - @Override - public Attributes getAttributes() { - return Attributes.EMPTY; - } + // handleResolutionError() after receiving CDS response. + cdsLoadBalancer.handleNameResolutionError(Status.DATA_LOSS.withDescription("fake status")); + // No more TRANSIENT_FAILURE. + verify(helper, times(1)).updateBalancingState( + eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); } - private static final class FakeTlsContextManager implements TlsContextManager { + @Test + public void handleCdsConfigUpdate() { + assertThat(xdsClient).isNull(); + ResolvedAddresses resolvedAddresses1 = ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setAttributes(Attributes.newBuilder() + .set(XdsAttributes.XDS_CLIENT_POOL, xdsClientPool) + .build()) + .setLoadBalancingPolicyConfig(new CdsConfig("foo.googleapis.com")) + .build(); + cdsLoadBalancer.handleResolvedAddresses(resolvedAddresses1); - @Override - public SslContextProvider findOrCreateClientSslContextProvider( - UpstreamTlsContext upstreamTlsContext) { - SslContextProvider sslContextProvider = mock(SslContextProvider.class); - when(sslContextProvider.getUpstreamTlsContext()).thenReturn(upstreamTlsContext); - return sslContextProvider; - } + ArgumentCaptor clusterWatcherCaptor1 = ArgumentCaptor.forClass(null); + verify(xdsClient).watchClusterData(eq("foo.googleapis.com"), clusterWatcherCaptor1.capture()); - @Override - public SslContextProvider releaseClientSslContextProvider( - SslContextProvider sslContextProvider) { - // no-op - return null; - } + ClusterWatcher clusterWatcher1 = clusterWatcherCaptor1.getValue(); + clusterWatcher1.onClusterChanged( + ClusterUpdate.newBuilder() + .setClusterName("foo.googleapis.com") + .setEdsServiceName("edsServiceFoo.googleapis.com") + .setLbPolicy("round_robin") + .build()); - @Override - public SslContextProvider findOrCreateServerSslContextProvider( - DownstreamTlsContext downstreamTlsContext) { - throw new UnsupportedOperationException("should not be called"); - } + assertThat(edsLbHelpers).hasSize(1); + assertThat(edsLoadBalancers).hasSize(1); + Helper edsLbHelper1 = edsLbHelpers.poll(); + LoadBalancer edsLoadBalancer1 = edsLoadBalancers.poll(); + ArgumentCaptor resolvedAddressesCaptor1 = ArgumentCaptor.forClass(null); + verify(edsLoadBalancer1).handleResolvedAddresses(resolvedAddressesCaptor1.capture()); + PolicySelection roundRobinPolicy = new PolicySelection( + fakeRoundRobinLbProvider, new HashMap(), "fake round robin config"); + EdsConfig expectedEdsConfig = new EdsConfig( + "foo.googleapis.com", + "edsServiceFoo.googleapis.com", + null, + roundRobinPolicy); + ResolvedAddresses resolvedAddressesFoo = resolvedAddressesCaptor1.getValue(); + assertThat(resolvedAddressesFoo.getLoadBalancingPolicyConfig()).isEqualTo(expectedEdsConfig); + assertThat(resolvedAddressesFoo.getAttributes().get(XdsAttributes.XDS_CLIENT_POOL)) + .isSameInstanceAs(xdsClientPool); - @Override - public SslContextProvider releaseServerSslContextProvider( - SslContextProvider sslContextProvider) { - throw new UnsupportedOperationException("should not be called"); - } + SubchannelPicker picker1 = mock(SubchannelPicker.class); + edsLbHelper1.updateBalancingState(ConnectivityState.READY, picker1); + verify(helper).updateBalancingState(ConnectivityState.READY, picker1); + + ResolvedAddresses resolvedAddresses2 = ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setAttributes(Attributes.newBuilder() + .set(XdsAttributes.XDS_CLIENT_POOL, xdsClientPool) + .build()) + .setLoadBalancingPolicyConfig(new CdsConfig("bar.googleapis.com")) + .build(); + cdsLoadBalancer.handleResolvedAddresses(resolvedAddresses2); + + ArgumentCaptor clusterWatcherCaptor2 = ArgumentCaptor.forClass(null); + verify(xdsClient).watchClusterData(eq("bar.googleapis.com"), clusterWatcherCaptor2.capture()); + + ClusterWatcher clusterWatcher2 = clusterWatcherCaptor2.getValue(); + clusterWatcher2.onClusterChanged( + ClusterUpdate.newBuilder() + .setClusterName("bar.googleapis.com") + .setEdsServiceName("edsServiceBar.googleapis.com") + .setLbPolicy("round_robin") + .setLrsServerName("lrsBar.googleapis.com") + .build()); + + assertThat(edsLbHelpers).hasSize(1); + assertThat(edsLoadBalancers).hasSize(1); + Helper edsLbHelper2 = edsLbHelpers.poll(); + LoadBalancer edsLoadBalancer2 = edsLoadBalancers.poll(); + ArgumentCaptor resolvedAddressesCaptor2 = ArgumentCaptor.forClass(null); + verify(edsLoadBalancer2).handleResolvedAddresses(resolvedAddressesCaptor2.capture()); + expectedEdsConfig = new EdsConfig( + "bar.googleapis.com", + "edsServiceBar.googleapis.com", + "lrsBar.googleapis.com", + roundRobinPolicy); + ResolvedAddresses resolvedAddressesBar = resolvedAddressesCaptor2.getValue(); + assertThat(resolvedAddressesBar.getLoadBalancingPolicyConfig()).isEqualTo(expectedEdsConfig); + assertThat(resolvedAddressesBar.getAttributes().get(XdsAttributes.XDS_CLIENT_POOL)) + .isSameInstanceAs(xdsClientPool); + + SubchannelPicker picker2 = mock(SubchannelPicker.class); + edsLbHelper2.updateBalancingState(ConnectivityState.CONNECTING, picker2); + verify(helper, never()).updateBalancingState(ConnectivityState.CONNECTING, picker2); + verify(edsLoadBalancer1, never()).shutdown(); + + picker2 = mock(SubchannelPicker.class); + edsLbHelper2.updateBalancingState(ConnectivityState.READY, picker2); + verify(helper).updateBalancingState(ConnectivityState.READY, picker2); + verify(edsLoadBalancer1).shutdown(); + verify(xdsClient).cancelClusterDataWatch("foo.googleapis.com", clusterWatcher1); + + clusterWatcher2.onClusterChanged( + ClusterUpdate.newBuilder() + .setClusterName("bar.googleapis.com") + .setEdsServiceName("edsServiceBar2.googleapis.com") + .setLbPolicy("round_robin") + .build()); + verify(edsLoadBalancer2, times(2)).handleResolvedAddresses(resolvedAddressesCaptor2.capture()); + expectedEdsConfig = new EdsConfig( + "bar.googleapis.com", + "edsServiceBar2.googleapis.com", + null, + roundRobinPolicy); + ResolvedAddresses resolvedAddressesBar2 = resolvedAddressesCaptor2.getValue(); + assertThat(resolvedAddressesBar2.getLoadBalancingPolicyConfig()).isEqualTo(expectedEdsConfig); + + cdsLoadBalancer.shutdown(); + verify(edsLoadBalancer2).shutdown(); + verify(xdsClient).cancelClusterDataWatch("bar.googleapis.com", clusterWatcher2); + assertThat(xdsClientPool.xdsClient).isNull(); + } + + @Test + public void handleCdsConfigUpdate_withUpstreamTlsContext() { + assertThat(cdsLoadBalancer).isInstanceOf(CdsLoadBalancer.class); + ((CdsLoadBalancer)cdsLoadBalancer).setXdsSecurity(true); + assertThat(xdsClient).isNull(); + ResolvedAddresses resolvedAddresses1 = + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setAttributes( + Attributes.newBuilder() + .set(XdsAttributes.XDS_CLIENT_POOL, xdsClientPool) + .build()) + .setLoadBalancingPolicyConfig(new CdsConfig("foo.googleapis.com")) + .build(); + cdsLoadBalancer.handleResolvedAddresses(resolvedAddresses1); + + ArgumentCaptor clusterWatcherCaptor1 = ArgumentCaptor.forClass(null); + verify(xdsClient).watchClusterData(eq("foo.googleapis.com"), clusterWatcherCaptor1.capture()); + + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( + CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); + + SslContextProvider mockSslContextProvider = mock(SslContextProvider.class); + doReturn(upstreamTlsContext).when(mockSslContextProvider).getUpstreamTlsContext(); + doReturn(mockSslContextProvider).when(mockTlsContextManager) + .findOrCreateClientSslContextProvider(same(upstreamTlsContext)); + + ClusterWatcher clusterWatcher1 = clusterWatcherCaptor1.getValue(); + clusterWatcher1.onClusterChanged( + ClusterUpdate.newBuilder() + .setClusterName("foo.googleapis.com") + .setEdsServiceName("edsServiceFoo.googleapis.com") + .setLbPolicy("round_robin") + .setUpstreamTlsContext(upstreamTlsContext) + .build()); + + assertThat(edsLbHelpers).hasSize(1); + assertThat(edsLoadBalancers).hasSize(1); + verify(mockTlsContextManager, never()) + .releaseClientSslContextProvider(any(SslContextProvider.class)); + Helper edsLbHelper1 = edsLbHelpers.poll(); + + ArrayList eagList = new ArrayList<>(); + eagList.add(new EquivalentAddressGroup(new InetSocketAddress("foo.com", 8080))); + eagList.add(new EquivalentAddressGroup(InetSocketAddress.createUnresolved("localhost", 8081), + Attributes.newBuilder().set(XdsAttributes.XDS_CLIENT_POOL, xdsClientPool).build())); + LoadBalancer.CreateSubchannelArgs createSubchannelArgs = + LoadBalancer.CreateSubchannelArgs.newBuilder() + .setAddresses(eagList) + .build(); + ArgumentCaptor createSubchannelArgsCaptor1 = + ArgumentCaptor.forClass(null); + verify(helper, never()) + .createSubchannel(any(LoadBalancer.CreateSubchannelArgs.class)); + edsLbHelper1.createSubchannel(createSubchannelArgs); + verifyUpstreamTlsContextAttribute(upstreamTlsContext, + createSubchannelArgsCaptor1); + + // update with same upstreamTlsContext + reset(mockTlsContextManager); + clusterWatcher1.onClusterChanged( + ClusterUpdate.newBuilder() + .setClusterName("bar.googleapis.com") + .setEdsServiceName("eds1ServiceFoo.googleapis.com") + .setLbPolicy("round_robin") + .setUpstreamTlsContext(upstreamTlsContext) + .build()); + + verify(mockTlsContextManager, never()) + .releaseClientSslContextProvider(any(SslContextProvider.class)); + verify(mockTlsContextManager, never()).findOrCreateClientSslContextProvider( + any(UpstreamTlsContext.class)); + + // update with different upstreamTlsContext + reset(mockTlsContextManager); + reset(helper); + UpstreamTlsContext upstreamTlsContext1 = + CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( + BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, CA_PEM_FILE); + SslContextProvider mockSslContextProvider1 = mock(SslContextProvider.class); + doReturn(upstreamTlsContext1).when(mockSslContextProvider1).getUpstreamTlsContext(); + doReturn(mockSslContextProvider1).when(mockTlsContextManager) + .findOrCreateClientSslContextProvider(same(upstreamTlsContext1)); + clusterWatcher1.onClusterChanged( + ClusterUpdate.newBuilder() + .setClusterName("bar.googleapis.com") + .setEdsServiceName("eds1ServiceFoo.googleapis.com") + .setLbPolicy("round_robin") + .setUpstreamTlsContext(upstreamTlsContext1) + .build()); + + verify(mockTlsContextManager).releaseClientSslContextProvider(same(mockSslContextProvider)); + verify(mockTlsContextManager).findOrCreateClientSslContextProvider(same(upstreamTlsContext1)); + ArgumentCaptor createSubchannelArgsCaptor2 = + ArgumentCaptor.forClass(null); + edsLbHelper1.createSubchannel(createSubchannelArgs); + verifyUpstreamTlsContextAttribute(upstreamTlsContext1, + createSubchannelArgsCaptor2); + + // update with null + reset(mockTlsContextManager); + reset(helper); + clusterWatcher1.onClusterChanged( + ClusterUpdate.newBuilder() + .setClusterName("bar.googleapis.com") + .setEdsServiceName("eds1ServiceFoo.googleapis.com") + .setLbPolicy("round_robin") + .setUpstreamTlsContext(null) + .build()); + verify(mockTlsContextManager).releaseClientSslContextProvider(same(mockSslContextProvider1)); + verify(mockTlsContextManager, never()).findOrCreateClientSslContextProvider( + any(UpstreamTlsContext.class)); + ArgumentCaptor createSubchannelArgsCaptor3 = + ArgumentCaptor.forClass(null); + edsLbHelper1.createSubchannel(createSubchannelArgs); + verifyUpstreamTlsContextAttribute(null, + createSubchannelArgsCaptor3); + + LoadBalancer edsLoadBalancer1 = edsLoadBalancers.poll(); + + cdsLoadBalancer.shutdown(); + verify(edsLoadBalancer1).shutdown(); + verify(xdsClient).cancelClusterDataWatch("foo.googleapis.com", clusterWatcher1); + assertThat(xdsClientPool.xdsClient).isNull(); + } + + private void verifyUpstreamTlsContextAttribute( + UpstreamTlsContext upstreamTlsContext, + ArgumentCaptor createSubchannelArgsCaptor1) { + verify(helper, times(1)).createSubchannel(createSubchannelArgsCaptor1.capture()); + CreateSubchannelArgs capturedValue = createSubchannelArgsCaptor1.getValue(); + List capturedEagList = capturedValue.getAddresses(); + assertThat(capturedEagList.size()).isEqualTo(2); + EquivalentAddressGroup capturedEag = capturedEagList.get(0); + UpstreamTlsContext capturedUpstreamTlsContext = + capturedEag.getAttributes().get(XdsAttributes.ATTR_UPSTREAM_TLS_CONTEXT); + assertThat(capturedUpstreamTlsContext).isSameInstanceAs(upstreamTlsContext); + capturedEag = capturedEagList.get(1); + capturedUpstreamTlsContext = + capturedEag.getAttributes().get(XdsAttributes.ATTR_UPSTREAM_TLS_CONTEXT); + assertThat(capturedUpstreamTlsContext).isSameInstanceAs(upstreamTlsContext); + assertThat(capturedEag.getAttributes().get(XdsAttributes.XDS_CLIENT_POOL)) + .isSameInstanceAs(xdsClientPool); + } + + @Test + public void clusterWatcher_resourceNotExist() { + ResolvedAddresses resolvedAddresses = ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setAttributes(Attributes.newBuilder() + .set(XdsAttributes.XDS_CLIENT_POOL, xdsClientPool) + .build()) + .setLoadBalancingPolicyConfig(new CdsConfig("foo.googleapis.com")) + .build(); + cdsLoadBalancer.handleResolvedAddresses(resolvedAddresses); + + ArgumentCaptor clusterWatcherCaptor = ArgumentCaptor.forClass(null); + verify(xdsClient).watchClusterData(eq("foo.googleapis.com"), clusterWatcherCaptor.capture()); + + ClusterWatcher clusterWatcher = clusterWatcherCaptor.getValue(); + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(null); + clusterWatcher.onResourceDoesNotExist("foo.googleapis.com"); + assertThat(edsLoadBalancers).isEmpty(); + verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); + assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(result.getStatus().getDescription()) + .isEqualTo("Resource foo.googleapis.com is unavailable"); + } + + @Test + public void clusterWatcher_resourceRemoved() { + ResolvedAddresses resolvedAddresses = ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setAttributes(Attributes.newBuilder() + .set(XdsAttributes.XDS_CLIENT_POOL, xdsClientPool) + .build()) + .setLoadBalancingPolicyConfig(new CdsConfig("foo.googleapis.com")) + .build(); + cdsLoadBalancer.handleResolvedAddresses(resolvedAddresses); + + ArgumentCaptor clusterWatcherCaptor = ArgumentCaptor.forClass(null); + verify(xdsClient).watchClusterData(eq("foo.googleapis.com"), clusterWatcherCaptor.capture()); + + ClusterWatcher clusterWatcher = clusterWatcherCaptor.getValue(); + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(null); + clusterWatcher.onClusterChanged( + ClusterUpdate.newBuilder() + .setClusterName("foo.googleapis.com") + .setEdsServiceName("edsServiceFoo.googleapis.com") + .setLbPolicy("round_robin") + .build()); + assertThat(edsLoadBalancers).hasSize(1); + assertThat(edsLbHelpers).hasSize(1); + LoadBalancer edsLoadBalancer = edsLoadBalancers.poll(); + Helper edsHelper = edsLbHelpers.poll(); + SubchannelPicker subchannelPicker = mock(SubchannelPicker.class); + edsHelper.updateBalancingState(READY, subchannelPicker); + verify(helper).updateBalancingState(eq(READY), same(subchannelPicker)); + + clusterWatcher.onResourceDoesNotExist("foo.googleapis.com"); + verify(edsLoadBalancer).shutdown(); + verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); + assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(result.getStatus().getDescription()) + .isEqualTo("Resource foo.googleapis.com is unavailable"); + } + + @Test + public void clusterWatcher_onErrorCalledBeforeAndAfterOnClusterChanged() { + ResolvedAddresses resolvedAddresses = ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setAttributes(Attributes.newBuilder() + .set(XdsAttributes.XDS_CLIENT_POOL, xdsClientPool) + .build()) + .setLoadBalancingPolicyConfig(new CdsConfig("foo.googleapis.com")) + .build(); + cdsLoadBalancer.handleResolvedAddresses(resolvedAddresses); + + ArgumentCaptor clusterWatcherCaptor = ArgumentCaptor.forClass(null); + verify(xdsClient).watchClusterData(eq("foo.googleapis.com"), clusterWatcherCaptor.capture()); + + ClusterWatcher clusterWatcher = clusterWatcherCaptor.getValue(); + + // Call onError() before onClusterChanged() ever called. + clusterWatcher.onError(Status.DATA_LOSS.withDescription("fake status")); + assertThat(edsLoadBalancers).isEmpty(); + verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + + clusterWatcher.onClusterChanged( + ClusterUpdate.newBuilder() + .setClusterName("foo.googleapis.com") + .setEdsServiceName("edsServiceFoo.googleapis.com") + .setLbPolicy("round_robin") + .build()); + + assertThat(edsLbHelpers).hasSize(1); + assertThat(edsLoadBalancers).hasSize(1); + Helper edsLbHelper = edsLbHelpers.poll(); + LoadBalancer edsLoadBalancer = edsLoadBalancers.poll(); + verify(edsLoadBalancer).handleResolvedAddresses(any(ResolvedAddresses.class)); + SubchannelPicker picker = mock(SubchannelPicker.class); + + edsLbHelper.updateBalancingState(ConnectivityState.READY, picker); + verify(helper).updateBalancingState(ConnectivityState.READY, picker); + + // Call onError() after onClusterChanged(). + clusterWatcher.onError(Status.DATA_LOSS.withDescription("fake status")); + // Verify no more TRANSIENT_FAILURE. + verify(helper, times(1)) + .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); } }