xds: synchronize LoadReportClient operations with lock (#7528)

Replace the SynchronizationContext used in LoadReportClient with a lock.
This commit is contained in:
Chengyuan Zhang 2020-10-20 16:58:08 -07:00 committed by GitHub
parent c329aad2bc
commit 0ec3bfb471
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 154 additions and 190 deletions

View File

@ -30,8 +30,6 @@ import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsRequest;
import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsResponse;
import io.grpc.InternalLogId;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.SynchronizationContext.ScheduledHandle;
import io.grpc.internal.BackoffPolicy;
import io.grpc.stub.StreamObserver;
import io.grpc.xds.EnvoyProtoData.ClusterStats;
@ -39,34 +37,33 @@ import io.grpc.xds.EnvoyProtoData.Node;
import io.grpc.xds.XdsClient.XdsChannel;
import io.grpc.xds.XdsLogger.XdsLogLevel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import javax.annotation.concurrent.NotThreadSafe;
/**
* Client of xDS load reporting service based on LRS protocol, which reports load stats of
* gRPC client's perspective to a management server.
*/
@NotThreadSafe
final class LoadReportClient {
private final InternalLogId logId;
private final XdsLogger logger;
private final XdsChannel xdsChannel;
private final Node node;
private final SynchronizationContext syncContext;
private final ScheduledExecutorService timerService;
private final Stopwatch retryStopwatch;
private final BackoffPolicy.Provider backoffPolicyProvider;
private final LoadStatsManager loadStatsManager;
private final Object lock = new Object();
private boolean started;
@Nullable
private BackoffPolicy lrsRpcRetryPolicy;
@Nullable
private ScheduledHandle lrsRpcRetryTimer;
private ScheduledFuture<?> lrsRpcRetryTimer;
@Nullable
private LrsStream lrsStream;
@ -74,13 +71,11 @@ final class LoadReportClient {
LoadStatsManager loadStatsManager,
XdsChannel xdsChannel,
Node node,
SynchronizationContext syncContext,
ScheduledExecutorService scheduledExecutorService,
BackoffPolicy.Provider backoffPolicyProvider,
Supplier<Stopwatch> stopwatchSupplier) {
this.loadStatsManager = checkNotNull(loadStatsManager, "loadStatsManager");
this.xdsChannel = checkNotNull(xdsChannel, "xdsChannel");
this.syncContext = checkNotNull(syncContext, "syncContext");
this.timerService = checkNotNull(scheduledExecutorService, "timeService");
this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
this.retryStopwatch = checkNotNull(stopwatchSupplier, "stopwatchSupplier").get();
@ -97,6 +92,7 @@ final class LoadReportClient {
* no-op.
*/
void startLoadReporting() {
synchronized (lock) {
if (started) {
return;
}
@ -104,28 +100,31 @@ final class LoadReportClient {
logger.log(XdsLogLevel.INFO, "Starting load reporting RPC");
startLrsRpc();
}
}
/**
* Terminates load reporting. Calling this method on an already stopped
* {@link LoadReportClient} is no-op.
*/
void stopLoadReporting() {
synchronized (lock) {
if (!started) {
return;
}
started = false;
logger.log(XdsLogLevel.INFO, "Stopping load reporting RPC");
if (lrsRpcRetryTimer != null) {
lrsRpcRetryTimer.cancel();
lrsRpcRetryTimer.cancel(false);
}
if (lrsStream != null) {
lrsStream.close(Status.CANCELLED.withDescription("stop load reporting").asException());
}
started = false;
// Do not shut down the channel, as it is not owned by this LoadReportClient.
}
}
@VisibleForTesting
static class LoadReportingTask implements Runnable {
class LoadReportingTask implements Runnable {
private final LrsStream stream;
LoadReportingTask(LrsStream stream) {
@ -134,20 +133,27 @@ final class LoadReportClient {
@Override
public void run() {
synchronized (lock) {
stream.sendLoadReport();
}
}
}
/**
 * Task scheduled on the timer service to retry the LRS RPC. It takes {@code lock} and then
 * calls {@code startLrsRpc()}, which returns immediately if load reporting is no longer started.
 */
@VisibleForTesting
class LrsRpcRetryTask implements Runnable {
@Override
public void run() {
// Runs on a timer thread, so the shared state touched by startLrsRpc() is guarded by the lock.
synchronized (lock) {
startLrsRpc();
}
}
}
private void startLrsRpc() {
if (!started) {
return;
}
checkState(lrsStream == null, "previous lbStream has not been cleared yet");
if (xdsChannel.isUseProtocolV3()) {
lrsStream = new LrsStreamV3();
@ -161,19 +167,19 @@ final class LoadReportClient {
private abstract class LrsStream {
boolean initialResponseReceived;
boolean closed;
long loadReportIntervalNano = -1;
long intervalNano = -1;
boolean reportAllClusters;
List<String> clusterNames; // clusters to report loads for, if not report all.
ScheduledHandle loadReportTimer;
ScheduledFuture<?> loadReportTimer;
abstract void start();
abstract void sendLoadStatsRequest(LoadStatsRequestData request);
abstract void sendLoadStatsRequest(List<ClusterStats> clusterStatsList);
abstract void sendError(Exception error);
// Must run in syncContext.
final void handleResponse(LoadStatsResponseData response) {
final void handleRpcResponse(List<String> clusters, boolean sendAllClusters,
long loadReportIntervalNano) {
if (closed) {
return;
}
@ -181,30 +187,30 @@ final class LoadReportClient {
logger.log(XdsLogLevel.DEBUG, "Initial LRS response received");
initialResponseReceived = true;
}
reportAllClusters = response.getSendAllClusters();
reportAllClusters = sendAllClusters;
if (reportAllClusters) {
logger.log(XdsLogLevel.INFO, "Report loads for all clusters");
} else {
logger.log(XdsLogLevel.INFO, "Report loads for clusters: ", response.getClustersList());
clusterNames = response.getClustersList();
logger.log(XdsLogLevel.INFO, "Report loads for clusters: ", clusters);
clusterNames = clusters;
}
long interval = response.getLoadReportingIntervalNanos();
logger.log(XdsLogLevel.INFO, "Update load reporting interval to {0} ns", interval);
loadReportIntervalNano = interval;
intervalNano = loadReportIntervalNano;
logger.log(XdsLogLevel.INFO, "Update load reporting interval to {0} ns", intervalNano);
scheduleNextLoadReport();
}
// Must run in syncContext.
final void handleRpcError(Throwable t) {
handleStreamClosed(Status.fromThrowable(t));
}
// Must run in syncContext.
final void handleRpcCompleted() {
handleStreamClosed(Status.UNAVAILABLE.withDescription("Closed by server"));
}
private void sendLoadReport() {
if (closed) {
return;
}
List<ClusterStats> clusterStatsList;
if (reportAllClusters) {
clusterStatsList = loadStatsManager.getAllLoadReports();
@ -214,21 +220,19 @@ final class LoadReportClient {
clusterStatsList.addAll(loadStatsManager.getClusterLoadReports(name));
}
}
LoadStatsRequestData request = new LoadStatsRequestData(node, clusterStatsList);
sendLoadStatsRequest(request);
sendLoadStatsRequest(clusterStatsList);
scheduleNextLoadReport();
}
private void scheduleNextLoadReport() {
// Cancel pending load report and reschedule with updated load reporting interval.
if (loadReportTimer != null && loadReportTimer.isPending()) {
loadReportTimer.cancel();
if (loadReportTimer != null && !loadReportTimer.isDone()) {
loadReportTimer.cancel(false);
loadReportTimer = null;
}
if (loadReportIntervalNano > 0) {
loadReportTimer = syncContext.schedule(
new LoadReportingTask(this), loadReportIntervalNano, TimeUnit.NANOSECONDS,
timerService);
if (intervalNano > 0) {
loadReportTimer = timerService.schedule(
new LoadReportingTask(this), intervalNano, TimeUnit.NANOSECONDS);
}
}
@ -263,8 +267,7 @@ final class LoadReportClient {
startLrsRpc();
} else {
lrsRpcRetryTimer =
syncContext.schedule(new LrsRpcRetryTask(), delayNanos, TimeUnit.NANOSECONDS,
timerService);
timerService.schedule(new LrsRpcRetryTask(), delayNanos, TimeUnit.NANOSECONDS);
}
}
@ -279,7 +282,7 @@ final class LoadReportClient {
private void cleanUp() {
if (loadReportTimer != null) {
loadReportTimer.cancel();
loadReportTimer.cancel(false);
loadReportTimer = null;
}
if (lrsStream == this) {
@ -298,34 +301,26 @@ final class LoadReportClient {
new StreamObserver<io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse>() {
@Override
public void onNext(
final io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse response) {
syncContext.execute(new Runnable() {
@Override
public void run() {
io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse response) {
synchronized (lock) {
logger.log(XdsLogLevel.DEBUG, "Received LoadStatsResponse:\n{0}", response);
handleResponse(LoadStatsResponseData.fromEnvoyProtoV2(response));
handleRpcResponse(response.getClustersList(), response.getSendAllClusters(),
Durations.toNanos(response.getLoadReportingInterval()));
}
});
}
@Override
public void onError(final Throwable t) {
syncContext.execute(new Runnable() {
@Override
public void run() {
public void onError(Throwable t) {
synchronized (lock) {
handleRpcError(t);
}
});
}
@Override
public void onCompleted() {
syncContext.execute(new Runnable() {
@Override
public void run() {
synchronized (lock) {
handleRpcCompleted();
}
});
}
};
io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc.LoadReportingServiceStub
@ -333,15 +328,20 @@ final class LoadReportClient {
xdsChannel.getManagedChannel());
lrsRequestWriterV2 = stubV2.withWaitForReady().streamLoadStats(lrsResponseReaderV2);
logger.log(XdsLogLevel.DEBUG, "Sending initial LRS request");
sendLoadStatsRequest(new LoadStatsRequestData(node, null));
sendLoadStatsRequest(Collections.<ClusterStats>emptyList());
}
@Override
void sendLoadStatsRequest(LoadStatsRequestData request) {
io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest requestProto =
request.toEnvoyProtoV2();
lrsRequestWriterV2.onNext(requestProto);
logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", requestProto);
void sendLoadStatsRequest(List<ClusterStats> clusterStatsList) {
  // Assemble the v2 request: this client's node plus one entry per cluster stats record.
  // An empty list yields a request carrying only the node (used for the initial request).
  io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.Builder requestBuilder =
      io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.newBuilder()
          .setNode(node.toEnvoyProtoNodeV2());
  for (ClusterStats stats : clusterStatsList) {
    requestBuilder.addClusterStats(stats.toEnvoyProtoClusterStatsV2());
  }
  io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest request = requestBuilder.build();
  // Send the message built above instead of building a second copy, so the logged request is
  // the exact object written to the stream (same pattern as the v3 implementation).
  lrsRequestWriterV2.onNext(request);
  logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", request);
}
@Override
@ -358,48 +358,45 @@ final class LoadReportClient {
StreamObserver<LoadStatsResponse> lrsResponseReaderV3 =
new StreamObserver<LoadStatsResponse>() {
@Override
public void onNext(final LoadStatsResponse response) {
syncContext.execute(new Runnable() {
@Override
public void run() {
public void onNext(LoadStatsResponse response) {
synchronized (lock) {
logger.log(XdsLogLevel.DEBUG, "Received LRS response:\n{0}", response);
handleResponse(LoadStatsResponseData.fromEnvoyProtoV3(response));
handleRpcResponse(response.getClustersList(), response.getSendAllClusters(),
Durations.toNanos(response.getLoadReportingInterval()));
}
});
}
@Override
public void onError(final Throwable t) {
syncContext.execute(new Runnable() {
@Override
public void run() {
public void onError(Throwable t) {
synchronized (lock) {
handleRpcError(t);
}
});
}
@Override
public void onCompleted() {
syncContext.execute(new Runnable() {
@Override
public void run() {
synchronized (lock) {
handleRpcCompleted();
}
});
}
};
LoadReportingServiceStub stubV3 =
LoadReportingServiceGrpc.newStub(xdsChannel.getManagedChannel());
lrsRequestWriterV3 = stubV3.withWaitForReady().streamLoadStats(lrsResponseReaderV3);
logger.log(XdsLogLevel.DEBUG, "Sending initial LRS request");
sendLoadStatsRequest(new LoadStatsRequestData(node, null));
sendLoadStatsRequest(Collections.<ClusterStats>emptyList());
}
@Override
void sendLoadStatsRequest(LoadStatsRequestData request) {
LoadStatsRequest requestProto = request.toEnvoyProtoV3();
lrsRequestWriterV3.onNext(requestProto);
logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", requestProto);
void sendLoadStatsRequest(List<ClusterStats> clusterStatsList) {
  // Start from a request that identifies this client's node, then append every stats record.
  LoadStatsRequest.Builder builder = LoadStatsRequest.newBuilder();
  builder.setNode(node.toEnvoyProtoNode());
  for (ClusterStats clusterStats : clusterStatsList) {
    builder.addClusterStats(clusterStats.toEnvoyProtoClusterStats());
  }
  // Build once; the same message instance is written to the stream and logged.
  LoadStatsRequest requestProto = builder.build();
  lrsRequestWriterV3.onNext(requestProto);
  logger.log(XdsLogLevel.DEBUG, "Sent LoadStatsRequest\n{0}", requestProto);
}
@Override
@ -407,78 +404,4 @@ final class LoadReportClient {
lrsRequestWriterV3.onError(error);
}
}
private static final class LoadStatsRequestData {
final Node node;
@Nullable
final List<ClusterStats> clusterStatsList;
LoadStatsRequestData(Node node, @Nullable List<ClusterStats> clusterStatsList) {
this.node = checkNotNull(node, "node");
this.clusterStatsList = clusterStatsList;
}
io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest toEnvoyProtoV2() {
io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.Builder builder
= io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest.newBuilder()
.setNode(node.toEnvoyProtoNodeV2());
if (clusterStatsList != null) {
for (ClusterStats stats : clusterStatsList) {
builder.addClusterStats(stats.toEnvoyProtoClusterStatsV2());
}
}
return builder.build();
}
LoadStatsRequest toEnvoyProtoV3() {
LoadStatsRequest.Builder builder = LoadStatsRequest.newBuilder()
.setNode(node.toEnvoyProtoNode());
if (clusterStatsList != null) {
for (ClusterStats stats : clusterStatsList) {
builder.addClusterStats(stats.toEnvoyProtoClusterStats());
}
}
return builder.build();
}
}
private static final class LoadStatsResponseData {
final boolean sendAllClusters;
final List<String> clusters;
final long loadReportingIntervalNanos;
LoadStatsResponseData(
boolean sendAllClusters, List<String> clusters, long loadReportingIntervalNanos) {
this.sendAllClusters = sendAllClusters;
this.clusters = checkNotNull(clusters, "clusters");
this.loadReportingIntervalNanos = loadReportingIntervalNanos;
}
boolean getSendAllClusters() {
return sendAllClusters;
}
List<String> getClustersList() {
return clusters;
}
long getLoadReportingIntervalNanos() {
return loadReportingIntervalNanos;
}
static LoadStatsResponseData fromEnvoyProtoV2(
io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse loadStatsResponse) {
return new LoadStatsResponseData(
loadStatsResponse.getSendAllClusters(),
loadStatsResponse.getClustersList(),
Durations.toNanos(loadStatsResponse.getLoadReportingInterval()));
}
static LoadStatsResponseData fromEnvoyProtoV3(LoadStatsResponse loadStatsResponse) {
return new LoadStatsResponseData(
loadStatsResponse.getSendAllClusters(),
loadStatsResponse.getClustersList(),
Durations.toNanos(loadStatsResponse.getLoadReportingInterval()));
}
}
}

View File

@ -159,7 +159,7 @@ final class XdsClientImpl2 extends XdsClient {
this.timeService = checkNotNull(timeService, "timeService");
this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
adsStreamRetryStopwatch = checkNotNull(stopwatchSupplier, "stopwatchSupplier").get();
lrsClient = new LoadReportClient(loadStatsManager, xdsChannel, node, syncContext, timeService,
lrsClient = new LoadReportClient(loadStatsManager, xdsChannel, node, timeService,
backoffPolicyProvider, stopwatchSupplier);
logId = InternalLogId.allocate("xds-client", null);
logger = XdsLogger.withLogId(logId);

View File

@ -44,7 +44,6 @@ import io.grpc.Context;
import io.grpc.Context.CancellationListener;
import io.grpc.ManagedChannel;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.BackoffPolicy;
@ -114,13 +113,6 @@ public class LoadReportClientTest {
@Rule
public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
private final SynchronizationContext syncContext = new SynchronizationContext(
new Thread.UncaughtExceptionHandler() {
@Override
public void uncaughtException(Thread t, Throwable e) {
throw new AssertionError(e);
}
});
private final FakeClock fakeClock = new FakeClock();
private final ArrayDeque<StreamObserver<LoadStatsRequest>> lrsRequestObservers =
new ArrayDeque<>();
@ -142,6 +134,8 @@ public class LoadReportClientTest {
private BackoffPolicy backoffPolicy2;
@Captor
private ArgumentCaptor<StreamObserver<LoadStatsResponse>> lrsResponseObserverCaptor;
@Captor
private ArgumentCaptor<Throwable> errorCaptor;
private LoadReportingServiceGrpc.LoadReportingServiceImplBase mockLoadReportingService;
private ManagedChannel channel;
@ -187,7 +181,6 @@ public class LoadReportClientTest {
loadStatsManager,
new XdsChannel(channel, false),
NODE,
syncContext,
fakeClock.getScheduledExecutorService(),
backoffPolicyProvider,
fakeClock.getStopwatchSupplier());
@ -407,7 +400,55 @@ public class LoadReportClientTest {
}
@Test
public void raceBetweenLoadReportingAndLbStreamClosure() {
public void raceBetweenStopAndLoadReporting() {
verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
StreamObserver<LoadStatsResponse> responseObserver = lrsResponseObserverCaptor.getValue();
StreamObserver<LoadStatsRequest> requestObserver =
Iterables.getOnlyElement(lrsRequestObservers);
verify(requestObserver).onNext(eq(buildInitialRequest()));
responseObserver.onNext(buildLrsResponse(Collections.singletonList(CLUSTER1), 1234));
assertEquals(1, fakeClock.numPendingTasks(LOAD_REPORTING_TASK_FILTER));
FakeClock.ScheduledTask scheduledTask =
Iterables.getOnlyElement(fakeClock.getPendingTasks(LOAD_REPORTING_TASK_FILTER));
assertEquals(1234, scheduledTask.getDelay(TimeUnit.NANOSECONDS));
fakeClock.forwardNanos(1233);
lrsClient.stopLoadReporting();
verify(requestObserver).onError(errorCaptor.capture());
assertEquals("CANCELLED: client cancelled", errorCaptor.getValue().getMessage());
assertThat(scheduledTask.isCancelled()).isTrue();
fakeClock.forwardNanos(1);
assertEquals(0, fakeClock.numPendingTasks(LOAD_REPORTING_TASK_FILTER));
fakeClock.forwardNanos(1234);
verifyNoMoreInteractions(requestObserver);
}
/**
 * Verifies that stopping load reporting shortly before a scheduled LRS stream retry is due
 * cancels the retry task, and that nothing further is sent on the old request stream.
 */
@Test
public void raceBetweenStopAndLrsStreamRetry() {
// The LRS stream has been opened and the initial request has been sent.
verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
StreamObserver<LoadStatsResponse> responseObserver = lrsResponseObserverCaptor.getValue();
StreamObserver<LoadStatsRequest> requestObserver =
Iterables.getOnlyElement(lrsRequestObservers);
verify(requestObserver).onNext(eq(buildInitialRequest()));
// Server completes the stream; exactly one retry task is expected, with a 1 second delay.
responseObserver.onCompleted();
assertEquals(1, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER));
FakeClock.ScheduledTask scheduledTask =
Iterables.getOnlyElement(fakeClock.getPendingTasks(LRS_RPC_RETRY_TASK_FILTER));
assertEquals(1, scheduledTask.getDelay(TimeUnit.SECONDS));
// Advance to 1 ms before the retry is due, then stop: the pending retry must be cancelled.
fakeClock.forwardTime(999, TimeUnit.MILLISECONDS);
lrsClient.stopLoadReporting();
assertThat(scheduledTask.isCancelled()).isTrue();
// Cross the original due time: no retry task may remain pending.
fakeClock.forwardTime(1, TimeUnit.MILLISECONDS);
assertEquals(0, fakeClock.numPendingTasks(LRS_RPC_RETRY_TASK_FILTER));
// Even well past the due time, the old request stream sees no further interactions.
fakeClock.forwardTime(10, TimeUnit.SECONDS);
verifyNoMoreInteractions(requestObserver);
}
@Test
public void raceBetweenLoadReportingAndLrsStreamClosure() {
verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
StreamObserver<LoadStatsResponse> responseObserver = lrsResponseObserverCaptor.getValue();
assertThat(lrsRequestObservers).hasSize(1);