diff --git a/core/src/main/java/io/grpc/ClientCall.java b/core/src/main/java/io/grpc/ClientCall.java index f6111a7321..268b9797a1 100644 --- a/core/src/main/java/io/grpc/ClientCall.java +++ b/core/src/main/java/io/grpc/ClientCall.java @@ -61,6 +61,16 @@ package io.grpc; * {@link Status#CANCELLED CANCELLED}. Otherwise, {@link Listener#onClose Listener.onClose()} is * called with whatever status the RPC was finished. We ensure that at most one is called. * + *

Example: A simple Unary (1 request, 1 response) RPC would look like this: + *

+ *   call = channel.newCall(method, callOptions);
+ *   call.start(listener, headers);
+ *   call.sendMessage(message);
+ *   call.halfClose();
+ *   call.request(1);
+ *   // wait for listener.onMessage()
+ * 
+ * * @param type of message sent one or more times to the server. * @param type of message received one or more times from the server. */ @@ -157,7 +167,8 @@ public abstract class ClientCall { public abstract void cancel(); /** - * Close the call for request message sending. Incoming response messages are unaffected. + * Close the call for request message sending. Incoming response messages are unaffected. This + * should be called when no more messages will be sent from the client. * * @throws IllegalStateException if call is already {@code halfClose()}d or {@link #cancel}ed */ diff --git a/core/src/main/java/io/grpc/internal/ServerCallImpl.java b/core/src/main/java/io/grpc/internal/ServerCallImpl.java index 622174ee05..2c9322c8ed 100644 --- a/core/src/main/java/io/grpc/internal/ServerCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerCallImpl.java @@ -34,11 +34,13 @@ package io.grpc.internal; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Throwables; import io.grpc.Context; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.MethodDescriptor.MethodType; import io.grpc.ServerCall; import io.grpc.Status; @@ -121,33 +123,45 @@ final class ServerCallImpl extends ServerCall { return cancelled; } - ServerStreamListenerImpl newServerStreamListener(ServerCall.Listener listener, + ServerStreamListener newServerStreamListener(ServerCall.Listener listener, Future timeout) { - return new ServerStreamListenerImpl(listener, timeout); + return new ServerStreamListenerImpl(this, listener, timeout); } /** * All of these callbacks are assumed to called on an application thread, and the caller is * responsible for handling thrown exceptions. */ - private class ServerStreamListenerImpl implements ServerStreamListener { + @VisibleForTesting + static final class ServerStreamListenerImpl implements ServerStreamListener { + private final ServerCallImpl call; private final ServerCall.Listener listener; private final Future timeout; + private boolean messageReceived; - public ServerStreamListenerImpl(ServerCall.Listener listener, Future timeout) { + public ServerStreamListenerImpl( + ServerCallImpl call, ServerCall.Listener listener, Future timeout) { + this.call = checkNotNull(call, "call"); this.listener = checkNotNull(listener, "listener must not be null"); - // TODO: check if timeout should not be null - this.timeout = timeout; + this.timeout = checkNotNull(timeout, "timeout"); } @Override public void messageRead(final InputStream message) { try { - if (cancelled) { + if (call.cancelled) { return; } + // Special case for unary calls. + if (messageReceived && call.method.getType() == MethodType.UNARY) { + call.stream.close(Status.INVALID_ARGUMENT.withDescription( + "More than one request messages for unary call or server streaming call"), + new Metadata()); + return; + } + messageReceived = true; - listener.onMessage(method.parseRequest(message)); + listener.onMessage(call.method.parseRequest(message)); } finally { try { message.close(); @@ -159,7 +173,7 @@ final class ServerCallImpl extends ServerCall { @Override public void halfClosed() { - if (cancelled) { + if (call.cancelled) { return; } @@ -172,14 +186,14 @@ final class ServerCallImpl extends ServerCall { if (status.isOk()) { listener.onComplete(); } else { - cancelled = true; + call.cancelled = true; listener.onCancel(); } } @Override public void onReady() { - if (cancelled) { + if (call.cancelled) { return; } listener.onReady(); diff --git a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java index 796cd3c506..2742ee28c3 100644 --- a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java @@ -31,6 +31,7 @@ package io.grpc.internal; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; @@ -41,13 +42,16 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.io.CharStreams; +import com.google.common.util.concurrent.Futures; import io.grpc.Context; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.MethodType; +import io.grpc.ServerCall; import io.grpc.Status; +import io.grpc.internal.ServerCallImpl.ServerStreamListenerImpl; import org.junit.Before; import org.junit.Rule; @@ -55,17 +59,25 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import java.io.ByteArrayInputStream; import java.io.InputStream; import java.io.InputStreamReader; +import java.util.concurrent.Future; @RunWith(JUnit4.class) public class ServerCallImplTest { @Rule public final ExpectedException thrown = ExpectedException.none(); @Mock private ServerStream stream; + @Mock private ServerCall.Listener callListener; + @Captor private ArgumentCaptor statusCaptor; + + private final Future timeout = Futures.immediateCancelledFuture(); private ServerCallImpl call; private Context.CancellableContext context; @@ -186,6 +198,110 @@ public class ServerCallImplTest { verify(stream).setMessageCompression(true); } + @Test + public void streamListener_halfClosed() { + ServerStreamListenerImpl streamListener = + new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout); + + streamListener.halfClosed(); + + verify(callListener).onHalfClose(); + } + + @Test + public void streamListener_halfClosed_onlyOnce() { + ServerStreamListenerImpl streamListener = + new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout); + streamListener.halfClosed(); + // canceling the call should short circuit future halfClosed() calls. + streamListener.closed(Status.CANCELLED); + + streamListener.halfClosed(); + + verify(callListener).onHalfClose(); + } + + @Test + public void streamListener_closedOk() { + ServerStreamListenerImpl streamListener = + new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout); + + streamListener.closed(Status.OK); + + verify(callListener).onComplete(); + assertTrue(timeout.isCancelled()); + } + + @Test + public void streamListener_closedCancelled() { + ServerStreamListenerImpl streamListener = + new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout); + + streamListener.closed(Status.CANCELLED); + + verify(callListener).onCancel(); + assertTrue(timeout.isCancelled()); + } + + @Test + public void streamListener_onReady() { + ServerStreamListenerImpl streamListener = + new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout); + + streamListener.onReady(); + + verify(callListener).onReady(); + } + + @Test + public void streamListener_onReady_onlyOnce() { + ServerStreamListenerImpl streamListener = + new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout); + streamListener.onReady(); + // canceling the call should short circuit future halfClosed() calls. + streamListener.closed(Status.CANCELLED); + + streamListener.onReady(); + + verify(callListener).onReady(); + } + + @Test + public void streamListener_messageRead() { + ServerStreamListenerImpl streamListener = + new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout); + streamListener.messageRead(method.streamRequest(1234L)); + + verify(callListener).onMessage(1234L); + } + + @Test + public void streamListener_messageRead_unaryFailsOnMultiple() { + ServerStreamListenerImpl streamListener = + new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout); + streamListener.messageRead(method.streamRequest(1234L)); + streamListener.messageRead(method.streamRequest(1234L)); + + // Makes sure this was only called once. + verify(callListener).onMessage(1234L); + + verify(stream).close(statusCaptor.capture(), Mockito.isA(Metadata.class)); + assertEquals(Status.Code.INVALID_ARGUMENT, statusCaptor.getValue().getCode()); + } + + @Test + public void streamListener_messageRead_onlyOnce() { + ServerStreamListenerImpl streamListener = + new ServerCallImpl.ServerStreamListenerImpl(call, callListener, timeout); + streamListener.messageRead(method.streamRequest(1234L)); + // canceling the call should short circuit future halfClosed() calls. + streamListener.closed(Status.CANCELLED); + + streamListener.messageRead(method.streamRequest(1234L)); + + verify(callListener).onMessage(1234L); + } + private static class LongMarshaller implements Marshaller { @Override public InputStream stream(Long value) { diff --git a/stub/src/main/java/io/grpc/stub/ServerCalls.java b/stub/src/main/java/io/grpc/stub/ServerCalls.java index dc2a460448..e01de42fa1 100644 --- a/stub/src/main/java/io/grpc/stub/ServerCalls.java +++ b/stub/src/main/java/io/grpc/stub/ServerCalls.java @@ -130,22 +130,15 @@ public class ServerCalls { Metadata headers) { final ResponseObserver responseObserver = new ResponseObserver(call); // We expect only 1 request, but we ask for 2 requests here so that if a misbehaving client - // sends more than 1 requests, we will catch it in onMessage() and emit INVALID_ARGUMENT. + // sends more than 1 requests, ServerCall will catch it. call.request(2); return new EmptyServerCallListener() { ReqT request; @Override public void onMessage(ReqT request) { - if (this.request == null) { - // We delay calling method.invoke() until onHalfClose(), because application may call - // close(OK) inside invoke(), while close(OK) is not allowed before onHalfClose(). - this.request = request; - } else { - call.close( - Status.INVALID_ARGUMENT.withDescription( - "More than one request messages for unary call or server streaming call"), - new Metadata()); - } + // We delay calling method.invoke() until onHalfClose() to make sure the client + // half-closes. + this.request = request; } @Override