Enforce sending headers before messages on server

ServerCall already had "headers must be sent before any messages, which
must be sent before closing," but the implementation did not enforce it
and our async server handler didn't obey.

The benefit of forcing sending headers first is that it removes the only
implicit call in our API and interceptors dealing just with metadata
don't need to override sendMessage. The implicit behavior was bug-prone
since it wasn't obvious you were forgetting that headers may not be
sent.
This commit is contained in:
Eric Anderson 2015-09-11 15:52:14 -07:00
parent 701f9cd7ee
commit 1cfba96d17
9 changed files with 33 additions and 39 deletions

View File

@ -102,9 +102,8 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
@Override
public final void writeMessage(InputStream message) {
if (!headersSent) {
writeHeaders(new Metadata());
headersSent = true;
if (outboundPhase() != Phase.MESSAGE) {
throw new IllegalStateException("Messages are only permitted after headers and before close");
}
super.writeMessage(message);
}

View File

@ -170,6 +170,7 @@ public class AbstractServerStreamTest {
capturedHeaders.set(captured);
}
};
stream.writeHeaders(new Metadata());
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
@ -204,6 +205,7 @@ public class AbstractServerStreamTest {
sendCalled.set(true);
}
};
stream.writeHeaders(new Metadata());
stream.closeFramer();
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
@ -220,6 +222,7 @@ public class AbstractServerStreamTest {
sendCalled.set(true);
}
};
stream.writeHeaders(new Metadata());
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
// Force the message to be flushed

View File

@ -37,7 +37,6 @@ import io.grpc.MethodDescriptor;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import java.util.logging.Logger;
@ -60,26 +59,10 @@ public class HeaderServerInterceptor implements ServerInterceptor {
ServerCallHandler<ReqT, RespT> next) {
logger.info("header received from client:" + requestHeaders.toString());
return next.startCall(method, new SimpleForwardingServerCall<RespT>(call) {
boolean sentHeaders = false;
@Override
public void sendHeaders(Metadata responseHeaders) {
responseHeaders.put(customHeadKey, "customRespondValue");
super.sendHeaders(responseHeaders);
sentHeaders = true;
}
@Override
public void sendMessage(RespT message) {
if (!sentHeaders) {
sendHeaders(new Metadata());
}
super.sendMessage(message);
}
@Override
public void close(Status status, Metadata trailers) {
super.close(status, trailers);
}
}, requestHeaders);
}

View File

@ -80,7 +80,7 @@ import java.io.InputStream;
* Tests for {@link NettyClientStream}.
*/
@RunWith(JUnit4.class)
public class NettyClientStreamTest extends NettyStreamTestBase {
public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream> {
@Mock
protected ClientStreamListener listener;
@ -371,6 +371,9 @@ public class NettyClientStreamTest extends NettyStreamTestBase {
return stream;
}
@Override
protected void sendHeadersIfServer() {}
@Override
protected void closeStream() {
stream().cancel(Status.CANCELLED);

View File

@ -339,6 +339,7 @@ public class NettyClientTransportTest {
this.stream = stream;
this.method = method;
this.headers = headers;
stream.writeHeaders(new Metadata());
stream.request(1);
}

View File

@ -71,7 +71,7 @@ import java.io.ByteArrayInputStream;
/** Unit tests for {@link NettyServerStream}. */
@RunWith(JUnit4.class)
public class NettyServerStreamTest extends NettyStreamTestBase {
public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream> {
@Mock
protected ServerStreamListener serverListener;
@ -92,13 +92,14 @@ public class NettyServerStreamTest extends NettyStreamTestBase {
@Test
public void writeMessageShouldSendResponse() throws Exception {
byte[] msg = smallMessage();
stream.writeMessage(new ByteArrayInputStream(msg));
stream.flush();
stream.writeHeaders(new Metadata());
Http2Headers headers = new DefaultHttp2Headers()
.status(Utils.STATUS_OK)
.set(Utils.CONTENT_TYPE_HEADER, Utils.CONTENT_TYPE_GRPC);
verify(writeQueue).enqueue(new SendResponseHeadersCommand(STREAM_ID, headers, false), true);
byte[] msg = smallMessage();
stream.writeMessage(new ByteArrayInputStream(msg));
stream.flush();
verify(writeQueue).enqueue(eq(new SendGrpcFrameCommand(stream, messageFrame(MESSAGE), false)),
any(ChannelPromise.class),
eq(true));
@ -266,6 +267,11 @@ public class NettyServerStreamTest extends NettyStreamTestBase {
return stream;
}
@Override
protected void sendHeadersIfServer() {
stream.writeHeaders(new Metadata());
}
@Override
protected void closeStream() {
stream().close(Status.ABORTED, new Metadata());

View File

@ -71,7 +71,7 @@ import java.util.concurrent.TimeUnit;
/**
* Base class for Netty stream unit tests.
*/
public abstract class NettyStreamTestBase {
public abstract class NettyStreamTestBase<T extends AbstractStream<Integer>> {
protected static final String MESSAGE = "hello world";
protected static final int STREAM_ID = 1;
@ -99,7 +99,7 @@ public abstract class NettyStreamTestBase {
@Mock
protected WriteQueue writeQueue;
protected AbstractStream<Integer> stream;
protected T stream;
/** Set up for test. */
@Before
@ -160,6 +160,7 @@ public abstract class NettyStreamTestBase {
@Test
public void notifiedOnReadyAfterWriteCompletes() throws IOException {
sendHeadersIfServer();
assertTrue(stream.isReady());
byte[] msg = largeMessage();
// The future is set up to automatically complete, indicating that the write is done.
@ -171,6 +172,7 @@ public abstract class NettyStreamTestBase {
@Test
public void shouldBeReadyForDataAfterWritingSmallMessage() throws IOException {
sendHeadersIfServer();
// Make sure the writes don't complete so we "back up"
reset(future);
@ -184,6 +186,7 @@ public abstract class NettyStreamTestBase {
@Test
public void shouldNotBeReadyForDataAfterWritingLargeMessage() throws IOException {
sendHeadersIfServer();
// Make sure the writes don't complete so we "back up"
reset(future);
@ -209,7 +212,9 @@ public abstract class NettyStreamTestBase {
return largeMessage;
}
protected abstract AbstractStream<Integer> createStream();
protected abstract T createStream();
protected abstract void sendHeadersIfServer();
protected abstract StreamListener listener();

View File

@ -223,6 +223,7 @@ public class ServerCalls {
private static class ResponseObserver<RespT> implements StreamObserver<RespT> {
final ServerCall<RespT> call;
volatile boolean cancelled;
private boolean sentHeaders;
ResponseObserver(ServerCall<RespT> call) {
this.call = call;
@ -233,6 +234,10 @@ public class ServerCalls {
if (cancelled) {
throw Status.CANCELLED.asRuntimeException();
}
if (!sentHeaders) {
call.sendHeaders(new Metadata());
sentHeaders = true;
}
call.sendMessage(response);
}

View File

@ -90,21 +90,10 @@ public class TestUtils {
ServerCallHandler<ReqT, RespT> next) {
return next.startCall(method,
new SimpleForwardingServerCall<RespT>(call) {
boolean sentHeaders;
@Override
public void sendHeaders(Metadata responseHeaders) {
responseHeaders.merge(requestHeaders, keySet);
super.sendHeaders(responseHeaders);
sentHeaders = true;
}
@Override
public void sendMessage(RespT message) {
if (!sentHeaders) {
sendHeaders(new Metadata());
}
super.sendMessage(message);
}
@Override