diff --git a/core/src/main/java/io/grpc/internal/AbstractServerStream.java b/core/src/main/java/io/grpc/internal/AbstractServerStream.java index cfafa94434..750364713d 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerStream.java @@ -102,9 +102,8 @@ public abstract class AbstractServerStream extends AbstractStream @Override public final void writeMessage(InputStream message) { - if (!headersSent) { - writeHeaders(new Metadata()); - headersSent = true; + if (outboundPhase() != Phase.MESSAGE) { + throw new IllegalStateException("Messages are only permitted after headers and before close"); } super.writeMessage(message); } diff --git a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java index 1286276462..2443082903 100644 --- a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java @@ -170,6 +170,7 @@ public class AbstractServerStreamTest { capturedHeaders.set(captured); } }; + stream.writeHeaders(new Metadata()); stream.writeMessage(new ByteArrayInputStream(new byte[]{})); @@ -204,6 +205,7 @@ public class AbstractServerStreamTest { sendCalled.set(true); } }; + stream.writeHeaders(new Metadata()); stream.closeFramer(); stream.writeMessage(new ByteArrayInputStream(new byte[]{})); @@ -220,6 +222,7 @@ public class AbstractServerStreamTest { sendCalled.set(true); } }; + stream.writeHeaders(new Metadata()); stream.writeMessage(new ByteArrayInputStream(new byte[]{})); // Force the message to be flushed diff --git a/examples/src/main/java/io/grpc/examples/header/HeaderServerInterceptor.java b/examples/src/main/java/io/grpc/examples/header/HeaderServerInterceptor.java index 6c8d07cb20..97a2b5a387 100644 --- a/examples/src/main/java/io/grpc/examples/header/HeaderServerInterceptor.java +++ b/examples/src/main/java/io/grpc/examples/header/HeaderServerInterceptor.java @@ -37,7 +37,6 @@ import io.grpc.MethodDescriptor; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; -import io.grpc.Status; import java.util.logging.Logger; @@ -60,26 +59,10 @@ public class HeaderServerInterceptor implements ServerInterceptor { ServerCallHandler next) { logger.info("header received from client:" + requestHeaders.toString()); return next.startCall(method, new SimpleForwardingServerCall(call) { - boolean sentHeaders = false; - @Override public void sendHeaders(Metadata responseHeaders) { responseHeaders.put(customHeadKey, "customRespondValue"); super.sendHeaders(responseHeaders); - sentHeaders = true; - } - - @Override - public void sendMessage(RespT message) { - if (!sentHeaders) { - sendHeaders(new Metadata()); - } - super.sendMessage(message); - } - - @Override - public void close(Status status, Metadata trailers) { - super.close(status, trailers); } }, requestHeaders); } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java index c866c18968..ad0964f727 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java @@ -80,7 +80,7 @@ import java.io.InputStream; * Tests for {@link NettyClientStream}. */ @RunWith(JUnit4.class) -public class NettyClientStreamTest extends NettyStreamTestBase { +public class NettyClientStreamTest extends NettyStreamTestBase { @Mock protected ClientStreamListener listener; @@ -371,6 +371,9 @@ public class NettyClientStreamTest extends NettyStreamTestBase { return stream; } + @Override + protected void sendHeadersIfServer() {} + @Override protected void closeStream() { stream().cancel(Status.CANCELLED); diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 83f634b5f9..40de685847 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -339,6 +339,7 @@ public class NettyClientTransportTest { this.stream = stream; this.method = method; this.headers = headers; + stream.writeHeaders(new Metadata()); stream.request(1); } diff --git a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java index 55ac31e693..2128ada9b3 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java @@ -71,7 +71,7 @@ import java.io.ByteArrayInputStream; /** Unit tests for {@link NettyServerStream}. */ @RunWith(JUnit4.class) -public class NettyServerStreamTest extends NettyStreamTestBase { +public class NettyServerStreamTest extends NettyStreamTestBase { @Mock protected ServerStreamListener serverListener; @@ -92,13 +92,14 @@ public class NettyServerStreamTest extends NettyStreamTestBase { @Test public void writeMessageShouldSendResponse() throws Exception { - byte[] msg = smallMessage(); - stream.writeMessage(new ByteArrayInputStream(msg)); - stream.flush(); + stream.writeHeaders(new Metadata()); Http2Headers headers = new DefaultHttp2Headers() .status(Utils.STATUS_OK) .set(Utils.CONTENT_TYPE_HEADER, Utils.CONTENT_TYPE_GRPC); verify(writeQueue).enqueue(new SendResponseHeadersCommand(STREAM_ID, headers, false), true); + byte[] msg = smallMessage(); + stream.writeMessage(new ByteArrayInputStream(msg)); + stream.flush(); verify(writeQueue).enqueue(eq(new SendGrpcFrameCommand(stream, messageFrame(MESSAGE), false)), any(ChannelPromise.class), eq(true)); @@ -266,6 +267,11 @@ public class NettyServerStreamTest extends NettyStreamTestBase { return stream; } + @Override + protected void sendHeadersIfServer() { + stream.writeHeaders(new Metadata()); + } + @Override protected void closeStream() { stream().close(Status.ABORTED, new Metadata()); diff --git a/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java b/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java index 90099af725..23d47ac957 100644 --- a/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java +++ b/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java @@ -71,7 +71,7 @@ import java.util.concurrent.TimeUnit; /** * Base class for Netty stream unit tests. */ -public abstract class NettyStreamTestBase { +public abstract class NettyStreamTestBase> { protected static final String MESSAGE = "hello world"; protected static final int STREAM_ID = 1; @@ -99,7 +99,7 @@ public abstract class NettyStreamTestBase { @Mock protected WriteQueue writeQueue; - protected AbstractStream stream; + protected T stream; /** Set up for test. */ @Before @@ -160,6 +160,7 @@ public abstract class NettyStreamTestBase { @Test public void notifiedOnReadyAfterWriteCompletes() throws IOException { + sendHeadersIfServer(); assertTrue(stream.isReady()); byte[] msg = largeMessage(); // The future is set up to automatically complete, indicating that the write is done. @@ -171,6 +172,7 @@ public abstract class NettyStreamTestBase { @Test public void shouldBeReadyForDataAfterWritingSmallMessage() throws IOException { + sendHeadersIfServer(); // Make sure the writes don't complete so we "back up" reset(future); @@ -184,6 +186,7 @@ public abstract class NettyStreamTestBase { @Test public void shouldNotBeReadyForDataAfterWritingLargeMessage() throws IOException { + sendHeadersIfServer(); // Make sure the writes don't complete so we "back up" reset(future); @@ -209,7 +212,9 @@ public abstract class NettyStreamTestBase { return largeMessage; } - protected abstract AbstractStream createStream(); + protected abstract T createStream(); + + protected abstract void sendHeadersIfServer(); protected abstract StreamListener listener(); diff --git a/stub/src/main/java/io/grpc/stub/ServerCalls.java b/stub/src/main/java/io/grpc/stub/ServerCalls.java index 9870e8d392..9cf44508ac 100644 --- a/stub/src/main/java/io/grpc/stub/ServerCalls.java +++ b/stub/src/main/java/io/grpc/stub/ServerCalls.java @@ -223,6 +223,7 @@ public class ServerCalls { private static class ResponseObserver implements StreamObserver { final ServerCall call; volatile boolean cancelled; + private boolean sentHeaders; ResponseObserver(ServerCall call) { this.call = call; @@ -233,6 +234,10 @@ public class ServerCalls { if (cancelled) { throw Status.CANCELLED.asRuntimeException(); } + if (!sentHeaders) { + call.sendHeaders(new Metadata()); + sentHeaders = true; + } call.sendMessage(response); } diff --git a/testing/src/main/java/io/grpc/testing/TestUtils.java b/testing/src/main/java/io/grpc/testing/TestUtils.java index 53e3c39d9a..1127cfbc43 100644 --- a/testing/src/main/java/io/grpc/testing/TestUtils.java +++ b/testing/src/main/java/io/grpc/testing/TestUtils.java @@ -90,21 +90,10 @@ public class TestUtils { ServerCallHandler next) { return next.startCall(method, new SimpleForwardingServerCall(call) { - boolean sentHeaders; - @Override public void sendHeaders(Metadata responseHeaders) { responseHeaders.merge(requestHeaders, keySet); super.sendHeaders(responseHeaders); - sentHeaders = true; - } - - @Override - public void sendMessage(RespT message) { - if (!sentHeaders) { - sendHeaders(new Metadata()); - } - super.sendMessage(message); } @Override