diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index 46da45aebd..37e8341f85 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -108,7 +108,7 @@ abstract class RetriableStream implements ClientStream { private final AtomicBoolean noMoreTransparentRetry = new AtomicBoolean(); private final AtomicInteger localOnlyTransparentRetries = new AtomicInteger(); private final AtomicInteger inFlightSubStreams = new AtomicInteger(); - private Status savedCancellationReason; + private SavedCloseMasterListenerReason savedCloseMasterListenerReason; // Used for recording the share of buffer used for the current call out of the channel buffer. // This field would not be necessary if there is no channel buffer limit. @@ -222,9 +222,10 @@ abstract class RetriableStream implements ClientStream { } } - @Nullable // returns null when cancelled + // returns null means we should not create new sub streams, e.g. cancelled or + // other close condition is met for retriableStream. + @Nullable private Substream createSubstream(int previousAttemptCount, boolean isTransparentRetry) { - // increment only when >= 0, i.e. not cancelled int inFlight; do { inFlight = inFlightSubStreams.get(); @@ -506,11 +507,8 @@ abstract class RetriableStream implements ClientStream { Runnable runnable = commit(noopSubstream); if (runnable != null) { - savedCancellationReason = reason; runnable.run(); - if (inFlightSubStreams.addAndGet(Integer.MIN_VALUE) == Integer.MIN_VALUE) { - safeCloseMasterListener(reason, RpcProgress.PROCESSED, new Metadata()); - } + safeCloseMasterListener(reason, RpcProgress.PROCESSED, new Metadata()); return; } @@ -816,14 +814,30 @@ abstract class RetriableStream implements ClientStream { } private void safeCloseMasterListener(Status status, RpcProgress progress, Metadata metadata) { - listenerSerializeExecutor.execute( - new Runnable() { - @Override - public void run() { - isClosed = true; - masterListener.closed(status, progress, metadata); - } - }); + savedCloseMasterListenerReason = new SavedCloseMasterListenerReason(status, progress, + metadata); + if (inFlightSubStreams.addAndGet(Integer.MIN_VALUE) == Integer.MIN_VALUE) { + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + isClosed = true; + masterListener.closed(status, progress, metadata); + } + }); + } + } + + private static final class SavedCloseMasterListenerReason { + private final Status status; + private final RpcProgress progress; + private final Metadata metadata; + + SavedCloseMasterListenerReason(Status status, RpcProgress progress, Metadata metadata) { + this.status = status; + this.progress = progress; + this.metadata = metadata; + } } private interface BufferEntry { @@ -864,8 +878,17 @@ abstract class RetriableStream implements ClientStream { } if (inFlightSubStreams.decrementAndGet() == Integer.MIN_VALUE) { - assert savedCancellationReason != null; - safeCloseMasterListener(savedCancellationReason, RpcProgress.PROCESSED, new Metadata()); + assert savedCloseMasterListenerReason != null; + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + isClosed = true; + masterListener.closed(savedCloseMasterListenerReason.status, + savedCloseMasterListenerReason.progress, + savedCloseMasterListenerReason.metadata); + } + }); return; } diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index 12bf697027..4910c0b613 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -2151,6 +2151,10 @@ public class RetriableStreamTest { assertEquals(Status.CANCELLED.getCode(), statusCaptor.getValue().getCode()); assertEquals(CANCELLED_BECAUSE_COMMITTED, statusCaptor.getValue().getDescription()); inOrder.verify(retriableStreamRecorder).postCommit(); + sublistenerCaptor1.getValue().closed( + Status.CANCELLED, PROCESSED, new Metadata()); + sublistenerCaptor4.getValue().closed( + Status.CANCELLED, PROCESSED, new Metadata()); inOrder.verify(masterListener).closed( any(Status.class), any(RpcProgress.class), any(Metadata.class)); inOrder.verifyNoMoreInteractions(); @@ -2158,7 +2162,8 @@ public class RetriableStreamTest { insight = new InsightBuilder(); hedgingStream.appendTimeoutInsight(insight); assertThat(insight.toString()).isEqualTo( - "[closed=[UNAVAILABLE, INTERNAL], committed=[remote_addr=2.2.2.2:81]]"); + "[closed=[UNAVAILABLE, INTERNAL, CANCELLED, CANCELLED], " + + "committed=[remote_addr=2.2.2.2:81]]"); } @Test @@ -2425,6 +2430,7 @@ public class RetriableStreamTest { assertEquals(Status.CANCELLED.getCode(), statusCaptor.getValue().getCode()); assertEquals(CANCELLED_BECAUSE_COMMITTED, statusCaptor.getValue().getDescription()); inOrder.verify(retriableStreamRecorder).postCommit(); + sublistenerCaptor3.getValue().closed(Status.CANCELLED, PROCESSED, metadata); inOrder.verify(masterListener).closed(fatal, PROCESSED, metadata); inOrder.verifyNoMoreInteractions(); } @@ -2605,6 +2611,8 @@ public class RetriableStreamTest { assertEquals(Status.CANCELLED.getCode(), statusCaptor.getValue().getCode()); assertEquals(CANCELLED_BECAUSE_COMMITTED, statusCaptor.getValue().getDescription()); verify(retriableStreamRecorder).postCommit(); + sublistenerCaptor1.getValue().closed(Status.CANCELLED, PROCESSED, metadata); + sublistenerCaptor4.getValue().closed(Status.CANCELLED, PROCESSED, metadata); verify(masterListener).closed(status, REFUSED, metadata); } @@ -2645,6 +2653,9 @@ public class RetriableStreamTest { assertEquals(Status.CANCELLED.getCode(), statusCaptor.getValue().getCode()); assertEquals(CANCELLED_BECAUSE_COMMITTED, statusCaptor.getValue().getDescription()); verify(retriableStreamRecorder).postCommit(); + sublistenerCaptor1.getValue() + .closed(Status.CANCELLED, REFUSED, new Metadata()); + //master listener close should wait until all substreams are closed verify(masterListener).closed(status, REFUSED, metadata); }