diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index 4ef743bf96..a4ebfa52d6 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -455,10 +455,10 @@ public abstract class AbstractClientStream extends AbstractStream if (!listenerClosed) { listenerClosed = true; statsTraceCtx.streamClosed(status); - listener().closed(status, rpcProgress, trailers); if (getTransportTracer() != null) { getTransportTracer().reportStreamClosed(status.isOk()); } + listener().closed(status, rpcProgress, trailers); } } } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java index 880320c6ca..dbbcf39d0a 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java @@ -18,7 +18,10 @@ package io.grpc.testing.integration; import static com.google.common.truth.Truth.assertThat; import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.Assert.assertNotNull; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; @@ -78,8 +81,6 @@ import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -103,8 +104,11 @@ public class RetryTest { @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); private final FakeClock fakeClock = new FakeClock(); - @Mock - private ClientCall.Listener mockCallListener; + private TestListener testCallListener = new TestListener(); + @SuppressWarnings("unchecked") + private ClientCall.Listener mockCallListener = + mock(ClientCall.Listener.class, delegatesTo(testCallListener)); + private CountDownLatch backoffLatch = new CountDownLatch(1); private final EventLoopGroup group = new DefaultEventLoopGroup() { @SuppressWarnings("FutureReturnValueIgnored") @@ -245,7 +249,9 @@ public class RetryTest { private void assertRpcStatusRecorded( Status.Code code, long roundtripLatencyMs, long outboundMessages) throws Exception { MetricsRecord record = clientStatsRecorder.pollRecord(7, SECONDS); + assertNotNull(record); TagValue statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); + assertNotNull(statusTag); assertThat(statusTag.asString()).isEqualTo(code.toString()); assertThat(record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)) .isEqualTo(1); @@ -295,14 +301,14 @@ public class RetryTest { verify(mockCallListener, never()).onClose(any(Status.class), any(Metadata.class)); // send one more message, should exceed buffer limit call.sendMessage(message); + // let attempt fail + testCallListener.clear(); serverCall.close( Status.UNAVAILABLE.withDescription("2nd attempt failed"), new Metadata()); // no more retry - ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(mockCallListener, timeout(5000)).onClose(statusCaptor.capture(), any(Metadata.class)); - assertThat(statusCaptor.getValue().getDescription()).contains("2nd attempt failed"); + testCallListener.verifyDescription("2nd attempt failed", 5000); } @Test @@ -534,4 +540,26 @@ public class RetryTest { assertRpcStatusRecorded(Code.INVALID_ARGUMENT, 0, 0); assertRetryStatsRecorded(0, 1, 0); } + + private static class TestListener extends ClientCall.Listener { + Status status = null; + private CountDownLatch closeLatch = new CountDownLatch(1); + + @Override + public void onClose(Status status, Metadata trailers) { + this.status = status; + closeLatch.countDown(); + } + + void clear() { + status = null; + closeLatch = new CountDownLatch(1); + } + + void verifyDescription(String description, long timeoutMs) throws InterruptedException { + closeLatch.await(timeoutMs, TimeUnit.MILLISECONDS); + assertNotNull(status); + assertThat(status.getDescription()).contains(description); + } + } }