all/tests: unmock ClientCall and ServerCall

This commit is contained in:
Carl Mastrangelo 2016-08-29 13:25:33 -07:00
parent 3bf8d94f02
commit 48c6b3d398
8 changed files with 406 additions and 131 deletions

View File

@ -31,11 +31,12 @@
package io.grpc.auth; package io.grpc.auth;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isA; import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same; import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -102,8 +103,7 @@ public class ClientAuthInterceptorTest {
@Mock @Mock
Channel channel; Channel channel;
@Mock ClientCallRecorder call = new ClientCallRecorder();
ClientCall<String, Integer> call;
ClientAuthInterceptor interceptor; ClientAuthInterceptor interceptor;
@ -130,7 +130,8 @@ public class ClientAuthInterceptorTest {
interceptor.interceptCall(descriptor, CallOptions.DEFAULT, channel); interceptor.interceptCall(descriptor, CallOptions.DEFAULT, channel);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
interceptedCall.start(listener, headers); interceptedCall.start(listener, headers);
verify(call).start(listener, headers); assertEquals(listener, call.responseListener);
assertEquals(headers, call.headers);
Iterable<String> authorization = headers.getAll(AUTHORIZATION); Iterable<String> authorization = headers.getAll(AUTHORIZATION);
Assert.assertArrayEquals(new String[]{"token1", "token2"}, Assert.assertArrayEquals(new String[]{"token1", "token2"},
@ -150,7 +151,8 @@ public class ClientAuthInterceptorTest {
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class); ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
Mockito.verify(listener).onClose(statusCaptor.capture(), isA(Metadata.class)); Mockito.verify(listener).onClose(statusCaptor.capture(), isA(Metadata.class));
Assert.assertNull(headers.getAll(AUTHORIZATION)); Assert.assertNull(headers.getAll(AUTHORIZATION));
Mockito.verify(call, never()).start(listener, headers); assertNull(call.responseListener);
assertNull(call.headers);
Assert.assertEquals(Status.Code.UNAUTHENTICATED, statusCaptor.getValue().getCode()); Assert.assertEquals(Status.Code.UNAUTHENTICATED, statusCaptor.getValue().getCode());
Assert.assertNotNull(statusCaptor.getValue().getCause()); Assert.assertNotNull(statusCaptor.getValue().getCause());
} }
@ -169,7 +171,8 @@ public class ClientAuthInterceptorTest {
interceptor.interceptCall(descriptor, CallOptions.DEFAULT, channel); interceptor.interceptCall(descriptor, CallOptions.DEFAULT, channel);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
interceptedCall.start(listener, headers); interceptedCall.start(listener, headers);
verify(call).start(listener, headers); assertEquals(listener, call.responseListener);
assertEquals(headers, call.headers);
Iterable<String> authorization = headers.getAll(AUTHORIZATION); Iterable<String> authorization = headers.getAll(AUTHORIZATION);
Assert.assertArrayEquals(new String[]{"Bearer allyourbase"}, Assert.assertArrayEquals(new String[]{"Bearer allyourbase"},
Iterables.toArray(authorization, String.class)); Iterables.toArray(authorization, String.class));
@ -191,4 +194,42 @@ public class ClientAuthInterceptorTest {
verify(credentials).getRequestMetadata(URI.create("https://example.com:123/a.service")); verify(credentials).getRequestMetadata(URI.create("https://example.com:123/a.service"));
interceptedCall.cancel("Cancel for test", null); interceptedCall.cancel("Cancel for test", null);
} }
private static final class ClientCallRecorder extends ClientCall<String, Integer> {
private ClientCall.Listener<Integer> responseListener;
private Metadata headers;
private int numMessages;
private String cancelMessage;
private Throwable cancelCause;
private boolean halfClosed;
private String sentMessage;
@Override
public void start(ClientCall.Listener<Integer> responseListener, Metadata headers) {
this.responseListener = responseListener;
this.headers = headers;
}
@Override
public void request(int numMessages) {
this.numMessages = numMessages;
}
@Override
public void cancel(String message, Throwable cause) {
this.cancelMessage = message;
this.cancelCause = cause;
}
@Override
public void halfClose() {
halfClosed = true;
}
@Override
public void sendMessage(String message) {
sentMessage = message;
}
}
} }

View File

