diff --git a/core/src/main/java/io/grpc/ClientCall.java b/core/src/main/java/io/grpc/ClientCall.java
index f6111a7321..268b9797a1 100644
--- a/core/src/main/java/io/grpc/ClientCall.java
+++ b/core/src/main/java/io/grpc/ClientCall.java
@@ -61,6 +61,16 @@ package io.grpc;
* {@link Status#CANCELLED CANCELLED}. Otherwise, {@link Listener#onClose Listener.onClose()} is
* called with whatever status the RPC was finished. We ensure that at most one is called.
*
+ *
Example: A simple Unary (1 request, 1 response) RPC would look like this:
+ *
+ * call = channel.newCall(method, callOptions);
+ * call.start(listener, headers);
+ * call.sendMessage(message);
+ * call.halfClose();
+ * call.request(1);
+ * // wait for listener.onMessage()
+ *
+ *
* @param type of message sent one or more times to the server.
* @param type of message received one or more times from the server.
*/
@@ -157,7 +167,8 @@ public abstract class ClientCall {
public abstract void cancel();
/**
- * Close the call for request message sending. Incoming response messages are unaffected.
+ * Close the call for request message sending. Incoming response messages are unaffected. This
+ * should be called when no more messages will be sent from the client.
*
* @throws IllegalStateException if call is already {@code halfClose()}d or {@link #cancel}ed
*/
diff --git a/core/src/main/java/io/grpc/internal/ServerCallImpl.java b/core/src/main/java/io/grpc/internal/ServerCallImpl.java
index 622174ee05..2c9322c8ed 100644
--- a/core/src/main/java/io/grpc/internal/ServerCallImpl.java
+++ b/core/src/main/java/io/grpc/internal/ServerCallImpl.java
@@ -34,11 +34,13 @@ package io.grpc.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables;
import io.grpc.Context;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
+import io.grpc.MethodDescriptor.MethodType;
import io.grpc.ServerCall;
import io.grpc.Status;
@@ -121,33 +123,45 @@ final class ServerCallImpl extends ServerCall {
return cancelled;
}
- ServerStreamListenerImpl newServerStreamListener(ServerCall.Listener listener,
+ ServerStreamListener newServerStreamListener(ServerCall.Listener listener,
Future> timeout) {
- return new ServerStreamListenerImpl(listener, timeout);
+ return new ServerStreamListenerImpl(this, listener, timeout);
}
/**
* All of these callbacks are assumed to called on an application thread, and the caller is
* responsible for handling thrown exceptions.
*/
- private class ServerStreamListenerImpl implements ServerStreamListener {
+ @VisibleForTesting
+ static final class ServerStreamListenerImpl implements ServerStreamListener {
+ private final ServerCallImpl call;
private final ServerCall.Listener listener;
private final Future> timeout;
+ private boolean messageReceived;
- public ServerStreamListenerImpl(ServerCall.Listener listener, Future> timeout) {
+ public ServerStreamListenerImpl(
+ ServerCallImpl call, ServerCall.Listener listener, Future> timeout) {
+ this.call = checkNotNull(call, "call");
this.listener = checkNotNull(listener, "listener must not be null");
- // TODO: check if timeout should not be null
- this.timeout = timeout;
+ this.timeout = checkNotNull(timeout, "timeout");
}
@Override
public void messageRead(final InputStream message) {
try {
- if (cancelled) {
+ if (call.cancelled) {
return;
}
+ // Special case for unary calls.
+ if (messageReceived && call.method.getType() == MethodType.UNARY) {
+ call.stream.close(Status.INVALID_ARGUMENT.withDescription(
+ "More than one request messages for unary call or server streaming call"),
+ new Metadata());
+ return;
+ }
+ messageReceived = true;
- listener.onMessage(method.parseRequest(message));
+ listener.onMessage(call.method.parseRequest(message));
} finally {
try {
message.close();
@@ -159,7 +173,7 @@ final class ServerCallImpl extends ServerCall {
@Override
public void halfClosed() {
- if (cancelled) {
+ if (call.cancelled) {
return;
}
@@ -172,14 +186,14 @@ final class ServerCallImpl extends ServerCall {
if (status.isOk()) {
listener.onComplete();
} else {
- cancelled = true;
+ call.cancelled = true;
listener.onCancel();
}
}
@Override
public void onReady() {
- if (cancelled) {
+ if (call.cancelled) {
return;
}
listener.onReady();
diff --git a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java
index 796cd3c506..2742ee28c3 100644
--- a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java
+++ b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java
@@ -31,6 +31,7 @@
package io.grpc.internal;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
@@ -41,13 +42,16 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.io.CharStreams;
+import com.google.common.util.concurrent.Futures;
import io.grpc.Context;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.MethodDescriptor.MethodType;
+import io.grpc.ServerCall;
import io.grpc.Status;
+import io.grpc.internal.ServerCallImpl.ServerStreamListenerImpl;
import org.junit.Before;
import org.junit.Rule;
@@ -55,17 +59,25 @@ import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
import org.mockito.Mock;
+import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.io.InputStreamReader;
+import java.util.concurrent.Future;
@RunWith(JUnit4.class)
public class ServerCallImplTest {
@Rule public final ExpectedException thrown = ExpectedException.none();
@Mock private ServerStream stream;
+ @Mock private ServerCall.Listener callListener;
+ @Captor private ArgumentCaptor statusCaptor;
+
+ private final Future> timeout = Futures.immediateCancelledFuture();
private ServerCallImpl call;
private Context.CancellableContext context;
@@ -186,6 +198,110 @@ public class ServerCallImplTest {
verify(stream).setMessageCompression(true);
}
+ @Test
+ public void streamListener_halfClosed() {
+ ServerStreamListenerImpl streamListener =
+ new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout);
+
+ streamListener.halfClosed();
+
+ verify(callListener).onHalfClose();
+ }
+
+ @Test
+ public void streamListener_halfClosed_onlyOnce() {
+ ServerStreamListenerImpl streamListener =
+ new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout);
+ streamListener.halfClosed();
+ // canceling the call should short circuit future halfClosed() calls.
+ streamListener.closed(Status.CANCELLED);
+
+ streamListener.halfClosed();
+
+ verify(callListener).onHalfClose();
+ }
+
+ @Test
+ public void streamListener_closedOk() {
+ ServerStreamListenerImpl streamListener =
+ new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout);
+
+ streamListener.closed(Status.OK);
+
+ verify(callListener).onComplete();
+ assertTrue(timeout.isCancelled());
+ }
+
+ @Test
+ public void streamListener_closedCancelled() {
+ ServerStreamListenerImpl streamListener =
+ new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout);
+
+ streamListener.closed(Status.CANCELLED);
+
+ verify(callListener).onCancel();
+ assertTrue(timeout.isCancelled());
+ }
+
+ @Test
+ public void streamListener_onReady() {
+ ServerStreamListenerImpl streamListener =
+ new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout);
+
+ streamListener.onReady();
+
+ verify(callListener).onReady();
+ }
+
+ @Test
+ public void streamListener_onReady_onlyOnce() {
+ ServerStreamListenerImpl streamListener =
+ new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout);
+ streamListener.onReady();
+ // canceling the call should short circuit future halfClosed() calls.
+ streamListener.closed(Status.CANCELLED);
+
+ streamListener.onReady();
+
+ verify(callListener).onReady();
+ }
+
+ @Test
+ public void streamListener_messageRead() {
+ ServerStreamListenerImpl streamListener =
+ new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout);
+ streamListener.messageRead(method.streamRequest(1234L));
+
+ verify(callListener).onMessage(1234L);
+ }
+
+ @Test
+ public void streamListener_messageRead_unaryFailsOnMultiple() {
+ ServerStreamListenerImpl streamListener =
+ new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout);
+ streamListener.messageRead(method.streamRequest(1234L));
+ streamListener.messageRead(method.streamRequest(1234L));
+
+ // Makes sure this was only called once.
+ verify(callListener).onMessage(1234L);
+
+ verify(stream).close(statusCaptor.capture(), Mockito.isA(Metadata.class));
+ assertEquals(Status.Code.INVALID_ARGUMENT, statusCaptor.getValue().getCode());
+ }
+
+ @Test
+ public void streamListener_messageRead_onlyOnce() {
+ ServerStreamListenerImpl streamListener =
+ new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout);
+ streamListener.messageRead(method.streamRequest(1234L));
+ // canceling the call should short circuit future halfClosed() calls.
+ streamListener.closed(Status.CANCELLED);
+
+ streamListener.messageRead(method.streamRequest(1234L));
+
+ verify(callListener).onMessage(1234L);
+ }
+
private static class LongMarshaller implements Marshaller {
@Override
public InputStream stream(Long value) {
diff --git a/stub/src/main/java/io/grpc/stub/ServerCalls.java b/stub/src/main/java/io/grpc/stub/ServerCalls.java
index dc2a460448..e01de42fa1 100644
--- a/stub/src/main/java/io/grpc/stub/ServerCalls.java
+++ b/stub/src/main/java/io/grpc/stub/ServerCalls.java
@@ -130,22 +130,15 @@ public class ServerCalls {
Metadata headers) {
final ResponseObserver responseObserver = new ResponseObserver(call);
// We expect only 1 request, but we ask for 2 requests here so that if a misbehaving client
- // sends more than 1 requests, we will catch it in onMessage() and emit INVALID_ARGUMENT.
+ // sends more than 1 requests, ServerCall will catch it.
call.request(2);
return new EmptyServerCallListener() {
ReqT request;
@Override
public void onMessage(ReqT request) {
- if (this.request == null) {
- // We delay calling method.invoke() until onHalfClose(), because application may call
- // close(OK) inside invoke(), while close(OK) is not allowed before onHalfClose().
- this.request = request;
- } else {
- call.close(
- Status.INVALID_ARGUMENT.withDescription(
- "More than one request messages for unary call or server streaming call"),
- new Metadata());
- }
+ // We delay calling method.invoke() until onHalfClose() to make sure the client
+ // half-closes.
+ this.request = request;
}
@Override