From 48c6b3d3986189dafa906c52e61b7c2c12429aec Mon Sep 17 00:00:00 2001 From: Carl Mastrangelo Date: Mon, 29 Aug 2016 13:25:33 -0700 Subject: [PATCH] all/tests: unmock ClientCall and ServerCall --- .../grpc/auth/ClientAuthInterceptorTest.java | 53 ++++- .../java/io/grpc/ClientInterceptorsTest.java | 112 +++++++--- core/src/test/java/io/grpc/ContextsTest.java | 26 ++- .../java/io/grpc/ServerInterceptorsTest.java | 35 ++- .../java/io/grpc/stub/StubConfigTest.java | 31 ++- stub/build.gradle | 1 + .../java/io/grpc/stub/ClientCallsTest.java | 200 ++++++++++++------ .../java/io/grpc/stub/ServerCallsTest.java | 79 +++++-- 8 files changed, 406 insertions(+), 131 deletions(-) diff --git a/auth/src/test/java/io/grpc/auth/ClientAuthInterceptorTest.java b/auth/src/test/java/io/grpc/auth/ClientAuthInterceptorTest.java index 96ccb45584..06f5a5f61c 100644 --- a/auth/src/test/java/io/grpc/auth/ClientAuthInterceptorTest.java +++ b/auth/src/test/java/io/grpc/auth/ClientAuthInterceptorTest.java @@ -31,11 +31,12 @@ package io.grpc.auth; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.mockito.Matchers.any; import static org.mockito.Matchers.isA; import static org.mockito.Matchers.same; import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -102,8 +103,7 @@ public class ClientAuthInterceptorTest { @Mock Channel channel; - @Mock - ClientCall call; + ClientCallRecorder call = new ClientCallRecorder(); ClientAuthInterceptor interceptor; @@ -130,7 +130,8 @@ public class ClientAuthInterceptorTest { interceptor.interceptCall(descriptor, CallOptions.DEFAULT, channel); Metadata headers = new Metadata(); interceptedCall.start(listener, headers); - verify(call).start(listener, headers); + assertEquals(listener, call.responseListener); + assertEquals(headers, call.headers); Iterable authorization = headers.getAll(AUTHORIZATION); Assert.assertArrayEquals(new String[]{"token1", "token2"}, @@ -150,7 +151,8 @@ public class ClientAuthInterceptorTest { ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); Mockito.verify(listener).onClose(statusCaptor.capture(), isA(Metadata.class)); Assert.assertNull(headers.getAll(AUTHORIZATION)); - Mockito.verify(call, never()).start(listener, headers); + assertNull(call.responseListener); + assertNull(call.headers); Assert.assertEquals(Status.Code.UNAUTHENTICATED, statusCaptor.getValue().getCode()); Assert.assertNotNull(statusCaptor.getValue().getCause()); } @@ -169,7 +171,8 @@ public class ClientAuthInterceptorTest { interceptor.interceptCall(descriptor, CallOptions.DEFAULT, channel); Metadata headers = new Metadata(); interceptedCall.start(listener, headers); - verify(call).start(listener, headers); + assertEquals(listener, call.responseListener); + assertEquals(headers, call.headers); Iterable authorization = headers.getAll(AUTHORIZATION); Assert.assertArrayEquals(new String[]{"Bearer allyourbase"}, Iterables.toArray(authorization, String.class)); @@ -191,4 +194,42 @@ public class ClientAuthInterceptorTest { verify(credentials).getRequestMetadata(URI.create("https://example.com:123/a.service")); interceptedCall.cancel("Cancel for test", null); } + + private static final class ClientCallRecorder extends ClientCall { + private ClientCall.Listener responseListener; + private Metadata headers; + private int numMessages; + private String cancelMessage; + private Throwable cancelCause; + private boolean halfClosed; + private String sentMessage; + + @Override + public void start(ClientCall.Listener responseListener, Metadata headers) { + this.responseListener = responseListener; + this.headers = headers; + } + + @Override + public void request(int numMessages) { + this.numMessages = numMessages; + } + + @Override + public void cancel(String message, Throwable cause) { + this.cancelMessage = message; + this.cancelCause = cause; + } + + @Override + public void halfClose() { + halfClosed = true; + } + + @Override + public void sendMessage(String message) { + sentMessage = message; + } + + } } diff --git a/core/src/test/java/io/grpc/ClientInterceptorsTest.java b/core/src/test/java/io/grpc/ClientInterceptorsTest.java index e4acc53b04..5d369d3e18 100644 --- a/core/src/test/java/io/grpc/ClientInterceptorsTest.java +++ b/core/src/test/java/io/grpc/ClientInterceptorsTest.java @@ -31,17 +31,16 @@ package io.grpc; +import static com.google.common.truth.Truth.assertThat; import static java.util.concurrent.TimeUnit.NANOSECONDS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyInt; -import static org.mockito.Matchers.eq; import static org.mockito.Matchers.isA; import static org.mockito.Matchers.same; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; @@ -61,8 +60,6 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import java.util.ArrayList; import java.util.Arrays; @@ -75,8 +72,7 @@ public class ClientInterceptorsTest { @Mock private Channel channel; - @Mock - private ClientCall call; + private BaseClientCall call = new BaseClientCall(); @Mock private MethodDescriptor method; @@ -89,18 +85,6 @@ public class ClientInterceptorsTest { when(channel.newCall( Mockito.>any(), any(CallOptions.class))) .thenReturn(call); - - // Emulate the precondition checks in ChannelImpl.CallImpl - Answer checkStartCalled = new Answer() { - @Override - public Void answer(InvocationOnMock invocation) { - verify(call).start(Mockito.>any(), Mockito.any()); - return null; - } - }; - doAnswer(checkStartCalled).when(call).request(anyInt()); - doAnswer(checkStartCalled).when(call).halfClose(); - doAnswer(checkStartCalled).when(call).sendMessage(Mockito.any()); } @Test(expected = NullPointerException.class) @@ -290,11 +274,10 @@ public class ClientInterceptorsTest { ClientCall interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT); // start() on the intercepted call will eventually reach the call created by the real channel interceptedCall.start(listener, new Metadata()); - ArgumentCaptor captor = ArgumentCaptor.forClass(Metadata.class); // The headers passed to the real channel call will contain the information inserted by the // interceptor. - verify(call).start(same(listener), captor.capture()); - assertEquals("abcd", captor.getValue().get(credKey)); + assertSame(listener, call.listener); + assertEquals("abcd", call.headers.get(credKey)); } @Test @@ -327,12 +310,11 @@ public class ClientInterceptorsTest { ClientCall interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT); interceptedCall.start(listener, new Metadata()); // Capture the underlying call listener that will receive headers from the transport. - ArgumentCaptor> captor = ArgumentCaptor.forClass(null); - verify(call).start(captor.capture(), Mockito.any()); + Metadata inboundHeaders = new Metadata(); // Simulate that a headers arrives on the underlying call listener. - captor.getValue().onHeaders(inboundHeaders); - assertEquals(Arrays.asList(inboundHeaders), examinedHeaders); + call.listener.onHeaders(inboundHeaders); + assertThat(examinedHeaders).contains(inboundHeaders); } @Test @@ -354,13 +336,14 @@ public class ClientInterceptorsTest { ClientCall.Listener listener = mock(ClientCall.Listener.class); Metadata headers = new Metadata(); interceptedCall.start(listener, headers); - verify(call).start(same(listener), same(headers)); + assertSame(listener, call.listener); + assertSame(headers, call.headers); interceptedCall.sendMessage("request"); - verify(call).sendMessage(eq("request")); + assertThat(call.messages).containsExactly("request"); interceptedCall.halfClose(); - verify(call).halfClose(); + assertTrue(call.halfClosed); interceptedCall.request(1); - verify(call).request(1); + assertThat(call.requests).containsExactly(1); } @Test @@ -392,7 +375,7 @@ public class ClientInterceptorsTest { interceptedCall.sendMessage("request"); interceptedCall.halfClose(); interceptedCall.request(1); - verifyNoMoreInteractions(call); + call.done = true; ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); verify(listener).onClose(captor.capture(), any(Metadata.class)); assertSame(error, captor.getValue().getCause()); @@ -406,7 +389,6 @@ public class ClientInterceptorsTest { noop.halfClose(); noop.sendMessage(null); assertFalse(noop.isReady()); - verifyNoMoreInteractions(call); } @Test @@ -432,12 +414,12 @@ public class ClientInterceptorsTest { CallOptions callOptions = CallOptions.DEFAULT.withOption(customOption, "value"); ArgumentCaptor passedOptions = ArgumentCaptor.forClass(CallOptions.class); ClientInterceptor interceptor = spy(new NoopInterceptor()); - + Channel intercepted = ClientInterceptors.intercept(channel, interceptor); - + assertSame(call, intercepted.newCall(method, callOptions)); verify(channel).newCall(same(method), same(callOptions)); - + verify(interceptor).interceptCall(same(method), passedOptions.capture(), isA(Channel.class)); assertSame("value", passedOptions.getValue().getOption(customOption)); } @@ -449,4 +431,64 @@ public class ClientInterceptorsTest { return next.newCall(method, callOptions); } } + + private static class BaseClientCall extends ClientCall { + private boolean started; + private boolean done; + private ClientCall.Listener listener; + private Metadata headers; + private List requests = new ArrayList(); + private List messages = new ArrayList(); + private boolean halfClosed; + private Throwable cancelCause; + private String cancelMessage; + + @Override + public void start(ClientCall.Listener listener, Metadata headers) { + checkNotDone(); + started = true; + this.listener = listener; + this.headers = headers; + } + + @Override + public void request(int numMessages) { + checkNotDone(); + checkStarted(); + requests.add(numMessages); + } + + @Override + public void cancel(String message, Throwable cause) { + checkNotDone(); + this.cancelMessage = message; + this.cancelCause = cause; + } + + @Override + public void halfClose() { + checkNotDone(); + checkStarted(); + this.halfClosed = true; + } + + @Override + public void sendMessage(String message) { + checkNotDone(); + checkStarted(); + messages.add(message); + } + + private void checkNotDone() { + if (done) { + throw new IllegalStateException("no more methods should be called"); + } + } + + private void checkStarted() { + if (!started) { + throw new IllegalStateException("should have called start"); + } + } + } } diff --git a/core/src/test/java/io/grpc/ContextsTest.java b/core/src/test/java/io/grpc/ContextsTest.java index 8ea6271492..8e9cb3fc5e 100644 --- a/core/src/test/java/io/grpc/ContextsTest.java +++ b/core/src/test/java/io/grpc/ContextsTest.java @@ -45,6 +45,7 @@ import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import io.grpc.internal.FakeClock; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -66,7 +67,30 @@ public class ContextsTest { @SuppressWarnings("unchecked") private MethodDescriptor method = mock(MethodDescriptor.class); @SuppressWarnings("unchecked") - private ServerCall call = mock(ServerCall.class); + private ServerCall call = new ServerCall() { + + @Override + public void request(int numMessages) {} + + @Override + public void sendHeaders(Metadata headers) {} + + @Override + public void sendMessage(Object message) {} + + @Override + public void close(Status status, Metadata trailers) {} + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public MethodDescriptor getMethodDescriptor() { + return null; + } + }; private Metadata headers = new Metadata(); @Test diff --git a/core/src/test/java/io/grpc/ServerInterceptorsTest.java b/core/src/test/java/io/grpc/ServerInterceptorsTest.java index e2a00e7078..4877fa0d51 100644 --- a/core/src/test/java/io/grpc/ServerInterceptorsTest.java +++ b/core/src/test/java/io/grpc/ServerInterceptorsTest.java @@ -45,7 +45,6 @@ import static org.mockito.Mockito.verifyZeroInteractions; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.MethodType; import io.grpc.ServerCall.Listener; -import io.grpc.ServerMethodDefinition; import org.junit.After; import org.junit.Before; @@ -78,9 +77,8 @@ public class ServerInterceptorsTest { private ServerCall.Listener listener; private MethodDescriptor flowMethod; - - @Mock - private ServerCall call; + + private ServerCall call = new BaseServerCall(); private ServerServiceDefinition serviceDefinition; @@ -282,7 +280,7 @@ public class ServerInterceptorsTest { @Test public void argumentsPassed() { @SuppressWarnings("unchecked") - final ServerCall call2 = mock(ServerCall.class); + final ServerCall call2 = new BaseServerCall(); @SuppressWarnings("unchecked") final ServerCall.Listener listener2 = mock(ServerCall.Listener.class); @@ -408,7 +406,7 @@ public class ServerInterceptorsTest { .intercept(inputStreamMessageService, interceptor2); ServerMethodDefinition serverMethod = (ServerMethodDefinition) intercepted2.getMethod("basic/wrapped"); - ServerCall call2 = mock(ServerCall.class); + ServerCall call2 = new BaseServerCall(); byte[] bytes = {}; serverMethod .getServerCallHandler() @@ -459,4 +457,29 @@ public class ServerInterceptorsTest { return inputStream; } } + + private static class BaseServerCall extends ServerCall { + + @Override + public void request(int numMessages) {} + + @Override + public void sendHeaders(Metadata headers) {} + + @Override + public void sendMessage(RespT message) {} + + @Override + public void close(Status status, Metadata trailers) {} + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public MethodDescriptor getMethodDescriptor() { + return null; + } + } } diff --git a/interop-testing/src/test/java/io/grpc/stub/StubConfigTest.java b/interop-testing/src/test/java/io/grpc/stub/StubConfigTest.java index 02f4e4b4f8..6bd8b0c335 100644 --- a/interop-testing/src/test/java/io/grpc/stub/StubConfigTest.java +++ b/interop-testing/src/test/java/io/grpc/stub/StubConfigTest.java @@ -48,6 +48,7 @@ import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.Deadline; +import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.testing.integration.Messages.SimpleRequest; import io.grpc.testing.integration.Messages.SimpleResponse; @@ -73,17 +74,37 @@ public class StubConfigTest { @Mock private StreamObserver responseObserver; - @Mock - private ClientCall call; - /** * Sets up mocks. */ @Before public void setUp() { MockitoAnnotations.initMocks(this); + ClientCall call = + new ClientCall() { + @Override + public void start( + ClientCall.Listener responseListener, Metadata headers) { + } + + @Override + public void request(int numMessages) { + } + + @Override + public void cancel(String message, Throwable cause) { + } + + @Override + public void halfClose() { + } + + @Override + public void sendMessage(SimpleRequest message) { + } + }; when(channel.newCall( - Mockito.>any(), any(CallOptions.class))) - .thenReturn(call); + Mockito.>any(), any(CallOptions.class))) + .thenReturn(call); } @Test diff --git a/stub/build.gradle b/stub/build.gradle index 2a591633a5..d4ed63b067 100644 --- a/stub/build.gradle +++ b/stub/build.gradle @@ -5,6 +5,7 @@ plugins { description = "gRPC: Stub" dependencies { compile project(':grpc-core') + testCompile libraries.truth } // Configure the animal sniffer plugin diff --git a/stub/src/test/java/io/grpc/stub/ClientCallsTest.java b/stub/src/test/java/io/grpc/stub/ClientCallsTest.java index 7cf836c8a0..e058780d9d 100644 --- a/stub/src/test/java/io/grpc/stub/ClientCallsTest.java +++ b/stub/src/test/java/io/grpc/stub/ClientCallsTest.java @@ -31,21 +31,18 @@ package io.grpc.stub; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.grpc.CallOptions; import io.grpc.ClientCall; -import io.grpc.ClientCall.Listener; import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -64,12 +61,7 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import java.util.ArrayList; import java.util.Arrays; @@ -80,6 +72,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; /** * Unit tests for {@link ClientCalls}. @@ -96,9 +89,6 @@ public class ClientCallsTest { private Server server; private ManagedChannel channel; - @Mock - private ClientCall call; - @Before public void setUp() { MockitoAnnotations.initMocks(this); @@ -121,16 +111,13 @@ public class ClientCallsTest { final Status status = Status.OK; final Metadata trailers = new Metadata(); - doAnswer(new Answer() { + BaseClientCall call = new BaseClientCall() { @Override - public Void answer(InvocationOnMock in) throws Throwable { - @SuppressWarnings("unchecked") - Listener listener = (Listener) in.getArguments()[0]; + public void start(ClientCall.Listener listener, Metadata headers) { listener.onMessage(resp); listener.onClose(status, trailers); - return null; } - }).when(call).start(Mockito.>any(), any(Metadata.class)); + }; String actualResponse = ClientCalls.blockingUnaryCall(call, req); assertEquals(resp, actualResponse); @@ -142,15 +129,12 @@ public class ClientCallsTest { final Status status = Status.INTERNAL.withDescription("Unique status"); final Metadata trailers = new Metadata(); - doAnswer(new Answer() { + BaseClientCall call = new BaseClientCall() { @Override - public Void answer(InvocationOnMock in) throws Throwable { - @SuppressWarnings("unchecked") - Listener listener = (Listener) in.getArguments()[0]; + public void start(io.grpc.ClientCall.Listener listener, Metadata headers) { listener.onClose(status, trailers); - return null; } - }).when(call).start(Mockito.>any(), any(Metadata.class)); + }; try { ClientCalls.blockingUnaryCall(call, req); @@ -163,27 +147,50 @@ public class ClientCallsTest { @Test public void unaryFutureCallSuccess() throws Exception { + final AtomicReference> listener = + new AtomicReference>(); + final AtomicReference message = new AtomicReference(); + final AtomicReference halfClosed = new AtomicReference(); + BaseClientCall call = new BaseClientCall() { + @Override + public void start(io.grpc.ClientCall.Listener responseListener, Metadata headers) { + listener.set(responseListener); + } + + @Override + public void sendMessage(Integer msg) { + message.set(msg); + } + + @Override + public void halfClose() { + halfClosed.set(true); + } + }; Integer req = 2; ListenableFuture future = ClientCalls.futureUnaryCall(call, req); - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(null); - verify(call).start(listenerCaptor.capture(), any(Metadata.class)); - ClientCall.Listener listener = listenerCaptor.getValue(); - verify(call).sendMessage(req); - verify(call).halfClose(); - listener.onMessage("bar"); - listener.onClose(Status.OK, new Metadata()); + + assertEquals(req, message.get()); + assertTrue(halfClosed.get()); + listener.get().onMessage("bar"); + listener.get().onClose(Status.OK, new Metadata()); assertEquals("bar", future.get()); } @Test public void unaryFutureCallFailed() throws Exception { + final AtomicReference> listener = + new AtomicReference>(); + BaseClientCall call = new BaseClientCall() { + @Override + public void start(io.grpc.ClientCall.Listener responseListener, Metadata headers) { + listener.set(responseListener); + } + }; Integer req = 2; ListenableFuture future = ClientCalls.futureUnaryCall(call, req); - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(null); - verify(call).start(listenerCaptor.capture(), any(Metadata.class)); - ClientCall.Listener listener = listenerCaptor.getValue(); Metadata trailers = new Metadata(); - listener.onClose(Status.INTERNAL, trailers); + listener.get().onClose(Status.INTERNAL, trailers); try { future.get(); fail("Should fail"); @@ -197,15 +204,29 @@ public class ClientCallsTest { @Test public void unaryFutureCallCancelled() throws Exception { + final AtomicReference> listener = + new AtomicReference>(); + final AtomicReference cancelMessage = new AtomicReference(); + final AtomicReference cancelCause = new AtomicReference(); + BaseClientCall call = new BaseClientCall() { + @Override + public void start(io.grpc.ClientCall.Listener responseListener, Metadata headers) { + listener.set(responseListener); + } + + @Override + public void cancel(String message, Throwable cause) { + cancelMessage.set(message); + cancelCause.set(cause); + } + }; Integer req = 2; ListenableFuture future = ClientCalls.futureUnaryCall(call, req); - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(null); - verify(call).start(listenerCaptor.capture(), any(Metadata.class)); - ClientCall.Listener listener = listenerCaptor.getValue(); future.cancel(true); - verify(call).cancel("GrpcFuture was cancelled", null); - listener.onMessage("bar"); - listener.onClose(Status.OK, new Metadata()); + assertEquals("GrpcFuture was cancelled", cancelMessage.get()); + assertNull(cancelCause.get()); + listener.get().onMessage("bar"); + listener.get().onClose(Status.OK, new Metadata()); try { future.get(); fail("Should fail"); @@ -216,6 +237,7 @@ public class ClientCallsTest { @Test public void cannotSetOnReadyAfterCallStarted() throws Exception { + BaseClientCall call = new BaseClientCall(); CallStreamObserver callStreamObserver = (CallStreamObserver) ClientCalls.asyncClientStreamingCall(call, new NoopStreamObserver()); @@ -235,7 +257,20 @@ public class ClientCallsTest { @Test public void disablingInboundAutoFlowControlSuppressesRequestsForMoreMessages() throws Exception { - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(null); + final AtomicReference> listener = + new AtomicReference>(); + final List requests = new ArrayList(); + BaseClientCall call = new BaseClientCall() { + @Override + public void start(io.grpc.ClientCall.Listener responseListener, Metadata headers) { + listener.set(responseListener); + } + + @Override + public void request(int numMessages) { + requests.add(numMessages); + } + }; ClientCalls.asyncBidiStreamingCall(call, new ClientResponseObserver() { @Override public void beforeStart(ClientCallStreamObserver requestStream) { @@ -257,15 +292,13 @@ public class ClientCallsTest { } }); - verify(call).start(listenerCaptor.capture(), any(Metadata.class)); - listenerCaptor.getValue().onMessage("message"); - verify(call, times(1)).request(1); + listener.get().onMessage("message"); + assertThat(requests).containsExactly(1); } @Test public void callStreamObserverPropagatesFlowControlRequestsToCall() throws Exception { - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(null); ClientResponseObserver responseObserver = new ClientResponseObserver() { @Override @@ -285,19 +318,32 @@ public class ClientCallsTest { public void onCompleted() { } }; + final AtomicReference> listener = + new AtomicReference>(); + final List requests = new ArrayList(); + BaseClientCall call = new BaseClientCall() { + @Override + public void start(io.grpc.ClientCall.Listener responseListener, Metadata headers) { + listener.set(responseListener); + } + + @Override + public void request(int numMessages) { + requests.add(numMessages); + } + }; CallStreamObserver requestObserver = (CallStreamObserver) ClientCalls.asyncBidiStreamingCall(call, responseObserver); - verify(call).start(listenerCaptor.capture(), any(Metadata.class)); - listenerCaptor.getValue().onMessage("message"); + listener.get().onMessage("message"); requestObserver.request(5); - verify(call, times(1)).request(5); + assertThat(requests).contains(5); } @Test public void canCaptureInboundFlowControlForServerStreamingObserver() throws Exception { - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(null); + ClientResponseObserver responseObserver = new ClientResponseObserver() { @Override @@ -318,11 +364,23 @@ public class ClientCallsTest { public void onCompleted() { } }; + final AtomicReference> listener = + new AtomicReference>(); + final List requests = new ArrayList(); + BaseClientCall call = new BaseClientCall() { + @Override + public void start(io.grpc.ClientCall.Listener responseListener, Metadata headers) { + listener.set(responseListener); + } + + @Override + public void request(int numMessages) { + requests.add(numMessages); + } + }; ClientCalls.asyncServerStreamingCall(call, 1, responseObserver); - verify(call).start(listenerCaptor.capture(), any(Metadata.class)); - listenerCaptor.getValue().onMessage("message"); - verify(call, times(1)).request(1); - verify(call, times(1)).request(5); + listener.get().onMessage("message"); + assertThat(requests).containsExactly(5, 1).inOrder(); } @Test @@ -497,13 +555,20 @@ public class ClientCallsTest { @Test public void blockingResponseStreamFailed() throws Exception { + final AtomicReference> listener = + new AtomicReference>(); + BaseClientCall call = new BaseClientCall() { + @Override + public void start(io.grpc.ClientCall.Listener responseListener, Metadata headers) { + listener.set(responseListener); + } + }; + Integer req = 2; Iterator iter = ClientCalls.blockingServerStreamingCall(call, req); - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(null); - verify(call).start(listenerCaptor.capture(), any(Metadata.class)); - ClientCall.Listener listener = listenerCaptor.getValue(); + Metadata trailers = new Metadata(); - listener.onClose(Status.INTERNAL, trailers); + listener.get().onClose(Status.INTERNAL, trailers); try { iter.next(); fail("Should fail"); @@ -514,4 +579,21 @@ public class ClientCallsTest { assertSame(trailers, metadata); } } + + private static class BaseClientCall extends ClientCall { + @Override + public void start(io.grpc.ClientCall.Listener responseListener, Metadata headers) {} + + @Override + public void request(int numMessages) {} + + @Override + public void cancel(String message, Throwable cause) {} + + @Override + public void halfClose() {} + + @Override + public void sendMessage(Integer message) {} + } } diff --git a/stub/src/test/java/io/grpc/stub/ServerCallsTest.java b/stub/src/test/java/io/grpc/stub/ServerCallsTest.java index 48d8f6235f..c6188dd7bd 100644 --- a/stub/src/test/java/io/grpc/stub/ServerCallsTest.java +++ b/stub/src/test/java/io/grpc/stub/ServerCallsTest.java @@ -31,12 +31,12 @@ package io.grpc.stub; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import static org.mockito.Mockito.times; import io.grpc.CallOptions; import io.grpc.ClientCall; @@ -51,13 +51,9 @@ import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.ManagedChannelImpl; -import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -65,6 +61,8 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; @@ -88,13 +86,7 @@ public class ServerCallsTest { "some/unarymethod", new IntegerMarshaller(), new IntegerMarshaller()); - @Mock - ServerCall serverCall; - - @Before - public void setUp() throws Exception { - MockitoAnnotations.initMocks(this); - } + private final ServerCallRecorder serverCall = new ServerCallRecorder(); @Test public void runtimeStreamObserverIsServerCallStreamObserver() throws Exception { @@ -130,8 +122,8 @@ public class ServerCallsTest { }); ServerCall.Listener callListener = callHandler.startCall(serverCall, new Metadata()); - Mockito.when(serverCall.isReady()).thenReturn(true).thenReturn(false); - Mockito.when(serverCall.isCancelled()).thenReturn(false).thenReturn(true); + serverCall.isReady = true; + serverCall.isCancelled = false; assertTrue(callObserver.get().isReady()); assertFalse(callObserver.get().isCancelled()); callListener.onReady(); @@ -140,11 +132,13 @@ public class ServerCallsTest { assertTrue(invokeCalled.get()); assertTrue(onReadyCalled.get()); assertTrue(onCancelCalled.get()); + serverCall.isReady = false; + serverCall.isCancelled = true; assertFalse(callObserver.get().isReady()); assertTrue(callObserver.get().isCancelled()); // Is called twice, once to permit the first message and once again after the first message // has been processed (auto flow control) - Mockito.verify(serverCall, times(2)).request(1); + assertThat(serverCall.requestCalls).containsExactly(1, 1).inOrder(); } @Test @@ -247,7 +241,7 @@ public class ServerCallsTest { // to verify that message delivery does not trigger a call to request(1). callListener.onMessage(1); // Should never be called - Mockito.verify(serverCall, times(0)).request(1); + assertThat(serverCall.requestCalls).isEmpty(); } @Test @@ -265,7 +259,7 @@ public class ServerCallsTest { callHandler.startCall(serverCall, new Metadata()); // Auto inbound flow-control always requests 2 messages for unary to detect a violation // of the unary semantic. - Mockito.verify(serverCall, times(1)).request(2); + assertThat(serverCall.requestCalls).containsExactly(2); } @Test @@ -288,8 +282,8 @@ public class ServerCallsTest { }); ServerCall.Listener callListener = callHandler.startCall(serverCall, new Metadata()); - Mockito.when(serverCall.isReady()).thenReturn(true).thenReturn(false); - Mockito.when(serverCall.isCancelled()).thenReturn(false).thenReturn(true); + serverCall.isReady = true; + serverCall.isCancelled = false; callListener.onReady(); // On ready is not called until the unary request message is delivered assertEquals(0, onReadyCalled.get()); @@ -392,4 +386,51 @@ public class ServerCallsTest { } } } + + private static class ServerCallRecorder extends ServerCall { + private List requestCalls = new ArrayList(); + private Metadata headers; + private Integer message; + private Metadata trailers; + private Status status; + private boolean isCancelled; + private MethodDescriptor methodDescriptor; + private boolean isReady; + + @Override + public void request(int numMessages) { + requestCalls.add(numMessages); + } + + @Override + public void sendHeaders(Metadata headers) { + this.headers = headers; + } + + @Override + public void sendMessage(Integer message) { + this.message = message; + } + + @Override + public void close(Status status, Metadata trailers) { + this.status = status; + this.trailers = trailers; + } + + @Override + public boolean isCancelled() { + return isCancelled; + } + + @Override + public boolean isReady() { + return isReady; + } + + @Override + public MethodDescriptor getMethodDescriptor() { + return methodDescriptor; + } + } }