diff --git a/xds/src/main/java/io/grpc/xds/LoadReportClient.java b/xds/src/main/java/io/grpc/xds/LoadReportClient.java index da52614790..d5126c6891 100644 --- a/xds/src/main/java/io/grpc/xds/LoadReportClient.java +++ b/xds/src/main/java/io/grpc/xds/LoadReportClient.java @@ -23,23 +23,21 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; -import com.google.protobuf.Struct; -import com.google.protobuf.Value; import com.google.protobuf.util.Durations; -import io.envoyproxy.envoy.api.v2.core.Node; -import io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc; -import io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest; -import io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse; import io.grpc.InternalLogId; -import io.grpc.ManagedChannel; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.BackoffPolicy; import io.grpc.stub.StreamObserver; import io.grpc.xds.EnvoyProtoData.ClusterStats; +import io.grpc.xds.EnvoyProtoData.Node; +import io.grpc.xds.XdsClient.XdsChannel; import io.grpc.xds.XdsLogger.XdsLogLevel; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; @@ -56,11 +54,10 @@ final class LoadReportClient { private final InternalLogId logId; private final XdsLogger logger; - private final ManagedChannel channel; + private final XdsChannel xdsChannel; private final Node node; private final SynchronizationContext syncContext; private final ScheduledExecutorService timerService; - private final Supplier stopwatchSupplier; private final Stopwatch retryStopwatch; private final BackoffPolicy.Provider backoffPolicyProvider; private final LoadStatsManager loadStatsManager; @@ -77,29 +74,27 @@ final class LoadReportClient { LoadReportClient( String targetName, LoadStatsManager loadStatsManager, - ManagedChannel channel, + XdsChannel xdsChannel, Node node, SynchronizationContext syncContext, ScheduledExecutorService scheduledExecutorService, BackoffPolicy.Provider backoffPolicyProvider, Supplier stopwatchSupplier) { this.loadStatsManager = checkNotNull(loadStatsManager, "loadStatsManager"); - this.channel = checkNotNull(channel, "channel"); + this.xdsChannel = checkNotNull(xdsChannel, "xdsChannel"); this.syncContext = checkNotNull(syncContext, "syncContext"); this.timerService = checkNotNull(scheduledExecutorService, "timeService"); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); - this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier"); + checkNotNull(stopwatchSupplier, "stopwatchSupplier"); this.retryStopwatch = stopwatchSupplier.get(); checkNotNull(targetName, "targetName"); checkNotNull(node, "node"); - Struct metadata = - node.getMetadata() - .toBuilder() - .putFields( - TARGET_NAME_METADATA_KEY, - Value.newBuilder().setStringValue(targetName).build()) - .build(); - this.node = node.toBuilder().setMetadata(metadata).build(); + Map newMetadata = new HashMap<>(); + if (node.getMetadata() != null) { + newMetadata.putAll(node.getMetadata()); + } + newMetadata.put(TARGET_NAME_METADATA_KEY, targetName); + this.node = node.toBuilder().setMetadata(newMetadata).build(); logId = InternalLogId.allocate("lrs-client", targetName); logger = XdsLogger.withLogId(logId); logger.log(XdsLogLevel.INFO, "Created"); @@ -163,17 +158,14 @@ final class LoadReportClient { private void startLrsRpc() { checkState(lrsStream == null, "previous lbStream has not been cleared yet"); - LoadReportingServiceGrpc.LoadReportingServiceStub stub - = LoadReportingServiceGrpc.newStub(channel); - lrsStream = new LrsStream(stub, stopwatchSupplier.get()); + // TODO(zdapeng): implement LrsStreamV3 and instantiate lrsStream based on value of + // xdsChannel.useProtocolV3 + lrsStream = new LrsStreamV2(); retryStopwatch.reset().start(); lrsStream.start(); } - private class LrsStream implements StreamObserver { - - final LoadReportingServiceGrpc.LoadReportingServiceStub stub; - StreamObserver lrsRequestWriter; + private abstract class LrsStream { boolean initialResponseReceived; boolean closed; long loadReportIntervalNano = -1; @@ -181,32 +173,39 @@ final class LoadReportClient { List clusterNames; // clusters to report loads for, if not report all. ScheduledHandle loadReportTimer; - LrsStream(LoadReportingServiceGrpc.LoadReportingServiceStub stub, Stopwatch stopwatch) { - this.stub = checkNotNull(stub, "stub"); - } + abstract void start(); - void start() { - lrsRequestWriter = stub.withWaitForReady().streamLoadStats(this); - LoadStatsRequest initRequest = - LoadStatsRequest.newBuilder() - .setNode(node) - .build(); - lrsRequestWriter.onNext(initRequest); - logger.log(XdsLogLevel.DEBUG, "Initial LRS request sent:\n{0}", initRequest); - } + abstract void sendLoadStatsRequest(LoadStatsRequestData request); - @Override - public void onNext(final LoadStatsResponse response) { + abstract void sendError(Exception error); + + final void handleResponse(final LoadStatsResponseData response) { syncContext.execute(new Runnable() { @Override public void run() { - handleResponse(response); + if (closed) { + return; + } + if (!initialResponseReceived) { + logger.log(XdsLogLevel.DEBUG, "Initial LRS response received"); + initialResponseReceived = true; + } + reportAllClusters = response.getSendAllClusters(); + if (reportAllClusters) { + logger.log(XdsLogLevel.INFO, "Report loads for all clusters"); + } else { + logger.log(XdsLogLevel.INFO, "Report loads for clusters: ", response.getClustersList()); + clusterNames = response.getClustersList(); + } + long interval = response.getLoadReportingIntervalNanos(); + logger.log(XdsLogLevel.INFO, "Update load reporting interval to {0} ns", interval); + loadReportIntervalNano = interval; + scheduleNextLoadReport(); } }); } - @Override - public void onError(final Throwable t) { + final void handleRpcError(final Throwable t) { syncContext.execute(new Runnable() { @Override public void run() { @@ -215,8 +214,7 @@ final class LoadReportClient { }); } - @Override - public void onCompleted() { + final void handleRpcComplete() { syncContext.execute(new Runnable() { @Override public void run() { @@ -227,21 +225,17 @@ final class LoadReportClient { } private void sendLoadReport() { - LoadStatsRequest.Builder requestBuilder = LoadStatsRequest.newBuilder().setNode(node); + List clusterStatsList; if (reportAllClusters) { - for (ClusterStats clusterStats : loadStatsManager.getAllLoadReports()) { - requestBuilder.addClusterStats(clusterStats.toEnvoyProtoClusterStatsV2()); - } + clusterStatsList = loadStatsManager.getAllLoadReports(); } else { + clusterStatsList = new ArrayList<>(); for (String name : clusterNames) { - for (ClusterStats clusterStats : loadStatsManager.getClusterLoadReports(name)) { - requestBuilder.addClusterStats(clusterStats.toEnvoyProtoClusterStatsV2()); - } + clusterStatsList.addAll(loadStatsManager.getClusterLoadReports(name)); } } - LoadStatsRequest request = requestBuilder.build(); - lrsRequestWriter.onNext(request); - logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", request); + LoadStatsRequestData request = new LoadStatsRequestData(node, clusterStatsList); + sendLoadStatsRequest(request); scheduleNextLoadReport(); } @@ -258,29 +252,6 @@ final class LoadReportClient { } } - private void handleResponse(LoadStatsResponse response) { - if (closed) { - return; - } - if (!initialResponseReceived) { - logger.log(XdsLogLevel.DEBUG, "Received LRS initial response:\n{0}", response); - initialResponseReceived = true; - } else { - logger.log(XdsLogLevel.DEBUG, "Received LRS response:\n{0}", response); - } - reportAllClusters = response.getSendAllClusters(); - if (reportAllClusters) { - logger.log(XdsLogLevel.INFO, "Report loads for all clusters"); - } else { - logger.log(XdsLogLevel.INFO, "Report loads for clusters: ", response.getClustersList()); - clusterNames = response.getClustersList(); - } - long interval = Durations.toNanos(response.getLoadReportingInterval()); - logger.log(XdsLogLevel.INFO, "Update load reporting interval to {0} ns", interval); - loadReportIntervalNano = interval; - scheduleNextLoadReport(); - } - private void handleStreamClosed(Status status) { checkArgument(!status.isOk(), "unexpected OK status"); if (closed) { @@ -317,17 +288,13 @@ final class LoadReportClient { } } - private void close(@Nullable Exception error) { + private void close(Exception error) { if (closed) { return; } closed = true; cleanUp(); - if (error == null) { - lrsRequestWriter.onCompleted(); - } else { - lrsRequestWriter.onError(error); - } + sendError(error); } private void cleanUp() { @@ -340,4 +307,107 @@ final class LoadReportClient { } } } + + private final class LrsStreamV2 extends LrsStream { + StreamObserver lrsRequestWriterV2; + + @Override + void start() { + StreamObserver + lrsResponseReaderV2 = + new StreamObserver() { + @Override + public void onNext( + io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse response) { + logger.log(XdsLogLevel.DEBUG, "Received LRS response:\n{0}", response); + handleResponse(LoadStatsResponseData.fromEnvoyProtoV2(response)); + } + + @Override + public void onError(Throwable t) { + handleRpcError(t); + } + + @Override + public void onCompleted() { + handleRpcComplete(); + } + }; + io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc.LoadReportingServiceStub + stubV2 = io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc.newStub( + xdsChannel.getManagedChannel()); + lrsRequestWriterV2 = stubV2.withWaitForReady().streamLoadStats(lrsResponseReaderV2); + logger.log(XdsLogLevel.DEBUG, "Sending initial LRS request"); + sendLoadStatsRequest(new LoadStatsRequestData(node, null)); + } + + @Override + void sendLoadStatsRequest(LoadStatsRequestData request) { + io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest requestProto = + request.toEnvoyProtoV2(); + lrsRequestWriterV2.onNext(requestProto); + logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", requestProto); + } + + @Override + void sendError(Exception error) { + lrsRequestWriterV2.onError(error); + } + } + + private static final class LoadStatsRequestData { + final Node node; + @Nullable + final List clusterStatsList; + + LoadStatsRequestData(Node node, @Nullable List clusterStatsList) { + this.node = checkNotNull(node, "node"); + this.clusterStatsList = clusterStatsList; + } + + io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest toEnvoyProtoV2() { + io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.Builder builder + = io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.newBuilder(); + builder.setNode(node.toEnvoyProtoNodeV2()); + if (clusterStatsList != null) { + for (ClusterStats stats : clusterStatsList) { + builder.addClusterStats(stats.toEnvoyProtoClusterStatsV2()); + } + } + return builder.build(); + } + } + + private static final class LoadStatsResponseData { + final boolean sendAllClusters; + final List clusters; + final long loadReportingIntervalNanos; + + LoadStatsResponseData( + boolean sendAllClusters, List clusters, long loadReportingIntervalNanos) { + this.sendAllClusters = sendAllClusters; + this.clusters = checkNotNull(clusters, "clusters"); + this.loadReportingIntervalNanos = loadReportingIntervalNanos; + } + + boolean getSendAllClusters() { + return sendAllClusters; + } + + List getClustersList() { + return clusters; + } + + long getLoadReportingIntervalNanos() { + return loadReportingIntervalNanos; + } + + static LoadStatsResponseData fromEnvoyProtoV2( + io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse loadStatsResponse) { + return new LoadStatsResponseData( + loadStatsResponse.getSendAllClusters(), + loadStatsResponse.getClustersList(), + Durations.toNanos(loadStatsResponse.getLoadReportingInterval())); + } + } } diff --git a/xds/src/main/java/io/grpc/xds/XdsClientImpl.java b/xds/src/main/java/io/grpc/xds/XdsClientImpl.java index c386daca3c..82addf72b3 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClientImpl.java +++ b/xds/src/main/java/io/grpc/xds/XdsClientImpl.java @@ -48,7 +48,6 @@ import io.envoyproxy.envoy.service.discovery.v3.AggregatedDiscoveryServiceGrpc; import io.envoyproxy.envoy.service.discovery.v3.DiscoveryRequest; import io.envoyproxy.envoy.service.discovery.v3.DiscoveryResponse; import io.grpc.InternalLogId; -import io.grpc.ManagedChannel; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; @@ -117,8 +116,7 @@ final class XdsClientImpl extends XdsClient { private final XdsLogger logger; // Name of the target server this gRPC client is trying to talk to. private final String targetName; - private final ManagedChannel channel; - private final boolean useProtocolV3; + private final XdsChannel xdsChannel; private final SynchronizationContext syncContext; private final ScheduledExecutorService timeService; private final BackoffPolicy.Provider backoffPolicyProvider; @@ -209,8 +207,7 @@ final class XdsClientImpl extends XdsClient { XdsChannel xdsChannel = checkNotNull(channelFactory, "channelFactory") .createChannel(checkNotNull(servers, "servers")); - this.channel = xdsChannel.getManagedChannel(); - this.useProtocolV3 = xdsChannel.isUseProtocolV3(); + this.xdsChannel = xdsChannel; this.node = checkNotNull(node, "node"); this.syncContext = checkNotNull(syncContext, "syncContext"); this.timeService = checkNotNull(timeService, "timeService"); @@ -225,7 +222,7 @@ final class XdsClientImpl extends XdsClient { @Override void shutdown() { logger.log(XdsLogLevel.INFO, "Shutting down"); - channel.shutdown(); + xdsChannel.getManagedChannel().shutdown(); if (adsStream != null) { adsStream.close(Status.CANCELLED.withDescription("shutdown").asException()); } @@ -484,8 +481,8 @@ final class XdsClientImpl extends XdsClient { new LoadReportClient( targetName, loadStatsManager, - channel, - node.toEnvoyProtoNodeV2(), + xdsChannel, + node, syncContext, timeService, backoffPolicyProvider, @@ -529,7 +526,7 @@ final class XdsClientImpl extends XdsClient { */ private void startRpcStream() { checkState(adsStream == null, "Previous adsStream has not been cleared yet"); - if (useProtocolV3) { + if (xdsChannel.isUseProtocolV3()) { adsStream = new AdsStream(); } else { adsStream = new AdsStreamV2(); @@ -1739,8 +1736,8 @@ final class XdsClientImpl extends XdsClient { private StreamObserver requestWriterV2; AdsStreamV2() { - stubV2 = - io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc.newStub(channel); + stubV2 = io.envoyproxy.envoy.service.discovery.v2.AggregatedDiscoveryServiceGrpc.newStub( + xdsChannel.getManagedChannel()); } @Override @@ -1795,7 +1792,7 @@ final class XdsClientImpl extends XdsClient { private StreamObserver requestWriter; AdsStream() { - stub = AggregatedDiscoveryServiceGrpc.newStub(channel); + stub = AggregatedDiscoveryServiceGrpc.newStub(xdsChannel.getManagedChannel()); } @Override diff --git a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java index a63171ef44..0a2951ad4d 100644 --- a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java @@ -30,6 +30,7 @@ import static org.mockito.Mockito.when; import com.google.common.base.Stopwatch; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.Struct; @@ -56,6 +57,7 @@ import io.grpc.xds.EnvoyProtoData.Locality; import io.grpc.xds.EnvoyProtoData.UpstreamLocalityStats; import io.grpc.xds.LoadStatsManager.LoadStatsStore; import io.grpc.xds.LoadStatsManager.LoadStatsStoreFactory; +import io.grpc.xds.XdsClient.XdsChannel; import java.util.ArrayDeque; import java.util.Arrays; import java.util.Collection; @@ -87,14 +89,10 @@ import org.mockito.MockitoAnnotations; public class LoadReportClientTest { private static final String TARGET_NAME = "lrs-test.example.com"; // bootstrap node identifier - private static final Node NODE = - Node.newBuilder() + private static final EnvoyProtoData.Node NODE = + EnvoyProtoData.Node.newBuilder() .setId("LRS test") - .setMetadata( - Struct.newBuilder() - .putFields( - "TRAFFICDIRECTOR_NETWORK_HOSTNAME", - Value.newBuilder().setStringValue("default").build())) + .setMetadata(ImmutableMap.of("TRAFFICDIRECTOR_NETWORK_HOSTNAME", "default")) .build(); private static final String CLUSTER1 = "cluster-foo.googleapis.com"; private static final String CLUSTER2 = "cluster-bar.googleapis.com"; @@ -189,7 +187,7 @@ public class LoadReportClientTest { new LoadReportClient( TARGET_NAME, loadStatsManager, - channel, + new XdsChannel(channel, false), NODE, syncContext, fakeClock.getScheduledExecutorService(),