xds: synchronize LoadReportClient operations with lock (#7528)

Replace the SynchronizationContext used in LoadReportClient with a lock.
This commit is contained in:
Chengyuan Zhang 2020-10-20 16:58:08 -07:00 committed by GitHub
parent c329aad2bc
commit 0ec3bfb471
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 154 additions and 190 deletions

View File

@ -30,8 +30,6 @@ import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsRequest;
import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsResponse; import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsResponse;
import io.grpc.InternalLogId; import io.grpc.InternalLogId;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.SynchronizationContext.ScheduledHandle;
import io.grpc.internal.BackoffPolicy; import io.grpc.internal.BackoffPolicy;
import io.grpc.stub.StreamObserver; import io.grpc.stub.StreamObserver;
import io.grpc.xds.EnvoyProtoData.ClusterStats; import io.grpc.xds.EnvoyProtoData.ClusterStats;
@ -39,34 +37,33 @@ import io.grpc.xds.EnvoyProtoData.Node;
import io.grpc.xds.XdsClient.XdsChannel; import io.grpc.xds.XdsClient.XdsChannel;
import io.grpc.xds.XdsLogger.XdsLogLevel; import io.grpc.xds.XdsLogger.XdsLogLevel;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.annotation.concurrent.NotThreadSafe;
/** /**
* Client of xDS load reporting service based on LRS protocol, which reports load stats of * Client of xDS load reporting service based on LRS protocol, which reports load stats of
* gRPC client's perspective to a management server. * gRPC client's perspective to a management server.
*/ */
@NotThreadSafe
final class LoadReportClient { final class LoadReportClient {
private final InternalLogId logId; private final InternalLogId logId;
private final XdsLogger logger; private final XdsLogger logger;
private final XdsChannel xdsChannel; private final XdsChannel xdsChannel;
private final Node node; private final Node node;
private final SynchronizationContext syncContext;
private final ScheduledExecutorService timerService; private final ScheduledExecutorService timerService;
private final Stopwatch retryStopwatch; private final Stopwatch retryStopwatch;
private final BackoffPolicy.Provider backoffPolicyProvider; private final BackoffPolicy.Provider backoffPolicyProvider;
private final LoadStatsManager loadStatsManager; private final LoadStatsManager loadStatsManager;
private final Object lock = new Object();
private boolean started; private boolean started;
@Nullable @Nullable
private BackoffPolicy lrsRpcRetryPolicy; private BackoffPolicy lrsRpcRetryPolicy;
@Nullable @Nullable
private ScheduledHandle lrsRpcRetryTimer; private ScheduledFuture<?> lrsRpcRetryTimer;
@Nullable @Nullable
private LrsStream lrsStream; private LrsStream lrsStream;
@ -74,13 +71,11 @@ final class LoadReportClient {
LoadStatsManager loadStatsManager, LoadStatsManager loadStatsManager,
XdsChannel xdsChannel, XdsChannel xdsChannel,
Node node, Node node,
SynchronizationContext syncContext,
ScheduledExecutorService scheduledExecutorService, ScheduledExecutorService scheduledExecutorService,
BackoffPolicy.Provider backoffPolicyProvider, BackoffPolicy.Provider backoffPolicyProvider,
Supplier<Stopwatch> stopwatchSupplier) { Supplier<Stopwatch> stopwatchSupplier) {
this.loadStatsManager = checkNotNull(loadStatsManager, "loadStatsManager"); this.loadStatsManager = checkNotNull(loadStatsManager, "loadStatsManager");
this.xdsChannel = checkNotNull(xdsChannel, "xdsChannel"); this.xdsChannel = checkNotNull(xdsChannel, "xdsChannel");
this.syncContext = checkNotNull(syncContext, "syncContext");
this.timerService = checkNotNull(scheduledExecutorService, "timeService"); this.timerService = checkNotNull(scheduledExecutorService, "timeService");
this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
this.retryStopwatch = checkNotNull(stopwatchSupplier, "stopwatchSupplier").get(); this.retryStopwatch = checkNotNull(stopwatchSupplier, "stopwatchSupplier").get();
@ -97,12 +92,14 @@ final class LoadReportClient {
* no-op. * no-op.
*/ */
void startLoadReporting() { void startLoadReporting() {
if (started) { synchronized (lock) {
return; if (started) {
return;
}
started = true;
logger.log(XdsLogLevel.INFO, "Starting load reporting RPC");
startLrsRpc();
} }
started = true;
logger.log(XdsLogLevel.INFO, "Starting load reporting RPC");
startLrsRpc();
} }
/** /**
@ -110,22 +107,24 @@ final class LoadReportClient {
* {@link LoadReportClient} is no-op. * {@link LoadReportClient} is no-op.
*/ */
void stopLoadReporting() { void stopLoadReporting() {
if (!started) { synchronized (lock) {
return; if (!started) {
return;
}
started = false;
logger.log(XdsLogLevel.INFO, "Stopping load reporting RPC");
if (lrsRpcRetryTimer != null) {
lrsRpcRetryTimer.cancel(false);
}
if (lrsStream != null) {
lrsStream.close(Status.CANCELLED.withDescription("stop load reporting").asException());
}
// Do not shutdown channel as it is not owned by LrsClient.
} }
logger.log(XdsLogLevel.INFO, "Stopping load reporting RPC");
if (lrsRpcRetryTimer != null) {
lrsRpcRetryTimer.cancel();
}
if (lrsStream != null) {
lrsStream.close(Status.CANCELLED.withDescription("stop load reporting").asException());
}
started = false;
// Do not shutdown channel as it is not owned by LrsClient.
} }
@VisibleForTesting @VisibleForTesting
static class LoadReportingTask implements Runnable { class LoadReportingTask implements Runnable {
private final LrsStream stream; private final LrsStream stream;
LoadReportingTask(LrsStream stream) { LoadReportingTask(LrsStream stream) {
@ -134,7 +133,9 @@ final class LoadReportClient {
@Override @Override
public void run() { public void run() {
stream.sendLoadReport(); synchronized (lock) {
stream.sendLoadReport();
}
} }
} }
@ -143,11 +144,16 @@ final class LoadReportClient {
@Override @Override
public void run() { public void run() {
startLrsRpc(); synchronized (lock) {
startLrsRpc();
}
} }
} }
private void startLrsRpc() { private void startLrsRpc() {
if (!started) {
return;
}
checkState(lrsStream == null, "previous lbStream has not been cleared yet"); checkState(lrsStream == null, "previous lbStream has not been cleared yet");
if (xdsChannel.isUseProtocolV3()) { if (xdsChannel.isUseProtocolV3()) {
lrsStream = new LrsStreamV3(); lrsStream = new LrsStreamV3();
@ -161,19 +167,19 @@ final class LoadReportClient {
private abstract class LrsStream { private abstract class LrsStream {
boolean initialResponseReceived; boolean initialResponseReceived;
boolean closed; boolean closed;
long loadReportIntervalNano = -1; long intervalNano = -1;
boolean reportAllClusters; boolean reportAllClusters;
List<String> clusterNames; // clusters to report loads for, if not report all. List<String> clusterNames; // clusters to report loads for, if not report all.
ScheduledHandle loadReportTimer; ScheduledFuture<?> loadReportTimer;
abstract void start(); abstract void start();
abstract void sendLoadStatsRequest(LoadStatsRequestData request); abstract void sendLoadStatsRequest(List<ClusterStats> clusterStatsList);
abstract void sendError(Exception error); abstract void sendError(Exception error);
// Must run in syncContext. final void handleRpcResponse(List<String> clusters, boolean sendAllClusters,
final void handleResponse(LoadStatsResponseData response) { long loadReportIntervalNano) {
if (closed) { if (closed) {
return; return;
} }
@ -181,30 +187,30 @@ final class LoadReportClient {
logger.log(XdsLogLevel.DEBUG, "Initial LRS response received"); logger.log(XdsLogLevel.DEBUG, "Initial LRS response received");
initialResponseReceived = true; initialResponseReceived = true;
} }
reportAllClusters = response.getSendAllClusters(); reportAllClusters = sendAllClusters;
if (reportAllClusters) { if (reportAllClusters) {
logger.log(XdsLogLevel.INFO, "Report loads for all clusters"); logger.log(XdsLogLevel.INFO, "Report loads for all clusters");
} else { } else {
logger.log(XdsLogLevel.INFO, "Report loads for clusters: ", response.getClustersList()); logger.log(XdsLogLevel.INFO, "Report loads for clusters: ", clusters);
clusterNames = response.getClustersList(); clusterNames = clusters;
} }
long interval = response.getLoadReportingIntervalNanos(); intervalNano = loadReportIntervalNano;
logger.log(XdsLogLevel.INFO, "Update load reporting interval to {0} ns", interval); logger.log(XdsLogLevel.INFO, "Update load reporting interval to {0} ns", intervalNano);
loadReportIntervalNano = interval;
scheduleNextLoadReport(); scheduleNextLoadReport();
} }
// Must run in syncContext.
final void handleRpcError(Throwable t) { final void handleRpcError(Throwable t) {
handleStreamClosed(Status.fromThrowable(t)); handleStreamClosed(Status.fromThrowable(t));
} }
// Must run in syncContext.
final void handleRpcCompleted() { final void handleRpcCompleted() {
handleStreamClosed(Status.UNAVAILABLE.withDescription("Closed by server")); handleStreamClosed(Status.UNAVAILABLE.withDescription("Closed by server"));
} }
private void sendLoadReport() { private void sendLoadReport() {
if (closed) {
return;
}
List<ClusterStats> clusterStatsList; List<ClusterStats> clusterStatsList;
if (reportAllClusters) { if (reportAllClusters) {
clusterStatsList = loadStatsManager.getAllLoadReports(); clusterStatsList = loadStatsManager.getAllLoadReports();
@ -214,21 +220,19 @@ final class LoadReportClient {
clusterStatsList.addAll(loadStatsManager.getClusterLoadReports(name)); clusterStatsList.addAll(loadStatsManager.getClusterLoadReports(name));
} }
} }
LoadStatsRequestData request = new LoadStatsRequestData(node, clusterStatsList); sendLoadStatsRequest(clusterStatsList);
sendLoadStatsRequest(request);
scheduleNextLoadReport(); scheduleNextLoadReport();
} }
private void scheduleNextLoadReport() { private void scheduleNextLoadReport() {
// Cancel pending load report and reschedule with updated load reporting interval. // Cancel pending load report and reschedule with updated load reporting interval.
if (loadReportTimer != null && loadReportTimer.isPending()) { if (loadReportTimer != null && !loadReportTimer.isDone()) {
loadReportTimer.cancel(); loadReportTimer.cancel(false);
loadReportTimer = null; loadReportTimer = null;
} }
if (loadReportIntervalNano > 0) { if (intervalNano > 0) {
loadReportTimer = syncContext.schedule( loadReportTimer = timerService.schedule(
new LoadReportingTask(this), loadReportIntervalNano, TimeUnit.NANOSECONDS, new LoadReportingTask(this), intervalNano, TimeUnit.NANOSECONDS);
timerService);
} }
} }
@ -263,8 +267,7 @@ final class LoadReportClient {
startLrsRpc(); startLrsRpc();
} else { } else {
lrsRpcRetryTimer = lrsRpcRetryTimer =
syncContext.schedule(new LrsRpcRetryTask(), delayNanos, TimeUnit.NANOSECONDS, timerService.schedule(new LrsRpcRetryTask(), delayNanos, TimeUnit.NANOSECONDS);
timerService);
} }
} }
@ -279,7 +282,7 @@ final class LoadReportClient {
private void cleanUp() { private void cleanUp() {
if (loadReportTimer != null) { if (loadReportTimer != null) {
loadReportTimer.cancel(); loadReportTimer.cancel(false);
loadReportTimer = null; loadReportTimer = null;
} }
if (lrsStream == this) { if (lrsStream == this) {
@ -298,34 +301,26 @@ final class LoadReportClient {
new StreamObserver<io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse>() { new StreamObserver<io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse>() {
@Override @Override
public void onNext( public void onNext(
final io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse response) { io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse response) {
syncContext.execute(new Runnable() { synchronized (lock) {
@Override logger.log(XdsLogLevel.DEBUG, "Received LoadStatsResponse:\n{0}", response);
public void run() { handleRpcResponse(response.getClustersList(), response.getSendAllClusters(),
logger.log(XdsLogLevel.DEBUG, "Received LoadStatsResponse:\n{0}", response); Durations.toNanos(response.getLoadReportingInterval()));
handleResponse(LoadStatsResponseData.fromEnvoyProtoV2(response)); }
}
});
} }
@Override @Override
public void onError(final Throwable t) { public void onError(Throwable t) {
syncContext.execute(new Runnable() { synchronized (lock) {
@Override handleRpcError(t);
public void run() { }
handleRpcError(t);
}
});
} }
@Override @Override
public void onCompleted() { public void onCompleted() {
syncContext.execute(new Runnable() { synchronized (lock) {
@Override handleRpcCompleted();
public void run() { }
handleRpcCompleted();
}
});
} }
}; };
io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc.LoadReportingServiceStub io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc.LoadReportingServiceStub
@ -333,15 +328,20 @@ final class LoadReportClient {
xdsChannel.getManagedChannel()); xdsChannel.getManagedChannel());
lrsRequestWriterV2 = stubV2.withWaitForReady().streamLoadStats(lrsResponseReaderV2); lrsRequestWriterV2 = stubV2.withWaitForReady().streamLoadStats(lrsResponseReaderV2);
logger.log(XdsLogLevel.DEBUG, "Sending initial LRS request"); logger.log(XdsLogLevel.DEBUG, "Sending initial LRS request");
sendLoadStatsRequest(new LoadStatsRequestData(node, null)); sendLoadStatsRequest(Collections.<ClusterStats>emptyList());
} }
@Override @Override
void sendLoadStatsRequest(LoadStatsRequestData request) { void sendLoadStatsRequest(List<ClusterStats> clusterStatsList) {
io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest requestProto = io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.Builder requestBuilder =
request.toEnvoyProtoV2(); io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.newBuilder()
lrsRequestWriterV2.onNext(requestProto); .setNode(node.toEnvoyProtoNodeV2());
logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", requestProto); for (ClusterStats stats : clusterStatsList) {
requestBuilder.addClusterStats(stats.toEnvoyProtoClusterStatsV2());
}
io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest request = requestBuilder.build();
lrsRequestWriterV2.onNext(requestBuilder.build());
logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", request);
} }
@Override @Override
@ -358,48 +358,45 @@ final class LoadReportClient {
StreamObserver<LoadStatsResponse> lrsResponseReaderV3 = StreamObserver<LoadStatsResponse> lrsResponseReaderV3 =
new StreamObserver<LoadStatsResponse>() { new StreamObserver<LoadStatsResponse>() {
@Override @Override
public void onNext(final LoadStatsResponse response) { public void onNext(LoadStatsResponse response) {
syncContext.execute(new Runnable() { synchronized (lock) {
@Override logger.log(XdsLogLevel.DEBUG, "Received LRS response:\n{0}", response);
public void run() { handleRpcResponse(response.getClustersList(), response.getSendAllClusters(),
logger.log(XdsLogLevel.DEBUG, "Received LRS response:\n{0}", response); Durations.toNanos(response.getLoadReportingInterval()));
handleResponse(LoadStatsResponseData.fromEnvoyProtoV3(response)); }
}
});
} }
@Override @Override
public void onError(final Throwable t) { public void onError(Throwable t) {
syncContext.execute(new Runnable() { synchronized (lock) {
@Override handleRpcError(t);
public void run() { }
handleRpcError(t);
}
});
} }
@Override @Override
public void onCompleted() { public void onCompleted() {
syncContext.execute(new Runnable() { synchronized (lock) {
@Override handleRpcCompleted();
public void run() { }
handleRpcCompleted();
}
});
} }
}; };
LoadReportingServiceStub stubV3 = LoadReportingServiceStub stubV3 =
LoadReportingServiceGrpc.newStub(xdsChannel.getManagedChannel()); LoadReportingServiceGrpc.newStub(xdsChannel.getManagedChannel());
lrsRequestWriterV3 = stubV3.withWaitForReady().streamLoadStats(lrsResponseReaderV3); lrsRequestWriterV3 = stubV3.withWaitForReady().streamLoadStats(lrsResponseReaderV3);
logger.log(XdsLogLevel.DEBUG, "Sending initial LRS request"); logger.log(XdsLogLevel.DEBUG, "Sending initial LRS request");
sendLoadStatsRequest(new LoadStatsRequestData(node, null)); sendLoadStatsRequest(Collections.<ClusterStats>emptyList());
} }
@Override @Override
void sendLoadStatsRequest(LoadStatsRequestData request) { void sendLoadStatsRequest(List<ClusterStats> clusterStatsList) {
LoadStatsRequest requestProto = request.toEnvoyProtoV3(); LoadStatsRequest.Builder requestBuilder =
lrsRequestWriterV3.onNext(requestProto); LoadStatsRequest.newBuilder().setNode(node.toEnvoyProtoNode());
logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", requestProto); for (ClusterStats stats : clusterStatsList) {
requestBuilder.addClusterStats(stats.toEnvoyProtoClusterStats());
}
LoadStatsRequest request = requestBuilder.build();
lrsRequestWriterV3.onNext(request);
logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", request);
} }
@Override @Override
@ -407,78 +404,4 @@ final class LoadReportClient {
lrsRequestWriterV3.onError(error); lrsRequestWriterV3.onError(error);
} }
} }
private static final class LoadStatsRequestData {
final Node node;
@Nullable
final List<ClusterStats> clusterStatsList;
LoadStatsRequestData(Node node, @Nullable List<ClusterStats> clusterStatsList) {
this.node = checkNotNull(node, "node");
this.clusterStatsList = clusterStatsList;
}
io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest toEnvoyProtoV2() {
io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.Builder builder
= io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.newBuilder()
.setNode(node.toEnvoyProtoNodeV2());
if (clusterStatsList != null) {
for (ClusterStats stats : clusterStatsList) {
builder.addClusterStats(stats.toEnvoyProtoClusterStatsV2());
}
}
return builder.build();
}
LoadStatsRequest toEnvoyProtoV3() {
LoadStatsRequest.Builder builder = LoadStatsRequest.newBuilder()
.setNode(node.toEnvoyProtoNode());
if (clusterStatsList != null) {
for (ClusterStats stats : clusterStatsList) {
builder.addClusterStats(stats.toEnvoyProtoClusterStats());
}
}
return builder.build();
}
}
private static final class LoadStatsResponseData {
final boolean sendAllClusters;
final List<String> clusters;
final long loadReportingIntervalNanos;
LoadStatsResponseData(
boolean sendAllClusters, List<String> clusters, long loadReportingIntervalNanos) {
this.sendAllClusters = sendAllClusters;
this.clusters = checkNotNull(clusters, "clusters");
this.loadReportingIntervalNanos = loadReportingIntervalNanos;
}
boolean getSendAllClusters() {
return sendAllClusters;
}
List<String> getClustersList() {
return clusters;
}
long getLoadReportingIntervalNanos() {
return loadReportingIntervalNanos;
}
static LoadStatsResponseData fromEnvoyProtoV2(
io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse loadStatsResponse) {
return new LoadStatsResponseData(
loadStatsResponse.getSendAllClusters(),
loadStatsResponse.getClustersList(),
Durations.toNanos(loadStatsResponse.getLoadReportingInterval()));
}
static LoadStatsResponseData fromEnvoyProtoV3(LoadStatsResponse loadStatsResponse) {
return new LoadStatsResponseData(
loadStatsResponse.getSendAllClusters(),
loadStatsResponse.getClustersList(),
Durations.toNanos(loadStatsResponse.getLoadReportingInterval()));
}
}
} }

View File

@ -159,7 +159,7 @@ final class XdsClientImpl2 extends XdsClient {
this.timeService = checkNotNull(timeService, "timeService"); this.timeService = checkNotNull(timeService, "timeService");
this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
adsStreamRetryStopwatch = checkNotNull(stopwatchSupplier, "stopwatchSupplier").get(); adsStreamRetryStopwatch = checkNotNull(stopwatchSupplier, "stopwatchSupplier").get();
lrsClient = new LoadReportClient(loadStatsManager, xdsChannel, node, syncContext, timeService, lrsClient = new LoadReportClient(loadStatsManager, xdsChannel, node, timeService,
backoffPolicyProvider, stopwatchSupplier); backoffPolicyProvider, stopwatchSupplier);
logId = InternalLogId.allocate("xds-client", null); logId = InternalLogId.allocate("xds-client", null);
logger = XdsLogger.withLogId(logId); logger = XdsLogger.withLogId(logId);

View File

@ -44,7 +44,6 @@ import io.grpc.Context;
import io.grpc.Context.CancellationListener; import io.grpc.Context.CancellationListener;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.BackoffPolicy; import io.grpc.internal.BackoffPolicy;
@ -114,13 +113,6 @@ public class LoadReportClientTest {
@Rule @Rule
public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
private final SynchronizationContext syncContext = new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
@Override
public void uncaughtException(Thread t, Throwable e) {
throw new AssertionError(e);
}
});
private final FakeClock fakeClock = new FakeClock(); private final FakeClock fakeClock = new FakeClock();
private final ArrayDeque<StreamObserver<LoadStatsRequest>> lrsRequestObservers = private final ArrayDeque<StreamObserver<LoadStatsRequest>> lrsRequestObservers =
new ArrayDeque<>(); new ArrayDeque<>();
@ -142,6 +134,8 @@ public class LoadReportClientTest {
private BackoffPolicy backoffPolicy2; private BackoffPolicy backoffPolicy2;
@Captor @Captor
private ArgumentCaptor<StreamObserver<LoadStatsResponse>> lrsResponseObserverCaptor; private ArgumentCaptor<StreamObserver<LoadStatsResponse>> lrsResponseObserverCaptor;
@Captor
private ArgumentCaptor<Throwable> errorCaptor;
private LoadReportingServiceGrpc.LoadReportingServiceImplBase mockLoadReportingService; private LoadReportingServiceGrpc.LoadReportingServiceImplBase mockLoadReportingService;
private ManagedChannel channel; private ManagedChannel channel;
@ -187,7 +181,6 @@ public class LoadReportClientTest {
loadStatsManager, loadStatsManager,
new XdsChannel(channel, false), new XdsChannel(channel, false),
NODE, NODE,
syncContext,
fakeClock.getScheduledExecutorService(), fakeClock.getScheduledExecutorService(),
backoffPolicyProvider, backoffPolicyProvider,
fakeClock.getStopwatchSupplier()); fakeClock.getStopwatchSupplier());
@ -407,7 +400,55 @@ public class LoadReportClientTest {
} }
@Test @Test
public void raceBetweenLoadReportingAndLbStreamClosure() { public void raceBetweenStopAndLoadReporting() {
verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
StreamObserver<LoadStatsResponse> responseObserver = lrsResponseObserverCaptor.getValue();
StreamObserver<LoadStatsRequest> requestObserver =
Iterables.getOnlyElement(lrsRequestObservers);
verify(requestObserver).onNext(eq(buildInitialRequest()));
responseObserver.onNext(buildLrsResponse(Collections.singletonList(CLUSTER1), 1234));
assertEquals(1, fakeClock.numPendingTasks(LOAD_REPORTING_TASK_FILTER));
FakeClock.ScheduledTask scheduledTask =
Iterables.getOnlyElement(fakeClock.getPendingTasks(LOAD_REPORTING_TASK_FILTER));
assertEquals(1234, scheduledTask.getDelay(TimeUnit.NANOSECONDS));
fakeClock.forwardNanos(1233);
lrsClient.stopLoadReporting();
verify(requestObserver).onError(errorCaptor.capture());
assertEquals("CANCELLED: client cancelled", errorCaptor.getValue().getMessage());
assertThat(scheduledTask.isCancelled()).isTrue();
fakeClock.forwardNanos(1);
assertEquals(0, fakeClock.numPendingTasks(LOAD_REPORTING_TASK_FILTER));
fakeClock.forwardNanos(1234);
verifyNoMoreInteractions(requestObserver);
}
@Test
public void raceBetweenStopAndLrsStreamRetry() {
verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
StreamObserver<LoadStatsResponse> responseObserver = lrsResponseObserverCaptor.getValue();
StreamObserver<LoadStatsRequest> requestObserver =
Iterables.getOnlyElement(lrsRequestObservers);
verify(requestObserver).onNext(eq(buildInitialRequest()));
responseObserver.onCompleted();
assertEquals(1, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER));
FakeClock.ScheduledTask scheduledTask =
Iterables.getOnlyElement(fakeClock.getPendingTasks(LRS_RPC_RETRY_TASK_FILTER));
assertEquals(1, scheduledTask.getDelay(TimeUnit.SECONDS));
fakeClock.forwardTime(999, TimeUnit.MILLISECONDS);
lrsClient.stopLoadReporting();
assertThat(scheduledTask.isCancelled()).isTrue();
fakeClock.forwardTime(1, TimeUnit.MILLISECONDS);
assertEquals(0, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER));
fakeClock.forwardTime(10, TimeUnit.SECONDS);
verifyNoMoreInteractions(requestObserver);
}
@Test
public void raceBetweenLoadReportingAndLrsStreamClosure() {
verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture()); verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
StreamObserver<LoadStatsResponse> responseObserver = lrsResponseObserverCaptor.getValue(); StreamObserver<LoadStatsResponse> responseObserver = lrsResponseObserverCaptor.getValue();
assertThat(lrsRequestObservers).hasSize(1); assertThat(lrsRequestObservers).hasSize(1);