diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index c3d1aeee78..c628c3d029 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -32,7 +32,6 @@ import io.grpc.Status; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashSet; import java.util.LinkedHashSet; import java.util.concurrent.Executor; import javax.annotation.Nonnull; @@ -65,9 +64,6 @@ final class DelayedClientTransport implements ManagedClientTransport { @GuardedBy("lock") private Collection pendingStreams = new LinkedHashSet(); - @GuardedBy("lock") - private Collection uncommittedRetriableStreams = new HashSet(); - /** * When {@code shutdownStatus != null && !hasPendingStreams()}, then the transport is considered * terminated. @@ -243,23 +239,16 @@ final class DelayedClientTransport implements ManagedClientTransport { Runnable savedReportTransportTerminated; synchronized (lock) { savedPendingStreams = pendingStreams; - savedUncommittedRetriableStreams = uncommittedRetriableStreams; savedReportTransportTerminated = reportTransportTerminated; reportTransportTerminated = null; if (!pendingStreams.isEmpty()) { pendingStreams = Collections.emptyList(); } - if (!uncommittedRetriableStreams.isEmpty()) { - uncommittedRetriableStreams = Collections.emptyList(); - } } if (savedReportTransportTerminated != null) { for (PendingStream stream : savedPendingStreams) { stream.cancel(status); } - for (ClientStream stream : savedUncommittedRetriableStreams) { - stream.cancel(status); - } channelExecutor.executeLater(savedReportTransportTerminated).drain(); } // If savedReportTransportTerminated == null, transportTerminated() has already been called in @@ -268,49 +257,14 @@ final class DelayedClientTransport implements ManagedClientTransport { public final boolean hasPendingStreams() { synchronized (lock) { - return !pendingStreams.isEmpty() || !uncommittedRetriableStreams.isEmpty(); + return !pendingStreams.isEmpty(); } } @VisibleForTesting final int getPendingStreamsCount() { synchronized (lock) { - return pendingStreams.size() + uncommittedRetriableStreams.size(); - } - } - - /** - * Registers a RetriableStream and return null if not shutdown, otherwise just returns the - * shutdown Status. - */ - @Nullable - final Status addUncommittedRetriableStream(RetriableStream retriableStream) { - synchronized (lock) { - if (shutdownStatus != null) { - return shutdownStatus; - } - uncommittedRetriableStreams.add(retriableStream); - if (getPendingStreamsCount() == 1) { - channelExecutor.executeLater(reportTransportInUse); - } - return null; - } - } - - final void removeUncommittedRetriableStream(RetriableStream retriableStream) { - synchronized (lock) { - uncommittedRetriableStreams.remove(retriableStream); - if (!hasPendingStreams()) { - channelExecutor.executeLater(reportTransportNotInUse); - if (shutdownStatus != null && reportTransportTerminated != null) { - channelExecutor.executeLater(reportTransportTerminated); - reportTransportTerminated = null; - } else { - // Because delayed transport is long-lived, we take this opportunity to down-size the - // hashmap. - uncommittedRetriableStreams = new HashSet(); - } - } + return pendingStreams.size(); } } diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 4458dd6c2b..215d1a6c3a 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -58,6 +58,8 @@ import java.lang.ref.SoftReference; import java.lang.ref.WeakReference; import java.net.URI; import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -162,6 +164,8 @@ public final class ManagedChannelImpl extends ManagedChannel { // reprocess() must be run from channelExecutor private final DelayedClientTransport delayedTransport; + private final UncommittedRetriableStreamsRegistry uncommittedRetriableStreamsRegistry + = new UncommittedRetriableStreamsRegistry(); // Shutdown states. // @@ -412,12 +416,12 @@ public final class ManagedChannelImpl extends ManagedChannel { return new RetriableStream(method) { @Override Status prestart() { - return delayedTransport.addUncommittedRetriableStream(this); + return uncommittedRetriableStreamsRegistry.add(this); } @Override void postCommit() { - delayedTransport.removeUncommittedRetriableStream(this); + uncommittedRetriableStreamsRegistry.remove(this); } @Override @@ -551,7 +555,7 @@ public final class ManagedChannelImpl extends ManagedChannel { } }); - delayedTransport.shutdown(SHUTDOWN_STATUS); + uncommittedRetriableStreamsRegistry.onShutdown(SHUTDOWN_STATUS); channelExecutor.executeLater(new Runnable() { @Override public void run() { @@ -572,7 +576,7 @@ public final class ManagedChannelImpl extends ManagedChannel { logger.log(Level.FINE, "[{0}] shutdownNow() called", getLogId()); shutdown(); phantom.shutdownNow = true; - delayedTransport.shutdownNow(SHUTDOWN_NOW_STATUS); + uncommittedRetriableStreamsRegistry.onShutdownNow(SHUTDOWN_NOW_STATUS); channelExecutor.executeLater(new Runnable() { @Override public void run() { @@ -710,6 +714,91 @@ public final class ManagedChannelImpl extends ManagedChannel { }).drain(); } + /** + * A registry that prevents channel shutdown from killing existing retry attempts that are in + * backoff. + */ + // TODO(zdapeng): add test coverage for shutdown during retry backoff once retry backoff is + // implemented. + private final class UncommittedRetriableStreamsRegistry { + // TODO(zdapeng): This means we would acquire a lock for each new retry-able stream, + // it's worthwhile to look for a lock-free approach. + final Object lock = new Object(); + + @GuardedBy("lock") + Collection uncommittedRetriableStreams = new HashSet(); + + @GuardedBy("lock") + Status shutdownStatus; + + void onShutdown(Status reason) { + boolean shouldShutdownDelayedTransport = false; + synchronized (lock) { + if (shutdownStatus != null) { + return; + } + shutdownStatus = reason; + // Keep the delayedTransport open until there is no more uncommitted streams, b/c those + // retriable streams, which may be in backoff and not using any transport, are already + // started RPCs. + if (uncommittedRetriableStreams.isEmpty()) { + shouldShutdownDelayedTransport = true; + } + } + + if (shouldShutdownDelayedTransport) { + delayedTransport.shutdown(reason); + } + } + + void onShutdownNow(Status reason) { + onShutdown(reason); + Collection streams; + + synchronized (lock) { + streams = new ArrayList(uncommittedRetriableStreams); + } + + for (ClientStream stream : streams) { + stream.cancel(reason); + } + delayedTransport.shutdownNow(reason); + } + + /** + * Registers a RetriableStream and return null if not shutdown, otherwise just returns the + * shutdown Status. + */ + @Nullable + Status add(RetriableStream retriableStream) { + synchronized (lock) { + if (shutdownStatus != null) { + return shutdownStatus; + } + uncommittedRetriableStreams.add(retriableStream); + return null; + } + } + + void remove(RetriableStream retriableStream) { + Status shutdownStatusCopy = null; + + synchronized (lock) { + uncommittedRetriableStreams.remove(retriableStream); + if (uncommittedRetriableStreams.isEmpty()) { + shutdownStatusCopy = shutdownStatus; + // Because retriable transport is long-lived, we take this opportunity to down-size the + // hashmap. + uncommittedRetriableStreams = new HashSet(); + } + } + + if (shutdownStatusCopy != null) { + delayedTransport.shutdown(shutdownStatusCopy); + } + } + } + private class LbHelperImpl extends LoadBalancer.Helper { LoadBalancer lb; final NameResolver nr;