diff --git a/xds/src/main/java/io/grpc/xds/XdsClientLoadRecorder.java b/xds/src/main/java/io/grpc/xds/XdsClientLoadRecorder.java
index edf531ca7f..c2738b506e 100644
--- a/xds/src/main/java/io/grpc/xds/XdsClientLoadRecorder.java
+++ b/xds/src/main/java/io/grpc/xds/XdsClientLoadRecorder.java
@@ -18,6 +18,7 @@ package io.grpc.xds;
import static com.google.common.base.Preconditions.checkNotNull;
+import com.google.common.annotations.VisibleForTesting;
import io.grpc.ClientStreamTracer;
import io.grpc.ClientStreamTracer.StreamInfo;
import io.grpc.Metadata;
@@ -72,6 +73,16 @@ final class XdsClientLoadRecorder extends ClientStreamTracer.Factory {
private final AtomicLong callsFailed = new AtomicLong();
private boolean active = true;
+ /**
+  * Creates a counter whose counts keep the values given by their field initializers.
+  */
+ ClientLoadCounter() {
+ }
+
+    /**
+     * Creates a counter pre-populated with the given counts, so that tests can start from a
+     * known state instead of driving real calls.
+     *
+     * @param callsInProgress initial value of the in-progress call count
+     * @param callsFinished initial value of the finished call count
+     * @param callsFailed initial value of the failed call count
+     */
+    @VisibleForTesting
+    ClientLoadCounter(long callsInProgress, long callsFinished, long callsFailed) {
+      // The three counters are independent of each other, so the order of seeding is irrelevant.
+      this.callsFailed.set(callsFailed);
+      this.callsFinished.set(callsFinished);
+      this.callsInProgress.set(callsInProgress);
+    }
+
/**
* Generate a query count snapshot and reset counts for next snapshot.
*/
diff --git a/xds/src/main/java/io/grpc/xds/XdsLoadStatsManager.java b/xds/src/main/java/io/grpc/xds/XdsLoadStatsManager.java
new file mode 100644
index 0000000000..6e2e5338f6
--- /dev/null
+++ b/xds/src/main/java/io/grpc/xds/XdsLoadStatsManager.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds;
+
+import io.envoyproxy.envoy.api.v2.core.Locality;
+import io.grpc.LoadBalancer.PickResult;
+import javax.annotation.concurrent.NotThreadSafe;
+
+/**
+ * An {@link XdsLoadStatsManager} is in charge of recording client side load stats, collecting
+ * backend cost metrics and sending load reports to the remote balancer. It shares the same
+ * channel with {@link XdsLoadBalancer} and its lifecycle is managed by {@link XdsLoadBalancer}.
+ */
+@NotThreadSafe
+interface XdsLoadStatsManager {
+
+ /**
+ * Establishes load reporting communication and negotiates with the remote balancer to report load
+ * stats periodically.
+ *
+ * <p>This method should be the first method to be called in the lifecycle of {@link
+ * XdsLoadStatsManager} and should only be called once.
+ *
+ * <p>This method is not thread-safe and should be called from the synchronization context
+ * returned by {@link io.grpc.LoadBalancer.Helper#getSynchronizationContext()} of the
+ * {@link XdsLoadBalancer}'s helper.
+ */
+ void startLoadReporting();
+
+ /**
+ * Terminates load reporting.
+ *
+ * <p>No method in {@link XdsLoadStatsManager} should be called after calling this method.
+ *
+ * <p>This method is not thread-safe and should be called from the synchronization context
+ * returned by {@link io.grpc.LoadBalancer.Helper#getSynchronizationContext()} of the
+ * {@link XdsLoadBalancer}'s helper.
+ */
+ void stopLoadReporting();
+
+ /**
+ * Applies client side load recording to {@link PickResult}s picked by the intra-locality picker
+ * for the provided locality.
+ *
+ * <p>This method is thread-safe.
+ *
+ * @param pickResult the pick result produced by the intra-locality picker
+ * @param locality the locality the pick was made for
+ * @return the pick result with client side load recording applied
+ */
+ PickResult interceptPickResult(PickResult pickResult, Locality locality);
+
+ /**
+ * Tracks load stats for endpoints in the provided locality. To be called upon balancer locality
+ * updates only for newly assigned localities. Only load stats for endpoints in added localities
+ * will be reported to the remote balancer.
+ *
+ * <p>This method is not thread-safe and should be called from the synchronization context
+ * returned by {@link io.grpc.LoadBalancer.Helper#getSynchronizationContext()} of the
+ * {@link XdsLoadBalancer}'s helper.
+ */
+ void addLocality(Locality locality);
+
+ /**
+ * Stops tracking load stats for endpoints in the provided locality. To be called upon balancer
+ * locality updates only for newly removed localities. Load stats for endpoints in removed
+ * localities will no longer be reported to the remote balancer once the client stops sending
+ * load to them.
+ *
+ * <p>This method is not thread-safe and should be called from the synchronization context
+ * returned by {@link io.grpc.LoadBalancer.Helper#getSynchronizationContext()} of the
+ * {@link XdsLoadBalancer}'s helper.
+ */
+ void removeLocality(Locality locality);
+
+ /**
+ * Records a client-side request drop with the provided category instructed by the remote
+ * balancer. Stats for dropped requests are aggregated at cluster level.
+ *
+ * <p>This method is thread-safe.
+ *
+ * @param category the drop category the request is accounted under
+ */
+ void recordDroppedRequest(String category);
+}
diff --git a/xds/src/main/java/io/grpc/xds/XdsLrsClient.java b/xds/src/main/java/io/grpc/xds/XdsLrsClient.java
new file mode 100644
index 0000000000..65bbe2990d
--- /dev/null
+++ b/xds/src/main/java/io/grpc/xds/XdsLrsClient.java
@@ -0,0 +1,351 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Stopwatch;
+import com.google.common.base.Supplier;
+import com.google.protobuf.Struct;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.Durations;
+import io.envoyproxy.envoy.api.v2.core.Locality;
+import io.envoyproxy.envoy.api.v2.core.Node;
+import io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc;
+import io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest;
+import io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse;
+import io.grpc.ChannelLogger;
+import io.grpc.ChannelLogger.ChannelLogLevel;
+import io.grpc.LoadBalancer.Helper;
+import io.grpc.LoadBalancer.PickResult;
+import io.grpc.ManagedChannel;
+import io.grpc.Status;
+import io.grpc.SynchronizationContext;
+import io.grpc.SynchronizationContext.ScheduledHandle;
+import io.grpc.internal.BackoffPolicy;
+import io.grpc.internal.GrpcUtil;
+import io.grpc.stub.StreamObserver;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
+
+/**
+ * Client of XDS load reporting service. Methods in this class are expected to be called in
+ * the same synchronized context that {@link XdsLoadBalancer#helper#getSynchronizationContext}
+ * returns.
+ */
+@NotThreadSafe
+final class XdsLrsClient implements XdsLoadStatsManager {
+
+ @VisibleForTesting
+ static final String TRAFFICDIRECTOR_HOSTNAME_FIELD
+ = "com.googleapis.trafficdirector.grpc_hostname";
+ private final String serviceName;
+ private final ManagedChannel channel;
+ private final SynchronizationContext syncContext;
+ private final ScheduledExecutorService timerService;
+ private final Supplier stopwatchSupplier;
+ private final Stopwatch retryStopwatch;
+ private final ChannelLogger logger;
+ private final BackoffPolicy.Provider backoffPolicyProvider;
+ private final XdsLoadReportStore loadReportStore;
+ private boolean started;
+
+ @Nullable
+ private BackoffPolicy lrsRpcRetryPolicy;
+ @Nullable
+ private ScheduledHandle lrsRpcRetryTimer;
+
+ @Nullable
+ private LrsStream lrsStream;
+
+ XdsLrsClient(ManagedChannel channel,
+ Helper helper,
+ BackoffPolicy.Provider backoffPolicyProvider) {
+ this(channel, helper, GrpcUtil.STOPWATCH_SUPPLIER, backoffPolicyProvider,
+ new XdsLoadReportStore(checkNotNull(helper, "helper").getAuthority()));
+ }
+
+ @VisibleForTesting
+ XdsLrsClient(ManagedChannel channel,
+ Helper helper,
+ Supplier stopwatchSupplier,
+ BackoffPolicy.Provider backoffPolicyProvider,
+ XdsLoadReportStore loadReportStore) {
+ this.channel = checkNotNull(channel, "channel");
+ this.serviceName = checkNotNull(helper.getAuthority(), "serviceName");
+ this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
+ this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier");
+ this.retryStopwatch = stopwatchSupplier.get();
+ this.logger = checkNotNull(helper.getChannelLogger(), "logger");
+ this.timerService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
+ this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
+ this.loadReportStore = checkNotNull(loadReportStore, "loadReportStore");
+ started = false;
+ }
+
+ @Override
+ public void startLoadReporting() {
+ checkState(!started, "load reporting has already started");
+ started = true;
+ startLrsRpc();
+ }
+
+ @Override
+ public void stopLoadReporting() {
+ if (lrsRpcRetryTimer != null) {
+ lrsRpcRetryTimer.cancel();
+ }
+ if (lrsStream != null) {
+ lrsStream.close(null);
+ }
+ // Do not shutdown channel as it is not owned by LrsClient.
+ }
+
+ @Override
+ public void addLocality(Locality locality) {
+ checkState(started, "load reporting must be started first");
+ syncContext.throwIfNotInThisSynchronizationContext();
+ loadReportStore.addLocality(locality);
+ }
+
+ @Override
+ public void removeLocality(final Locality locality) {
+ checkState(started, "load reporting must be started first");
+ syncContext.throwIfNotInThisSynchronizationContext();
+ loadReportStore.removeLocality(locality);
+ }
+
+ @Override
+ public void recordDroppedRequest(String category) {
+ checkState(started, "load reporting must be started first");
+ loadReportStore.recordDroppedRequest(category);
+ }
+
+ @Override
+ public PickResult interceptPickResult(PickResult pickResult, Locality locality) {
+ checkState(started, "load reporting must be started first");
+ return loadReportStore.interceptPickResult(pickResult, locality);
+ }
+
+ @VisibleForTesting
+ static class LoadReportingTask implements Runnable {
+ private final LrsStream stream;
+
+ LoadReportingTask(LrsStream stream) {
+ this.stream = stream;
+ }
+
+ @Override
+ public void run() {
+ stream.sendLoadReport();
+ }
+ }
+
+ @VisibleForTesting
+ class LrsRpcRetryTask implements Runnable {
+
+ @Override
+ public void run() {
+ startLrsRpc();
+ }
+ }
+
+ private void startLrsRpc() {
+ checkState(lrsStream == null, "previous lbStream has not been cleared yet");
+ LoadReportingServiceGrpc.LoadReportingServiceStub stub
+ = LoadReportingServiceGrpc.newStub(channel);
+ lrsStream = new LrsStream(stub, stopwatchSupplier.get());
+ retryStopwatch.reset().start();
+ lrsStream.start();
+ }
+
+ private class LrsStream implements StreamObserver {
+
+ final LoadReportingServiceGrpc.LoadReportingServiceStub stub;
+ final Stopwatch reportStopwatch;
+ StreamObserver lrsRequestWriter;
+ boolean initialResponseReceived;
+ boolean closed;
+ long loadReportIntervalNano = -1;
+ ScheduledHandle loadReportTimer;
+
+ LrsStream(LoadReportingServiceGrpc.LoadReportingServiceStub stub, Stopwatch stopwatch) {
+ this.stub = checkNotNull(stub, "stub");
+ reportStopwatch = checkNotNull(stopwatch, "stopwatch");
+ }
+
+ void start() {
+ lrsRequestWriter = stub.withWaitForReady().streamLoadStats(this);
+ reportStopwatch.reset().start();
+ LoadStatsRequest initRequest =
+ LoadStatsRequest.newBuilder()
+ .setNode(Node.newBuilder()
+ .setMetadata(Struct.newBuilder()
+ .putFields(
+ TRAFFICDIRECTOR_HOSTNAME_FIELD,
+ Value.newBuilder().setStringValue(serviceName).build())))
+ .build();
+ lrsRequestWriter.onNext(initRequest);
+ logger.log(ChannelLogLevel.DEBUG, "Initial LRS request sent: {0}", initRequest);
+ }
+
+ @Override
+ public void onNext(final LoadStatsResponse response) {
+ syncContext.execute(new Runnable() {
+ @Override
+ public void run() {
+ handleResponse(response);
+ }
+ });
+ }
+
+ @Override
+ public void onError(final Throwable t) {
+ syncContext.execute(new Runnable() {
+ @Override
+ public void run() {
+ handleStreamClosed(Status.fromThrowable(t)
+ .augmentDescription("Stream to XDS management server had an error"));
+ }
+ });
+ }
+
+ @Override
+ public void onCompleted() {
+ syncContext.execute(new Runnable() {
+ @Override
+ public void run() {
+ handleStreamClosed(
+ Status.UNAVAILABLE.withDescription("Stream to XDS management server was closed"));
+ }
+ });
+ }
+
+ private void sendLoadReport() {
+ long interval = reportStopwatch.elapsed(TimeUnit.NANOSECONDS);
+ reportStopwatch.reset().start();
+ lrsRequestWriter.onNext(LoadStatsRequest.newBuilder()
+ .setNode(Node.newBuilder()
+ .setMetadata(Struct.newBuilder()
+ .putFields(
+ TRAFFICDIRECTOR_HOSTNAME_FIELD,
+ Value.newBuilder().setStringValue(serviceName).build())))
+ .addClusterStats(loadReportStore.generateLoadReport(Durations.fromNanos(interval)))
+ .build());
+ scheduleNextLoadReport();
+ }
+
+ private void scheduleNextLoadReport() {
+ // Cancel pending load report and reschedule with updated load reporting interval.
+ if (loadReportTimer != null && loadReportTimer.isPending()) {
+ loadReportTimer.cancel();
+ loadReportTimer = null;
+ }
+ if (loadReportIntervalNano > 0) {
+ loadReportTimer = syncContext.schedule(
+ new LoadReportingTask(this), loadReportIntervalNano, TimeUnit.NANOSECONDS,
+ timerService);
+ }
+ }
+
+ private void handleResponse(LoadStatsResponse response) {
+ if (closed) {
+ return;
+ }
+
+ if (!initialResponseReceived) {
+ logger.log(ChannelLogLevel.DEBUG, "Received LRS initial response: {0}", response);
+ initialResponseReceived = true;
+ } else {
+ logger.log(ChannelLogLevel.DEBUG, "Received an LRS response: {0}", response);
+ }
+ loadReportIntervalNano = Durations.toNanos(response.getLoadReportingInterval());
+ List serviceList = Collections.unmodifiableList(response.getClustersList());
+ // For gRPC use case, LRS response will only contain one cluster, which is the same as in
+ // the EDS response.
+ if (serviceList.size() != 1 || !serviceList.get(0).equals(serviceName)) {
+ logger.log(ChannelLogLevel.ERROR, "Unmatched cluster name(s): {0} with EDS response: {1}",
+ serviceList, serviceName);
+ return;
+ }
+ scheduleNextLoadReport();
+ }
+
+ private void handleStreamClosed(Status status) {
+ checkArgument(!status.isOk(), "unexpected OK status");
+ if (closed) {
+ return;
+ }
+ closed = true;
+ cleanUp();
+
+ long delayNanos = 0;
+ if (initialResponseReceived || lrsRpcRetryPolicy == null) {
+ // Reset the backoff sequence if balancer has sent the initial response, or backoff sequence
+ // has never been initialized.
+ lrsRpcRetryPolicy = backoffPolicyProvider.get();
+ }
+ // Backoff only when balancer wasn't working previously.
+ if (!initialResponseReceived) {
+ // The back-off policy determines the interval between consecutive RPC upstarts, thus the
+ // actual delay may be smaller than the value from the back-off policy, or even negative,
+ // depending how much time was spent in the previous RPC.
+ delayNanos =
+ lrsRpcRetryPolicy.nextBackoffNanos() - retryStopwatch.elapsed(TimeUnit.NANOSECONDS);
+ }
+ logger.log(ChannelLogLevel.DEBUG, "LRS stream closed, backoff in {0} second(s)",
+ TimeUnit.NANOSECONDS.toSeconds(delayNanos <= 0 ? 0 : delayNanos));
+ if (delayNanos <= 0) {
+ startLrsRpc();
+ } else {
+ lrsRpcRetryTimer =
+ syncContext.schedule(new LrsRpcRetryTask(), delayNanos, TimeUnit.NANOSECONDS,
+ timerService);
+ }
+ }
+
+ private void close(@Nullable Exception error) {
+ if (closed) {
+ return;
+ }
+ closed = true;
+ cleanUp();
+ if (error == null) {
+ lrsRequestWriter.onCompleted();
+ } else {
+ lrsRequestWriter.onError(error);
+ }
+ }
+
+ private void cleanUp() {
+ if (loadReportTimer != null) {
+ loadReportTimer.cancel();
+ loadReportTimer = null;
+ }
+ if (lrsStream == this) {
+ lrsStream = null;
+ }
+ }
+ }
+}
diff --git a/xds/src/test/java/io/grpc/xds/XdsLrsClientTest.java b/xds/src/test/java/io/grpc/xds/XdsLrsClientTest.java
new file mode 100644
index 0000000000..e19d0639bc
--- /dev/null
+++ b/xds/src/test/java/io/grpc/xds/XdsLrsClientTest.java
@@ -0,0 +1,512 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.AdditionalAnswers.delegatesTo;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.inOrder;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
+
+import com.google.common.collect.Iterables;
+import com.google.protobuf.Struct;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.Durations;
+import io.envoyproxy.envoy.api.v2.core.Locality;
+import io.envoyproxy.envoy.api.v2.core.Node;
+import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats;
+import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats.DroppedRequests;
+import io.envoyproxy.envoy.api.v2.endpoint.UpstreamLocalityStats;
+import io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc;
+import io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest;
+import io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse;
+import io.grpc.ChannelLogger;
+import io.grpc.LoadBalancer.Helper;
+import io.grpc.ManagedChannel;
+import io.grpc.Status;
+import io.grpc.SynchronizationContext;
+import io.grpc.inprocess.InProcessChannelBuilder;
+import io.grpc.inprocess.InProcessServerBuilder;
+import io.grpc.internal.BackoffPolicy;
+import io.grpc.internal.FakeClock;
+import io.grpc.stub.StreamObserver;
+import io.grpc.testing.GrpcCleanupRule;
+import io.grpc.xds.XdsClientLoadRecorder.ClientLoadCounter;
+import java.text.MessageFormat;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.InOrder;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+/** Unit tests for {@link XdsLrsClient}. */
+@RunWith(JUnit4.class)
+public class XdsLrsClientTest {
+
+  private static final String SERVICE_AUTHORITY = "api.google.com";
+  // The request every new LRS stream is expected to open with.
+  private static final LoadStatsRequest EXPECTED_INITIAL_REQ = LoadStatsRequest.newBuilder()
+      .setNode(Node.newBuilder()
+          .setMetadata(Struct.newBuilder()
+              .putFields(
+                  XdsLrsClient.TRAFFICDIRECTOR_HOSTNAME_FIELD,
+                  Value.newBuilder().setStringValue(SERVICE_AUTHORITY).build())))
+      .build();
+  // The filters below pick scheduled tasks out of the fake clock by looking for the task class
+  // name in the string form of the scheduled command.
+  private static final FakeClock.TaskFilter LOAD_REPORTING_TASK_FILTER =
+      new FakeClock.TaskFilter() {
+        @Override
+        public boolean shouldAccept(Runnable command) {
+          return command.toString().contains(XdsLrsClient.LoadReportingTask.class.getSimpleName());
+        }
+      };
+  private static final FakeClock.TaskFilter LRS_RPC_RETRY_TASK_FILTER =
+      new FakeClock.TaskFilter() {
+        @Override
+        public boolean shouldAccept(Runnable command) {
+          return command.toString().contains(XdsLrsClient.LrsRpcRetryTask.class.getSimpleName());
+        }
+      };
+
+  @Rule
+  public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
+  // Any exception escaping a task run in the synchronization context fails the test.
+  private final SynchronizationContext syncContext = new SynchronizationContext(
+      new Thread.UncaughtExceptionHandler() {
+        @Override
+        public void uncaughtException(Thread t, Throwable e) {
+          throw new AssertionError(e);
+        }
+      });
+  // Channel log lines recorded as "LEVEL: message", in emission order.
+  private final LinkedList<String> logs = new LinkedList<>();
+  private final ChannelLogger channelLogger = new ChannelLogger() {
+    @Override
+    public void log(ChannelLogLevel level, String msg) {
+      logs.add(level + ": " + msg);
+    }
+
+    @Override
+    public void log(ChannelLogLevel level, String template, Object... args) {
+      log(level, MessageFormat.format(template, args));
+    }
+  };
+  private final FakeClock fakeClock = new FakeClock();
+  // Server-side request observers, one per LRS stream opened against the fake service.
+  private final LinkedList<StreamObserver<LoadStatsRequest>> lrsRequestObservers =
+      new LinkedList<>();
+  private LoadReportingServiceGrpc.LoadReportingServiceImplBase mockLoadReportingService;
+  @Captor
+  private ArgumentCaptor<StreamObserver<LoadStatsResponse>> lrsResponseObserverCaptor;
+
+  @Mock
+  private Helper helper;
+  @Mock
+  private BackoffPolicy.Provider backoffPolicyProvider;
+  @Mock
+  private BackoffPolicy backoffPolicy1;
+  @Mock
+  private BackoffPolicy backoffPolicy2;
+  private ManagedChannel channel;
+  private XdsLrsClient lrsClient;
+
+  /**
+   * Returns the cluster stats expected in a load report for {@code SERVICE_AUTHORITY} that covers
+   * the given interval and carries no locality or drop data.
+   */
+  private static ClusterStats buildEmptyClusterStats(long loadReportIntervalNanos) {
+    ClusterStats.Builder stats = ClusterStats.newBuilder();
+    stats.setClusterName(SERVICE_AUTHORITY);
+    stats.setLoadReportInterval(Durations.fromNanos(loadReportIntervalNanos));
+    return stats.build();
+  }
+
+  /**
+   * Returns an LRS response that names {@code SERVICE_AUTHORITY} as its only cluster and asks for
+   * the given load reporting interval.
+   */
+  private static LoadStatsResponse buildLrsResponse(long loadReportIntervalNanos) {
+    LoadStatsResponse.Builder response = LoadStatsResponse.newBuilder();
+    response.addClusters(SERVICE_AUTHORITY);
+    response.setLoadReportingInterval(Durations.fromNanos(loadReportIntervalNanos));
+    return response.build();
+  }
+
+  /**
+   * Starts an in-process fake load reporting server, wires the mocked {@link Helper} to the fake
+   * clock and test logger, and starts load reporting on a fresh {@link XdsLrsClient}.
+   */
+  @SuppressWarnings("unchecked")
+  @Before
+  public void setUp() throws Exception {
+    MockitoAnnotations.initMocks(this);
+    // The mock delegates to a real implementation so that calls can be verified while each
+    // stream still gets a working request observer.
+    mockLoadReportingService = mock(LoadReportingServiceGrpc.LoadReportingServiceImplBase.class,
+        delegatesTo(
+            new LoadReportingServiceGrpc.LoadReportingServiceImplBase() {
+              @Override
+              public StreamObserver<LoadStatsRequest> streamLoadStats(
+                  final StreamObserver<LoadStatsResponse> responseObserver) {
+                StreamObserver<LoadStatsRequest> requestObserver =
+                    mock(StreamObserver.class);
+                // Complete the response side when the client half-closes, so the RPC terminates.
+                Answer<Void> closeRpc = new Answer<Void>() {
+                  @Override
+                  public Void answer(InvocationOnMock invocation) {
+                    responseObserver.onCompleted();
+                    return null;
+                  }
+                };
+                doAnswer(closeRpc).when(requestObserver).onCompleted();
+                lrsRequestObservers.add(requestObserver);
+                return requestObserver;
+              }
+            }
+        ));
+    String serverName = "fakeLoadReportingServer";
+    cleanupRule.register(InProcessServerBuilder.forName(serverName).directExecutor()
+        .addService(mockLoadReportingService).build().start());
+    channel = cleanupRule.register(
+        InProcessChannelBuilder.forName(serverName).directExecutor().build());
+    when(helper.getSynchronizationContext()).thenReturn(syncContext);
+    when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService());
+    when(helper.getChannelLogger()).thenReturn(channelLogger);
+    when(helper.getAuthority()).thenReturn(SERVICE_AUTHORITY);
+    when(backoffPolicyProvider.get()).thenReturn(backoffPolicy1, backoffPolicy2);
+    when(backoffPolicy1.nextBackoffNanos())
+        .thenReturn(TimeUnit.SECONDS.toNanos(1L), TimeUnit.SECONDS.toNanos(10L));
+    when(backoffPolicy2.nextBackoffNanos())
+        .thenReturn(TimeUnit.SECONDS.toNanos(1L), TimeUnit.SECONDS.toNanos(10L));
+    logs.clear();
+    lrsClient = new XdsLrsClient(channel, helper, fakeClock.getStopwatchSupplier(),
+        backoffPolicyProvider, new XdsLoadReportStore(SERVICE_AUTHORITY));
+    lrsClient.startLoadReporting();
+  }
+
+ /** Stops load reporting on the client under test after each test. */
+ @After
+ public void tearDown() {
+ lrsClient.stopLoadReporting();
+ }
+
+ private void assertNextReport(InOrder inOrder, StreamObserver requestObserver,
+ ClusterStats expectedStats) {
+ long loadReportIntervalNanos = Durations.toNanos(expectedStats.getLoadReportInterval());
+ assertEquals(0, fakeClock.forwardTime(loadReportIntervalNanos - 1, TimeUnit.NANOSECONDS));
+ inOrder.verifyNoMoreInteractions();
+ assertEquals(1, fakeClock.forwardTime(1, TimeUnit.NANOSECONDS));
+ // A second load report is scheduled upon the first is sent.
+ assertEquals(1, fakeClock.numPendingTasks(LOAD_REPORTING_TASK_FILTER));
+ ArgumentCaptor reportCaptor = ArgumentCaptor.forClass(null);
+ inOrder.verify(requestObserver).onNext(reportCaptor.capture());
+ LoadStatsRequest report = reportCaptor.getValue();
+ assertEquals(report.getNode(), Node.newBuilder()
+ .setMetadata(Struct.newBuilder()
+ .putFields(
+ XdsLrsClient.TRAFFICDIRECTOR_HOSTNAME_FIELD,
+ Value.newBuilder().setStringValue(SERVICE_AUTHORITY).build()))
+ .build());
+ assertEquals(1, report.getClusterStatsCount());
+ assertClusterStatsEqual(expectedStats, report.getClusterStats(0));
+ }
+
+  /**
+   * Asserts that two cluster stats agree on cluster name, report interval, and on their locality
+   * stats and dropped-request entries regardless of the order in which the entries are listed.
+   */
+  private void assertClusterStatsEqual(ClusterStats stats1, ClusterStats stats2) {
+    HashSet<UpstreamLocalityStats> localityStats1 =
+        new HashSet<>(stats1.getUpstreamLocalityStatsList());
+    HashSet<UpstreamLocalityStats> localityStats2 =
+        new HashSet<>(stats2.getUpstreamLocalityStatsList());
+    HashSet<DroppedRequests> droppedRequests1 = new HashSet<>(stats1.getDroppedRequestsList());
+    HashSet<DroppedRequests> droppedRequests2 = new HashSet<>(stats2.getDroppedRequestsList());
+
+    assertEquals(stats1.getClusterName(), stats2.getClusterName());
+    assertEquals(stats1.getLoadReportInterval(), stats2.getLoadReportInterval());
+    // Counts are compared before the sets so that duplicated entries cannot hide a mismatch.
+    assertEquals(stats1.getUpstreamLocalityStatsCount(), stats2.getUpstreamLocalityStatsCount());
+    assertEquals(stats1.getDroppedRequestsCount(), stats2.getDroppedRequestsCount());
+    assertEquals(localityStats1, localityStats2);
+    assertEquals(droppedRequests1, droppedRequests2);
+  }
+
+  /** Starting load reporting opens one stream and sends only the initial request on it. */
+  @Test
+  public void loadReportInitialRequest() {
+    verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
+    assertThat(lrsRequestObservers).hasSize(1);
+    StreamObserver<LoadStatsRequest> requestObserver = lrsRequestObservers.poll();
+    verify(requestObserver).onNext(EXPECTED_INITIAL_REQ);
+    // No more request should be sent until receiving initial response. No load reporting
+    // should be scheduled.
+    assertThat(fakeClock.getPendingTasks(LOAD_REPORTING_TASK_FILTER)).isEmpty();
+    verifyNoMoreInteractions(requestObserver);
+  }
+
+  /** The first load report is sent exactly one server-specified interval after the response. */
+  @Test
+  public void loadReportActualIntervalAsSpecified() {
+    verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
+    StreamObserver<LoadStatsResponse> responseObserver = lrsResponseObserverCaptor.getValue();
+    assertThat(lrsRequestObservers).hasSize(1);
+    StreamObserver<LoadStatsRequest> requestObserver = lrsRequestObservers.poll();
+    InOrder inOrder = inOrder(requestObserver);
+    inOrder.verify(requestObserver).onNext(EXPECTED_INITIAL_REQ);
+    assertThat(logs).containsExactly("DEBUG: Initial LRS request sent: " + EXPECTED_INITIAL_REQ);
+    logs.poll();
+
+    responseObserver.onNext(buildLrsResponse(1453));
+    assertThat(logs).containsExactly(
+        "DEBUG: Received LRS initial response: " + buildLrsResponse(1453));
+    assertNextReport(inOrder, requestObserver, buildEmptyClusterStats(1453));
+  }
+
+  /** A later response with a different interval reschedules the next report accordingly. */
+  @Test
+  public void loadReportIntervalUpdate() {
+    verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
+    StreamObserver<LoadStatsResponse> responseObserver = lrsResponseObserverCaptor.getValue();
+    assertThat(lrsRequestObservers).hasSize(1);
+    StreamObserver<LoadStatsRequest> requestObserver = lrsRequestObservers.poll();
+    InOrder inOrder = inOrder(requestObserver);
+    inOrder.verify(requestObserver).onNext(EXPECTED_INITIAL_REQ);
+    assertThat(logs).containsExactly("DEBUG: Initial LRS request sent: " + EXPECTED_INITIAL_REQ);
+    logs.poll();
+
+    responseObserver.onNext(buildLrsResponse(1362));
+    assertThat(logs).containsExactly(
+        "DEBUG: Received LRS initial response: " + buildLrsResponse(1362));
+    logs.poll();
+    assertNextReport(inOrder, requestObserver, buildEmptyClusterStats(1362));
+
+    responseObserver.onNext(buildLrsResponse(2183345));
+    assertThat(logs).containsExactly(
+        "DEBUG: Received an LRS response: " + buildLrsResponse(2183345));
+    // Updated load reporting interval becomes effective immediately.
+    assertNextReport(inOrder, requestObserver, buildEmptyClusterStats(2183345));
+  }
+
+ @Test
+ public void reportRecordedLoadData() {
+ lrsClient.stopLoadReporting();
+ verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
+ lrsRequestObservers.clear();
+
+ ConcurrentMap localityCounters =
+ new ConcurrentHashMap<>();
+ ConcurrentMap dropCounters = new ConcurrentHashMap<>();
+ XdsLoadReportStore loadReportStore =
+ new XdsLoadReportStore(SERVICE_AUTHORITY, localityCounters, dropCounters);
+ lrsClient = new XdsLrsClient(channel, helper, fakeClock.getStopwatchSupplier(),
+ backoffPolicyProvider, loadReportStore);
+ lrsClient.startLoadReporting();
+
+ verify(mockLoadReportingService, times(2)).streamLoadStats(lrsResponseObserverCaptor.capture());
+ StreamObserver responseObserver = lrsResponseObserverCaptor.getValue();
+ assertThat(lrsRequestObservers).hasSize(1);
+ StreamObserver requestObserver = lrsRequestObservers.poll();
+ InOrder inOrder = inOrder(requestObserver);
+ inOrder.verify(requestObserver).onNext(EXPECTED_INITIAL_REQ);
+
+ Locality locality = Locality.newBuilder()
+ .setRegion("test_region")
+ .setZone("test_zone")
+ .setSubZone("test_subzone")
+ .build();
+ Random rand = new Random();
+ // Integer range is large enough for testing.
+ long callsInProgress1 = rand.nextInt(Integer.MAX_VALUE);
+ long callsFinished1 = rand.nextInt(Integer.MAX_VALUE);
+ long callsFailed1 = callsFinished1 - rand.nextInt((int) callsFinished1);
+ localityCounters.put(locality,
+ new ClientLoadCounter(callsInProgress1, callsFinished1, callsFailed1));
+
+ long numLbDrops = rand.nextLong();
+ long numThrottleDrops = rand.nextLong();
+ dropCounters.put("lb", new AtomicLong(numLbDrops));
+ dropCounters.put("throttle", new AtomicLong(numThrottleDrops));
+
+ responseObserver.onNext(buildLrsResponse(1362));
+
+ ClusterStats expectedStats = ClusterStats.newBuilder()
+ .setClusterName(SERVICE_AUTHORITY)
+ .setLoadReportInterval(Durations.fromNanos(1362))
+ .addUpstreamLocalityStats(UpstreamLocalityStats.newBuilder()
+ .setLocality(locality)
+ .setTotalRequestsInProgress(callsInProgress1)
+ .setTotalSuccessfulRequests(callsFinished1 - callsFailed1)
+ .setTotalErrorRequests(callsFailed1))
+ .addDroppedRequests(DroppedRequests.newBuilder()
+ .setCategory("lb")
+ .setDroppedCount(numLbDrops))
+ .addDroppedRequests(DroppedRequests.newBuilder()
+ .setCategory("throttle")
+ .setDroppedCount(numThrottleDrops))
+ .build();
+ assertNextReport(inOrder, requestObserver, expectedStats);
+
+ // No client load happens upon next load reporting, only number of in-progress
+ // calls are non-zero.
+ expectedStats = ClusterStats.newBuilder()
+ .setClusterName(SERVICE_AUTHORITY)
+ .setLoadReportInterval(Durations.fromNanos(1362))
+ .addUpstreamLocalityStats(UpstreamLocalityStats.newBuilder()
+ .setLocality(locality)
+ .setTotalRequestsInProgress(callsInProgress1))
+ .addDroppedRequests(DroppedRequests.newBuilder()
+ .setCategory("lb")
+ .setDroppedCount(0))
+ .addDroppedRequests(DroppedRequests.newBuilder()
+ .setCategory("throttle")
+ .setDroppedCount(0))
+ .build();
+ assertNextReport(inOrder, requestObserver, expectedStats);
+ }
+
+ @Test
+ public void lrsStreamClosedAndRetried() { // Closed LRS streams must be retried with backoff.
+ InOrder inOrder = inOrder(mockLoadReportingService, backoffPolicyProvider, backoffPolicy1,
+ backoffPolicy2);
+ inOrder.verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
+ StreamObserver responseObserver = lrsResponseObserverCaptor.getValue();
+ assertEquals(1, lrsRequestObservers.size());
+ StreamObserver requestObserver = lrsRequestObservers.poll();
+
+ // First LRS RPC: the initial request is sent and logged, and no retry task is pending.
+ verify(requestObserver).onNext(EXPECTED_INITIAL_REQ);
+ assertThat(logs).containsExactly("DEBUG: Initial LRS request sent: " + EXPECTED_INITIAL_REQ);
+ logs.poll();
+ assertEquals(0, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER));
+
+ // The response stream is completed right away, before any response has been delivered.
+ responseObserver.onCompleted();
+
+ // Starts backoff sequence 1 (backoffPolicy1); the first delay is logged as 1 second.
+ inOrder.verify(backoffPolicyProvider).get();
+ inOrder.verify(backoffPolicy1).nextBackoffNanos();
+ assertThat(logs).containsExactly("DEBUG: LRS stream closed, backoff in 1 second(s)");
+ logs.poll();
+ assertEquals(1, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER));
+
+ // Fast-forward to one nanosecond before the retry is due: no new stream must be opened yet.
+ fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(1) - 1);
+ verifyNoMoreInteractions(mockLoadReportingService);
+ // Advance the last nanosecond: the retry opens a new stream and resends the initial request.
+ fakeClock.forwardNanos(1);
+ inOrder.verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
+ responseObserver = lrsResponseObserverCaptor.getValue();
+ assertEquals(1, lrsRequestObservers.size());
+ requestObserver = lrsRequestObservers.poll();
+ verify(requestObserver).onNext(eq(EXPECTED_INITIAL_REQ));
+ assertThat(logs).containsExactly("DEBUG: Initial LRS request sent: " + EXPECTED_INITIAL_REQ);
+ logs.poll();
+ assertEquals(0, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER));
+
+ // The response stream is closed with an error, again before any response was delivered.
+ responseObserver.onError(Status.UNAVAILABLE.asException());
+ // Continues backoff sequence 1: no new policy is requested and the next delay is 10 seconds.
+ verifyNoMoreInteractions(backoffPolicyProvider);
+ inOrder.verify(backoffPolicy1).nextBackoffNanos();
+ assertThat(logs).containsExactly("DEBUG: LRS stream closed, backoff in 10 second(s)");
+ logs.poll();
+ assertEquals(1, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER));
+
+ // Fast-forward to one nanosecond before the retry is due: no new stream must be opened yet.
+ fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(10) - 1);
+ verifyNoMoreInteractions(mockLoadReportingService);
+ // Advance the last nanosecond: the retry opens a new stream and resends the initial request.
+ fakeClock.forwardNanos(1);
+ inOrder.verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
+ responseObserver = lrsResponseObserverCaptor.getValue();
+ assertEquals(1, lrsRequestObservers.size());
+ requestObserver = lrsRequestObservers.poll();
+ verify(requestObserver).onNext(eq(EXPECTED_INITIAL_REQ));
+ assertThat(logs).containsExactly("DEBUG: Initial LRS request sent: " + EXPECTED_INITIAL_REQ);
+ logs.poll();
+ assertEquals(0, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER));
+
+ // This time an initial LRS response is delivered on the stream.
+ responseObserver.onNext(buildLrsResponse(0));
+ assertThat(logs).containsExactly(
+ "DEBUG: Received LRS initial response: " + buildLrsResponse(0));
+ logs.poll();
+
+ // Then the RPC fails.
+ responseObserver.onError(Status.UNAVAILABLE.asException());
+
+ // A fresh backoff policy is requested and the retry happens immediately (no clock advance),
+ // because a response had been received on the failed stream.
+ inOrder.verify(backoffPolicyProvider).get();
+ inOrder.verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
+ assertThat(logs).containsExactly("DEBUG: LRS stream closed, backoff in 0 second(s)",
+ "DEBUG: Initial LRS request sent: " + EXPECTED_INITIAL_REQ);
+ logs.clear();
+ responseObserver = lrsResponseObserverCaptor.getValue();
+ assertEquals(1, lrsRequestObservers.size());
+ requestObserver = lrsRequestObservers.poll();
+ verify(requestObserver).onNext(eq(EXPECTED_INITIAL_REQ));
+
+ // Fail the retry after 4 ns have elapsed on the fake clock.
+ fakeClock.forwardNanos(4);
+ responseObserver.onError(Status.UNAVAILABLE.asException());
+
+ // This is the first retry (1s) of backoff sequence 2 (backoffPolicy2).
+ inOrder.verify(backoffPolicy2).nextBackoffNanos();
+ // The logged delay is 0 seconds because the log message only has one-second granularity.
+ assertThat(logs).containsExactly("DEBUG: LRS stream closed, backoff in 0 second(s)");
+ logs.poll();
+ assertEquals(1, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER));
+
+ // Fast-forward to just before the retry; the 4 ns spent in the last attempt are deducted.
+ fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(1) - 4 - 1);
+ verifyNoMoreInteractions(mockLoadReportingService);
+ // Advance the last nanosecond: the retry opens a new stream and resends the initial request.
+ fakeClock.forwardNanos(1);
+ inOrder.verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
+ assertEquals(1, lrsRequestObservers.size());
+ requestObserver = lrsRequestObservers.poll();
+ verify(requestObserver).onNext(eq(EXPECTED_INITIAL_REQ));
+ assertThat(logs).containsExactly("DEBUG: Initial LRS request sent: " + EXPECTED_INITIAL_REQ);
+ assertEquals(0, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER));
+
+ // Wrapping up: total call counts across the whole scenario.
+ verify(backoffPolicyProvider, times(2)).get();
+ verify(backoffPolicy1, times(2)).nextBackoffNanos();
+ verify(backoffPolicy2, times(1)).nextBackoffNanos();
+ }
+
+ @Test
+ public void raceBetweenLoadReportingAndLbStreamClosure() { // A late-running task sends nothing.
+ verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
+ StreamObserver responseObserver = lrsResponseObserverCaptor.getValue();
+ assertEquals(1, lrsRequestObservers.size());
+ StreamObserver requestObserver = lrsRequestObservers.poll();
+ InOrder inOrder = inOrder(requestObserver);
+
+ // First LRS RPC: the initial request is sent and no retry task is pending.
+ inOrder.verify(requestObserver).onNext(EXPECTED_INITIAL_REQ);
+ assertEquals(0, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER));
+
+ // Simulate receiving an LRS response (built with 1983, the expected task delay in ns).
+ assertEquals(0, fakeClock.numPendingTasks(LOAD_REPORTING_TASK_FILTER));
+ responseObserver.onNext(buildLrsResponse(1983));
+ // Load reporting task is scheduled; keep a handle on it so it can be run manually below.
+ assertEquals(1, fakeClock.numPendingTasks(LOAD_REPORTING_TASK_FILTER));
+ FakeClock.ScheduledTask scheduledTask =
+ Iterables.getOnlyElement(fakeClock.getPendingTasks(LOAD_REPORTING_TASK_FILTER));
+ assertEquals(1983, scheduledTask.getDelay(TimeUnit.NANOSECONDS));
+
+ // Close the stream by completing the request observer.
+ requestObserver.onCompleted();
+
+ // Reporting task cancelled
+ assertEquals(0, fakeClock.numPendingTasks(LOAD_REPORTING_TASK_FILTER));
+
+ // Simulate a race condition where the task has just started when it is cancelled
+ scheduledTask.command.run();
+
+ // No report sent. No new task scheduled
+ inOrder.verify(requestObserver, never()).onNext(any(LoadStatsRequest.class));
+ assertEquals(0, fakeClock.numPendingTasks(LOAD_REPORTING_TASK_FILTER));
+ }
+}
\ No newline at end of file