From a8db154702edbd84b06cd0997dcf3d5afdd65331 Mon Sep 17 00:00:00 2001 From: Eric Gribkoff Date: Mon, 25 Sep 2017 17:58:05 -0700 Subject: [PATCH] testing: reduce mocks in AbstractTransportTest to eliminate flakes --- .../testing/AbstractTransportTest.java | 656 +++++++++--------- .../testing/TestClientStreamTracer.java | 13 + 2 files changed, 322 insertions(+), 347 deletions(-) diff --git a/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java b/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java index b96b155995..ac9fedab17 100644 --- a/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java +++ b/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java @@ -32,11 +32,9 @@ import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.same; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; @@ -65,13 +63,13 @@ import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerTransport; import io.grpc.internal.ServerTransportListener; import io.grpc.internal.StatsTraceContext; -import io.grpc.internal.StreamListener; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingQueue; @@ -87,8 +85,6 @@ import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.InOrder; import org.mockito.Matchers; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; /** Standard unit tests for {@link ClientTransport}s and {@link ServerTransport}s. */ @RunWith(JUnit4.class) @@ -151,15 +147,11 @@ public abstract class AbstractTransportTest { private ManagedClientTransport.Listener mockClientTransportListener = mock(ManagedClientTransport.Listener.class); - private ClientStreamListener mockClientStreamListener = mock(ClientStreamListener.class); - private final BlockingQueue clientStreamMessageQueue = - new LinkedBlockingQueue(); private MockServerListener serverListener = new MockServerListener(); - private ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); private ArgumentCaptor throwableCaptor = ArgumentCaptor.forClass(Throwable.class); - private ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); private final ClientStreamTracer.Factory clientStreamTracerFactory = mock(ClientStreamTracer.Factory.class); + private final TestClientStreamTracer clientStreamTracer1 = new TestClientStreamTracer(); private final TestClientStreamTracer clientStreamTracer2 = new TestClientStreamTracer(); private final ServerStreamTracer.Factory serverStreamTracerFactory = @@ -181,22 +173,6 @@ public abstract class AbstractTransportTest { .thenReturn(serverStreamTracer1) .thenReturn(serverStreamTracer2); callOptions = CallOptions.DEFAULT.withStreamTracerFactory(clientStreamTracerFactory); - - doAnswer( - new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - StreamListener.MessageProducer producer = - (StreamListener.MessageProducer) invocation.getArguments()[0]; - InputStream message; - while ((message = producer.next()) != null) { - clientStreamMessageQueue.add(message); - } - return null; - } - }) - .when(mockClientStreamListener) - .messagesAvailable(Matchers.any()); } @After @@ -247,7 +223,8 @@ public abstract class AbstractTransportTest { // Netty channel. ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); - stream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + stream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); stream.flush(); @@ -260,8 +237,9 @@ public abstract class AbstractTransportTest { serverStreamCreation.stream.writeMessage(methodDescriptor.streamResponse("bar")); serverStreamCreation.stream.flush(); - verify(mockClientStreamListener, timeout(250)) - .closed(eq(Status.CANCELLED), any(Metadata.class)); + assertEquals( + Status.CANCELLED, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); ClientStreamListener mockClientStreamListener2 = mock(ClientStreamListener.class); @@ -289,6 +267,7 @@ public abstract class AbstractTransportTest { InOrder inOrder = inOrder(mockClientTransportListener); runIfNotNull(client.start(mockClientTransportListener)); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); inOrder.verify(mockClientTransportListener).transportShutdown(statusCaptor.capture()); assertCodeEquals(Status.UNAVAILABLE, statusCaptor.getValue()); inOrder.verify(mockClientTransportListener).transportTerminated(); @@ -349,12 +328,13 @@ public abstract class AbstractTransportTest { serverTransport = serverTransportListener.transport; ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); - clientStream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); ServerStream serverStream = serverStreamCreation.stream; - ServerStreamListener mockServerStreamListener = serverStreamCreation.listener; + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; client.shutdown(Status.UNAVAILABLE); client = null; @@ -376,8 +356,8 @@ public abstract class AbstractTransportTest { // the stream still functions. serverStream.writeHeaders(new Metadata()); clientStream.halfClose(); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)).headersRead(any(Metadata.class)); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).halfClosed(); + assertNotNull(clientStreamListener.headers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertTrue(serverStreamListener.awaitHalfClosed(TIMEOUT_MS, TimeUnit.MILLISECONDS)); verify(mockClientTransportListener, never()).transportTerminated(); verify(mockClientTransportListener, never()).transportInUse(false); @@ -400,11 +380,12 @@ public abstract class AbstractTransportTest { serverTransport = serverTransportListener.transport; ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); - clientStream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); - ServerStreamListener mockServerStreamListener = serverStreamCreation.listener; + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; Status status = Status.UNKNOWN.withDescription("test shutdownNow"); client.shutdownNow(status); @@ -416,15 +397,15 @@ public abstract class AbstractTransportTest { assertTrue(serverTransportListener.waitForTermination(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(serverTransportListener.isTerminated()); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)) - .closed(same(status), any(Metadata.class)); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(statusCaptor.capture()); - assertFalse(statusCaptor.getValue().isOk()); + assertEquals(status, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status serverStatus = serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertFalse(serverStatus.isOk()); if (metricsExpected()) { assertTrue(clientStreamTracer1.await(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertSame(status, clientStreamTracer1.getStatus()); assertTrue(serverStreamTracer1.await(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertSame(statusCaptor.getValue(), serverStreamTracer1.getStatus()); + assertSame(serverStatus, serverStreamTracer1.getStatus()); } } @@ -438,11 +419,12 @@ public abstract class AbstractTransportTest { serverTransport = serverTransportListener.transport; ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); - clientStream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); - ServerStreamListener mockServerStreamListener = serverStreamCreation.listener; + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; Status shutdownStatus = Status.UNKNOWN.withDescription("test shutdownNow"); serverTransport.shutdownNow(shutdownStatus); @@ -454,12 +436,12 @@ public abstract class AbstractTransportTest { assertTrue(serverTransportListener.waitForTermination(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(serverTransportListener.isTerminated()); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)) - .closed(statusCaptor.capture(), any(Metadata.class)); - assertFalse(statusCaptor.getValue().isOk()); + Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertFalse(clientStreamStatus.isOk()); + assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); if (metricsExpected()) { assertTrue(clientStreamTracer1.await(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus()); + assertSame(clientStreamStatus, clientStreamTracer1.getStatus()); assertTrue(serverStreamTracer1.await(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertSame(shutdownStatus, serverStreamTracer1.getStatus()); } @@ -467,7 +449,7 @@ public abstract class AbstractTransportTest { // Generally will be same status provided to shutdownNow, but InProcessTransport can't // differentiate between client and server shutdownNow. The status is not really used on // server-side, so we don't care much. - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(any(Status.class)); + assertNotNull(serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); } @Test @@ -493,7 +475,8 @@ public abstract class AbstractTransportTest { runIfNotNull(client.start(mockClientTransportListener)); // Stream prevents termination ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); - stream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + stream.start(clientStreamListener); client.shutdown(Status.UNAVAILABLE); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); ClientTransport.PingCallback mockPingCallback = mock(ClientTransport.PingCallback.class); @@ -540,7 +523,8 @@ public abstract class AbstractTransportTest { inOrder.verify(clientStreamTracerFactory).newClientStreamTracer( any(CallOptions.class), any(Metadata.class)); } - stream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + stream.start(clientStreamListener); client.shutdown(Status.UNAVAILABLE); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); @@ -549,28 +533,14 @@ public abstract class AbstractTransportTest { inOrder.verify(clientStreamTracerFactory).newClientStreamTracer( any(CallOptions.class), any(Metadata.class)); } - ClientStreamListener mockClientStreamListener2 = mock(ClientStreamListener.class); - doAnswer( - new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - StreamListener.MessageProducer producer = - (StreamListener.MessageProducer) invocation.getArguments()[0]; - InputStream message; - while ((message = producer.next()) != null) { - message.close(); - } - return null; - } - }) - .when(mockClientStreamListener2) - .messagesAvailable(Matchers.any()); - stream2.start(mockClientStreamListener2); - verify(mockClientStreamListener2, timeout(TIMEOUT_MS)) - .closed(statusCaptor.capture(), any(Metadata.class)); - assertCodeEquals(Status.UNAVAILABLE, statusCaptor.getValue()); + ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); + stream2.start(clientStreamListener2); + Status clientStreamStatus2 = + clientStreamListener2.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(clientStreamListener2.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertCodeEquals(Status.UNAVAILABLE, clientStreamStatus2); if (metricsExpected()) { - assertSame(statusCaptor.getValue(), clientStreamTracer2.getStatus()); + assertSame(clientStreamStatus2, clientStreamTracer2.getStatus()); } // Make sure earlier stream works. @@ -582,9 +552,8 @@ public abstract class AbstractTransportTest { StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(20 * TIMEOUT_MS, TimeUnit.MILLISECONDS); serverStreamCreation.stream.close(Status.OK, new Metadata()); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)) - .closed(statusCaptor.capture(), any(Metadata.class)); - assertCodeEquals(Status.OK, statusCaptor.getValue()); + assertCodeEquals(Status.OK, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); } @Test @@ -600,9 +569,11 @@ public abstract class AbstractTransportTest { verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); Thread.sleep(100); ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); - stream.start(mockClientStreamListener); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)) - .closed(same(shutdownReason), any(Metadata.class)); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + stream.start(clientStreamListener); + assertEquals( + shutdownReason, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); verify(mockClientTransportListener, never()).transportInUse(anyBoolean()); if (metricsExpected()) { verify(clientStreamTracerFactory).newClientStreamTracer( @@ -619,14 +590,16 @@ public abstract class AbstractTransportTest { client = newClientTransport(server); runIfNotNull(client.start(mockClientTransportListener)); ClientStream stream1 = client.newStream(methodDescriptor, new Metadata(), callOptions); - stream1.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener1 = new ClientStreamListenerBase(); + stream1.start(clientStreamListener1); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); MockServerTransportListener serverTransportListener = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); StreamCreation serverStreamCreation1 = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions); - stream2.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); + stream2.start(clientStreamListener2); StreamCreation serverStreamCreation2 = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); @@ -647,10 +620,12 @@ public abstract class AbstractTransportTest { client = newClientTransport(server); runIfNotNull(client.start(mockClientTransportListener)); ClientStream stream1 = client.newStream(methodDescriptor, new Metadata(), callOptions); - stream1.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener1 = new ClientStreamListenerBase(); + stream1.start(clientStreamListener1); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions); - stream2.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); + stream2.start(clientStreamListener2); stream1.cancel(Status.CANCELLED); verify(mockClientTransportListener, never()).transportInUse(false); @@ -687,11 +662,12 @@ public abstract class AbstractTransportTest { same(callOptions), same(clientHeaders)); } - clientStream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); if (metricsExpected()) { - assertTrue(clientStreamTracer1.getOutboundHeaders()); + assertTrue(clientStreamTracer1.awaitOutboundHeaders(TIMEOUT_MS, TimeUnit.MILLISECONDS)); } assertEquals(methodDescriptor.getFullMethodName(), serverStreamCreation.method); assertEquals(Lists.newArrayList(clientHeadersCopy.getAll(asciiKey)), @@ -699,24 +675,7 @@ public abstract class AbstractTransportTest { assertEquals(Lists.newArrayList(clientHeadersCopy.getAll(binaryKey)), Lists.newArrayList(serverStreamCreation.headers.getAll(binaryKey))); ServerStream serverStream = serverStreamCreation.stream; - ServerStreamListener mockServerStreamListener = serverStreamCreation.listener; - final BlockingQueue serverStreamMessageQueue = - new LinkedBlockingQueue(); - doAnswer( - new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - StreamListener.MessageProducer producer = - (StreamListener.MessageProducer) invocation.getArguments()[0]; - InputStream message; - while ((message = producer.next()) != null) { - serverStreamMessageQueue.add(message); - } - return null; - } - }) - .when(mockServerStreamListener) - .messagesAvailable(Matchers.any()); + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; if (metricsExpected()) { serverInOrder.verify(serverStreamTracerFactory).newServerStreamTracer( @@ -728,7 +687,7 @@ public abstract class AbstractTransportTest { assertNotNull(serverStream.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)); serverStream.request(1); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)).onReady(); + assertTrue(clientStreamListener.awaitOnReady(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(clientStream.isReady()); clientStream.writeMessage(methodDescriptor.streamRequest("Hello!")); if (metricsExpected()) { @@ -737,7 +696,7 @@ public abstract class AbstractTransportTest { } clientStream.flush(); - InputStream message = serverStreamMessageQueue.poll(TIMEOUT_MS, TimeUnit.MILLISECONDS); + InputStream message = serverStreamListener.messageQueue.poll(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertEquals("Hello!", methodDescriptor.parseRequest(message)); message.close(); if (metricsExpected()) { @@ -748,10 +707,10 @@ public abstract class AbstractTransportTest { assertThat(serverStreamTracer1.nextInboundEvent()).isEqualTo("inboundMessage(0)"); assertThat(serverStreamTracer1.nextInboundEvent()).isEqualTo("inboundMessage()"); } - assertNull("no additional message expected", serverStreamMessageQueue.poll()); + assertNull("no additional message expected", serverStreamListener.messageQueue.poll()); clientStream.halfClose(); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).halfClosed(); + assertTrue(serverStreamListener.awaitHalfClosed(TIMEOUT_MS, TimeUnit.MILLISECONDS)); if (metricsExpected()) { assertThat(serverStreamTracer1.getInboundWireSize()).isGreaterThan(0L); @@ -768,14 +727,17 @@ public abstract class AbstractTransportTest { Metadata serverHeadersCopy = new Metadata(); serverHeadersCopy.merge(serverHeaders); serverStream.writeHeaders(serverHeaders); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)).headersRead(metadataCaptor.capture()); - assertEquals(Lists.newArrayList(serverHeadersCopy.getAll(asciiKey)), - Lists.newArrayList(metadataCaptor.getValue().getAll(asciiKey))); - assertEquals(Lists.newArrayList(serverHeadersCopy.getAll(binaryKey)), - Lists.newArrayList(metadataCaptor.getValue().getAll(binaryKey))); + Metadata headers = clientStreamListener.headers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(headers); + assertEquals( + Lists.newArrayList(serverHeadersCopy.getAll(asciiKey)), + Lists.newArrayList(headers.getAll(asciiKey))); + assertEquals( + Lists.newArrayList(serverHeadersCopy.getAll(binaryKey)), + Lists.newArrayList(headers.getAll(binaryKey))); clientStream.request(1); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).onReady(); + assertTrue(serverStreamListener.awaitOnReady(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(serverStream.isReady()); serverStream.writeMessage(methodDescriptor.streamResponse("Hi. Who are you?")); if (metricsExpected()) { @@ -784,9 +746,8 @@ public abstract class AbstractTransportTest { } serverStream.flush(); - verify(mockClientStreamListener, timeout(TIMEOUT_MS).atLeast(1)) - .messagesAvailable(any(StreamListener.MessageProducer.class)); - message = clientStreamMessageQueue.poll(TIMEOUT_MS, TimeUnit.MILLISECONDS); + message = clientStreamListener.messageQueue.poll(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull("message expected", message); if (metricsExpected()) { assertThat(serverStreamTracer1.nextOutboundEvent()) .matches("outboundMessageSent\\(0, -?[0-9]+, -?[0-9]+\\)"); @@ -804,7 +765,7 @@ public abstract class AbstractTransportTest { assertThat(clientStreamTracer1.getInboundUncompressedSize()).isGreaterThan(0L); } message.close(); - assertNull("no additional message expected", clientStreamMessageQueue.poll()); + assertNull("no additional message expected", clientStreamListener.messageQueue.poll()); Status status = Status.OK.withDescription("That was normal"); Metadata trailers = new Metadata(); @@ -818,21 +779,23 @@ public abstract class AbstractTransportTest { assertNull(serverStreamTracer1.nextInboundEvent()); assertNull(serverStreamTracer1.nextOutboundEvent()); } - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(statusCaptor.capture()); - assertCodeEquals(Status.OK, statusCaptor.getValue()); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)) - .closed(statusCaptor.capture(), metadataCaptor.capture()); + assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Metadata clientStreamTrailers = + clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); if (metricsExpected()) { - assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus()); + assertSame(clientStreamStatus, clientStreamTracer1.getStatus()); assertNull(clientStreamTracer1.nextInboundEvent()); assertNull(clientStreamTracer1.nextOutboundEvent()); } - assertEquals(status.getCode(), statusCaptor.getValue().getCode()); - assertEquals(status.getDescription(), statusCaptor.getValue().getDescription()); - assertEquals(Lists.newArrayList(trailers.getAll(asciiKey)), - Lists.newArrayList(metadataCaptor.getValue().getAll(asciiKey))); - assertEquals(Lists.newArrayList(trailers.getAll(binaryKey)), - Lists.newArrayList(metadataCaptor.getValue().getAll(binaryKey))); + assertEquals(status.getCode(), clientStreamStatus.getCode()); + assertEquals(status.getDescription(), clientStreamStatus.getDescription()); + assertEquals( + Lists.newArrayList(trailers.getAll(asciiKey)), + Lists.newArrayList(clientStreamTrailers.getAll(asciiKey))); + assertEquals( + Lists.newArrayList(trailers.getAll(binaryKey)), + Lists.newArrayList(clientStreamTrailers.getAll(binaryKey))); } @Test @@ -846,7 +809,8 @@ public abstract class AbstractTransportTest { Metadata clientHeaders = new Metadata(); ClientStream clientStream = client.newStream(methodDescriptor, clientHeaders, callOptions); - clientStream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); ServerStream serverStream = serverStreamCreation.stream; @@ -864,30 +828,30 @@ public abstract class AbstractTransportTest { serverTransport = serverTransportListener.transport; ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); - clientStream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); ServerStream serverStream = serverStreamCreation.stream; - ServerStreamListener mockServerStreamListener = serverStreamCreation.listener; + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; clientStream.halfClose(); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).halfClosed(); + assertTrue(serverStreamListener.awaitHalfClosed(TIMEOUT_MS, TimeUnit.MILLISECONDS)); serverStream.writeHeaders(new Metadata()); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)).headersRead(any(Metadata.class)); + assertNotNull(clientStreamListener.headers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); Status status = Status.OK.withDescription("Nice talking to you"); serverStream.close(status, new Metadata()); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(statusCaptor.capture()); - assertCodeEquals(Status.OK, statusCaptor.getValue()); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)) - .closed(statusCaptor.capture(), any(Metadata.class)); - assertEquals(status.getCode(), statusCaptor.getValue().getCode()); - assertEquals(status.getDescription(), statusCaptor.getValue().getDescription()); + assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertEquals(status.getCode(), clientStreamStatus.getCode()); + assertEquals(status.getDescription(), clientStreamStatus.getDescription()); if (metricsExpected()) { assertTrue(clientStreamTracer1.getOutboundHeaders()); assertTrue(clientStreamTracer1.getInboundHeaders()); - assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus()); + assertSame(clientStreamStatus, clientStreamTracer1.getStatus()); assertSame(status, serverStreamTracer1.getStatus()); } } @@ -902,43 +866,28 @@ public abstract class AbstractTransportTest { serverTransport = serverTransportListener.transport; ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); - clientStream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); ServerStream serverStream = serverStreamCreation.stream; - ServerStreamListener mockServerStreamListener = serverStreamCreation.listener; - doAnswer( - new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - StreamListener.MessageProducer producer = - (StreamListener.MessageProducer) invocation.getArguments()[0]; - InputStream message; - while ((message = producer.next()) != null) { - message.close(); - } - return null; - } - }) - .when(mockServerStreamListener) - .messagesAvailable(Matchers.any()); + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; serverStream.writeHeaders(new Metadata()); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)).headersRead(any(Metadata.class)); + assertNotNull(clientStreamListener.headers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); Status status = Status.OK.withDescription("Hello. Goodbye.").withCause(new Exception()); serverStream.close(status, new Metadata()); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(statusCaptor.capture()); - assertCodeEquals(Status.OK, statusCaptor.getValue()); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)) - .closed(statusCaptor.capture(), any(Metadata.class)); - assertEquals(status.getCode(), statusCaptor.getValue().getCode()); - assertEquals("Hello. Goodbye.", statusCaptor.getValue().getDescription()); - assertNull(statusCaptor.getValue().getCause()); + assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertEquals(status.getCode(), clientStreamStatus.getCode()); + assertEquals("Hello. Goodbye.", clientStreamStatus.getDescription()); + assertNull(clientStreamStatus.getCause()); if (metricsExpected()) { assertTrue(clientStreamTracer1.getOutboundHeaders()); assertTrue(clientStreamTracer1.getInboundHeaders()); - assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus()); + assertSame(clientStreamStatus, clientStreamTracer1.getStatus()); assertSame(status, serverStreamTracer1.getStatus()); } } @@ -953,11 +902,12 @@ public abstract class AbstractTransportTest { serverTransport = serverTransportListener.transport; ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); - clientStream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); ServerStream serverStream = serverStreamCreation.stream; - ServerStreamListener mockServerStreamListener = serverStreamCreation.listener; + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; Status status = Status.OK.withDescription("Hellogoodbye").withCause(new Exception()); Metadata trailers = new Metadata(); @@ -966,21 +916,23 @@ public abstract class AbstractTransportTest { trailers.put(asciiKey, "dupvalue"); trailers.put(binaryKey, "äbinarytrailers"); serverStream.close(status, trailers); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(statusCaptor.capture()); - assertCodeEquals(Status.OK, statusCaptor.getValue()); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)) - .closed(statusCaptor.capture(), metadataCaptor.capture()); - assertEquals(status.getCode(), statusCaptor.getValue().getCode()); - assertEquals("Hellogoodbye", statusCaptor.getValue().getDescription()); + assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Metadata clientStreamTrailers = + clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertEquals(status.getCode(), clientStreamStatus.getCode()); + assertEquals("Hellogoodbye", clientStreamStatus.getDescription()); // Cause should not be transmitted to the client. - assertNull(statusCaptor.getValue().getCause()); - assertEquals(Lists.newArrayList(trailers.getAll(asciiKey)), - Lists.newArrayList(metadataCaptor.getValue().getAll(asciiKey))); - assertEquals(Lists.newArrayList(trailers.getAll(binaryKey)), - Lists.newArrayList(metadataCaptor.getValue().getAll(binaryKey))); + assertNull(clientStreamStatus.getCause()); + assertEquals( + Lists.newArrayList(trailers.getAll(asciiKey)), + Lists.newArrayList(clientStreamTrailers.getAll(asciiKey))); + assertEquals( + Lists.newArrayList(trailers.getAll(binaryKey)), + Lists.newArrayList(clientStreamTrailers.getAll(binaryKey))); if (metricsExpected()) { assertTrue(clientStreamTracer1.getOutboundHeaders()); - assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus()); + assertSame(clientStreamStatus, clientStreamTracer1.getStatus()); assertSame(status, serverStreamTracer1.getStatus()); } } @@ -995,25 +947,25 @@ public abstract class AbstractTransportTest { serverTransport = serverTransportListener.transport; ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); - clientStream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); ServerStream serverStream = serverStreamCreation.stream; - ServerStreamListener mockServerStreamListener = serverStreamCreation.listener; + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; Status status = Status.INTERNAL.withDescription("I'm not listening").withCause(new Exception()); serverStream.close(status, new Metadata()); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(statusCaptor.capture()); - assertCodeEquals(Status.OK, statusCaptor.getValue()); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)) - .closed(statusCaptor.capture(), any(Metadata.class)); - assertEquals(status.getCode(), statusCaptor.getValue().getCode()); - assertEquals(status.getDescription(), statusCaptor.getValue().getDescription()); - assertNull(statusCaptor.getValue().getCause()); + assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertEquals(status.getCode(), clientStreamStatus.getCode()); + assertEquals(status.getDescription(), clientStreamStatus.getDescription()); + assertNull(clientStreamStatus.getCause()); if (metricsExpected()) { assertTrue(clientStreamTracer1.getOutboundHeaders()); - assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus()); + assertSame(clientStreamStatus, clientStreamTracer1.getStatus()); assertSame(status, serverStreamTracer1.getStatus()); } } @@ -1028,44 +980,26 @@ public abstract class AbstractTransportTest { serverTransport = serverTransportListener.transport; ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); - clientStream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); - ServerStreamListener mockServerStreamListener = serverStreamCreation.listener; - doAnswer( - new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - StreamListener.MessageProducer producer = - (StreamListener.MessageProducer) invocation.getArguments()[0]; - InputStream message; - while ((message = producer.next()) != null) { - message.close(); - } - return null; - } - }) - .when(mockServerStreamListener) - .messagesAvailable(Matchers.any()); + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; Status status = Status.CANCELLED.withDescription("Nevermind").withCause(new Exception()); clientStream.cancel(status); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)) - .closed(Matchers.same(status), any(Metadata.class)); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(statusCaptor.capture()); - assertNotEquals(Status.Code.OK, statusCaptor.getValue().getCode()); + assertEquals(status, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status serverStatus = serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotEquals(Status.Code.OK, serverStatus.getCode()); // Cause should not be transmitted between client and server - assertNull(statusCaptor.getValue().getCause()); + assertNull(serverStatus.getCause()); - reset(mockServerStreamListener); - reset(mockClientStreamListener); clientStream.cancel(status); - verify(mockServerStreamListener, never()).closed(any(Status.class)); - verify(mockClientStreamListener, never()).closed(any(Status.class), any(Metadata.class)); if (metricsExpected()) { assertTrue(clientStreamTracer1.getOutboundHeaders()); assertSame(status, clientStreamTracer1.getStatus()); - assertSame(statusCaptor.getValue(), serverStreamTracer1.getStatus()); + assertSame(serverStatus, serverStreamTracer1.getStatus()); } } @@ -1118,8 +1052,8 @@ public abstract class AbstractTransportTest { = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertEquals(methodDescriptor.getFullMethodName(), serverStreamCreation.method); ServerStream serverStream = serverStreamCreation.stream; - ServerStreamListener mockServerStreamListener = serverStreamCreation.listener; - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).onReady(); + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; + assertTrue(serverStreamListener.awaitOnReady(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(serverStream.isReady()); serverStream.writeHeaders(new Metadata()); @@ -1155,40 +1089,37 @@ public abstract class AbstractTransportTest { serverTransport = serverTransportListener.transport; ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); - clientStream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); ServerStream serverStream = serverStreamCreation.stream; - ServerStreamListener mockServerStreamListener = serverStreamCreation.listener; + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; Status status = Status.DEADLINE_EXCEEDED.withDescription("It was bound to happen") .withCause(new Exception()); serverStream.cancel(status); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(Matchers.same(status)); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)) - .closed(statusCaptor.capture(), any(Metadata.class)); + assertEquals(status, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); // Presently we can't sent much back to the client in this case. Verify that is the current // behavior for consistency between transports. - assertCodeEquals(Status.CANCELLED, statusCaptor.getValue()); + assertCodeEquals(Status.CANCELLED, clientStreamStatus); // Cause should not be transmitted between server and client - assertNull(statusCaptor.getValue().getCause()); + assertNull(clientStreamStatus.getCause()); if (metricsExpected()) { verify(clientStreamTracerFactory).newClientStreamTracer( any(CallOptions.class), any(Metadata.class)); assertTrue(clientStreamTracer1.getOutboundHeaders()); - assertSame(statusCaptor.getValue(), clientStreamTracer1.getStatus()); + assertSame(clientStreamStatus, clientStreamTracer1.getStatus()); verify(serverStreamTracerFactory).newServerStreamTracer(anyString(), any(Metadata.class)); assertSame(status, serverStreamTracer1.getStatus()); } // Second cancellation shouldn't trigger additional callbacks - reset(mockServerStreamListener); - reset(mockClientStreamListener); serverStream.cancel(status); doPingPong(serverListener); - verify(mockServerStreamListener, never()).closed(any(Status.class)); - verify(mockClientStreamListener, never()).closed(any(Status.class), any(Metadata.class)); } @Test @@ -1196,34 +1127,18 @@ public abstract class AbstractTransportTest { server.start(serverListener); client = newClientTransport(server); runIfNotNull(client.start(mockClientTransportListener)); - MockServerTransportListener serverTransportListener - = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); - clientStream.start(mockClientStreamListener); - StreamCreation serverStreamCreation - = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); + StreamCreation serverStreamCreation = + serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertEquals(methodDescriptor.getFullMethodName(), serverStreamCreation.method); ServerStream serverStream = serverStreamCreation.stream; - ServerStreamListener mockServerStreamListener = serverStreamCreation.listener; - final BlockingQueue serverStreamMessageQueue = - new LinkedBlockingQueue(); - doAnswer( - new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - StreamListener.MessageProducer producer = - (StreamListener.MessageProducer) invocation.getArguments()[0]; - InputStream message; - while ((message = producer.next()) != null) { - serverStreamMessageQueue.add(message); - } - return null; - } - }) - .when(mockServerStreamListener) - .messagesAvailable(Matchers.any()); + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; serverStream.writeHeaders(new Metadata()); @@ -1238,7 +1153,7 @@ public abstract class AbstractTransportTest { } serverStream.request(1); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)).onReady(); + assertTrue(clientStreamListener.awaitOnReady(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(clientStream.isReady()); final int maxToSend = 10 * 1024; int clientSent; @@ -1260,10 +1175,10 @@ public abstract class AbstractTransportTest { } doPingPong(serverListener); - int serverReceived = verifyMessageCountAndClose(serverStreamMessageQueue, 1); + int serverReceived = verifyMessageCountAndClose(serverStreamListener.messageQueue, 1); clientStream.request(1); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).onReady(); + assertTrue(serverStreamListener.awaitOnReady(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(serverStream.isReady()); int serverSent; // Verify that flow control will push back on server. @@ -1284,25 +1199,26 @@ public abstract class AbstractTransportTest { } doPingPong(serverListener); - int clientReceived = verifyMessageCountAndClose(clientStreamMessageQueue, 1); + int clientReceived = verifyMessageCountAndClose(clientStreamListener.messageQueue, 1); serverStream.request(3); clientStream.request(3); doPingPong(serverListener); - clientReceived += verifyMessageCountAndClose(clientStreamMessageQueue, 3); - serverReceived += verifyMessageCountAndClose(serverStreamMessageQueue, 3); + clientReceived += verifyMessageCountAndClose(clientStreamListener.messageQueue, 3); + serverReceived += verifyMessageCountAndClose(serverStreamListener.messageQueue, 3); // Request the rest serverStream.request(clientSent); clientStream.request(serverSent); clientReceived += - verifyMessageCountAndClose(clientStreamMessageQueue, serverSent - clientReceived); + verifyMessageCountAndClose(clientStreamListener.messageQueue, serverSent - clientReceived); serverReceived += - verifyMessageCountAndClose(serverStreamMessageQueue, clientSent - serverReceived); + verifyMessageCountAndClose(serverStreamListener.messageQueue, clientSent - serverReceived); - verify(mockClientStreamListener, timeout(TIMEOUT_MS).times(2)).onReady(); + assertTrue(clientStreamListener.awaitOnReady(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertTrue(clientStreamListener.awaitOnReady(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(clientStream.isReady()); - verify(mockServerStreamListener, timeout(TIMEOUT_MS).times(2)).onReady(); + assertTrue(serverStreamListener.awaitOnReady(TIMEOUT_MS, TimeUnit.MILLISECONDS)); // ??? assertTrue(serverStream.isReady()); // Request four more @@ -1313,43 +1229,46 @@ public abstract class AbstractTransportTest { serverStream.flush(); } doPingPong(serverListener); - clientReceived += verifyMessageCountAndClose(clientStreamMessageQueue, 4); - serverReceived += verifyMessageCountAndClose(serverStreamMessageQueue, 4); + clientReceived += verifyMessageCountAndClose(clientStreamListener.messageQueue, 4); + serverReceived += verifyMessageCountAndClose(serverStreamListener.messageQueue, 4); // Drain exactly how many messages are left serverStream.request(1); clientStream.request(1); - clientReceived += verifyMessageCountAndClose(clientStreamMessageQueue, 1); - serverReceived += verifyMessageCountAndClose(serverStreamMessageQueue, 1); + clientReceived += verifyMessageCountAndClose(clientStreamListener.messageQueue, 1); + serverReceived += verifyMessageCountAndClose(serverStreamListener.messageQueue, 1); // And now check that the streams can still complete gracefully clientStream.writeMessage(methodDescriptor.streamRequest(largeMessage)); clientStream.flush(); clientStream.halfClose(); doPingPong(serverListener); - verify(mockServerStreamListener, never()).halfClosed(); + assertFalse(serverStreamListener.awaitHalfClosed(TIMEOUT_MS, TimeUnit.MILLISECONDS)); serverStream.request(1); - serverReceived += verifyMessageCountAndClose(serverStreamMessageQueue, 1); + serverReceived += verifyMessageCountAndClose(serverStreamListener.messageQueue, 1); assertEquals(clientSent + 6, serverReceived); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).halfClosed(); + assertTrue(serverStreamListener.awaitHalfClosed(TIMEOUT_MS, TimeUnit.MILLISECONDS)); serverStream.writeMessage(methodDescriptor.streamResponse(largeMessage)); serverStream.flush(); Status status = Status.OK.withDescription("... quite a lengthy discussion"); serverStream.close(status, new Metadata()); doPingPong(serverListener); - verify(mockClientStreamListener, never()).closed(any(Status.class), any(Metadata.class)); + try { + clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + fail("Expected TimeoutException"); + } catch (TimeoutException expectedException) { + } clientStream.request(1); - clientReceived += verifyMessageCountAndClose(clientStreamMessageQueue, 1); + clientReceived += verifyMessageCountAndClose(clientStreamListener.messageQueue, 1); assertEquals(serverSent + 6, clientReceived); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(statusCaptor.capture()); - assertCodeEquals(Status.OK, statusCaptor.getValue()); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)) - .closed(statusCaptor.capture(), any(Metadata.class)); - assertEquals(status.getCode(), statusCaptor.getValue().getCode()); - assertEquals(status.getDescription(), statusCaptor.getValue().getDescription()); + assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertEquals(status.getCode(), clientStreamStatus.getCode()); + assertEquals(status.getDescription(), clientStreamStatus.getDescription()); } private int verifyMessageCountAndClose(BlockingQueue messageQueue, int count) @@ -1376,41 +1295,21 @@ public abstract class AbstractTransportTest { // boilerplate ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); - ClientStreamListener clientListener = mock(ClientStreamListener.class); - doAnswer( - new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - StreamListener.MessageProducer producer = - (StreamListener.MessageProducer) invocation.getArguments()[0]; - InputStream message; - while ((message = producer.next()) != null) { - message.close(); - } - return null; - } - }) - .when(clientListener) - .messagesAvailable(Matchers.any()); - clientStream.start(clientListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); StreamCreation server = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); // setup clientStream.request(1); server.stream.close(Status.INTERNAL, new Metadata()); - verify(clientListener, timeout(TIMEOUT_MS).times(1)) - .closed(any(Status.class), any(Metadata.class)); - reset(clientListener); + assertNotNull(clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); // Ensure that for a closed ServerStream, interactions are noops server.stream.writeHeaders(new Metadata()); server.stream.writeMessage(methodDescriptor.streamResponse("response")); server.stream.close(Status.INTERNAL, new Metadata()); - // Even though the client requested a message earlier, the write should not go through - verify(clientListener, never()).headersRead(any(Metadata.class)); - verify(clientListener, never()).messagesAvailable(any(StreamListener.MessageProducer.class)); - verify(clientListener, never()).closed(any(Status.class), any(Metadata.class)); // Make sure new streams still work properly doPingPong(serverListener); @@ -1436,17 +1335,12 @@ public abstract class AbstractTransportTest { // setup server.stream.request(1); clientStream.cancel(Status.UNKNOWN); - verify(server.listener, timeout(TIMEOUT_MS)).closed(any(Status.class)); - reset(server.listener); + assertNotNull(server.listener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); // Ensure that for a cancelled ClientStream, interactions are noops clientStream.writeMessage(methodDescriptor.streamRequest("request")); clientStream.halfClose(); clientStream.cancel(Status.UNKNOWN); - // Even though the server requested a message earlier, the write should not go through - verify(server.listener, never()).messagesAvailable(any(StreamListener.MessageProducer.class)); - verify(server.listener, never()).halfClosed(); - verify(server.listener, never()).closed(any(Status.class)); // Make sure new streams still work properly doPingPong(serverListener); @@ -1458,39 +1352,24 @@ public abstract class AbstractTransportTest { * callbacks, it generally provides plenty of time for Runnables to execute. But it is also faster * on faster machines and more reliable on slower machines. */ - private void doPingPong(MockServerListener serverListener) throws InterruptedException { + private void doPingPong(MockServerListener serverListener) throws Exception { ManagedClientTransport client = newClientTransport(server); runIfNotNull(client.start(mock(ManagedClientTransport.Listener.class))); ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); - ClientStreamListener mockClientStreamListener = mock(ClientStreamListener.class); - doAnswer( - new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - StreamListener.MessageProducer producer = - (StreamListener.MessageProducer) invocation.getArguments()[0]; - InputStream message; - while ((message = producer.next()) != null) { - message.close(); - } - return null; - } - }) - .when(mockClientStreamListener) - .messagesAvailable(Matchers.any()); - clientStream.start(mockClientStreamListener); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); ServerStream serverStream = serverStreamCreation.stream; - ServerStreamListener mockServerStreamListener = serverStreamCreation.listener; + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; serverStream.close(Status.OK, new Metadata()); - verify(mockClientStreamListener, timeout(TIMEOUT_MS)) - .closed(any(Status.class), any(Metadata.class)); - verify(mockServerStreamListener, timeout(TIMEOUT_MS)).closed(any(Status.class)); + assertNotNull(clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); client.shutdown(Status.UNAVAILABLE); } @@ -1571,22 +1450,7 @@ public abstract class AbstractTransportTest { @Override public void streamCreated(ServerStream stream, String method, Metadata headers) { - ServerStreamListener listener = mock(ServerStreamListener.class); - doAnswer( - new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - StreamListener.MessageProducer producer = - (StreamListener.MessageProducer) invocation.getArguments()[0]; - InputStream message; - while ((message = producer.next()) != null) { - message.close(); - } - return null; - } - }) - .when(listener) - .messagesAvailable(Matchers.any()); + ServerStreamListenerBase listener = new ServerStreamListenerBase(); streams.add(new StreamCreation(stream, method, headers, listener)); stream.setListener(listener); } @@ -1622,14 +1486,112 @@ public abstract class AbstractTransportTest { } } + private static class ServerStreamListenerBase implements ServerStreamListener { + private final BlockingQueue messageQueue = new LinkedBlockingQueue(); + private final CountDownLatch onReadyLatch = new CountDownLatch(1); + private final CountDownLatch halfClosedLatch = new CountDownLatch(1); + private final SettableFuture status = SettableFuture.create(); + + private boolean awaitOnReady(int timeout, TimeUnit unit) throws Exception { + return onReadyLatch.await(timeout, unit); + } + + private boolean awaitHalfClosed(int timeout, TimeUnit unit) throws Exception { + return halfClosedLatch.await(timeout, unit); + } + + @Override + public void messagesAvailable(MessageProducer producer) { + if (status.isDone()) { + fail("messagesAvailable invoked after closed"); + } + InputStream message; + while ((message = producer.next()) != null) { + messageQueue.add(message); + } + } + + @Override + public void onReady() { + if (status.isDone()) { + fail("onReady invoked after closed"); + } + onReadyLatch.countDown(); + } + + @Override + public void halfClosed() { + if (status.isDone()) { + fail("halfClosed invoked after closed"); + } + halfClosedLatch.countDown(); + } + + @Override + public void closed(Status status) { + if (this.status.isDone()) { + fail("closed invoked more than once"); + } + this.status.set(status); + } + } + + private static class ClientStreamListenerBase implements ClientStreamListener { + private final BlockingQueue messageQueue = new LinkedBlockingQueue(); + private final CountDownLatch onReadyLatch = new CountDownLatch(1); + private final SettableFuture headers = SettableFuture.create(); + private final SettableFuture trailers = SettableFuture.create(); + private final SettableFuture status = SettableFuture.create(); + + private boolean awaitOnReady(int timeout, TimeUnit unit) throws Exception { + return onReadyLatch.await(timeout, unit); + } + + @Override + public void messagesAvailable(MessageProducer producer) { + if (status.isDone()) { + fail("messagesAvailable invoked after closed"); + } + InputStream message; + while ((message = producer.next()) != null) { + messageQueue.add(message); + } + } + + @Override + public void onReady() { + if (status.isDone()) { + fail("onReady invoked after closed"); + } + onReadyLatch.countDown(); + } + + @Override + public void headersRead(Metadata headers) { + if (status.isDone()) { + fail("headersRead invoked after closed"); + } + this.headers.set(headers); + } + + @Override + public void closed(Status status, Metadata trailers) { + if (this.status.isDone()) { + fail("headersRead invoked after closed"); + } + this.status.set(status); + this.trailers.set(trailers); + } + } + private static class StreamCreation { public final ServerStream stream; public final String method; public final Metadata headers; - public final ServerStreamListener listener; + public final ServerStreamListenerBase listener; - public StreamCreation(ServerStream stream, String method, - Metadata headers, ServerStreamListener listener) { + public StreamCreation( + ServerStream stream, String method, Metadata headers, ServerStreamListenerBase listener) { this.stream = stream; this.method = method; this.headers = headers; diff --git a/testing/src/main/java/io/grpc/internal/testing/TestClientStreamTracer.java b/testing/src/main/java/io/grpc/internal/testing/TestClientStreamTracer.java index ee8c700801..c4e35dd169 100644 --- a/testing/src/main/java/io/grpc/internal/testing/TestClientStreamTracer.java +++ b/testing/src/main/java/io/grpc/internal/testing/TestClientStreamTracer.java @@ -18,6 +18,7 @@ package io.grpc.internal.testing; import io.grpc.ClientStreamTracer; import io.grpc.Status; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -26,6 +27,7 @@ import java.util.concurrent.atomic.AtomicBoolean; */ public class TestClientStreamTracer extends ClientStreamTracer implements TestStreamTracer { private final TestBaseStreamTracer delegate = new TestBaseStreamTracer(); + protected final CountDownLatch outboundHeadersLatch = new CountDownLatch(1); protected final AtomicBoolean outboundHeadersCalled = new AtomicBoolean(); protected final AtomicBoolean inboundHeadersCalled = new AtomicBoolean(); @@ -53,6 +55,16 @@ public class TestClientStreamTracer extends ClientStreamTracer implements TestSt return outboundHeadersCalled.get(); } + /** + * Allow tests to await the outbound header event, which depending on the test case may be + * necessary (e.g., if we test for a Netty client's outbound headers upon receiving the start of + * stream on the server side, the tracer won't know that headers were sent until a channel future + * executes). + */ + public boolean awaitOutboundHeaders(int timeout, TimeUnit unit) throws Exception { + return outboundHeadersLatch.await(timeout, unit); + } + @Override public Status getStatus() { return delegate.getStatus(); @@ -151,6 +163,7 @@ public class TestClientStreamTracer extends ClientStreamTracer implements TestSt && delegate.failDuplicateCallbacks.get()) { throw new AssertionError("outboundHeaders called more than once"); } + outboundHeadersLatch.countDown(); } @Override