@ -31,17 +31,16 @@
package io.grpc; package io.grpc;
import static com.google.common.truth.Truth.assertThat;
import static java.util.concurrent.TimeUnit.NANOSECONDS; import static java.util.concurrent.TimeUnit.NANOSECONDS;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isA; import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same; import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
@ -61,8 +60,6 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -75,8 +72,7 @@ public class ClientInterceptorsTest {
@Mock @Mock
private Channel channel; private Channel channel;
@Mock private BaseClientCall call = new BaseClientCall();
private ClientCall<String, Integer> call;
@Mock @Mock
private MethodDescriptor<String, Integer> method; private MethodDescriptor<String, Integer> method;
@ -89,18 +85,6 @@ public class ClientInterceptorsTest {
when(channel.newCall( when(channel.newCall(
Mockito.<MethodDescriptor<String, Integer>>any(), any(CallOptions.class))) Mockito.<MethodDescriptor<String, Integer>>any(), any(CallOptions.class)))
.thenReturn(call); .thenReturn(call);
// Emulate the precondition checks in ChannelImpl.CallImpl
Answer<Void> checkStartCalled = new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) {
verify(call).start(Mockito.<ClientCall.Listener<Integer>>any(), Mockito.<Metadata>any());
return null;
}
};
doAnswer(checkStartCalled).when(call).request(anyInt());
doAnswer(checkStartCalled).when(call).halfClose();
doAnswer(checkStartCalled).when(call).sendMessage(Mockito.<String>any());
} }
@Test(expected = NullPointerException.class) @Test(expected = NullPointerException.class)
@ -290,11 +274,10 @@ public class ClientInterceptorsTest {
ClientCall<String, Integer> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT); ClientCall<String, Integer> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
// start() on the intercepted call will eventually reach the call created by the real channel // start() on the intercepted call will eventually reach the call created by the real channel
interceptedCall.start(listener, new Metadata()); interceptedCall.start(listener, new Metadata());
ArgumentCaptor<Metadata> captor = ArgumentCaptor.forClass(Metadata.class);
// The headers passed to the real channel call will contain the information inserted by the // The headers passed to the real channel call will contain the information inserted by the
// interceptor. // interceptor.
verify(call).start(same(listener), captor.capture()); assertSame(listener, call.listener);
assertEquals("abcd", captor.getValue().get(credKey)); assertEquals("abcd", call.headers.get(credKey));
} }
@Test @Test
@ -327,12 +310,11 @@ public class ClientInterceptorsTest {
ClientCall<String, Integer> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT); ClientCall<String, Integer> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
interceptedCall.start(listener, new Metadata()); interceptedCall.start(listener, new Metadata());
// Capture the underlying call listener that will receive headers from the transport. // Capture the underlying call listener that will receive headers from the transport.
ArgumentCaptor<ClientCall.Listener<Integer>> captor = ArgumentCaptor.forClass(null);
verify(call).start(captor.capture(), Mockito.<Metadata>any());
Metadata inboundHeaders = new Metadata(); Metadata inboundHeaders = new Metadata();
// Simulate that a headers arrives on the underlying call listener. // Simulate that a headers arrives on the underlying call listener.
captor.getValue().onHeaders(inboundHeaders); call.listener.onHeaders(inboundHeaders);
assertEquals(Arrays.asList(inboundHeaders), examinedHeaders); assertThat(examinedHeaders).contains(inboundHeaders);
} }
@Test @Test
@ -354,13 +336,14 @@ public class ClientInterceptorsTest {
ClientCall.Listener<Integer> listener = mock(ClientCall.Listener.class); ClientCall.Listener<Integer> listener = mock(ClientCall.Listener.class);
Metadata headers = new Metadata(); Metadata headers = new Metadata();
interceptedCall.start(listener, headers); interceptedCall.start(listener, headers);
verify(call).start(same(listener), same(headers)); assertSame(listener, call.listener);
assertSame(headers, call.headers);
interceptedCall.sendMessage("request"); interceptedCall.sendMessage("request");
verify(call).sendMessage(eq("request")); assertThat(call.messages).containsExactly("request");
interceptedCall.halfClose(); interceptedCall.halfClose();
verify(call).halfClose(); assertTrue(call.halfClosed);
interceptedCall.request(1); interceptedCall.request(1);
verify(call).request(1); assertThat(call.requests).containsExactly(1);
} }
@Test @Test
@ -392,7 +375,7 @@ public class ClientInterceptorsTest {
interceptedCall.sendMessage("request"); interceptedCall.sendMessage("request");
interceptedCall.halfClose(); interceptedCall.halfClose();
interceptedCall.request(1); interceptedCall.request(1);
verifyNoMoreInteractions(call); call.done = true;
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class); ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(listener).onClose(captor.capture(), any(Metadata.class)); verify(listener).onClose(captor.capture(), any(Metadata.class));
assertSame(error, captor.getValue().getCause()); assertSame(error, captor.getValue().getCause());
@ -406,7 +389,6 @@ public class ClientInterceptorsTest {
noop.halfClose(); noop.halfClose();
noop.sendMessage(null); noop.sendMessage(null);
assertFalse(noop.isReady()); assertFalse(noop.isReady());
verifyNoMoreInteractions(call);
} }
@Test @Test
@ -432,12 +414,12 @@ public class ClientInterceptorsTest {
CallOptions callOptions = CallOptions.DEFAULT.withOption(customOption, "value"); CallOptions callOptions = CallOptions.DEFAULT.withOption(customOption, "value");
ArgumentCaptor<CallOptions> passedOptions = ArgumentCaptor.forClass(CallOptions.class); ArgumentCaptor<CallOptions> passedOptions = ArgumentCaptor.forClass(CallOptions.class);
ClientInterceptor interceptor = spy(new NoopInterceptor()); ClientInterceptor interceptor = spy(new NoopInterceptor());
Channel intercepted = ClientInterceptors.intercept(channel, interceptor); Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
assertSame(call, intercepted.newCall(method, callOptions)); assertSame(call, intercepted.newCall(method, callOptions));
verify(channel).newCall(same(method), same(callOptions)); verify(channel).newCall(same(method), same(callOptions));
verify(interceptor).interceptCall(same(method), passedOptions.capture(), isA(Channel.class)); verify(interceptor).interceptCall(same(method), passedOptions.capture(), isA(Channel.class));
assertSame("value", passedOptions.getValue().getOption(customOption)); assertSame("value", passedOptions.getValue().getOption(customOption));
} }
@ -449,4 +431,64 @@ public class ClientInterceptorsTest {
return next.newCall(method, callOptions); return next.newCall(method, callOptions);
} }
} }
private static class BaseClientCall extends ClientCall<String, Integer> {
private boolean started;
private boolean done;
private ClientCall.Listener<Integer> listener;
private Metadata headers;
private List<Integer> requests = new ArrayList<Integer>();
private List<String> messages = new ArrayList<String>();
private boolean halfClosed;
private Throwable cancelCause;
private String cancelMessage;
@Override
public void start(ClientCall.Listener<Integer> listener, Metadata headers) {
checkNotDone();
started = true;
this.listener = listener;
this.headers = headers;
}
@Override
public void request(int numMessages) {
checkNotDone();
checkStarted();
requests.add(numMessages);
}
@Override
public void cancel(String message, Throwable cause) {
checkNotDone();
this.cancelMessage = message;
this.cancelCause = cause;
}
@Override
public void halfClose() {
checkNotDone();
checkStarted();
this.halfClosed = true;
}
@Override
public void sendMessage(String message) {
checkNotDone();
checkStarted();
messages.add(message);
}
private void checkNotDone() {
if (done) {
throw new IllegalStateException("no more methods should be called");
}
}
private void checkStarted() {
if (!started) {
throw new IllegalStateException("should have called start");
}
}
}
} }

