From 1cfba96d17526bd27ad896fc1601f2b6d3968c0c Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Fri, 11 Sep 2015 15:52:14 -0700 Subject: [PATCH] Enforce sending headers before messages on server ServerCall already had "headers must be sent before any messages, which must be sent before closing," but the implementation did not enforce it and our async server handler didn't obey. The benefit of forcing sending headers first is that it removes the only implicit call in our API and interceptors dealing just with metadata don't need to override sendMessage. The implicit behavior was bug-prone since it wasn't obvious you were forgetting that headers may not be sent. --- .../io/grpc/internal/AbstractServerStream.java | 5 ++--- .../grpc/internal/AbstractServerStreamTest.java | 3 +++ .../header/HeaderServerInterceptor.java | 17 ----------------- .../io/grpc/netty/NettyClientStreamTest.java | 5 ++++- .../io/grpc/netty/NettyClientTransportTest.java | 1 + .../io/grpc/netty/NettyServerStreamTest.java | 14 ++++++++++---- .../java/io/grpc/netty/NettyStreamTestBase.java | 11 ++++++++--- .../src/main/java/io/grpc/stub/ServerCalls.java | 5 +++++ .../main/java/io/grpc/testing/TestUtils.java | 11 ----------- 9 files changed, 33 insertions(+), 39 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/AbstractServerStream.java b/core/src/main/java/io/grpc/internal/AbstractServerStream.java index cfafa94434..750364713d 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerStream.java @@ -102,9 +102,8 @@ public abstract class AbstractServerStream extends AbstractStream @Override public final void writeMessage(InputStream message) { - if (!headersSent) { - writeHeaders(new Metadata()); - headersSent = true; + if (outboundPhase() != Phase.MESSAGE) { + throw new IllegalStateException("Messages are only permitted after headers and before close"); } super.writeMessage(message); } diff --git a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java index 1286276462..2443082903 100644 --- a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java @@ -170,6 +170,7 @@ public class AbstractServerStreamTest { capturedHeaders.set(captured); } }; + stream.writeHeaders(new Metadata()); stream.writeMessage(new ByteArrayInputStream(new byte[]{})); @@ -204,6 +205,7 @@ public class AbstractServerStreamTest { sendCalled.set(true); } }; + stream.writeHeaders(new Metadata()); stream.closeFramer(); stream.writeMessage(new ByteArrayInputStream(new byte[]{})); @@ -220,6 +222,7 @@ public class AbstractServerStreamTest { sendCalled.set(true); } }; + stream.writeHeaders(new Metadata()); stream.writeMessage(new ByteArrayInputStream(new byte[]{})); // Force the message to be flushed diff --git a/examples/src/main/java/io/grpc/examples/header/HeaderServerInterceptor.java b/examples/src/main/java/io/grpc/examples/header/HeaderServerInterceptor.java index 6c8d07cb20..97a2b5a387 100644 --- a/examples/src/main/java/io/grpc/examples/header/HeaderServerInterceptor.java +++ b/examples/src/main/java/io/grpc/examples/header/HeaderServerInterceptor.java @@ -37,7 +37,6 @@ import io.grpc.MethodDescriptor; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; -import io.grpc.Status; import java.util.logging.Logger; @@ -60,26 +59,10 @@ public class HeaderServerInterceptor implements ServerInterceptor { ServerCallHandler next) { logger.info("header received from client:" + requestHeaders.toString()); return next.startCall(method, new SimpleForwardingServerCall(call) { - boolean sentHeaders = false; - @Override public void sendHeaders(Metadata responseHeaders) { responseHeaders.put(customHeadKey, "customRespondValue"); super.sendHeaders(responseHeaders); - sentHeaders = true; - } - - @Override - public void sendMessage(RespT message) { - if (!sentHeaders) { - sendHeaders(new Metadata()); - } - super.sendMessage(message); - } - - @Override - public void close(Status status, Metadata trailers) { - super.close(status, trailers); } }, requestHeaders); } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java index c866c18968..ad0964f727 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java @@ -80,7 +80,7 @@ import java.io.InputStream; * Tests for {@link NettyClientStream}. */ @RunWith(JUnit4.class) -public class NettyClientStreamTest extends NettyStreamTestBase { +public class NettyClientStreamTest extends NettyStreamTestBase { @Mock protected ClientStreamListener listener; @@ -371,6 +371,9 @@ public class NettyClientStreamTest extends NettyStreamTestBase { return stream; } + @Override + protected void sendHeadersIfServer() {} + @Override protected void closeStream() { stream().cancel(Status.CANCELLED); diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 83f634b5f9..40de685847 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -339,6 +339,7 @@ public class NettyClientTransportTest { this.stream = stream; this.method = method; this.headers = headers; + stream.writeHeaders(new Metadata()); stream.request(1); } diff --git a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java index 55ac31e693..2128ada9b3 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java @@ -71,7 +71,7 @@ import java.io.ByteArrayInputStream; /** Unit tests for {@link NettyServerStream}. */ @RunWith(JUnit4.class) -public class NettyServerStreamTest extends NettyStreamTestBase { +public class NettyServerStreamTest extends NettyStreamTestBase { @Mock protected ServerStreamListener serverListener; @@ -92,13 +92,14 @@ public class NettyServerStreamTest extends NettyStreamTestBase { @Test public void writeMessageShouldSendResponse() throws Exception { - byte[] msg = smallMessage(); - stream.writeMessage(new ByteArrayInputStream(msg)); - stream.flush(); + stream.writeHeaders(new Metadata()); Http2Headers headers = new DefaultHttp2Headers() .status(Utils.STATUS_OK) .set(Utils.CONTENT_TYPE_HEADER, Utils.CONTENT_TYPE_GRPC); verify(writeQueue).enqueue(new SendResponseHeadersCommand(STREAM_ID, headers, false), true); + byte[] msg = smallMessage(); + stream.writeMessage(new ByteArrayInputStream(msg)); + stream.flush(); verify(writeQueue).enqueue(eq(new SendGrpcFrameCommand(stream, messageFrame(MESSAGE), false)), any(ChannelPromise.class), eq(true)); @@ -266,6 +267,11 @@ public class NettyServerStreamTest extends NettyStreamTestBase { return stream; } + @Override + protected void sendHeadersIfServer() { + stream.writeHeaders(new Metadata()); + } + @Override protected void closeStream() { stream().close(Status.ABORTED, new Metadata()); diff --git a/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java b/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java index 90099af725..23d47ac957 100644 --- a/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java +++ b/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java @@ -71,7 +71,7 @@ import java.util.concurrent.TimeUnit; /** * Base class for Netty stream unit tests. */ -public abstract class NettyStreamTestBase { +public abstract class NettyStreamTestBase> { protected static final String MESSAGE = "hello world"; protected static final int STREAM_ID = 1; @@ -99,7 +99,7 @@ public abstract class NettyStreamTestBase { @Mock protected WriteQueue writeQueue; - protected AbstractStream stream; + protected T stream; /** Set up for test. */ @Before @@ -160,6 +160,7 @@ public abstract class NettyStreamTestBase { @Test public void notifiedOnReadyAfterWriteCompletes() throws IOException { + sendHeadersIfServer(); assertTrue(stream.isReady()); byte[] msg = largeMessage(); // The future is set up to automatically complete, indicating that the write is done. @@ -171,6 +172,7 @@ public abstract class NettyStreamTestBase { @Test public void shouldBeReadyForDataAfterWritingSmallMessage() throws IOException { + sendHeadersIfServer(); // Make sure the writes don't complete so we "back up" reset(future); @@ -184,6 +186,7 @@ public abstract class NettyStreamTestBase { @Test public void shouldNotBeReadyForDataAfterWritingLargeMessage() throws IOException { + sendHeadersIfServer(); // Make sure the writes don't complete so we "back up" reset(future); @@ -209,7 +212,9 @@ public abstract class NettyStreamTestBase { return largeMessage; } - protected abstract AbstractStream createStream(); + protected abstract T createStream(); + + protected abstract void sendHeadersIfServer(); protected abstract StreamListener listener(); diff --git a/stub/src/main/java/io/grpc/stub/ServerCalls.java b/stub/src/main/java/io/grpc/stub/ServerCalls.java index 9870e8d392..9cf44508ac 100644 --- a/stub/src/main/java/io/grpc/stub/ServerCalls.java +++ b/stub/src/main/java/io/grpc/stub/ServerCalls.java @@ -223,6 +223,7 @@ public class ServerCalls { private static class ResponseObserver implements StreamObserver { final ServerCall call; volatile boolean cancelled; + private boolean sentHeaders; ResponseObserver(ServerCall call) { this.call = call; @@ -233,6 +234,10 @@ public class ServerCalls { if (cancelled) { throw Status.CANCELLED.asRuntimeException(); } + if (!sentHeaders) { + call.sendHeaders(new Metadata()); + sentHeaders = true; + } call.sendMessage(response); } diff --git a/testing/src/main/java/io/grpc/testing/TestUtils.java b/testing/src/main/java/io/grpc/testing/TestUtils.java index 53e3c39d9a..1127cfbc43 100644 --- a/testing/src/main/java/io/grpc/testing/TestUtils.java +++ b/testing/src/main/java/io/grpc/testing/TestUtils.java @@ -90,21 +90,10 @@ public class TestUtils { ServerCallHandler next) { return next.startCall(method, new SimpleForwardingServerCall(call) { - boolean sentHeaders; - @Override public void sendHeaders(Metadata responseHeaders) { responseHeaders.merge(requestHeaders, keySet); super.sendHeaders(responseHeaders); - sentHeaders = true; - } - - @Override - public void sendMessage(RespT message) { - if (!sentHeaders) { - sendHeaders(new Metadata()); - } - super.sendMessage(message); } @Override