diff --git a/xds/src/main/java/io/grpc/xds/LookasideChannelLb.java b/xds/src/main/java/io/grpc/xds/LookasideChannelLb.java index b9237cf38e..52658be5d3 100644 --- a/xds/src/main/java/io/grpc/xds/LookasideChannelLb.java +++ b/xds/src/main/java/io/grpc/xds/LookasideChannelLb.java @@ -16,22 +16,18 @@ package io.grpc.xds; -import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.envoyproxy.envoy.api.v2.ClusterLoadAssignment; -import io.envoyproxy.envoy.api.v2.core.Node; import io.grpc.LoadBalancer; -import io.grpc.ManagedChannel; import io.grpc.Status; -import io.grpc.internal.ExponentialBackoffPolicy; -import io.grpc.internal.GrpcUtil; import io.grpc.xds.EnvoyProtoData.DropOverload; import io.grpc.xds.EnvoyProtoData.Locality; import io.grpc.xds.EnvoyProtoData.LocalityLbEndpoints; import io.grpc.xds.LoadReportClient.LoadReportCallback; -import io.grpc.xds.XdsComms2.AdsStreamCallback; +import io.grpc.xds.XdsClient.EndpointUpdate; +import io.grpc.xds.XdsClient.EndpointWatcher; import java.util.List; +import java.util.Map; /** * A load balancer that has a lookaside channel. This layer of load balancer creates a channel to @@ -40,33 +36,16 @@ import java.util.List; */ final class LookasideChannelLb extends LoadBalancer { - private final ManagedChannel lbChannel; private final LoadReportClient lrsClient; - private final XdsComms2 xdsComms2; + private final XdsClient xdsClient; LookasideChannelLb( - Helper helper, LookasideChannelCallback lookasideChannelCallback, ManagedChannel lbChannel, - LocalityStore localityStore, Node node) { - this( - helper, - lookasideChannelCallback, - lbChannel, - new LoadReportClientImpl( - lbChannel, helper, GrpcUtil.STOPWATCH_SUPPLIER, new ExponentialBackoffPolicy.Provider(), - localityStore.getLoadStatsStore()), - localityStore, - node); - } - - @VisibleForTesting - LookasideChannelLb( - Helper helper, + String edsServiceName, LookasideChannelCallback lookasideChannelCallback, - ManagedChannel lbChannel, + XdsClient xdsClient, LoadReportClient lrsClient, - final LocalityStore localityStore, - Node node) { - this.lbChannel = lbChannel; + final LocalityStore localityStore) { + this.xdsClient = xdsClient; LoadReportCallback lrsCallback = new LoadReportCallback() { @Override @@ -76,11 +55,9 @@ final class LookasideChannelLb extends LoadBalancer { }; this.lrsClient = lrsClient; - AdsStreamCallback adsCallback = new AdsStreamCallbackImpl( + EndpointWatcher endpointWatcher = new EndpointWatcherImpl( lookasideChannelCallback, lrsClient, lrsCallback, localityStore) ; - xdsComms2 = new XdsComms2( - lbChannel, helper, adsCallback, new ExponentialBackoffPolicy.Provider(), - GrpcUtil.STOPWATCH_SUPPLIER, node); + xdsClient.watchEndpointData(edsServiceName, endpointWatcher); } @Override @@ -91,11 +68,10 @@ final class LookasideChannelLb extends LoadBalancer { @Override public void shutdown() { lrsClient.stopLoadReporting(); - xdsComms2.shutdownLbRpc(); - lbChannel.shutdown(); + xdsClient.shutdown(); } - private static final class AdsStreamCallbackImpl implements AdsStreamCallback { + private static final class EndpointWatcherImpl implements EndpointWatcher { final LookasideChannelCallback lookasideChannelCallback; final LoadReportClient lrsClient; @@ -103,7 +79,7 @@ final class LookasideChannelLb extends LoadBalancer { final LocalityStore localityStore; boolean firstEdsResponseReceived; - AdsStreamCallbackImpl( + EndpointWatcherImpl( LookasideChannelCallback lookasideChannelCallback, LoadReportClient lrsClient, LoadReportCallback lrsCallback, LocalityStore localityStore) { this.lookasideChannelCallback = lookasideChannelCallback; @@ -113,39 +89,32 @@ final class LookasideChannelLb extends LoadBalancer { } @Override - public void onEdsResponse(ClusterLoadAssignment clusterLoadAssignment) { + public void onEndpointChanged(EndpointUpdate endpointUpdate) { if (!firstEdsResponseReceived) { firstEdsResponseReceived = true; lookasideChannelCallback.onWorking(); lrsClient.startLoadReporting(lrsCallback); } - List dropOverloadsProto = - clusterLoadAssignment.getPolicy().getDropOverloadsList(); + List dropOverloads = endpointUpdate.getDropPolicies(); ImmutableList.Builder dropOverloadsBuilder = ImmutableList.builder(); - for (ClusterLoadAssignment.Policy.DropOverload drop : dropOverloadsProto) { - DropOverload dropOverload = DropOverload.fromEnvoyProtoDropOverload(drop); + for (DropOverload dropOverload : dropOverloads) { dropOverloadsBuilder.add(dropOverload); if (dropOverload.getDropsPerMillion() == 1_000_000) { lookasideChannelCallback.onAllDrop(); break; } } - ImmutableList dropOverloads = dropOverloadsBuilder.build(); - localityStore.updateDropPercentage(dropOverloads); + localityStore.updateDropPercentage(dropOverloadsBuilder.build()); - List localities = - clusterLoadAssignment.getEndpointsList(); ImmutableMap.Builder localityEndpointsMapping = new ImmutableMap.Builder<>(); - for (io.envoyproxy.envoy.api.v2.endpoint.LocalityLbEndpoints localityLbEndpoints - : localities) { - Locality locality = Locality.fromEnvoyProtoLocality(localityLbEndpoints.getLocality()); - int localityWeight = localityLbEndpoints.getLoadBalancingWeight().getValue(); + for (Map.Entry entry + : endpointUpdate.getLocalityLbEndpointsMap().entrySet()) { + int localityWeight = entry.getValue().getLocalityWeight(); if (localityWeight != 0) { - localityEndpointsMapping.put( - locality, LocalityLbEndpoints.fromEnvoyProtoLocalityLbEndpoints(localityLbEndpoints)); + localityEndpointsMapping.put(entry.getKey(), entry.getValue()); } } @@ -153,7 +122,7 @@ final class LookasideChannelLb extends LoadBalancer { } @Override - public void onError() { + public void onError(Status error) { lookasideChannelCallback.onError(); } } diff --git a/xds/src/main/java/io/grpc/xds/LookasideLb.java b/xds/src/main/java/io/grpc/xds/LookasideLb.java index b5884d5d0f..41723be0c4 100644 --- a/xds/src/main/java/io/grpc/xds/LookasideLb.java +++ b/xds/src/main/java/io/grpc/xds/LookasideLb.java @@ -33,6 +33,8 @@ import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.NameResolver.ConfigOrError; import io.grpc.alts.GoogleDefaultChannelBuilder; +import io.grpc.internal.ExponentialBackoffPolicy; +import io.grpc.internal.GrpcUtil; import io.grpc.util.ForwardingLoadBalancer; import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.Bootstrapper.ChannelCreds; @@ -92,6 +94,8 @@ final class LookasideLb extends ForwardingLoadBalancer { XdsConfig xdsConfig = (XdsConfig) cfg.getConfig(); String newBalancerName = xdsConfig.balancerName; + + // The is to handle the legacy usecase that requires balancerName from xds config. if (!newBalancerName.equals(balancerName)) { balancerName = newBalancerName; // cache the name and check next time for optimization Node node = resolvedAddresses.getAttributes().get(XDS_NODE); @@ -111,6 +115,7 @@ final class LookasideLb extends ForwardingLoadBalancer { lookasideChannelLb.switchTo(newLookasideChannelLbProvider( newBalancerName, node, channelCredsList)); } + lookasideChannelLb.handleResolvedAddresses(resolvedAddresses); } @@ -157,10 +162,18 @@ final class LookasideLb extends ForwardingLoadBalancer { public LoadBalancer newLoadBalancer( Helper helper, LookasideChannelCallback lookasideChannelCallback, String balancerName, Node node, List channelCredsList) { + ManagedChannel channel = initLbChannel(helper, balancerName, channelCredsList); + XdsClient xdsClient = new XdsComms2( + channel, helper, new ExponentialBackoffPolicy.Provider(), + GrpcUtil.STOPWATCH_SUPPLIER, node); + LocalityStore localityStore = + new LocalityStoreImpl(helper, LoadBalancerRegistry.getDefaultRegistry()); + // TODO(zdapeng): Use XdsClient to do Lrs directly. + LoadReportClient lrsClient = new LoadReportClientImpl( + channel, helper, GrpcUtil.STOPWATCH_SUPPLIER, new ExponentialBackoffPolicy.Provider(), + localityStore.getLoadStatsStore()); return new LookasideChannelLb( - helper, lookasideChannelCallback, initLbChannel(helper, balancerName, channelCredsList), - new LocalityStoreImpl(helper, LoadBalancerRegistry.getDefaultRegistry()), - node); + node.getCluster(), lookasideChannelCallback, xdsClient, lrsClient, localityStore); } private static ManagedChannel initLbChannel( diff --git a/xds/src/main/java/io/grpc/xds/XdsClient.java b/xds/src/main/java/io/grpc/xds/XdsClient.java index 9a2daf08e9..b6043d7880 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClient.java +++ b/xds/src/main/java/io/grpc/xds/XdsClient.java @@ -28,6 +28,7 @@ import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import javax.annotation.Nullable; /** @@ -231,6 +232,25 @@ abstract class XdsClient { return Collections.unmodifiableList(dropPolicies); } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + EndpointUpdate that = (EndpointUpdate) o; + return clusterName.equals(that.clusterName) + && localityLbEndpointsMap.equals(that.localityLbEndpointsMap) + && dropPolicies.equals(that.dropPolicies); + } + + @Override + public int hashCode() { + return Objects.hash(clusterName, localityLbEndpointsMap, dropPolicies); + } + static final class Builder { private String clusterName; private Map localityLbEndpointsMap = new LinkedHashMap<>(); diff --git a/xds/src/main/java/io/grpc/xds/XdsComms2.java b/xds/src/main/java/io/grpc/xds/XdsComms2.java index d189a9e89c..ef90698e1e 100644 --- a/xds/src/main/java/io/grpc/xds/XdsComms2.java +++ b/xds/src/main/java/io/grpc/xds/XdsComms2.java @@ -19,13 +19,16 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; import com.google.protobuf.InvalidProtocolBufferException; import io.envoyproxy.envoy.api.v2.ClusterLoadAssignment; +import io.envoyproxy.envoy.api.v2.ClusterLoadAssignment.Policy.DropOverload; import io.envoyproxy.envoy.api.v2.DiscoveryRequest; import io.envoyproxy.envoy.api.v2.DiscoveryResponse; import io.envoyproxy.envoy.api.v2.core.Node; +import io.envoyproxy.envoy.api.v2.endpoint.LocalityLbEndpoints; import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.LoadBalancer.Helper; @@ -42,7 +45,7 @@ import javax.annotation.CheckForNull; */ // TODO(zdapeng): This is a temporary and easy refactor of XdsComms, will be replaced by XdsClient. // Tests are deferred in XdsClientTest, otherwise it's just a refactor of XdsCommsTest. -final class XdsComms2 { +final class XdsComms2 extends XdsClient { private final ManagedChannel channel; private final Helper helper; private final BackoffPolicy.Provider backoffPolicyProvider; @@ -62,7 +65,7 @@ final class XdsComms2 { static final String EDS_TYPE_URL = "type.googleapis.com/envoy.api.v2.ClusterLoadAssignment"; - final AdsStreamCallback adsStreamCallback; + final XdsClient.EndpointWatcher endpointWatcher; final StreamObserver xdsRequestWriter; final Stopwatch retryStopwatch = stopwatchSupplier.get().start(); @@ -89,7 +92,7 @@ final class XdsComms2 { value.getResources(0).unpack(ClusterLoadAssignment.class); } catch (InvalidProtocolBufferException | RuntimeException e) { cancelRpc("Received invalid EDS response", e); - adsStreamCallback.onError(); + endpointWatcher.onError(Status.fromThrowable(e)); scheduleRetry(); return; } @@ -98,7 +101,25 @@ final class XdsComms2 { ChannelLogLevel.DEBUG, "Received an EDS response: {0}", clusterLoadAssignment); firstEdsResponseReceived = true; - adsStreamCallback.onEdsResponse(clusterLoadAssignment); + + // Converts clusterLoadAssignment data to EndpointUpdate + EndpointUpdate.Builder endpointUpdateBuilder = EndpointUpdate.newBuilder(); + endpointUpdateBuilder.setClusterName(clusterLoadAssignment.getClusterName()); + for (DropOverload dropOverload : + clusterLoadAssignment.getPolicy().getDropOverloadsList()) { + endpointUpdateBuilder.addDropPolicy( + EnvoyProtoData.DropOverload.fromEnvoyProtoDropOverload(dropOverload)); + } + for (LocalityLbEndpoints localityLbEndpoints : + clusterLoadAssignment.getEndpointsList()) { + endpointUpdateBuilder.addLocalityLbEndpoints( + EnvoyProtoData.Locality.fromEnvoyProtoLocality( + localityLbEndpoints.getLocality()), + EnvoyProtoData.LocalityLbEndpoints.fromEnvoyProtoLocalityLbEndpoints( + localityLbEndpoints)); + + } + endpointWatcher.onEndpointChanged(endpointUpdateBuilder.build()); } } } @@ -107,7 +128,7 @@ final class XdsComms2 { } @Override - public void onError(Throwable t) { + public void onError(final Throwable t) { helper.getSynchronizationContext().execute( new Runnable() { @Override @@ -116,7 +137,7 @@ final class XdsComms2 { if (cancelled) { return; } - adsStreamCallback.onError(); + endpointWatcher.onError(Status.fromThrowable(t)); scheduleRetry(); } }); @@ -124,7 +145,7 @@ final class XdsComms2 { @Override public void onCompleted() { - onError(Status.INTERNAL.withDescription("Server closed the ADS streaming RPC") + onError(Status.UNAVAILABLE.withDescription("Server closed the ADS streaming RPC") .asException()); } @@ -168,8 +189,8 @@ final class XdsComms2 { boolean cancelled; boolean closed; - AdsStream(AdsStreamCallback adsStreamCallback) { - this.adsStreamCallback = adsStreamCallback; + AdsStream(XdsClient.EndpointWatcher endpointWatcher) { + this.endpointWatcher = endpointWatcher; this.xdsRequestWriter = AggregatedDiscoveryServiceGrpc.newStub(channel).withWaitForReady() .streamAggregatedResources(xdsResponseReader); @@ -186,7 +207,7 @@ final class XdsComms2 { } AdsStream(AdsStream adsStream) { - this(adsStream.adsStreamCallback); + this(adsStream.endpointWatcher); } // run in SynchronizationContext @@ -204,15 +225,13 @@ final class XdsComms2 { * Starts a new ADS streaming RPC. */ XdsComms2( - ManagedChannel channel, Helper helper, AdsStreamCallback adsStreamCallback, + ManagedChannel channel, Helper helper, BackoffPolicy.Provider backoffPolicyProvider, Supplier stopwatchSupplier, Node node) { this.channel = checkNotNull(channel, "channel"); this.helper = checkNotNull(helper, "helper"); this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier"); this.node = node; - this.adsStream = new AdsStream( - checkNotNull(adsStreamCallback, "adsStreamCallback")); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); this.adsRpcRetryPolicy = backoffPolicyProvider.get(); } @@ -227,12 +246,20 @@ final class XdsComms2 { } } - // run in SynchronizationContext - // TODO: Change method name to shutdown or shutdownXdsComms if that gives better semantics ( - // cancel LB RPC and clean up retry timer). - void shutdownLbRpc() { - adsStream.cancelRpc("shutdown", null); + @Override + void watchEndpointData(String clusterName, EndpointWatcher watcher) { + if (adsStream == null) { + adsStream = new AdsStream(watcher); + } + } + + @Override + void shutdown() { + if (adsStream != null) { + adsStream.cancelRpc("shutdown", null); + } cancelRetryTimer(); + channel.shutdown(); } // run in SynchronizationContext @@ -244,12 +271,28 @@ final class XdsComms2 { } /** - * Callback on ADS stream events. The callback methods should be called in a proper {@link - * io.grpc.SynchronizationContext}. + * Converts ClusterLoadAssignment data to {@link EndpointUpdate}. All the needed data, that is + * clusterName, localityLbEndpointsMap and dropPolicies, is extracted from ClusterLoadAssignment, + * and all other data is ignored. */ - interface AdsStreamCallback { - void onEdsResponse(ClusterLoadAssignment clusterLoadAssignment); + @VisibleForTesting + static EndpointUpdate getEndpointUpdatefromClusterAssignment( + ClusterLoadAssignment clusterLoadAssignment) { + EndpointUpdate.Builder endpointUpdateBuilder = EndpointUpdate.newBuilder(); + endpointUpdateBuilder.setClusterName(clusterLoadAssignment.getClusterName()); + for (DropOverload dropOverload : + clusterLoadAssignment.getPolicy().getDropOverloadsList()) { + endpointUpdateBuilder.addDropPolicy( + EnvoyProtoData.DropOverload.fromEnvoyProtoDropOverload(dropOverload)); + } + for (LocalityLbEndpoints localityLbEndpoints : clusterLoadAssignment.getEndpointsList()) { + endpointUpdateBuilder.addLocalityLbEndpoints( + EnvoyProtoData.Locality.fromEnvoyProtoLocality( + localityLbEndpoints.getLocality()), + EnvoyProtoData.LocalityLbEndpoints.fromEnvoyProtoLocalityLbEndpoints( + localityLbEndpoints)); - void onError(); + } + return endpointUpdateBuilder.build(); } } diff --git a/xds/src/main/java/io/grpc/xds/XdsLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/XdsLoadBalancerProvider.java index ce9a13335a..007cd06fa5 100644 --- a/xds/src/main/java/io/grpc/xds/XdsLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/XdsLoadBalancerProvider.java @@ -16,8 +16,6 @@ package io.grpc.xds; -import static com.google.common.base.Preconditions.checkNotNull; - import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Objects; @@ -133,6 +131,7 @@ public final class XdsLoadBalancerProvider extends LoadBalancerProvider { */ static final class XdsConfig { // TODO(chengyuanzhang): delete after shifting to use bootstrap. + @Nullable final String balancerName; // TODO(carl-mastrangelo): make these Object's containing the fully parsed child configs. @Nullable @@ -150,9 +149,10 @@ public final class XdsLoadBalancerProvider extends LoadBalancerProvider { final String lrsServerName; XdsConfig( - String balancerName, @Nullable LbConfig childPolicy, @Nullable LbConfig fallbackPolicy, - @Nullable String edsServiceName, @Nullable String lrsServerName) { - this.balancerName = checkNotNull(balancerName, "balancerName"); + @Nullable String balancerName, @Nullable LbConfig childPolicy, + @Nullable LbConfig fallbackPolicy, @Nullable String edsServiceName, + @Nullable String lrsServerName) { + this.balancerName = balancerName; this.childPolicy = childPolicy; this.fallbackPolicy = fallbackPolicy; this.edsServiceName = edsServiceName; diff --git a/xds/src/test/java/io/grpc/xds/LookasideChannelLbTest.java b/xds/src/test/java/io/grpc/xds/LookasideChannelLbTest.java index de2ea94db0..c3ed9c9c3f 100644 --- a/xds/src/test/java/io/grpc/xds/LookasideChannelLbTest.java +++ b/xds/src/test/java/io/grpc/xds/LookasideChannelLbTest.java @@ -51,6 +51,8 @@ import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.internal.ExponentialBackoffPolicy; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.testing.StreamRecorder; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; @@ -168,9 +170,11 @@ public class LookasideChannelLbTest { doReturn(mock(ChannelLogger.class)).when(helper).getChannelLogger(); doReturn(loadStatsStore).when(localityStore).getLoadStatsStore(); + XdsClient xdsClient = new XdsComms2( + channel, helper, new ExponentialBackoffPolicy.Provider(), + GrpcUtil.STOPWATCH_SUPPLIER, Node.getDefaultInstance()); lookasideChannelLb = new LookasideChannelLb( - helper, lookasideChannelCallback, channel, loadReportClient, localityStore, - Node.getDefaultInstance()); + "cluster1", lookasideChannelCallback, xdsClient, loadReportClient, localityStore); } @Test diff --git a/xds/src/test/java/io/grpc/xds/XdsCommsTest.java b/xds/src/test/java/io/grpc/xds/XdsCommsTest.java index 07dcc79e77..5d0828b9a4 100644 --- a/xds/src/test/java/io/grpc/xds/XdsCommsTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsCommsTest.java @@ -17,11 +17,12 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsComms2.getEndpointUpdatefromClusterAssignment; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; @@ -33,6 +34,8 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import com.google.protobuf.Any; import com.google.protobuf.UInt32Value; import io.envoyproxy.envoy.api.v2.ClusterLoadAssignment; +import io.envoyproxy.envoy.api.v2.ClusterLoadAssignment.Policy; +import io.envoyproxy.envoy.api.v2.ClusterLoadAssignment.Policy.DropOverload; import io.envoyproxy.envoy.api.v2.DiscoveryRequest; import io.envoyproxy.envoy.api.v2.DiscoveryResponse; import io.envoyproxy.envoy.api.v2.core.Address; @@ -43,6 +46,8 @@ import io.envoyproxy.envoy.api.v2.endpoint.Endpoint; import io.envoyproxy.envoy.api.v2.endpoint.LbEndpoint; import io.envoyproxy.envoy.api.v2.endpoint.LocalityLbEndpoints; import io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase; +import io.envoyproxy.envoy.type.FractionalPercent; +import io.envoyproxy.envoy.type.FractionalPercent.DenominatorType; import io.grpc.ChannelLogger; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; @@ -58,13 +63,17 @@ import io.grpc.internal.FakeClock; import io.grpc.internal.testing.StreamRecorder; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; -import io.grpc.xds.XdsComms2.AdsStreamCallback; +import io.grpc.xds.XdsClient.EndpointUpdate; +import io.grpc.xds.XdsClient.EndpointWatcher; +import java.util.Map; import java.util.concurrent.TimeUnit; +import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -89,7 +98,7 @@ public class XdsCommsTest { @Mock private Helper helper; @Mock - private AdsStreamCallback adsStreamCallback; + private EndpointWatcher endpointWatcher; @Mock private BackoffPolicy.Provider backoffPolicyProvider; @Mock @@ -184,21 +193,27 @@ public class XdsCommsTest { doReturn(10L, 100L, 1000L).when(backoffPolicy1).nextBackoffNanos(); doReturn(20L, 200L).when(backoffPolicy2).nextBackoffNanos(); xdsComms = new XdsComms2( - channel, helper, adsStreamCallback, backoffPolicyProvider, + channel, helper, backoffPolicyProvider, fakeClock.getStopwatchSupplier(), Node.getDefaultInstance()); + xdsComms.watchEndpointData("", endpointWatcher); + } + + @After + public void tearDown() { + xdsComms.shutdown(); } @Test - public void shutdownLbRpc_verifyChannelNotShutdown() throws Exception { - xdsComms.shutdownLbRpc(); + public void shutdownLbRpc_verifyChannelShutdown() throws Exception { + xdsComms.shutdown(); assertTrue(streamRecorder.awaitCompletion(1, TimeUnit.SECONDS)); assertEquals(Status.Code.CANCELLED, Status.fromThrowable(streamRecorder.getError()).getCode()); - assertFalse(channel.isShutdown()); + assertTrue(channel.isShutdown()); } @Test public void cancel() throws Exception { - xdsComms.shutdownLbRpc(); + xdsComms.shutdown(); assertTrue(streamRecorder.awaitCompletion(1, TimeUnit.SECONDS)); assertEquals(Status.Code.CANCELLED, Status.fromThrowable(streamRecorder.getError()).getCode()); } @@ -273,7 +288,8 @@ public class XdsCommsTest { .build(); responseWriter.onNext(edsResponse); - verify(adsStreamCallback).onEdsResponse(clusterLoadAssignment); + verify(endpointWatcher).onEndpointChanged( + getEndpointUpdatefromClusterAssignment(clusterLoadAssignment)); ClusterLoadAssignment clusterLoadAssignment2 = ClusterLoadAssignment.newBuilder() .addEndpoints(LocalityLbEndpoints.newBuilder() @@ -285,7 +301,7 @@ public class XdsCommsTest { .setLocality(localityProto1) .addLbEndpoints(endpoint11) .addLbEndpoints(endpoint12) - .setLoadBalancingWeight(UInt32Value.of(1))) + .setLoadBalancingWeight(UInt32Value.of(3))) .build(); edsResponse = DiscoveryResponse.newBuilder() .addResources(Any.pack(clusterLoadAssignment2)) @@ -293,18 +309,19 @@ public class XdsCommsTest { .build(); responseWriter.onNext(edsResponse); - verify(adsStreamCallback).onEdsResponse(clusterLoadAssignment2); - verifyNoMoreInteractions(adsStreamCallback); - - xdsComms.shutdownLbRpc(); + verify(endpointWatcher).onEndpointChanged( + getEndpointUpdatefromClusterAssignment(clusterLoadAssignment2)); + verifyNoMoreInteractions(endpointWatcher); } @Test public void serverOnCompleteShouldFailClient() { responseWriter.onCompleted(); - verify(adsStreamCallback).onError(); - verifyNoMoreInteractions(adsStreamCallback); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(endpointWatcher).onError(statusCaptor.capture()); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + verifyNoMoreInteractions(endpointWatcher); } /** @@ -325,7 +342,7 @@ public class XdsCommsTest { * Verify retry is scheduled. Verify the 6th PRC starts after backoff. * *

The 6th RPC fails with response observer onError() without receiving initial response. - * Verify retry is scheduled. Call {@link XdsComms2#shutdownLbRpc()}, verify retry timer is + * Verify retry is scheduled. Call {@link XdsComms2#shutdown()} ()}, verify retry timer is * cancelled. */ @Test @@ -333,7 +350,7 @@ public class XdsCommsTest { StreamRecorder currentStreamRecorder = streamRecorder; assertThat(currentStreamRecorder.getValues()).hasSize(1); InOrder inOrder = - inOrder(backoffPolicyProvider, backoffPolicy1, backoffPolicy2, adsStreamCallback); + inOrder(backoffPolicyProvider, backoffPolicy1, backoffPolicy2, endpointWatcher); inOrder.verify(backoffPolicyProvider).get(); assertEquals(0, fakeClock.numPendingTasks(LB_RPC_RETRY_TASK_FILTER)); @@ -341,7 +358,7 @@ public class XdsCommsTest { DiscoveryResponse.newBuilder().setTypeUrl(EDS_TYPE_URL).build(); // The 1st ADS RPC receives invalid response responseWriter.onNext(invalidResponse); - inOrder.verify(adsStreamCallback).onError(); + inOrder.verify(endpointWatcher).onError(any(Status.class)); assertThat(currentStreamRecorder.getError()).isNotNull(); // Will start backoff sequence 1 (10ns) @@ -364,7 +381,7 @@ public class XdsCommsTest { fakeClock.forwardNanos(4); // The 2nd RPC fails with response observer onError() without receiving initial response responseWriter.onError(new Exception("fake error")); - inOrder.verify(adsStreamCallback).onError(); + inOrder.verify(endpointWatcher).onError(any(Status.class)); // Will start backoff sequence 2 (100ns) inOrder.verify(backoffPolicy1).nextBackoffNanos(); @@ -386,7 +403,7 @@ public class XdsCommsTest { fakeClock.forwardNanos(5); // The 3rd PRC receives invalid initial response. responseWriter.onNext(invalidResponse); - inOrder.verify(adsStreamCallback).onError(); + inOrder.verify(endpointWatcher).onError(any(Status.class)); assertThat(currentStreamRecorder.getError()).isNotNull(); // Will start backoff sequence 3 (1000ns) @@ -447,7 +464,7 @@ public class XdsCommsTest { // The 5th RPC fails with response observer onError() without receiving initial response fakeClock.forwardNanos(8); responseWriter.onError(new Exception("fake error")); - inOrder.verify(adsStreamCallback).onError(); + inOrder.verify(endpointWatcher).onError(any(Status.class)); // Will start backoff sequence 1 (20ns) inOrder.verify(backoffPolicy2).nextBackoffNanos(); @@ -472,25 +489,106 @@ public class XdsCommsTest { // The 6th RPC fails with response observer onError() without receiving initial response responseWriter.onError(new Exception("fake error")); - inOrder.verify(adsStreamCallback).onError(); + inOrder.verify(endpointWatcher).onError(any(Status.class)); // Retry is scheduled assertEquals(1, fakeClock.numPendingTasks(LB_RPC_RETRY_TASK_FILTER)); // Shutdown cancels retry - xdsComms.shutdownLbRpc(); + xdsComms.shutdown(); assertEquals(0, fakeClock.numPendingTasks(LB_RPC_RETRY_TASK_FILTER)); } @Test public void refreshAdsStreamCancelsExistingRetry() { responseWriter.onError(new Exception("fake error")); - verify(adsStreamCallback).onError(); + verify(endpointWatcher).onError(any(Status.class)); assertEquals(1, fakeClock.numPendingTasks(LB_RPC_RETRY_TASK_FILTER)); xdsComms.refreshAdsStream(); assertEquals(0, fakeClock.numPendingTasks(LB_RPC_RETRY_TASK_FILTER)); + } - xdsComms.shutdownLbRpc(); + @Test + public void convertClusterLoadAssignmentToEndpointUpdate() { + Locality localityProto1 = Locality.newBuilder() + .setRegion("region1").setZone("zone1").setSubZone("subzone1").build(); + LbEndpoint endpoint11 = LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("addr11").setPortValue(11)))) + .setLoadBalancingWeight(UInt32Value.of(11)) + .build(); + LbEndpoint endpoint12 = LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("addr12").setPortValue(12)))) + .setLoadBalancingWeight(UInt32Value.of(12)) + .build(); + Locality localityProto2 = Locality.newBuilder() + .setRegion("region2").setZone("zone2").setSubZone("subzone2").build(); + LbEndpoint endpoint21 = LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("addr21").setPortValue(21)))) + .setLoadBalancingWeight(UInt32Value.of(21)) + .build(); + LbEndpoint endpoint22 = LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("addr22").setPortValue(22)))) + .setLoadBalancingWeight(UInt32Value.of(22)) + .build(); + LocalityLbEndpoints localityLbEndpointsProto1 = LocalityLbEndpoints.newBuilder() + .setLocality(localityProto1) + .setPriority(1) + .addLbEndpoints(endpoint11) + .addLbEndpoints(endpoint12) + .setLoadBalancingWeight(UInt32Value.of(1)) + .build(); + LocalityLbEndpoints localityLbEndpointsProto2 = LocalityLbEndpoints.newBuilder() + .setLocality(localityProto2) + .addLbEndpoints(endpoint21) + .addLbEndpoints(endpoint22) + .setLoadBalancingWeight(UInt32Value.of(2)) + .build(); + DropOverload dropOverloadProto1 = DropOverload.newBuilder() + .setCategory("cat1") + .setDropPercentage(FractionalPercent.newBuilder() + .setDenominator(DenominatorType.TEN_THOUSAND).setNumerator(123)) + .build(); + DropOverload dropOverloadProto2 = DropOverload.newBuilder() + .setCategory("cat2") + .setDropPercentage(FractionalPercent.newBuilder() + .setDenominator(DenominatorType.TEN_THOUSAND).setNumerator(456)) + .build(); + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName("cluster1") + .addEndpoints(localityLbEndpointsProto1) + .addEndpoints(localityLbEndpointsProto2) + .setPolicy(Policy.newBuilder() + .addDropOverloads(dropOverloadProto1) + .addDropOverloads(dropOverloadProto2)) + .build(); + + EndpointUpdate endpointUpdate = getEndpointUpdatefromClusterAssignment(clusterLoadAssignment); + + assertThat(endpointUpdate.getClusterName()).isEqualTo("cluster1"); + Map localityLbEndpointsMap = + endpointUpdate.getLocalityLbEndpointsMap(); + assertThat(localityLbEndpointsMap).containsExactly( + EnvoyProtoData.Locality.fromEnvoyProtoLocality(localityProto1), + EnvoyProtoData.LocalityLbEndpoints.fromEnvoyProtoLocalityLbEndpoints( + localityLbEndpointsProto1), + EnvoyProtoData.Locality.fromEnvoyProtoLocality(localityProto2), + EnvoyProtoData.LocalityLbEndpoints.fromEnvoyProtoLocalityLbEndpoints( + localityLbEndpointsProto2)); + assertThat(endpointUpdate.getDropPolicies()).containsExactly( + EnvoyProtoData.DropOverload.fromEnvoyProtoDropOverload(dropOverloadProto1), + EnvoyProtoData.DropOverload.fromEnvoyProtoDropOverload(dropOverloadProto2)); } }