From 0ec3bfb47167b86f0689a14ac824fbbdeafc8411 Mon Sep 17 00:00:00 2001 From: Chengyuan Zhang Date: Tue, 20 Oct 2020 16:58:08 -0700 Subject: [PATCH] xds: synchronize LoadReportClient operations with lock (#7528) Replace the SynchronizationContext used in LoadReportClient with a lock. --- .../java/io/grpc/xds/LoadReportClient.java | 281 +++++++----------- .../main/java/io/grpc/xds/XdsClientImpl2.java | 2 +- .../io/grpc/xds/LoadReportClientTest.java | 61 +++- 3 files changed, 154 insertions(+), 190 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/LoadReportClient.java b/xds/src/main/java/io/grpc/xds/LoadReportClient.java index 147946d44c..0c4bf61574 100644 --- a/xds/src/main/java/io/grpc/xds/LoadReportClient.java +++ b/xds/src/main/java/io/grpc/xds/LoadReportClient.java @@ -30,8 +30,6 @@ import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsRequest; import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsResponse; import io.grpc.InternalLogId; import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.BackoffPolicy; import io.grpc.stub.StreamObserver; import io.grpc.xds.EnvoyProtoData.ClusterStats; @@ -39,34 +37,33 @@ import io.grpc.xds.EnvoyProtoData.Node; import io.grpc.xds.XdsClient.XdsChannel; import io.grpc.xds.XdsLogger.XdsLogLevel; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; -import javax.annotation.concurrent.NotThreadSafe; /** * Client of xDS load reporting service based on LRS protocol, which reports load stats of * gRPC client's perspective to a management server. */ -@NotThreadSafe final class LoadReportClient { private final InternalLogId logId; private final XdsLogger logger; private final XdsChannel xdsChannel; private final Node node; - private final SynchronizationContext syncContext; private final ScheduledExecutorService timerService; private final Stopwatch retryStopwatch; private final BackoffPolicy.Provider backoffPolicyProvider; private final LoadStatsManager loadStatsManager; + private final Object lock = new Object(); private boolean started; - @Nullable private BackoffPolicy lrsRpcRetryPolicy; @Nullable - private ScheduledHandle lrsRpcRetryTimer; + private ScheduledFuture lrsRpcRetryTimer; @Nullable private LrsStream lrsStream; @@ -74,13 +71,11 @@ final class LoadReportClient { LoadStatsManager loadStatsManager, XdsChannel xdsChannel, Node node, - SynchronizationContext syncContext, ScheduledExecutorService scheduledExecutorService, BackoffPolicy.Provider backoffPolicyProvider, Supplier stopwatchSupplier) { this.loadStatsManager = checkNotNull(loadStatsManager, "loadStatsManager"); this.xdsChannel = checkNotNull(xdsChannel, "xdsChannel"); - this.syncContext = checkNotNull(syncContext, "syncContext"); this.timerService = checkNotNull(scheduledExecutorService, "timeService"); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); this.retryStopwatch = checkNotNull(stopwatchSupplier, "stopwatchSupplier").get(); @@ -97,12 +92,14 @@ final class LoadReportClient { * no-op. */ void startLoadReporting() { - if (started) { - return; + synchronized (lock) { + if (started) { + return; + } + started = true; + logger.log(XdsLogLevel.INFO, "Starting load reporting RPC"); + startLrsRpc(); } - started = true; - logger.log(XdsLogLevel.INFO, "Starting load reporting RPC"); - startLrsRpc(); } /** @@ -110,22 +107,24 @@ final class LoadReportClient { * {@link LoadReportClient} is no-op. */ void stopLoadReporting() { - if (!started) { - return; + synchronized (lock) { + if (!started) { + return; + } + started = false; + logger.log(XdsLogLevel.INFO, "Stopping load reporting RPC"); + if (lrsRpcRetryTimer != null) { + lrsRpcRetryTimer.cancel(false); + } + if (lrsStream != null) { + lrsStream.close(Status.CANCELLED.withDescription("stop load reporting").asException()); + } + // Do not shutdown channel as it is not owned by LrsClient. } - logger.log(XdsLogLevel.INFO, "Stopping load reporting RPC"); - if (lrsRpcRetryTimer != null) { - lrsRpcRetryTimer.cancel(); - } - if (lrsStream != null) { - lrsStream.close(Status.CANCELLED.withDescription("stop load reporting").asException()); - } - started = false; - // Do not shutdown channel as it is not owned by LrsClient. } @VisibleForTesting - static class LoadReportingTask implements Runnable { + class LoadReportingTask implements Runnable { private final LrsStream stream; LoadReportingTask(LrsStream stream) { @@ -134,7 +133,9 @@ final class LoadReportClient { @Override public void run() { - stream.sendLoadReport(); + synchronized (lock) { + stream.sendLoadReport(); + } } } @@ -143,11 +144,16 @@ final class LoadReportClient { @Override public void run() { - startLrsRpc(); + synchronized (lock) { + startLrsRpc(); + } } } private void startLrsRpc() { + if (!started) { + return; + } checkState(lrsStream == null, "previous lbStream has not been cleared yet"); if (xdsChannel.isUseProtocolV3()) { lrsStream = new LrsStreamV3(); @@ -161,19 +167,19 @@ final class LoadReportClient { private abstract class LrsStream { boolean initialResponseReceived; boolean closed; - long loadReportIntervalNano = -1; + long intervalNano = -1; boolean reportAllClusters; List clusterNames; // clusters to report loads for, if not report all. - ScheduledHandle loadReportTimer; + ScheduledFuture loadReportTimer; abstract void start(); - abstract void sendLoadStatsRequest(LoadStatsRequestData request); + abstract void sendLoadStatsRequest(List clusterStatsList); abstract void sendError(Exception error); - // Must run in syncContext. - final void handleResponse(LoadStatsResponseData response) { + final void handleRpcResponse(List clusters, boolean sendAllClusters, + long loadReportIntervalNano) { if (closed) { return; } @@ -181,30 +187,30 @@ final class LoadReportClient { logger.log(XdsLogLevel.DEBUG, "Initial LRS response received"); initialResponseReceived = true; } - reportAllClusters = response.getSendAllClusters(); + reportAllClusters = sendAllClusters; if (reportAllClusters) { logger.log(XdsLogLevel.INFO, "Report loads for all clusters"); } else { - logger.log(XdsLogLevel.INFO, "Report loads for clusters: ", response.getClustersList()); - clusterNames = response.getClustersList(); + logger.log(XdsLogLevel.INFO, "Report loads for clusters: ", clusters); + clusterNames = clusters; } - long interval = response.getLoadReportingIntervalNanos(); - logger.log(XdsLogLevel.INFO, "Update load reporting interval to {0} ns", interval); - loadReportIntervalNano = interval; + intervalNano = loadReportIntervalNano; + logger.log(XdsLogLevel.INFO, "Update load reporting interval to {0} ns", intervalNano); scheduleNextLoadReport(); } - // Must run in syncContext. final void handleRpcError(Throwable t) { handleStreamClosed(Status.fromThrowable(t)); } - // Must run in syncContext. final void handleRpcCompleted() { handleStreamClosed(Status.UNAVAILABLE.withDescription("Closed by server")); } private void sendLoadReport() { + if (closed) { + return; + } List clusterStatsList; if (reportAllClusters) { clusterStatsList = loadStatsManager.getAllLoadReports(); @@ -214,21 +220,19 @@ final class LoadReportClient { clusterStatsList.addAll(loadStatsManager.getClusterLoadReports(name)); } } - LoadStatsRequestData request = new LoadStatsRequestData(node, clusterStatsList); - sendLoadStatsRequest(request); + sendLoadStatsRequest(clusterStatsList); scheduleNextLoadReport(); } private void scheduleNextLoadReport() { // Cancel pending load report and reschedule with updated load reporting interval. - if (loadReportTimer != null && loadReportTimer.isPending()) { - loadReportTimer.cancel(); + if (loadReportTimer != null && !loadReportTimer.isDone()) { + loadReportTimer.cancel(false); loadReportTimer = null; } - if (loadReportIntervalNano > 0) { - loadReportTimer = syncContext.schedule( - new LoadReportingTask(this), loadReportIntervalNano, TimeUnit.NANOSECONDS, - timerService); + if (intervalNano > 0) { + loadReportTimer = timerService.schedule( + new LoadReportingTask(this), intervalNano, TimeUnit.NANOSECONDS); } } @@ -263,8 +267,7 @@ final class LoadReportClient { startLrsRpc(); } else { lrsRpcRetryTimer = - syncContext.schedule(new LrsRpcRetryTask(), delayNanos, TimeUnit.NANOSECONDS, - timerService); + timerService.schedule(new LrsRpcRetryTask(), delayNanos, TimeUnit.NANOSECONDS); } } @@ -279,7 +282,7 @@ final class LoadReportClient { private void cleanUp() { if (loadReportTimer != null) { - loadReportTimer.cancel(); + loadReportTimer.cancel(false); loadReportTimer = null; } if (lrsStream == this) { @@ -298,34 +301,26 @@ final class LoadReportClient { new StreamObserver() { @Override public void onNext( - final io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse response) { - syncContext.execute(new Runnable() { - @Override - public void run() { - logger.log(XdsLogLevel.DEBUG, "Received LoadStatsResponse:\n{0}", response); - handleResponse(LoadStatsResponseData.fromEnvoyProtoV2(response)); - } - }); + io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse response) { + synchronized (lock) { + logger.log(XdsLogLevel.DEBUG, "Received LoadStatsResponse:\n{0}", response); + handleRpcResponse(response.getClustersList(), response.getSendAllClusters(), + Durations.toNanos(response.getLoadReportingInterval())); + } } @Override - public void onError(final Throwable t) { - syncContext.execute(new Runnable() { - @Override - public void run() { - handleRpcError(t); - } - }); + public void onError(Throwable t) { + synchronized (lock) { + handleRpcError(t); + } } @Override public void onCompleted() { - syncContext.execute(new Runnable() { - @Override - public void run() { - handleRpcCompleted(); - } - }); + synchronized (lock) { + handleRpcCompleted(); + } } }; io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc.LoadReportingServiceStub @@ -333,15 +328,20 @@ final class LoadReportClient { xdsChannel.getManagedChannel()); lrsRequestWriterV2 = stubV2.withWaitForReady().streamLoadStats(lrsResponseReaderV2); logger.log(XdsLogLevel.DEBUG, "Sending initial LRS request"); - sendLoadStatsRequest(new LoadStatsRequestData(node, null)); + sendLoadStatsRequest(Collections.emptyList()); } @Override - void sendLoadStatsRequest(LoadStatsRequestData request) { - io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest requestProto = - request.toEnvoyProtoV2(); - lrsRequestWriterV2.onNext(requestProto); - logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", requestProto); + void sendLoadStatsRequest(List clusterStatsList) { + io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.Builder requestBuilder = + io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.newBuilder() + .setNode(node.toEnvoyProtoNodeV2()); + for (ClusterStats stats : clusterStatsList) { + requestBuilder.addClusterStats(stats.toEnvoyProtoClusterStatsV2()); + } + io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest request = requestBuilder.build(); + lrsRequestWriterV2.onNext(requestBuilder.build()); + logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", request); } @Override @@ -358,48 +358,45 @@ final class LoadReportClient { StreamObserver lrsResponseReaderV3 = new StreamObserver() { @Override - public void onNext(final LoadStatsResponse response) { - syncContext.execute(new Runnable() { - @Override - public void run() { - logger.log(XdsLogLevel.DEBUG, "Received LRS response:\n{0}", response); - handleResponse(LoadStatsResponseData.fromEnvoyProtoV3(response)); - } - }); + public void onNext(LoadStatsResponse response) { + synchronized (lock) { + logger.log(XdsLogLevel.DEBUG, "Received LRS response:\n{0}", response); + handleRpcResponse(response.getClustersList(), response.getSendAllClusters(), + Durations.toNanos(response.getLoadReportingInterval())); + } } @Override - public void onError(final Throwable t) { - syncContext.execute(new Runnable() { - @Override - public void run() { - handleRpcError(t); - } - }); + public void onError(Throwable t) { + synchronized (lock) { + handleRpcError(t); + } } @Override public void onCompleted() { - syncContext.execute(new Runnable() { - @Override - public void run() { - handleRpcCompleted(); - } - }); + synchronized (lock) { + handleRpcCompleted(); + } } }; LoadReportingServiceStub stubV3 = LoadReportingServiceGrpc.newStub(xdsChannel.getManagedChannel()); lrsRequestWriterV3 = stubV3.withWaitForReady().streamLoadStats(lrsResponseReaderV3); logger.log(XdsLogLevel.DEBUG, "Sending initial LRS request"); - sendLoadStatsRequest(new LoadStatsRequestData(node, null)); + sendLoadStatsRequest(Collections.emptyList()); } @Override - void sendLoadStatsRequest(LoadStatsRequestData request) { - LoadStatsRequest requestProto = request.toEnvoyProtoV3(); - lrsRequestWriterV3.onNext(requestProto); - logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", requestProto); + void sendLoadStatsRequest(List clusterStatsList) { + LoadStatsRequest.Builder requestBuilder = + LoadStatsRequest.newBuilder().setNode(node.toEnvoyProtoNode()); + for (ClusterStats stats : clusterStatsList) { + requestBuilder.addClusterStats(stats.toEnvoyProtoClusterStats()); + } + LoadStatsRequest request = requestBuilder.build(); + lrsRequestWriterV3.onNext(request); + logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", request); } @Override @@ -407,78 +404,4 @@ final class LoadReportClient { lrsRequestWriterV3.onError(error); } } - - private static final class LoadStatsRequestData { - final Node node; - @Nullable - final List clusterStatsList; - - LoadStatsRequestData(Node node, @Nullable List clusterStatsList) { - this.node = checkNotNull(node, "node"); - this.clusterStatsList = clusterStatsList; - } - - io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest toEnvoyProtoV2() { - io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.Builder builder - = io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.newBuilder() - .setNode(node.toEnvoyProtoNodeV2()); - if (clusterStatsList != null) { - for (ClusterStats stats : clusterStatsList) { - builder.addClusterStats(stats.toEnvoyProtoClusterStatsV2()); - } - } - return builder.build(); - } - - LoadStatsRequest toEnvoyProtoV3() { - LoadStatsRequest.Builder builder = LoadStatsRequest.newBuilder() - .setNode(node.toEnvoyProtoNode()); - if (clusterStatsList != null) { - for (ClusterStats stats : clusterStatsList) { - builder.addClusterStats(stats.toEnvoyProtoClusterStats()); - } - } - return builder.build(); - } - } - - private static final class LoadStatsResponseData { - final boolean sendAllClusters; - final List clusters; - final long loadReportingIntervalNanos; - - LoadStatsResponseData( - boolean sendAllClusters, List clusters, long loadReportingIntervalNanos) { - this.sendAllClusters = sendAllClusters; - this.clusters = checkNotNull(clusters, "clusters"); - this.loadReportingIntervalNanos = loadReportingIntervalNanos; - } - - boolean getSendAllClusters() { - return sendAllClusters; - } - - List getClustersList() { - return clusters; - } - - long getLoadReportingIntervalNanos() { - return loadReportingIntervalNanos; - } - - static LoadStatsResponseData fromEnvoyProtoV2( - io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse loadStatsResponse) { - return new LoadStatsResponseData( - loadStatsResponse.getSendAllClusters(), - loadStatsResponse.getClustersList(), - Durations.toNanos(loadStatsResponse.getLoadReportingInterval())); - } - - static LoadStatsResponseData fromEnvoyProtoV3(LoadStatsResponse loadStatsResponse) { - return new LoadStatsResponseData( - loadStatsResponse.getSendAllClusters(), - loadStatsResponse.getClustersList(), - Durations.toNanos(loadStatsResponse.getLoadReportingInterval())); - } - } } diff --git a/xds/src/main/java/io/grpc/xds/XdsClientImpl2.java b/xds/src/main/java/io/grpc/xds/XdsClientImpl2.java index aef999fed3..4717748bdb 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClientImpl2.java +++ b/xds/src/main/java/io/grpc/xds/XdsClientImpl2.java @@ -159,7 +159,7 @@ final class XdsClientImpl2 extends XdsClient { this.timeService = checkNotNull(timeService, "timeService"); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); adsStreamRetryStopwatch = checkNotNull(stopwatchSupplier, "stopwatchSupplier").get(); - lrsClient = new LoadReportClient(loadStatsManager, xdsChannel, node, syncContext, timeService, + lrsClient = new LoadReportClient(loadStatsManager, xdsChannel, node, timeService, backoffPolicyProvider, stopwatchSupplier); logId = InternalLogId.allocate("xds-client", null); logger = XdsLogger.withLogId(logId); diff --git a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java index 7feadc8225..a180d65861 100644 --- a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java @@ -44,7 +44,6 @@ import io.grpc.Context; import io.grpc.Context.CancellationListener; import io.grpc.ManagedChannel; import io.grpc.Status; -import io.grpc.SynchronizationContext; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.BackoffPolicy; @@ -114,13 +113,6 @@ public class LoadReportClientTest { @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); - private final SynchronizationContext syncContext = new SynchronizationContext( - new Thread.UncaughtExceptionHandler() { - @Override - public void uncaughtException(Thread t, Throwable e) { - throw new AssertionError(e); - } - }); private final FakeClock fakeClock = new FakeClock(); private final ArrayDeque> lrsRequestObservers = new ArrayDeque<>(); @@ -142,6 +134,8 @@ public class LoadReportClientTest { private BackoffPolicy backoffPolicy2; @Captor private ArgumentCaptor> lrsResponseObserverCaptor; + @Captor + private ArgumentCaptor errorCaptor; private LoadReportingServiceGrpc.LoadReportingServiceImplBase mockLoadReportingService; private ManagedChannel channel; @@ -187,7 +181,6 @@ public class LoadReportClientTest { loadStatsManager, new XdsChannel(channel, false), NODE, - syncContext, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, fakeClock.getStopwatchSupplier()); @@ -407,7 +400,55 @@ public class LoadReportClientTest { } @Test - public void raceBetweenLoadReportingAndLbStreamClosure() { + public void raceBetweenStopAndLoadReporting() { + verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture()); + StreamObserver responseObserver = lrsResponseObserverCaptor.getValue(); + StreamObserver requestObserver = + Iterables.getOnlyElement(lrsRequestObservers); + verify(requestObserver).onNext(eq(buildInitialRequest())); + + responseObserver.onNext(buildLrsResponse(Collections.singletonList(CLUSTER1), 1234)); + assertEquals(1, fakeClock.numPendingTasks(LOAD_REPORTING_TASK_FILTER)); + FakeClock.ScheduledTask scheduledTask = + Iterables.getOnlyElement(fakeClock.getPendingTasks(LOAD_REPORTING_TASK_FILTER)); + assertEquals(1234, scheduledTask.getDelay(TimeUnit.NANOSECONDS)); + + fakeClock.forwardNanos(1233); + lrsClient.stopLoadReporting(); + verify(requestObserver).onError(errorCaptor.capture()); + assertEquals("CANCELLED: client cancelled", errorCaptor.getValue().getMessage()); + assertThat(scheduledTask.isCancelled()).isTrue(); + fakeClock.forwardNanos(1); + assertEquals(0, fakeClock.numPendingTasks(LOAD_REPORTING_TASK_FILTER)); + fakeClock.forwardNanos(1234); + verifyNoMoreInteractions(requestObserver); + } + + @Test + public void raceBetweenStopAndLrsStreamRetry() { + verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture()); + StreamObserver responseObserver = lrsResponseObserverCaptor.getValue(); + StreamObserver requestObserver = + Iterables.getOnlyElement(lrsRequestObservers); + verify(requestObserver).onNext(eq(buildInitialRequest())); + + responseObserver.onCompleted(); + assertEquals(1, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER)); + FakeClock.ScheduledTask scheduledTask = + Iterables.getOnlyElement(fakeClock.getPendingTasks(LRS_RPC_RETRY_TASK_FILTER)); + assertEquals(1, scheduledTask.getDelay(TimeUnit.SECONDS)); + + fakeClock.forwardTime(999, TimeUnit.MILLISECONDS); + lrsClient.stopLoadReporting(); + assertThat(scheduledTask.isCancelled()).isTrue(); + fakeClock.forwardTime(1, TimeUnit.MILLISECONDS); + assertEquals(0, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER)); + fakeClock.forwardTime(10, TimeUnit.SECONDS); + verifyNoMoreInteractions(requestObserver); + } + + @Test + public void raceBetweenLoadReportingAndLrsStreamClosure() { verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture()); StreamObserver responseObserver = lrsResponseObserverCaptor.getValue(); assertThat(lrsRequestObservers).hasSize(1);