View File

@ -45,6 +45,7 @@ import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import io.grpc.internal.FakeClock; import io.grpc.internal.FakeClock;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
@ -66,7 +67,30 @@ public class ContextsTest {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private MethodDescriptor<Object, Object> method = mock(MethodDescriptor.class); private MethodDescriptor<Object, Object> method = mock(MethodDescriptor.class);
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private ServerCall<Object, Object> call = mock(ServerCall.class); private ServerCall<Object, Object> call = new ServerCall<Object, Object>() {
@Override
public void request(int numMessages) {}
@Override
public void sendHeaders(Metadata headers) {}
@Override
public void sendMessage(Object message) {}
@Override
public void close(Status status, Metadata trailers) {}
@Override
public boolean isCancelled() {
return false;
}
@Override
public MethodDescriptor<Object, Object> getMethodDescriptor() {
return null;
}
};
private Metadata headers = new Metadata(); private Metadata headers = new Metadata();
@Test @Test

View File

@ -45,7 +45,6 @@ import static org.mockito.Mockito.verifyZeroInteractions;
import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.MethodDescriptor.MethodType; import io.grpc.MethodDescriptor.MethodType;
import io.grpc.ServerCall.Listener; import io.grpc.ServerCall.Listener;
import io.grpc.ServerMethodDefinition;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
@ -78,9 +77,8 @@ public class ServerInterceptorsTest {
private ServerCall.Listener<String> listener; private ServerCall.Listener<String> listener;
private MethodDescriptor<String, Integer> flowMethod; private MethodDescriptor<String, Integer> flowMethod;
@Mock private ServerCall<String, Integer> call = new BaseServerCall<String, Integer>();
private ServerCall<String, Integer> call;
private ServerServiceDefinition serviceDefinition; private ServerServiceDefinition serviceDefinition;
@ -282,7 +280,7 @@ public class ServerInterceptorsTest {
@Test @Test
public void argumentsPassed() { public void argumentsPassed() {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final ServerCall<String, Integer> call2 = mock(ServerCall.class); final ServerCall<String, Integer> call2 = new BaseServerCall<String, Integer>();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final ServerCall.Listener<String> listener2 = mock(ServerCall.Listener.class); final ServerCall.Listener<String> listener2 = mock(ServerCall.Listener.class);
@ -408,7 +406,7 @@ public class ServerInterceptorsTest {
.intercept(inputStreamMessageService, interceptor2); .intercept(inputStreamMessageService, interceptor2);
ServerMethodDefinition<InputStream, InputStream> serverMethod = ServerMethodDefinition<InputStream, InputStream> serverMethod =
(ServerMethodDefinition<InputStream, InputStream>) intercepted2.getMethod("basic/wrapped"); (ServerMethodDefinition<InputStream, InputStream>) intercepted2.getMethod("basic/wrapped");
ServerCall<InputStream, InputStream> call2 = mock(ServerCall.class); ServerCall<InputStream, InputStream> call2 = new BaseServerCall<InputStream, InputStream>();
byte[] bytes = {}; byte[] bytes = {};
serverMethod serverMethod
.getServerCallHandler() .getServerCallHandler()
@ -459,4 +457,29 @@ public class ServerInterceptorsTest {
return inputStream; return inputStream;
} }
} }
private static class BaseServerCall<ReqT, RespT> extends ServerCall<ReqT, RespT> {
@Override
public void request(int numMessages) {}
@Override
public void sendHeaders(Metadata headers) {}
@Override
public void sendMessage(RespT message) {}
@Override
public void close(Status status, Metadata trailers) {}
@Override
public boolean isCancelled() {
return false;
}
@Override
public MethodDescriptor<ReqT, RespT> getMethodDescriptor() {
return null;
}
}
} }

View File

@ -48,6 +48,7 @@ import io.grpc.CallOptions;
import io.grpc.Channel; import io.grpc.Channel;
import io.grpc.ClientCall; import io.grpc.ClientCall;
import io.grpc.Deadline; import io.grpc.Deadline;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.testing.integration.Messages.SimpleRequest; import io.grpc.testing.integration.Messages.SimpleRequest;
import io.grpc.testing.integration.Messages.SimpleResponse; import io.grpc.testing.integration.Messages.SimpleResponse;
@ -73,17 +74,37 @@ public class StubConfigTest {
@Mock @Mock
private StreamObserver<SimpleResponse> responseObserver; private StreamObserver<SimpleResponse> responseObserver;
@Mock
private ClientCall<SimpleRequest, SimpleResponse> call;
/** /**
* Sets up mocks. * Sets up mocks.
*/ */
@Before public void setUp() { @Before public void setUp() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
ClientCall<SimpleRequest, SimpleResponse> call =
new ClientCall<SimpleRequest, SimpleResponse>() {
@Override
public void start(
ClientCall.Listener<SimpleResponse> responseListener, Metadata headers) {
}
@Override
public void request(int numMessages) {
}
@Override
public void cancel(String message, Throwable cause) {
}
@Override
public void halfClose() {
}
@Override
public void sendMessage(SimpleRequest message) {
}
};
when(channel.newCall( when(channel.newCall(
Mockito.<MethodDescriptor<SimpleRequest, SimpleResponse>>any(), any(CallOptions.class))) Mockito.<MethodDescriptor<SimpleRequest, SimpleResponse>>any(), any(CallOptions.class)))
.thenReturn(call); .thenReturn(call);
} }
@Test @Test

View File

@ -5,6 +5,7 @@ plugins {
description = "gRPC: Stub" description = "gRPC: Stub"
dependencies { dependencies {
compile project(':grpc-core') compile project(':grpc-core')
testCompile libraries.truth
} }
// Configure the animal sniffer plugin // Configure the animal sniffer plugin

View File

@ -31,21 +31,18 @@
package io.grpc.stub; package io.grpc.stub;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture; import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.ClientCall; import io.grpc.ClientCall;
import io.grpc.ClientCall.Listener;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
@ -64,12 +61,7 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -80,6 +72,7 @@ import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.Semaphore; import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
/** /**
* Unit tests for {@link ClientCalls}. * Unit tests for {@link ClientCalls}.
@ -96,9 +89,6 @@ public class ClientCallsTest {
private Server server; private Server server;
private ManagedChannel channel; private ManagedChannel channel;
@Mock
private ClientCall<Integer, String> call;
@Before @Before
public void setUp() { public void setUp() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
@ -121,16 +111,13 @@ public class ClientCallsTest {
final Status status = Status.OK; final Status status = Status.OK;
final Metadata trailers = new Metadata(); final Metadata trailers = new Metadata();
doAnswer(new Answer<Void>() { BaseClientCall call = new BaseClientCall() {
@Override @Override
public Void answer(InvocationOnMock in) throws Throwable { public void start(ClientCall.Listener<String> listener, Metadata headers) {
@SuppressWarnings("unchecked")
Listener<String> listener = (Listener<String>) in.getArguments()[0];
listener.onMessage(resp); listener.onMessage(resp);
listener.onClose(status, trailers); listener.onClose(status, trailers);
return null;
} }
}).when(call).start(Mockito.<Listener<String>>any(), any(Metadata.class)); };
String actualResponse = ClientCalls.blockingUnaryCall(call, req); String actualResponse = ClientCalls.blockingUnaryCall(call, req);
assertEquals(resp, actualResponse); assertEquals(resp, actualResponse);
@ -142,15 +129,12 @@ public class ClientCallsTest {
final Status status = Status.INTERNAL.withDescription("Unique status"); final Status status = Status.INTERNAL.withDescription("Unique status");
final Metadata trailers = new Metadata(); final Metadata trailers = new Metadata();
doAnswer(new Answer<Void>() { BaseClientCall call = new BaseClientCall() {
@Override @Override
public Void answer(InvocationOnMock in) throws Throwable { public void start(io.grpc.ClientCall.Listener<String> listener, Metadata headers) {
@SuppressWarnings("unchecked")
Listener<String> listener = (Listener<String>) in.getArguments()[0];
listener.onClose(status, trailers); listener.onClose(status, trailers);
return null;
} }
}).when(call).start(Mockito.<Listener<String>>any(), any(Metadata.class)); };
try { try {
ClientCalls.blockingUnaryCall(call, req); ClientCalls.blockingUnaryCall(call, req);
@ -163,27 +147,50 @@ public class ClientCallsTest {
@Test @Test
public void unaryFutureCallSuccess() throws Exception { public void unaryFutureCallSuccess() throws Exception {
final AtomicReference<ClientCall.Listener<String>> listener =
new AtomicReference<ClientCall.Listener<String>>();
final AtomicReference<Integer> message = new AtomicReference<Integer>();
final AtomicReference<Boolean> halfClosed = new AtomicReference<Boolean>();
BaseClientCall call = new BaseClientCall() {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {
listener.set(responseListener);
}
@Override
public void sendMessage(Integer msg) {
message.set(msg);
}
@Override
public void halfClose() {
halfClosed.set(true);
}
};
Integer req = 2; Integer req = 2;
ListenableFuture<String> future = ClientCalls.futureUnaryCall(call, req); ListenableFuture<String> future = ClientCalls.futureUnaryCall(call, req);
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
verify(call).start(listenerCaptor.capture(), any(Metadata.class)); assertEquals(req, message.get());
ClientCall.Listener<String> listener = listenerCaptor.getValue(); assertTrue(halfClosed.get());
verify(call).sendMessage(req); listener.get().onMessage("bar");
verify(call).halfClose(); listener.get().onClose(Status.OK, new Metadata());
listener.onMessage("bar");
listener.onClose(Status.OK, new Metadata());
assertEquals("bar", future.get()); assertEquals("bar", future.get());
} }
@Test @Test
public void unaryFutureCallFailed() throws Exception { public void unaryFutureCallFailed() throws Exception {
final AtomicReference<ClientCall.Listener<String>> listener =
new AtomicReference<ClientCall.Listener<String>>();
BaseClientCall call = new BaseClientCall() {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {
listener.set(responseListener);
}
};
Integer req = 2; Integer req = 2;
ListenableFuture<String> future = ClientCalls.futureUnaryCall(call, req); ListenableFuture<String> future = ClientCalls.futureUnaryCall(call, req);
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
verify(call).start(listenerCaptor.capture(), any(Metadata.class));
ClientCall.Listener<String> listener = listenerCaptor.getValue();
Metadata trailers = new Metadata(); Metadata trailers = new Metadata();
listener.onClose(Status.INTERNAL, trailers); listener.get().onClose(Status.INTERNAL, trailers);
try { try {
future.get(); future.get();
fail("Should fail"); fail("Should fail");
@ -197,15 +204,29 @@ public class ClientCallsTest {
@Test @Test
public void unaryFutureCallCancelled() throws Exception { public void unaryFutureCallCancelled() throws Exception {
final AtomicReference<ClientCall.Listener<String>> listener =
new AtomicReference<ClientCall.Listener<String>>();
final AtomicReference<String> cancelMessage = new AtomicReference<String>();
final AtomicReference<Throwable> cancelCause = new AtomicReference<Throwable>();
BaseClientCall call = new BaseClientCall() {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {
listener.set(responseListener);
}
@Override
public void cancel(String message, Throwable cause) {
cancelMessage.set(message);
cancelCause.set(cause);
}
};
Integer req = 2; Integer req = 2;
ListenableFuture<String> future = ClientCalls.futureUnaryCall(call, req); ListenableFuture<String> future = ClientCalls.futureUnaryCall(call, req);
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
verify(call).start(listenerCaptor.capture(), any(Metadata.class));
ClientCall.Listener<String> listener = listenerCaptor.getValue();
future.cancel(true); future.cancel(true);
verify(call).cancel("GrpcFuture was cancelled", null); assertEquals("GrpcFuture was cancelled", cancelMessage.get());
listener.onMessage("bar"); assertNull(cancelCause.get());
listener.onClose(Status.OK, new Metadata()); listener.get().onMessage("bar");
listener.get().onClose(Status.OK, new Metadata());
try { try {
future.get(); future.get();
fail("Should fail"); fail("Should fail");
@ -216,6 +237,7 @@ public class ClientCallsTest {
@Test @Test
public void cannotSetOnReadyAfterCallStarted() throws Exception { public void cannotSetOnReadyAfterCallStarted() throws Exception {
BaseClientCall call = new BaseClientCall();
CallStreamObserver<Integer> callStreamObserver = CallStreamObserver<Integer> callStreamObserver =
(CallStreamObserver<Integer>) ClientCalls.asyncClientStreamingCall(call, (CallStreamObserver<Integer>) ClientCalls.asyncClientStreamingCall(call,
new NoopStreamObserver<String>()); new NoopStreamObserver<String>());
@ -235,7 +257,20 @@ public class ClientCallsTest {
@Test @Test
public void disablingInboundAutoFlowControlSuppressesRequestsForMoreMessages() public void disablingInboundAutoFlowControlSuppressesRequestsForMoreMessages()
throws Exception { throws Exception {
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null); final AtomicReference<ClientCall.Listener<String>> listener =
new AtomicReference<ClientCall.Listener<String>>();
final List<Integer> requests = new ArrayList<Integer>();
BaseClientCall call = new BaseClientCall() {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {
listener.set(responseListener);
}
@Override
public void request(int numMessages) {
requests.add(numMessages);
}
};
ClientCalls.asyncBidiStreamingCall(call, new ClientResponseObserver<Integer, String>() { ClientCalls.asyncBidiStreamingCall(call, new ClientResponseObserver<Integer, String>() {
@Override @Override
public void beforeStart(ClientCallStreamObserver<Integer> requestStream) { public void beforeStart(ClientCallStreamObserver<Integer> requestStream) {
@ -257,15 +292,13 @@ public class ClientCallsTest {
} }
}); });
verify(call).start(listenerCaptor.capture(), any(Metadata.class)); listener.get().onMessage("message");
listenerCaptor.getValue().onMessage("message"); assertThat(requests).containsExactly(1);
verify(call, times(1)).request(1);
} }
@Test @Test
public void callStreamObserverPropagatesFlowControlRequestsToCall() public void callStreamObserverPropagatesFlowControlRequestsToCall()
throws Exception { throws Exception {
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
ClientResponseObserver<Integer, String> responseObserver = ClientResponseObserver<Integer, String> responseObserver =
new ClientResponseObserver<Integer, String>() { new ClientResponseObserver<Integer, String>() {
@Override @Override
@ -285,19 +318,32 @@ public class ClientCallsTest {
public void onCompleted() { public void onCompleted() {
} }
}; };
final AtomicReference<ClientCall.Listener<String>> listener =
new AtomicReference<ClientCall.Listener<String>>();
final List<Integer> requests = new ArrayList<Integer>();
BaseClientCall call = new BaseClientCall() {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {
listener.set(responseListener);
}
@Override
public void request(int numMessages) {
requests.add(numMessages);
}
};
CallStreamObserver<Integer> requestObserver = CallStreamObserver<Integer> requestObserver =
(CallStreamObserver<Integer>) (CallStreamObserver<Integer>)
ClientCalls.asyncBidiStreamingCall(call, responseObserver); ClientCalls.asyncBidiStreamingCall(call, responseObserver);
verify(call).start(listenerCaptor.capture(), any(Metadata.class)); listener.get().onMessage("message");
listenerCaptor.getValue().onMessage("message");
requestObserver.request(5); requestObserver.request(5);
verify(call, times(1)).request(5); assertThat(requests).contains(5);
} }
@Test @Test
public void canCaptureInboundFlowControlForServerStreamingObserver() public void canCaptureInboundFlowControlForServerStreamingObserver()
throws Exception { throws Exception {
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
ClientResponseObserver<Integer, String> responseObserver = ClientResponseObserver<Integer, String> responseObserver =
new ClientResponseObserver<Integer, String>() { new ClientResponseObserver<Integer, String>() {
@Override @Override
@ -318,11 +364,23 @@ public class ClientCallsTest {
public void onCompleted() { public void onCompleted() {
} }
}; };
final AtomicReference<ClientCall.Listener<String>> listener =
new AtomicReference<ClientCall.Listener<String>>();
final List<Integer> requests = new ArrayList<Integer>();
BaseClientCall call = new BaseClientCall() {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {
listener.set(responseListener);
}
@Override
public void request(int numMessages) {
requests.add(numMessages);
}
};
ClientCalls.asyncServerStreamingCall(call, 1, responseObserver); ClientCalls.asyncServerStreamingCall(call, 1, responseObserver);
verify(call).start(listenerCaptor.capture(), any(Metadata.class)); listener.get().onMessage("message");
listenerCaptor.getValue().onMessage("message"); assertThat(requests).containsExactly(5, 1).inOrder();
verify(call, times(1)).request(1);
verify(call, times(1)).request(5);
} }
@Test @Test
@ -497,13 +555,20 @@ public class ClientCallsTest {
@Test @Test
public void blockingResponseStreamFailed() throws Exception { public void blockingResponseStreamFailed() throws Exception {
final AtomicReference<ClientCall.Listener<String>> listener =
new AtomicReference<ClientCall.Listener<String>>();
BaseClientCall call = new BaseClientCall() {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {
listener.set(responseListener);
}
};
Integer req = 2; Integer req = 2;
Iterator<String> iter = ClientCalls.blockingServerStreamingCall(call, req); Iterator<String> iter = ClientCalls.blockingServerStreamingCall(call, req);
ArgumentCaptor<ClientCall.Listener<String>> listenerCaptor = ArgumentCaptor.forClass(null);
verify(call).start(listenerCaptor.capture(), any(Metadata.class));
ClientCall.Listener<String> listener = listenerCaptor.getValue();
Metadata trailers = new Metadata(); Metadata trailers = new Metadata();
listener.onClose(Status.INTERNAL, trailers); listener.get().onClose(Status.INTERNAL, trailers);
try { try {
iter.next(); iter.next();
fail("Should fail"); fail("Should fail");
@ -514,4 +579,21 @@ public class ClientCallsTest {
assertSame(trailers, metadata); assertSame(trailers, metadata);
} }
} }
private static class BaseClientCall extends ClientCall<Integer, String> {
@Override
public void start(io.grpc.ClientCall.Listener<String> responseListener, Metadata headers) {}
@Override
public void request(int numMessages) {}
@Override
public void cancel(String message, Throwable cause) {}
@Override
public void halfClose() {}
@Override
public void sendMessage(Integer message) {}
}
} }

View File

@ -31,12 +31,12 @@
package io.grpc.stub; package io.grpc.stub;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Mockito.times;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.ClientCall; import io.grpc.ClientCall;
@ -51,13 +51,9 @@ import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.ManagedChannelImpl; import io.grpc.internal.ManagedChannelImpl;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
@ -65,6 +61,8 @@ import java.io.DataInputStream;
import java.io.DataOutputStream; import java.io.DataOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore; import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -88,13 +86,7 @@ public class ServerCallsTest {
"some/unarymethod", "some/unarymethod",
new IntegerMarshaller(), new IntegerMarshaller()); new IntegerMarshaller(), new IntegerMarshaller());
@Mock private final ServerCallRecorder serverCall = new ServerCallRecorder();
ServerCall<Integer, Integer> serverCall;
@Before
public void setUp() throws Exception {
MockitoAnnotations.initMocks(this);
}
@Test @Test
public void runtimeStreamObserverIsServerCallStreamObserver() throws Exception { public void runtimeStreamObserverIsServerCallStreamObserver() throws Exception {
@ -130,8 +122,8 @@ public class ServerCallsTest {
}); });
ServerCall.Listener<Integer> callListener = ServerCall.Listener<Integer> callListener =
callHandler.startCall(serverCall, new Metadata()); callHandler.startCall(serverCall, new Metadata());
Mockito.when(serverCall.isReady()).thenReturn(true).thenReturn(false); serverCall.isReady = true;
Mockito.when(serverCall.isCancelled()).thenReturn(false).thenReturn(true); serverCall.isCancelled = false;
assertTrue(callObserver.get().isReady()); assertTrue(callObserver.get().isReady());
assertFalse(callObserver.get().isCancelled()); assertFalse(callObserver.get().isCancelled());
callListener.onReady(); callListener.onReady();
@ -140,11 +132,13 @@ public class ServerCallsTest {
assertTrue(invokeCalled.get()); assertTrue(invokeCalled.get());
assertTrue(onReadyCalled.get()); assertTrue(onReadyCalled.get());
assertTrue(onCancelCalled.get()); assertTrue(onCancelCalled.get());
serverCall.isReady = false;
serverCall.isCancelled = true;
assertFalse(callObserver.get().isReady()); assertFalse(callObserver.get().isReady());
assertTrue(callObserver.get().isCancelled()); assertTrue(callObserver.get().isCancelled());
// Is called twice, once to permit the first message and once again after the first message // Is called twice, once to permit the first message and once again after the first message
// has been processed (auto flow control) // has been processed (auto flow control)
Mockito.verify(serverCall, times(2)).request(1); assertThat(serverCall.requestCalls).containsExactly(1, 1).inOrder();
} }
@Test @Test
@ -247,7 +241,7 @@ public class ServerCallsTest {
// to verify that message delivery does not trigger a call to request(1). // to verify that message delivery does not trigger a call to request(1).
callListener.onMessage(1); callListener.onMessage(1);
// Should never be called // Should never be called
Mockito.verify(serverCall, times(0)).request(1); assertThat(serverCall.requestCalls).isEmpty();
} }
@Test @Test
@ -265,7 +259,7 @@ public class ServerCallsTest {
callHandler.startCall(serverCall, new Metadata()); callHandler.startCall(serverCall, new Metadata());
// Auto inbound flow-control always requests 2 messages for unary to detect a violation // Auto inbound flow-control always requests 2 messages for unary to detect a violation
// of the unary semantic. // of the unary semantic.
Mockito.verify(serverCall, times(1)).request(2); assertThat(serverCall.requestCalls).containsExactly(2);
} }
@Test @Test
@ -288,8 +282,8 @@ public class ServerCallsTest {
}); });
ServerCall.Listener<Integer> callListener = ServerCall.Listener<Integer> callListener =
callHandler.startCall(serverCall, new Metadata()); callHandler.startCall(serverCall, new Metadata());
Mockito.when(serverCall.isReady()).thenReturn(true).thenReturn(false); serverCall.isReady = true;
Mockito.when(serverCall.isCancelled()).thenReturn(false).thenReturn(true); serverCall.isCancelled = false;
callListener.onReady(); callListener.onReady();
// On ready is not called until the unary request message is delivered // On ready is not called until the unary request message is delivered
assertEquals(0, onReadyCalled.get()); assertEquals(0, onReadyCalled.get());
@ -392,4 +386,51 @@ public class ServerCallsTest {
} }
} }
} }
private static class ServerCallRecorder extends ServerCall<Integer, Integer> {
private List<Integer> requestCalls = new ArrayList<Integer>();
private Metadata headers;
private Integer message;
private Metadata trailers;
private Status status;
private boolean isCancelled;
private MethodDescriptor<Integer, Integer> methodDescriptor;
private boolean isReady;
@Override
public void request(int numMessages) {
requestCalls.add(numMessages);
}
@Override
public void sendHeaders(Metadata headers) {
this.headers = headers;
}
@Override
public void sendMessage(Integer message) {
this.message = message;
}
@Override
public void close(Status status, Metadata trailers) {
this.status = status;
this.trailers = trailers;
}
@Override
public boolean isCancelled() {
return isCancelled;
}
@Override
public boolean isReady() {
return isReady;
}
@Override
public MethodDescriptor<Integer, Integer> getMethodDescriptor() {
return methodDescriptor;
}
}
} }