diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index 8b2bcf199b..91edfb631e 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -28,6 +28,7 @@ import io.grpc.Decompressor; import io.grpc.DecompressorRegistry; import io.grpc.Metadata; import io.grpc.Status; +import io.grpc.internal.ClientStreamListener.RpcProgress; import java.io.InputStream; import java.util.logging.Level; import java.util.logging.Logger; @@ -356,6 +357,27 @@ public abstract class AbstractClientStream extends AbstractStream */ public final void transportReportStatus(final Status status, boolean stopDelivery, final Metadata trailers) { + transportReportStatus(status, RpcProgress.PROCESSED, stopDelivery, trailers); + } + + /** + * Report stream closure with status to the application layer if not already reported. This + * method must be called from the transport thread. + * + * @param status the new status to set + * @param rpcProgress RPC progress that the + * {@link ClientStreamListener#closed(Status, RpcProgress, Metadata)} + * will receive + * @param stopDelivery if {@code true}, interrupts any further delivery of inbound messages that + * may already be queued up in the deframer. If {@code false}, the listener will be + * notified immediately after all currently completed messages in the deframer have been + * delivered to the application. + * @param trailers new instance of {@code Trailers}, either empty or those returned by the + * server + */ + public final void transportReportStatus( + final Status status, final RpcProgress rpcProgress, boolean stopDelivery, + final Metadata trailers) { checkNotNull(status, "status"); checkNotNull(trailers, "trailers"); // If stopDelivery, we continue in case previous invocation is waiting for stall @@ -367,13 +389,13 @@ public abstract class AbstractClientStream extends AbstractStream if (deframerClosed) { deframerClosedTask = null; - closeListener(status, trailers); + closeListener(status, rpcProgress, trailers); } else { deframerClosedTask = new Runnable() { @Override public void run() { - closeListener(status, trailers); + closeListener(status, rpcProgress, trailers); } }; closeDeframer(stopDelivery); @@ -385,11 +407,12 @@ public abstract class AbstractClientStream extends AbstractStream * * @throws IllegalStateException if the call has not yet been started. */ - private void closeListener(Status status, Metadata trailers) { + private void closeListener( + Status status, RpcProgress rpcProgress, Metadata trailers) { if (!listenerClosed) { listenerClosed = true; statsTraceCtx.streamClosed(status); - listener().closed(status, trailers); + listener().closed(status, rpcProgress, trailers); if (getTransportTracer() != null) { getTransportTracer().reportStreamClosed(status.isOk()); } diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index 7c8776ddf8..b81a6614ca 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -559,6 +559,11 @@ final class ClientCallImpl extends ClientCall { @Override public void closed(Status status, Metadata trailers) { + closed(status, RpcProgress.PROCESSED, trailers); + } + + @Override + public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { Deadline deadline = effectiveDeadline(); if (status.getCode() == Status.Code.CANCELLED && deadline != null) { // When the server's deadline expires, it can only reset the stream with CANCEL and no diff --git a/core/src/main/java/io/grpc/internal/ClientStreamListener.java b/core/src/main/java/io/grpc/internal/ClientStreamListener.java index 4606a1c421..e205391619 100644 --- a/core/src/main/java/io/grpc/internal/ClientStreamListener.java +++ b/core/src/main/java/io/grpc/internal/ClientStreamListener.java @@ -45,5 +45,39 @@ public interface ClientStreamListener extends StreamListener { * @param status details about the remote closure * @param trailers trailing metadata */ + // TODO(zdapeng): remove this method in favor of the 3-arg one. void closed(Status status, Metadata trailers); + + /** + * Called when the stream is fully closed. {@link + * io.grpc.Status.Code#OK} is the only status code that is guaranteed + * to have been sent from the remote server. Any other status code may have been caused by + * abnormal stream termination. This is guaranteed to always be the final call on a listener. No + * further callbacks will be issued. + * + *

This method should return quickly, as the same thread may be used to process other streams. + * + * @param status details about the remote closure + * @param rpcProgress RPC progress when client stream listener is closed + * @param trailers trailing metadata + */ + void closed(Status status, RpcProgress rpcProgress, Metadata trailers); + + /** + * The progress of the RPC when client stream listener is closed. + */ + enum RpcProgress { + /** + * The RPC is processed by the server normally. + */ + PROCESSED, + /** + * The RPC is not processed by the server's application logic. + */ + REFUSED, + /** + * The RPC is dropped (by load balancer). + */ + DROPPED + } } diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index 1d9048059a..81bbd22c9f 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -420,6 +420,18 @@ class DelayedStream implements ClientStream { }); } + @Override + public void closed( + final Status status, final RpcProgress rpcProgress, + final Metadata trailers) { + delayOrExecute(new Runnable() { + @Override + public void run() { + realListener.closed(status, rpcProgress, trailers); + } + }); + } + public void drainPendingCallbacks() { assert !passThrough; List toRun = new ArrayList(); diff --git a/core/src/main/java/io/grpc/internal/ForwardingClientStreamListener.java b/core/src/main/java/io/grpc/internal/ForwardingClientStreamListener.java index e371aa20f9..04e8085742 100644 --- a/core/src/main/java/io/grpc/internal/ForwardingClientStreamListener.java +++ b/core/src/main/java/io/grpc/internal/ForwardingClientStreamListener.java @@ -33,6 +33,11 @@ abstract class ForwardingClientStreamListener implements ClientStreamListener { delegate().closed(status, trailers); } + @Override + public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { + delegate().closed(status, rpcProgress, trailers); + } + @Override public void messagesAvailable(MessageProducer producer) { delegate().messagesAvailable(producer); diff --git a/core/src/main/java/io/grpc/internal/InternalSubchannel.java b/core/src/main/java/io/grpc/internal/InternalSubchannel.java index baefc81c5d..c6806c3c1b 100644 --- a/core/src/main/java/io/grpc/internal/InternalSubchannel.java +++ b/core/src/main/java/io/grpc/internal/InternalSubchannel.java @@ -650,6 +650,13 @@ final class InternalSubchannel implements Instrumented { callTracer.reportCallEnded(status.isOk()); super.closed(status, trailers); } + + @Override + public void closed( + Status status, RpcProgress rpcProgress, Metadata trailers) { + callTracer.reportCallEnded(status.isOk()); + super.closed(status, rpcProgress, trailers); + } }); } }; diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index c09be6f7c8..54aff78195 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -54,6 +54,7 @@ abstract class RetriableStream implements ClientStream { @VisibleForTesting static final Metadata.Key GRPC_PREVIOUS_RPC_ATTEMPTS = Metadata.Key.of("grpc-previous-rpc-attempts", Metadata.ASCII_STRING_MARSHALLER); + @VisibleForTesting static final Metadata.Key GRPC_RETRY_PUSHBACK_MS = Metadata.Key.of("grpc-retry-pushback-ms", Metadata.ASCII_STRING_MARSHALLER); @@ -80,6 +81,11 @@ abstract class RetriableStream implements ClientStream { private volatile State state = new State( new ArrayList(), Collections.emptySet(), null, false, false); + /** + * Either transparent retry happened or reached server's application logic. + */ + private boolean noMoreTransparentRetry; + // Used for recording the share of buffer used for the current call out of the channel buffer. // This field would not be necessary if there is no channel buffer limit. @GuardedBy("lock") @@ -152,10 +158,6 @@ abstract class RetriableStream implements ClientStream { } } - private void retry(int previousAttempts) { - Substream substream = createSubstream(previousAttempts); - drain(substream); - } private Substream createSubstream(int previousAttempts) { Substream sub = new Substream(previousAttempts); @@ -183,11 +185,11 @@ abstract class RetriableStream implements ClientStream { /** Adds grpc-previous-rpc-attempts in the headers of a retry/hedging RPC. */ @VisibleForTesting - final Metadata updateHeaders(Metadata originalHeaders, int previousAttempts) { - Metadata newHeaders = originalHeaders; + final Metadata updateHeaders( + Metadata originalHeaders, int previousAttempts) { + Metadata newHeaders = new Metadata(); + newHeaders.merge(originalHeaders); if (previousAttempts > 0) { - newHeaders = new Metadata(); - newHeaders.merge(originalHeaders); newHeaders.put(GRPC_PREVIOUS_RPC_ATTEMPTS, String.valueOf(previousAttempts)); } return newHeaders; @@ -282,7 +284,7 @@ abstract class RetriableStream implements ClientStream { @Override public final void cancel(Status reason) { - Substream noopSubstream = new Substream(0 /* previousAttempts doesn't matter here*/); + Substream noopSubstream = new Substream(0 /* previousAttempts doesn't matter here */); noopSubstream.stream = new NoopClientStream(); Runnable runnable = commit(noopSubstream); @@ -530,6 +532,11 @@ abstract class RetriableStream implements ClientStream { @Override public void closed(Status status, Metadata trailers) { + closed(status, RpcProgress.PROCESSED, trailers); + } + + @Override + public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { synchronized (lock) { state = state.substreamClosed(substream); } @@ -545,6 +552,22 @@ abstract class RetriableStream implements ClientStream { } if (state.winningSubstream == null) { + if (rpcProgress == RpcProgress.REFUSED && !noMoreTransparentRetry) { + // TODO(zdapeng): in hedging case noMoreTransparentRetry might need be synchronized. + noMoreTransparentRetry = true; + callExecutor.execute(new Runnable() { + @Override + public void run() { + // transparent retry + Substream newSubstream = createSubstream( + substream.previousAttempts); + drain(newSubstream); + } + }); + return; + } // TODO(zdapeng): else if (rpcProgress == RpcProgress.DROPPED) + + noMoreTransparentRetry = true; RetryPlan retryPlan = makeRetryDecision(retryPolicy, status, trailers); if (retryPlan.shouldRetry) { // The check state.winningSubstream == null, checking if is not already committed, is @@ -557,7 +580,9 @@ abstract class RetriableStream implements ClientStream { callExecutor.execute(new Runnable() { @Override public void run() { - retry(substream.previousAttempts + 1); + // retry + Substream newSubstream = createSubstream(substream.previousAttempts + 1); + drain(newSubstream); } }); } @@ -768,7 +793,6 @@ abstract class RetriableStream implements ClientStream { // setting to true must be GuardedBy RetriableStream.lock boolean bufferLimitExceeded; - // TODO(zdapeng): add transparent-retry-attempts final int previousAttempts; Substream(int previousAttempts) { diff --git a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java index 5923c690e4..c6ca66ba94 100644 --- a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java @@ -17,6 +17,7 @@ package io.grpc.internal; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; @@ -24,6 +25,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -136,7 +138,7 @@ public class AbstractClientStreamTest { stream.cancel(Status.DEADLINE_EXCEEDED); stream.cancel(Status.DEADLINE_EXCEEDED); - verify(mockListener).closed(any(Status.class), any(Metadata.class)); + verify(mockListener).closed(any(Status.class), same(PROCESSED), any(Metadata.class)); } @Test @@ -324,7 +326,7 @@ public class AbstractClientStreamTest { // Simulate getting a reset stream.transportState().transportReportStatus(status, false /*stop delivery*/, new Metadata()); - verify(mockListener).closed(any(Status.class), any(Metadata.class)); + verify(mockListener).closed(any(Status.class), same(PROCESSED), any(Metadata.class)); } @Test diff --git a/core/src/test/java/io/grpc/internal/Http2ClientStreamTransportStateTest.java b/core/src/test/java/io/grpc/internal/Http2ClientStreamTransportStateTest.java index 86c7f6513d..5cb526299f 100644 --- a/core/src/test/java/io/grpc/internal/Http2ClientStreamTransportStateTest.java +++ b/core/src/test/java/io/grpc/internal/Http2ClientStreamTransportStateTest.java @@ -17,6 +17,7 @@ package io.grpc.internal; import static com.google.common.base.Charsets.US_ASCII; +import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -79,7 +80,7 @@ public class Http2ClientStreamTransportStateTest { "application/grpc"); state.transportHeadersReceived(headers); - verify(mockListener, never()).closed(any(Status.class), any(Metadata.class)); + verify(mockListener, never()).closed(any(Status.class), same(PROCESSED), any(Metadata.class)); verify(mockListener).headersRead(headers); } @@ -93,7 +94,7 @@ public class Http2ClientStreamTransportStateTest { "application/grpc"); state.transportHeadersReceived(headers); - verify(mockListener, never()).closed(any(Status.class), any(Metadata.class)); + verify(mockListener, never()).closed(any(Status.class), same(PROCESSED), any(Metadata.class)); verify(mockListener).headersRead(headers); } @@ -108,7 +109,7 @@ public class Http2ClientStreamTransportStateTest { state.transportDataReceived(ReadableBuffers.empty(), true); verify(mockListener, never()).headersRead(any(Metadata.class)); - verify(mockListener).closed(statusCaptor.capture(), same(headers)); + verify(mockListener).closed(statusCaptor.capture(), same(PROCESSED), same(headers)); assertEquals(Code.INTERNAL, statusCaptor.getValue().getCode()); } @@ -123,7 +124,7 @@ public class Http2ClientStreamTransportStateTest { state.transportDataReceived(ReadableBuffers.empty(), true); verify(mockListener, never()).headersRead(any(Metadata.class)); - verify(mockListener).closed(statusCaptor.capture(), same(headers)); + verify(mockListener).closed(statusCaptor.capture(), same(PROCESSED), same(headers)); assertEquals(Code.UNKNOWN, statusCaptor.getValue().getCode()); assertTrue(statusCaptor.getValue().getDescription().contains("200")); } @@ -139,7 +140,7 @@ public class Http2ClientStreamTransportStateTest { state.transportDataReceived(ReadableBuffers.empty(), true); verify(mockListener, never()).headersRead(any(Metadata.class)); - verify(mockListener).closed(statusCaptor.capture(), same(headers)); + verify(mockListener).closed(statusCaptor.capture(), same(PROCESSED), same(headers)); assertEquals(Code.UNAUTHENTICATED, statusCaptor.getValue().getCode()); assertTrue(statusCaptor.getValue().getDescription().contains("401")); assertTrue(statusCaptor.getValue().getDescription().contains("text/html")); @@ -163,7 +164,7 @@ public class Http2ClientStreamTransportStateTest { "application/grpc"); state.transportHeadersReceived(headers); - verify(mockListener, never()).closed(any(Status.class), any(Metadata.class)); + verify(mockListener, never()).closed(any(Status.class), same(PROCESSED), any(Metadata.class)); verify(mockListener).headersRead(headers); } @@ -181,7 +182,7 @@ public class Http2ClientStreamTransportStateTest { state.transportDataReceived(ReadableBuffers.empty(), true); verify(mockListener).headersRead(headers); - verify(mockListener).closed(statusCaptor.capture(), same(headersAgain)); + verify(mockListener).closed(statusCaptor.capture(), same(PROCESSED), same(headersAgain)); assertEquals(Code.INTERNAL, statusCaptor.getValue().getCode()); assertTrue(statusCaptor.getValue().getDescription().contains("twice")); } @@ -201,7 +202,7 @@ public class Http2ClientStreamTransportStateTest { state.transportDataReceived(ReadableBuffers.empty(), true); verify(mockListener, never()).headersRead(any(Metadata.class)); - verify(mockListener).closed(statusCaptor.capture(), same(headers)); + verify(mockListener).closed(statusCaptor.capture(), same(PROCESSED), same(headers)); assertEquals(Code.UNKNOWN, statusCaptor.getValue().getCode()); assertTrue(statusCaptor.getValue().getDescription().contains(testString)); } @@ -213,7 +214,7 @@ public class Http2ClientStreamTransportStateTest { String testString = "This is a test"; state.transportDataReceived(ReadableBuffers.wrap(testString.getBytes(US_ASCII)), true); - verify(mockListener).closed(statusCaptor.capture(), any(Metadata.class)); + verify(mockListener).closed(statusCaptor.capture(), same(PROCESSED), any(Metadata.class)); assertEquals(Code.INTERNAL, statusCaptor.getValue().getCode()); } @@ -228,7 +229,7 @@ public class Http2ClientStreamTransportStateTest { String testString = "This is a test"; state.transportDataReceived(ReadableBuffers.wrap(testString.getBytes(US_ASCII)), true); - verify(mockListener).closed(statusCaptor.capture(), same(headers)); + verify(mockListener).closed(statusCaptor.capture(), same(PROCESSED), same(headers)); assertTrue(statusCaptor.getValue().getDescription().contains(testString)); } @@ -244,7 +245,7 @@ public class Http2ClientStreamTransportStateTest { state.transportTrailersReceived(trailers); verify(mockListener, never()).headersRead(any(Metadata.class)); - verify(mockListener).closed(Status.OK, trailers); + verify(mockListener).closed(Status.OK, PROCESSED, trailers); } @Test @@ -261,7 +262,7 @@ public class Http2ClientStreamTransportStateTest { state.transportTrailersReceived(trailers); verify(mockListener).headersRead(headers); - verify(mockListener).closed(Status.OK, trailers); + verify(mockListener).closed(Status.OK, PROCESSED, trailers); } @Test @@ -276,7 +277,7 @@ public class Http2ClientStreamTransportStateTest { state.transportTrailersReceived(trailers); verify(mockListener, never()).headersRead(any(Metadata.class)); - verify(mockListener).closed(Status.CANCELLED, trailers); + verify(mockListener).closed(Status.CANCELLED, PROCESSED, trailers); } @Test @@ -290,7 +291,7 @@ public class Http2ClientStreamTransportStateTest { state.transportTrailersReceived(trailers); verify(mockListener, never()).headersRead(any(Metadata.class)); - verify(mockListener).closed(statusCaptor.capture(), same(trailers)); + verify(mockListener).closed(statusCaptor.capture(), same(PROCESSED), same(trailers)); assertEquals(Code.UNAUTHENTICATED, statusCaptor.getValue().getCode()); assertTrue(statusCaptor.getValue().getDescription().contains("401")); } @@ -306,7 +307,7 @@ public class Http2ClientStreamTransportStateTest { state.transportTrailersReceived(trailers); verify(mockListener, never()).headersRead(any(Metadata.class)); - verify(mockListener).closed(statusCaptor.capture(), same(trailers)); + verify(mockListener).closed(statusCaptor.capture(), same(PROCESSED), same(trailers)); assertEquals(Code.INTERNAL, statusCaptor.getValue().getCode()); } @@ -320,7 +321,7 @@ public class Http2ClientStreamTransportStateTest { state.transportTrailersReceived(trailers); verify(mockListener, never()).headersRead(any(Metadata.class)); - verify(mockListener).closed(statusCaptor.capture(), same(trailers)); + verify(mockListener).closed(statusCaptor.capture(), same(PROCESSED), same(trailers)); assertEquals(Code.INTERNAL, statusCaptor.getValue().getCode()); } @@ -338,7 +339,7 @@ public class Http2ClientStreamTransportStateTest { state.transportTrailersReceived(trailers); verify(mockListener).headersRead(headers); - verify(mockListener).closed(statusCaptor.capture(), same(trailers)); + verify(mockListener).closed(statusCaptor.capture(), same(PROCESSED), same(trailers)); assertEquals(Code.UNKNOWN, statusCaptor.getValue().getCode()); } diff --git a/core/src/test/java/io/grpc/internal/NoopClientStreamListener.java b/core/src/test/java/io/grpc/internal/NoopClientStreamListener.java index 7392a68683..3436ee6efa 100644 --- a/core/src/test/java/io/grpc/internal/NoopClientStreamListener.java +++ b/core/src/test/java/io/grpc/internal/NoopClientStreamListener.java @@ -22,7 +22,7 @@ import io.grpc.Status; /** * No-op base class for testing. */ -class NoopClientStreamListener implements ClientStreamListener { +public class NoopClientStreamListener implements ClientStreamListener { @Override public void messagesAvailable(MessageProducer producer) {} @@ -34,4 +34,7 @@ class NoopClientStreamListener implements ClientStreamListener { @Override public void closed(Status status, Metadata trailers) {} + + @Override + public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) {} } diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index 62042a4750..20d7c070f2 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -17,11 +17,13 @@ package io.grpc.internal; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; +import static io.grpc.internal.ClientStreamListener.RpcProgress.REFUSED; import static io.grpc.internal.RetriableStream.GRPC_PREVIOUS_RPC_ATTEMPTS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.Matchers.any; @@ -927,7 +929,8 @@ public class RetriableStreamTest { public void updateHeaders() { Metadata originalHeaders = new Metadata(); Metadata headers = retriableStream.updateHeaders(originalHeaders, 0); - assertSame(originalHeaders, headers); + assertNotSame(originalHeaders, headers); + assertNull(headers.get(GRPC_PREVIOUS_RPC_ATTEMPTS)); headers = retriableStream.updateHeaders(originalHeaders, 345); assertEquals("345", headers.get(GRPC_PREVIOUS_RPC_ATTEMPTS)); @@ -1363,6 +1366,146 @@ public class RetriableStreamTest { assertTrue(throttle.isAboveThreshold()); // count = 2.6 } + @Test + public void transparentRetry() { + ClientStream mockStream1 = mock(ClientStream.class); + ClientStream mockStream2 = mock(ClientStream.class); + ClientStream mockStream3 = mock(ClientStream.class); + InOrder inOrder = inOrder( + retriableStreamRecorder, + mockStream1, mockStream2, mockStream3); + + // start + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + retriableStream.start(masterListener); + + inOrder.verify(retriableStreamRecorder).newSubstream(0); + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verifyNoMoreInteractions(); + + // transparent retry + doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(0); + sublistenerCaptor1.getValue() + .closed(Status.fromCode(NON_RETRIABLE_STATUS_CODE), REFUSED, new Metadata()); + + inOrder.verify(retriableStreamRecorder).newSubstream(0); + ArgumentCaptor sublistenerCaptor2 = + ArgumentCaptor.forClass(ClientStreamListener.class); + inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verifyNoMoreInteractions(); + verify(retriableStreamRecorder, never()).postCommit(); + assertEquals(0, fakeClock.numPendingTasks()); + + // no more transparent retry + doReturn(mockStream3).when(retriableStreamRecorder).newSubstream(1); + sublistenerCaptor2.getValue() + .closed(Status.fromCode(RETRIABLE_STATUS_CODE_1), REFUSED, new Metadata()); + + assertEquals(1, fakeClock.numPendingTasks()); + fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + inOrder.verify(retriableStreamRecorder).newSubstream(1); + ArgumentCaptor sublistenerCaptor3 = + ArgumentCaptor.forClass(ClientStreamListener.class); + inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verifyNoMoreInteractions(); + verify(retriableStreamRecorder, never()).postCommit(); + assertEquals(0, fakeClock.numPendingTasks()); + } + + @Test + public void normalRetry_thenNoTransparentRetry_butNormalRetry() { + ClientStream mockStream1 = mock(ClientStream.class); + ClientStream mockStream2 = mock(ClientStream.class); + ClientStream mockStream3 = mock(ClientStream.class); + InOrder inOrder = inOrder( + retriableStreamRecorder, + mockStream1, mockStream2, mockStream3); + + // start + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + retriableStream.start(masterListener); + + inOrder.verify(retriableStreamRecorder).newSubstream(0); + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verifyNoMoreInteractions(); + + // normal retry + doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); + sublistenerCaptor1.getValue() + .closed(Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); + + assertEquals(1, fakeClock.numPendingTasks()); + fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + inOrder.verify(retriableStreamRecorder).newSubstream(1); + ArgumentCaptor sublistenerCaptor2 = + ArgumentCaptor.forClass(ClientStreamListener.class); + inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verifyNoMoreInteractions(); + verify(retriableStreamRecorder, never()).postCommit(); + assertEquals(0, fakeClock.numPendingTasks()); + + // no more transparent retry + doReturn(mockStream3).when(retriableStreamRecorder).newSubstream(2); + sublistenerCaptor2.getValue() + .closed(Status.fromCode(RETRIABLE_STATUS_CODE_1), REFUSED, new Metadata()); + + assertEquals(1, fakeClock.numPendingTasks()); + fakeClock.forwardTime( + (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * FAKE_RANDOM), TimeUnit.SECONDS); + inOrder.verify(retriableStreamRecorder).newSubstream(2); + ArgumentCaptor sublistenerCaptor3 = + ArgumentCaptor.forClass(ClientStreamListener.class); + inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verifyNoMoreInteractions(); + verify(retriableStreamRecorder, never()).postCommit(); + } + + @Test + public void normalRetry_thenNoTransparentRetry_andNoMoreRetry() { + ClientStream mockStream1 = mock(ClientStream.class); + ClientStream mockStream2 = mock(ClientStream.class); + ClientStream mockStream3 = mock(ClientStream.class); + InOrder inOrder = inOrder( + retriableStreamRecorder, + mockStream1, mockStream2, mockStream3); + + // start + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + retriableStream.start(masterListener); + + inOrder.verify(retriableStreamRecorder).newSubstream(0); + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verifyNoMoreInteractions(); + + // normal retry + doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); + sublistenerCaptor1.getValue() + .closed(Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); + + assertEquals(1, fakeClock.numPendingTasks()); + fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + inOrder.verify(retriableStreamRecorder).newSubstream(1); + ArgumentCaptor sublistenerCaptor2 = + ArgumentCaptor.forClass(ClientStreamListener.class); + inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verifyNoMoreInteractions(); + verify(retriableStreamRecorder, never()).postCommit(); + assertEquals(0, fakeClock.numPendingTasks()); + + // no more transparent retry + doReturn(mockStream3).when(retriableStreamRecorder).newSubstream(2); + sublistenerCaptor2.getValue() + .closed(Status.fromCode(NON_RETRIABLE_STATUS_CODE), REFUSED, new Metadata()); + + verify(retriableStreamRecorder).postCommit(); + } + /** * Used to stub a retriable stream as well as to record methods of the retriable stream being * called. diff --git a/cronet/src/test/java/io/grpc/cronet/CronetClientStreamTest.java b/cronet/src/test/java/io/grpc/cronet/CronetClientStreamTest.java index eac3ee153f..176033c3a5 100644 --- a/cronet/src/test/java/io/grpc/cronet/CronetClientStreamTest.java +++ b/cronet/src/test/java/io/grpc/cronet/CronetClientStreamTest.java @@ -35,6 +35,7 @@ import io.grpc.MethodDescriptor; import io.grpc.Status; import io.grpc.cronet.CronetChannelBuilder.StreamBuilderFactory; import io.grpc.internal.ClientStreamListener; +import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.GrpcUtil; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.StreamListener.MessageProducer; @@ -322,7 +323,8 @@ public final class CronetClientStreamTest { // Verify trailer ArgumentCaptor trailerCaptor = ArgumentCaptor.forClass(Metadata.class); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(clientListener).closed(statusCaptor.capture(), trailerCaptor.capture()); + verify(clientListener) + .closed(statusCaptor.capture(), isA(RpcProgress.class), trailerCaptor.capture()); // Verify recevied headers. Metadata trailers = trailerCaptor.getValue(); Status status = statusCaptor.getValue(); @@ -365,7 +367,8 @@ public final class CronetClientStreamTest { callback.onSucceeded(cronetStream, info); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(clientListener).closed(statusCaptor.capture(), isA(Metadata.class)); + verify(clientListener) + .closed(statusCaptor.capture(), isA(RpcProgress.class), isA(Metadata.class)); // Verify error status. Status status = statusCaptor.getValue(); assertFalse(status.isOk()); @@ -390,7 +393,8 @@ public final class CronetClientStreamTest { clientStream.transportState().transportReportStatus(Status.UNAVAILABLE, false, new Metadata()); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(clientListener).closed(statusCaptor.capture(), isA(Metadata.class)); + verify(clientListener) + .closed(statusCaptor.capture(), isA(RpcProgress.class), isA(Metadata.class)); Status status = statusCaptor.getValue(); assertEquals(Status.UNAVAILABLE.getCode(), status.getCode()); } @@ -417,7 +421,8 @@ public final class CronetClientStreamTest { clientStream.transportState().transportReportStatus(Status.UNAVAILABLE, false, new Metadata()); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(clientListener).closed(statusCaptor.capture(), isA(Metadata.class)); + verify(clientListener) + .closed(statusCaptor.capture(), isA(RpcProgress.class), isA(Metadata.class)); Status status = statusCaptor.getValue(); assertEquals(Status.UNAVAILABLE.getCode(), status.getCode()); } @@ -447,7 +452,8 @@ public final class CronetClientStreamTest { clientStream.transportState().transportReportStatus(Status.UNAVAILABLE, false, new Metadata()); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(clientListener).closed(statusCaptor.capture(), isA(Metadata.class)); + verify(clientListener) + .closed(statusCaptor.capture(), isA(RpcProgress.class), isA(Metadata.class)); Status status = statusCaptor.getValue(); // Stream has already finished so OK status should be reported. assertEquals(Status.UNAVAILABLE.getCode(), status.getCode()); @@ -479,7 +485,8 @@ public final class CronetClientStreamTest { clientStream.transportState().transportReportStatus(Status.UNAVAILABLE, false, new Metadata()); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(clientListener).closed(statusCaptor.capture(), isA(Metadata.class)); + verify(clientListener) + .closed(statusCaptor.capture(), isA(RpcProgress.class), isA(Metadata.class)); Status status = statusCaptor.getValue(); // Stream has already finished so OK status should be reported. assertEquals(Status.OK.getCode(), status.getCode()); @@ -522,12 +529,14 @@ public final class CronetClientStreamTest { // Receive trailer first ((CronetClientStream.BidirectionalStreamCallback) callback) .processTrailers(trailers(Status.UNAUTHENTICATED.getCode().value())); - verify(clientListener, times(0)).closed(isA(Status.class), isA(Metadata.class)); + verify(clientListener, times(0)) + .closed(isA(Status.class), isA(RpcProgress.class), isA(Metadata.class)); // Receive cronet's endOfStream callback.onReadCompleted(cronetStream, null, ByteBuffer.allocate(0), true); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(clientListener, times(1)).closed(statusCaptor.capture(), isA(Metadata.class)); + verify(clientListener, times(1)) + .closed(statusCaptor.capture(), isA(RpcProgress.class), isA(Metadata.class)); Status status = statusCaptor.getValue(); assertEquals(Status.UNAUTHENTICATED.getCode(), status.getCode()); } @@ -548,13 +557,15 @@ public final class CronetClientStreamTest { callback.onResponseHeadersReceived(cronetStream, info); // Receive cronet's endOfStream callback.onReadCompleted(cronetStream, null, ByteBuffer.allocate(0), true); - verify(clientListener, times(0)).closed(isA(Status.class), isA(Metadata.class)); + verify(clientListener, times(0)) + .closed(isA(Status.class), isA(RpcProgress.class), isA(Metadata.class)); // Receive trailer ((CronetClientStream.BidirectionalStreamCallback) callback) .processTrailers(trailers(Status.UNAUTHENTICATED.getCode().value())); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(clientListener, times(1)).closed(statusCaptor.capture(), isA(Metadata.class)); + verify(clientListener, times(1)) + .closed(statusCaptor.capture(), isA(RpcProgress.class), isA(Metadata.class)); Status status = statusCaptor.getValue(); assertEquals(Status.UNAUTHENTICATED.getCode(), status.getCode()); } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index 053831b616..552d174bbc 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -27,6 +27,7 @@ import io.grpc.Attributes; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.StatusException; +import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.ClientTransport.PingCallback; import io.grpc.internal.GrpcUtil; import io.grpc.internal.Http2Ping; @@ -348,7 +349,12 @@ class NettyClientHandler extends AbstractNettyHandler { if (stream != null) { Status status = GrpcUtil.Http2Error.statusForCode((int) errorCode) .augmentDescription("Received Rst Stream"); - stream.transportReportStatus(status, false /*stop delivery*/, new Metadata()); + stream.transportReportStatus( + status, + errorCode == Http2Error.REFUSED_STREAM.code() + ? RpcProgress.REFUSED : RpcProgress.PROCESSED, + false /*stop delivery*/, + new Metadata()); if (keepAliveManager != null) { keepAliveManager.onDataReceived(); } @@ -617,7 +623,7 @@ class NettyClientHandler extends AbstractNettyHandler { } /** - * Handler for a GOAWAY being either sent or received. Fails any streams created after the + * Handler for a GOAWAY being received. Fails any streams created after the * last known stream. */ private void goingAway(Status status) { @@ -631,7 +637,8 @@ class NettyClientHandler extends AbstractNettyHandler { if (stream.id() > lastKnownStream) { NettyClientStream.TransportState clientStream = clientStream(stream); if (clientStream != null) { - clientStream.transportReportStatus(goAwayStatus, false, new Metadata()); + clientStream.transportReportStatus( + goAwayStatus, RpcProgress.REFUSED, false, new Metadata()); } stream.close(); } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index edb22579f2..cfdd5a742b 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -17,6 +17,8 @@ package io.grpc.netty; import static com.google.common.base.Charsets.UTF_8; +import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; +import static io.grpc.internal.ClientStreamListener.RpcProgress.REFUSED; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC; import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER; @@ -35,8 +37,10 @@ import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.notNull; +import static org.mockito.Matchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -52,6 +56,7 @@ import io.grpc.Metadata; import io.grpc.Status; import io.grpc.StatusException; import io.grpc.internal.ClientStreamListener; +import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.ClientTransport; import io.grpc.internal.ClientTransport.PingCallback; import io.grpc.internal.GrpcUtil; @@ -193,7 +198,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase captor = ArgumentCaptor.forClass(Status.class); - verify(streamListener).closed(captor.capture(), notNull(Metadata.class)); + verify(streamListener).closed(captor.capture(), same(REFUSED), notNull(Metadata.class)); assertEquals(Status.CANCELLED.getCode(), captor.getValue().getCode()); assertEquals("HTTP/2 error code: CANCEL\nReceived Goaway\nthis is a test", captor.getValue().getDescription()); @@ -423,7 +463,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase captor = ArgumentCaptor.forClass(Status.class); - verify(streamListener).closed(captor.capture(), notNull(Metadata.class)); + verify(streamListener).closed(captor.capture(), same(PROCESSED), notNull(Metadata.class)); assertEquals(Status.UNAVAILABLE.getCode(), captor.getValue().getCode()); } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java index 13bec73368..fe67313231 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java @@ -17,6 +17,7 @@ package io.grpc.netty; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.netty.NettyTestUtil.messageFrame; import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC; @@ -208,14 +209,14 @@ public class NettyClientStreamTest extends NettyStreamTestBase captor = ArgumentCaptor.forClass(Status.class); - verify(listener).closed(captor.capture(), any(Metadata.class)); + verify(listener).closed(captor.capture(), same(PROCESSED), any(Metadata.class)); assertEquals(Status.INTERNAL.getCode(), captor.getValue().getCode()); } @@ -295,7 +296,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase captor = ArgumentCaptor.forClass(Status.class); ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); - verify(listener).closed(captor.capture(), metadataCaptor.capture()); + verify(listener).closed(captor.capture(), same(PROCESSED), metadataCaptor.capture()); assertEquals(Status.UNKNOWN.getCode(), captor.getValue().getCode()); assertEquals("4", metadataCaptor.getValue() .get(Metadata.Key.of("random", Metadata.ASCII_STRING_MARSHALLER))); @@ -314,7 +315,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase captor = ArgumentCaptor.forClass(Status.class); ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); - verify(listener).closed(captor.capture(), metadataCaptor.capture()); + verify(listener).closed(captor.capture(), same(PROCESSED), metadataCaptor.capture()); Status status = captor.getValue(); assertEquals(Status.Code.UNKNOWN, status.getCode()); assertTrue(status.getDescription().contains("content-type")); @@ -326,7 +327,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase captor = ArgumentCaptor.forClass(Status.class); - verify(listener).closed(captor.capture(), any(Metadata.class)); + verify(listener).closed(captor.capture(), same(PROCESSED), any(Metadata.class)); assertEquals(Status.Code.INTERNAL, captor.getValue().getCode()); } @@ -362,7 +363,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase captor = ArgumentCaptor.forClass(Status.class); - verify(listener).closed(captor.capture(), any(Metadata.class)); + verify(listener).closed(captor.capture(), same(PROCESSED), any(Metadata.class)); assertEquals(Status.Code.INTERNAL, captor.getValue().getCode()); } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index cd9845ab58..7e07238207 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -676,6 +676,11 @@ public class NettyClientTransportTest { @Override public void closed(Status status, Metadata trailers) { + closed(status, RpcProgress.PROCESSED, trailers); + } + + @Override + public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { if (status.isOk()) { closedFuture.set(null); } else { diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java index 60bdcf7334..4cb6153a9c 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java @@ -18,6 +18,7 @@ package io.grpc.okhttp; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import com.google.common.io.BaseEncoding; import io.grpc.Attributes; @@ -307,8 +308,11 @@ class OkHttpClientStream extends AbstractClientStream { window -= length; if (window < 0) { frameWriter.rstStream(id(), ErrorCode.FLOW_CONTROL_ERROR); - transport.finishStream(id(), Status.INTERNAL.withDescription( - "Received data size exceeded our receiving window size"), false, null, null); + transport.finishStream( + id(), + Status.INTERNAL.withDescription( + "Received data size exceeded our receiving window size"), + PROCESSED, false, null, null); return; } super.transportDataReceived(new OkHttpReadableBuffer(frame), endOfStream); @@ -319,9 +323,9 @@ class OkHttpClientStream extends AbstractClientStream { if (!framer().isClosed()) { // If server's end-of-stream is received before client sends end-of-stream, we just send a // reset to server to fully close the server side stream. - transport.finishStream(id(), null, false, ErrorCode.CANCEL, null); + transport.finishStream(id(),null, PROCESSED, false, ErrorCode.CANCEL, null); } else { - transport.finishStream(id(), null, false, null, null); + transport.finishStream(id(), null, PROCESSED, false, null, null); } } @@ -344,7 +348,8 @@ class OkHttpClientStream extends AbstractClientStream { } else { // If pendingData is null, start must have already been called, which means synStream has // been called as well. - transport.finishStream(id(), reason, stopDelivery, ErrorCode.CANCEL, trailers); + transport.finishStream( + id(), reason, PROCESSED, stopDelivery, ErrorCode.CANCEL, trailers); } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index 5a0a70d10e..a1fc9fa1c3 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -39,6 +39,7 @@ import io.grpc.Status.Code; import io.grpc.StatusException; import io.grpc.internal.Channelz.Security; import io.grpc.internal.Channelz.SocketStats; +import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.ConnectionClientTransport; import io.grpc.internal.GrpcUtil; import io.grpc.internal.Http2Ping; @@ -733,12 +734,14 @@ class OkHttpClientTransport implements ConnectionClientTransport { Map.Entry entry = it.next(); if (entry.getKey() > lastKnownStreamId) { it.remove(); - entry.getValue().transportState().transportReportStatus(status, false, new Metadata()); + entry.getValue().transportState().transportReportStatus( + status, RpcProgress.REFUSED, false, new Metadata()); } } for (OkHttpClientStream stream : pendingStreams) { - stream.transportState().transportReportStatus(status, true, new Metadata()); + stream.transportState().transportReportStatus( + status, RpcProgress.REFUSED, true, new Metadata()); } pendingStreams.clear(); maybeClearInUse(); @@ -765,6 +768,7 @@ class OkHttpClientTransport implements ConnectionClientTransport { void finishStream( int streamId, @Nullable Status status, + RpcProgress rpcProgress, boolean stopDelivery, @Nullable ErrorCode errorCode, @Nullable Metadata trailers) { @@ -779,6 +783,7 @@ class OkHttpClientTransport implements ConnectionClientTransport { .transportState() .transportReportStatus( status, + rpcProgress, stopDelivery, trailers != null ? trailers : new Metadata()); } @@ -1020,7 +1025,10 @@ class OkHttpClientTransport implements ConnectionClientTransport { Status status = toGrpcStatus(errorCode).augmentDescription("Rst Stream"); boolean stopDelivery = (status.getCode() == Code.CANCELLED || status.getCode() == Code.DEADLINE_EXCEEDED); - finishStream(streamId, status, stopDelivery, null, null); + finishStream( + streamId, status, + errorCode == ErrorCode.REFUSED_STREAM ? RpcProgress.REFUSED : RpcProgress.PROCESSED, + stopDelivery, null, null); } @Override @@ -1112,8 +1120,9 @@ class OkHttpClientTransport implements ConnectionClientTransport { if (streamId == 0) { onError(ErrorCode.PROTOCOL_ERROR, errorMsg); } else { - finishStream(streamId, - Status.INTERNAL.withDescription(errorMsg), false, ErrorCode.PROTOCOL_ERROR, null); + finishStream( + streamId, Status.INTERNAL.withDescription(errorMsg), RpcProgress.PROCESSED, false, + ErrorCode.PROTOCOL_ERROR, null); } return; } diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java index 007e35f643..79340a716d 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java @@ -17,6 +17,7 @@ package io.grpc.okhttp; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.eq; @@ -30,8 +31,8 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.MethodType; import io.grpc.Status; -import io.grpc.internal.ClientStreamListener; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.NoopClientStreamListener; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; import io.grpc.okhttp.internal.framed.ErrorCode; @@ -102,7 +103,8 @@ public class OkHttpClientStreamTest { final AtomicReference statusRef = new AtomicReference(); stream.start(new BaseClientStreamListener() { @Override - public void closed(Status status, Metadata trailers) { + public void closed( + Status status, RpcProgress rpcProgress, Metadata trailers) { statusRef.set(status); assertTrue(Thread.holdsLock(lock)); } @@ -123,11 +125,12 @@ public class OkHttpClientStreamTest { assertTrue(Thread.holdsLock(lock)); return null; } - }).when(transport).finishStream(1234, Status.CANCELLED, true, ErrorCode.CANCEL, null); + }).when(transport).finishStream( + 1234, Status.CANCELLED, PROCESSED, true, ErrorCode.CANCEL, null); stream.cancel(Status.CANCELLED); - verify(transport).finishStream(1234, Status.CANCELLED, true, ErrorCode.CANCEL, null); + verify(transport).finishStream(1234, Status.CANCELLED, PROCESSED,true, ErrorCode.CANCEL, null); } @Test @@ -213,20 +216,12 @@ public class OkHttpClientStreamTest { // TODO(carl-mastrangelo): extract this out into a testing/ directory and remove other definitions // of it. - private static class BaseClientStreamListener implements ClientStreamListener { - @Override - public void onReady() {} + private static class BaseClientStreamListener extends NoopClientStreamListener { @Override public void messagesAvailable(MessageProducer producer) { while (producer.next() != null) {} } - - @Override - public void headersRead(Metadata headers) {} - - @Override - public void closed(Status status, Metadata trailers) {} } } diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index b9d7061779..bac6a8e67c 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -18,6 +18,8 @@ package io.grpc.okhttp; import static com.google.common.base.Charsets.UTF_8; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; +import static io.grpc.internal.ClientStreamListener.RpcProgress.REFUSED; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.okhttp.Headers.CONTENT_TYPE_HEADER; import static io.grpc.okhttp.Headers.METHOD_HEADER; @@ -128,6 +130,7 @@ public class OkHttpClientTransportTest { private static final ProxyParameters NO_PROXY = null; private static final String NO_USER = null; private static final String NO_PW = null; + private static final int DEFAULT_START_STREAM_ID = 3; @Rule public final Timeout globalTimeout = Timeout.seconds(10); @@ -168,7 +171,7 @@ public class OkHttpClientTransportTest { } private void initTransport() throws Exception { - startTransport(3, null, true, DEFAULT_MAX_MESSAGE_SIZE, null); + startTransport(DEFAULT_START_STREAM_ID, null, true, DEFAULT_MAX_MESSAGE_SIZE, null); } private void initTransport(int startId) throws Exception { @@ -177,7 +180,8 @@ public class OkHttpClientTransportTest { private void initTransportAndDelayConnected() throws Exception { delayConnectedCallback = new DelayConnectedCallback(); - startTransport(3, delayConnectedCallback, false, DEFAULT_MAX_MESSAGE_SIZE, null); + startTransport( + DEFAULT_START_STREAM_ID, delayConnectedCallback, false, DEFAULT_MAX_MESSAGE_SIZE, null); } private void startTransport(int startId, @Nullable Runnable connectingCallback, @@ -1663,6 +1667,88 @@ public class OkHttpClientTransportTest { shutdownAndVerify(); } + @Test + public void goAway_streamListenerRpcProgress() throws Exception { + initTransport(); + setMaxConcurrentStreams(2); + MockStreamListener listener1 = new MockStreamListener(); + MockStreamListener listener2 = new MockStreamListener(); + MockStreamListener listener3 = new MockStreamListener(); + OkHttpClientStream stream1 = + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + stream1.start(listener1); + OkHttpClientStream stream2 = + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + stream2.start(listener2); + OkHttpClientStream stream3 = + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + stream3.start(listener3); + waitForStreamPending(1); + + assertEquals(2, activeStreamCount()); + assertContainStream(DEFAULT_START_STREAM_ID); + assertContainStream(DEFAULT_START_STREAM_ID + 2); + + frameHandler() + .goAway(DEFAULT_START_STREAM_ID, ErrorCode.CANCEL, ByteString.encodeUtf8("blablabla")); + + listener2.waitUntilStreamClosed(); + listener3.waitUntilStreamClosed(); + assertNull(listener1.rpcProgress); + assertEquals(REFUSED, listener2.rpcProgress); + assertEquals(REFUSED, listener3.rpcProgress); + assertEquals(1, activeStreamCount()); + assertContainStream(DEFAULT_START_STREAM_ID); + + getStream(DEFAULT_START_STREAM_ID).cancel(Status.CANCELLED); + + listener1.waitUntilStreamClosed(); + assertEquals(PROCESSED, listener1.rpcProgress); + + shutdownAndVerify(); + } + + @Test + public void reset_streamListenerRpcProgress() throws Exception { + initTransport(); + MockStreamListener listener1 = new MockStreamListener(); + MockStreamListener listener2 = new MockStreamListener(); + MockStreamListener listener3 = new MockStreamListener(); + OkHttpClientStream stream1 = + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + stream1.start(listener1); + OkHttpClientStream stream2 = + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + stream2.start(listener2); + OkHttpClientStream stream3 = + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + stream3.start(listener3); + + assertEquals(3, activeStreamCount()); + assertContainStream(DEFAULT_START_STREAM_ID); + assertContainStream(DEFAULT_START_STREAM_ID + 2); + assertContainStream(DEFAULT_START_STREAM_ID + 4); + + frameHandler().rstStream(DEFAULT_START_STREAM_ID + 2, ErrorCode.REFUSED_STREAM); + + listener2.waitUntilStreamClosed(); + assertNull(listener1.rpcProgress); + assertEquals(REFUSED, listener2.rpcProgress); + assertNull(listener3.rpcProgress); + + frameHandler().rstStream(DEFAULT_START_STREAM_ID, ErrorCode.CANCEL); + listener1.waitUntilStreamClosed(); + assertEquals(PROCESSED, listener1.rpcProgress); + assertNull(listener3.rpcProgress); + + getStream(DEFAULT_START_STREAM_ID + 4).cancel(Status.CANCELLED); + + listener3.waitUntilStreamClosed(); + assertEquals(PROCESSED, listener3.rpcProgress); + + shutdownAndVerify(); + } + private int activeStreamCount() { return clientTransport.getActiveStreams().length; } @@ -1813,6 +1899,7 @@ public class OkHttpClientTransportTest { Status status; Metadata headers; Metadata trailers; + RpcProgress rpcProgress; CountDownLatch closed = new CountDownLatch(1); ArrayList messages = new ArrayList(); boolean onReadyCalled; @@ -1838,8 +1925,14 @@ public class OkHttpClientTransportTest { @Override public void closed(Status status, Metadata trailers) { + closed(status, PROCESSED, trailers); + } + + @Override + public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { this.status = status; this.trailers = trailers; + this.rpcProgress = rpcProgress; closed.countDown(); } diff --git a/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java b/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java index e165a46940..8bb6921658 100644 --- a/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java +++ b/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java @@ -982,6 +982,14 @@ public abstract class AbstractTransportTest { // This simulates the blocking calls which can trigger clientStream.cancel(). clientStream.cancel(Status.CANCELLED.withCause(status.asRuntimeException())); } + + @Override + public void closed( + Status status, RpcProgress rpcProgress, Metadata trailers) { + super.closed(status, rpcProgress, trailers); + // This simulates the blocking calls which can trigger clientStream.cancel(). + clientStream.cancel(Status.CANCELLED.withCause(status.asRuntimeException())); + } }; clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1056,6 +1064,12 @@ public abstract class AbstractTransportTest { @Override public void closed(Status status, Metadata trailers) { + closed(status, RpcProgress.PROCESSED, trailers); + } + + @Override + public void closed( + Status status, RpcProgress rpcProgress, Metadata trailers) { assertEquals(Status.CANCELLED.getCode(), status.getCode()); assertEquals("nevermind", status.getDescription()); closedCalled.set(true); @@ -1950,6 +1964,11 @@ public abstract class AbstractTransportTest { @Override public void closed(Status status, Metadata trailers) { + closed(status, RpcProgress.PROCESSED, trailers); + } + + @Override + public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { if (this.status.isDone()) { fail("headersRead invoked after closed"); }