mirror of https://github.com/grpc/grpc-java.git
Enforce sending headers before messages on server
ServerCall already had "headers must be sent before any messages, which must be sent before closing," but the implementation did not enforce it and our async server handler didn't obey. The benefit of forcing sending headers first is that it removes the only implicit call in our API and interceptors dealing just with metadata don't need to override sendMessage. The implicit behavior was bug-prone since it wasn't obvious you were forgetting that headers may not be sent.
This commit is contained in:
parent
701f9cd7ee
commit
1cfba96d17
|
|
@ -102,9 +102,8 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final void writeMessage(InputStream message) {
|
public final void writeMessage(InputStream message) {
|
||||||
if (!headersSent) {
|
if (outboundPhase() != Phase.MESSAGE) {
|
||||||
writeHeaders(new Metadata());
|
throw new IllegalStateException("Messages are only permitted after headers and before close");
|
||||||
headersSent = true;
|
|
||||||
}
|
}
|
||||||
super.writeMessage(message);
|
super.writeMessage(message);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -170,6 +170,7 @@ public class AbstractServerStreamTest {
|
||||||
capturedHeaders.set(captured);
|
capturedHeaders.set(captured);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
stream.writeHeaders(new Metadata());
|
||||||
|
|
||||||
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
|
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
|
||||||
|
|
||||||
|
|
@ -204,6 +205,7 @@ public class AbstractServerStreamTest {
|
||||||
sendCalled.set(true);
|
sendCalled.set(true);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
stream.writeHeaders(new Metadata());
|
||||||
stream.closeFramer();
|
stream.closeFramer();
|
||||||
|
|
||||||
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
|
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
|
||||||
|
|
@ -220,6 +222,7 @@ public class AbstractServerStreamTest {
|
||||||
sendCalled.set(true);
|
sendCalled.set(true);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
stream.writeHeaders(new Metadata());
|
||||||
|
|
||||||
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
|
stream.writeMessage(new ByteArrayInputStream(new byte[]{}));
|
||||||
// Force the message to be flushed
|
// Force the message to be flushed
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,6 @@ import io.grpc.MethodDescriptor;
|
||||||
import io.grpc.ServerCall;
|
import io.grpc.ServerCall;
|
||||||
import io.grpc.ServerCallHandler;
|
import io.grpc.ServerCallHandler;
|
||||||
import io.grpc.ServerInterceptor;
|
import io.grpc.ServerInterceptor;
|
||||||
import io.grpc.Status;
|
|
||||||
|
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
|
@ -60,26 +59,10 @@ public class HeaderServerInterceptor implements ServerInterceptor {
|
||||||
ServerCallHandler<ReqT, RespT> next) {
|
ServerCallHandler<ReqT, RespT> next) {
|
||||||
logger.info("header received from client:" + requestHeaders.toString());
|
logger.info("header received from client:" + requestHeaders.toString());
|
||||||
return next.startCall(method, new SimpleForwardingServerCall<RespT>(call) {
|
return next.startCall(method, new SimpleForwardingServerCall<RespT>(call) {
|
||||||
boolean sentHeaders = false;
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void sendHeaders(Metadata responseHeaders) {
|
public void sendHeaders(Metadata responseHeaders) {
|
||||||
responseHeaders.put(customHeadKey, "customRespondValue");
|
responseHeaders.put(customHeadKey, "customRespondValue");
|
||||||
super.sendHeaders(responseHeaders);
|
super.sendHeaders(responseHeaders);
|
||||||
sentHeaders = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void sendMessage(RespT message) {
|
|
||||||
if (!sentHeaders) {
|
|
||||||
sendHeaders(new Metadata());
|
|
||||||
}
|
|
||||||
super.sendMessage(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void close(Status status, Metadata trailers) {
|
|
||||||
super.close(status, trailers);
|
|
||||||
}
|
}
|
||||||
}, requestHeaders);
|
}, requestHeaders);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,7 @@ import java.io.InputStream;
|
||||||
* Tests for {@link NettyClientStream}.
|
* Tests for {@link NettyClientStream}.
|
||||||
*/
|
*/
|
||||||
@RunWith(JUnit4.class)
|
@RunWith(JUnit4.class)
|
||||||
public class NettyClientStreamTest extends NettyStreamTestBase {
|
public class NettyClientStreamTest extends NettyStreamTestBase<NettyClientStream> {
|
||||||
@Mock
|
@Mock
|
||||||
protected ClientStreamListener listener;
|
protected ClientStreamListener listener;
|
||||||
|
|
||||||
|
|
@ -371,6 +371,9 @@ public class NettyClientStreamTest extends NettyStreamTestBase {
|
||||||
return stream;
|
return stream;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void sendHeadersIfServer() {}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void closeStream() {
|
protected void closeStream() {
|
||||||
stream().cancel(Status.CANCELLED);
|
stream().cancel(Status.CANCELLED);
|
||||||
|
|
|
||||||
|
|
@ -339,6 +339,7 @@ public class NettyClientTransportTest {
|
||||||
this.stream = stream;
|
this.stream = stream;
|
||||||
this.method = method;
|
this.method = method;
|
||||||
this.headers = headers;
|
this.headers = headers;
|
||||||
|
stream.writeHeaders(new Metadata());
|
||||||
stream.request(1);
|
stream.request(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ import java.io.ByteArrayInputStream;
|
||||||
|
|
||||||
/** Unit tests for {@link NettyServerStream}. */
|
/** Unit tests for {@link NettyServerStream}. */
|
||||||
@RunWith(JUnit4.class)
|
@RunWith(JUnit4.class)
|
||||||
public class NettyServerStreamTest extends NettyStreamTestBase {
|
public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream> {
|
||||||
@Mock
|
@Mock
|
||||||
protected ServerStreamListener serverListener;
|
protected ServerStreamListener serverListener;
|
||||||
|
|
||||||
|
|
@ -92,13 +92,14 @@ public class NettyServerStreamTest extends NettyStreamTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void writeMessageShouldSendResponse() throws Exception {
|
public void writeMessageShouldSendResponse() throws Exception {
|
||||||
byte[] msg = smallMessage();
|
stream.writeHeaders(new Metadata());
|
||||||
stream.writeMessage(new ByteArrayInputStream(msg));
|
|
||||||
stream.flush();
|
|
||||||
Http2Headers headers = new DefaultHttp2Headers()
|
Http2Headers headers = new DefaultHttp2Headers()
|
||||||
.status(Utils.STATUS_OK)
|
.status(Utils.STATUS_OK)
|
||||||
.set(Utils.CONTENT_TYPE_HEADER, Utils.CONTENT_TYPE_GRPC);
|
.set(Utils.CONTENT_TYPE_HEADER, Utils.CONTENT_TYPE_GRPC);
|
||||||
verify(writeQueue).enqueue(new SendResponseHeadersCommand(STREAM_ID, headers, false), true);
|
verify(writeQueue).enqueue(new SendResponseHeadersCommand(STREAM_ID, headers, false), true);
|
||||||
|
byte[] msg = smallMessage();
|
||||||
|
stream.writeMessage(new ByteArrayInputStream(msg));
|
||||||
|
stream.flush();
|
||||||
verify(writeQueue).enqueue(eq(new SendGrpcFrameCommand(stream, messageFrame(MESSAGE), false)),
|
verify(writeQueue).enqueue(eq(new SendGrpcFrameCommand(stream, messageFrame(MESSAGE), false)),
|
||||||
any(ChannelPromise.class),
|
any(ChannelPromise.class),
|
||||||
eq(true));
|
eq(true));
|
||||||
|
|
@ -266,6 +267,11 @@ public class NettyServerStreamTest extends NettyStreamTestBase {
|
||||||
return stream;
|
return stream;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void sendHeadersIfServer() {
|
||||||
|
stream.writeHeaders(new Metadata());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void closeStream() {
|
protected void closeStream() {
|
||||||
stream().close(Status.ABORTED, new Metadata());
|
stream().close(Status.ABORTED, new Metadata());
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ import java.util.concurrent.TimeUnit;
|
||||||
/**
|
/**
|
||||||
* Base class for Netty stream unit tests.
|
* Base class for Netty stream unit tests.
|
||||||
*/
|
*/
|
||||||
public abstract class NettyStreamTestBase {
|
public abstract class NettyStreamTestBase<T extends AbstractStream<Integer>> {
|
||||||
protected static final String MESSAGE = "hello world";
|
protected static final String MESSAGE = "hello world";
|
||||||
protected static final int STREAM_ID = 1;
|
protected static final int STREAM_ID = 1;
|
||||||
|
|
||||||
|
|
@ -99,7 +99,7 @@ public abstract class NettyStreamTestBase {
|
||||||
@Mock
|
@Mock
|
||||||
protected WriteQueue writeQueue;
|
protected WriteQueue writeQueue;
|
||||||
|
|
||||||
protected AbstractStream<Integer> stream;
|
protected T stream;
|
||||||
|
|
||||||
/** Set up for test. */
|
/** Set up for test. */
|
||||||
@Before
|
@Before
|
||||||
|
|
@ -160,6 +160,7 @@ public abstract class NettyStreamTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void notifiedOnReadyAfterWriteCompletes() throws IOException {
|
public void notifiedOnReadyAfterWriteCompletes() throws IOException {
|
||||||
|
sendHeadersIfServer();
|
||||||
assertTrue(stream.isReady());
|
assertTrue(stream.isReady());
|
||||||
byte[] msg = largeMessage();
|
byte[] msg = largeMessage();
|
||||||
// The future is set up to automatically complete, indicating that the write is done.
|
// The future is set up to automatically complete, indicating that the write is done.
|
||||||
|
|
@ -171,6 +172,7 @@ public abstract class NettyStreamTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void shouldBeReadyForDataAfterWritingSmallMessage() throws IOException {
|
public void shouldBeReadyForDataAfterWritingSmallMessage() throws IOException {
|
||||||
|
sendHeadersIfServer();
|
||||||
// Make sure the writes don't complete so we "back up"
|
// Make sure the writes don't complete so we "back up"
|
||||||
reset(future);
|
reset(future);
|
||||||
|
|
||||||
|
|
@ -184,6 +186,7 @@ public abstract class NettyStreamTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void shouldNotBeReadyForDataAfterWritingLargeMessage() throws IOException {
|
public void shouldNotBeReadyForDataAfterWritingLargeMessage() throws IOException {
|
||||||
|
sendHeadersIfServer();
|
||||||
// Make sure the writes don't complete so we "back up"
|
// Make sure the writes don't complete so we "back up"
|
||||||
reset(future);
|
reset(future);
|
||||||
|
|
||||||
|
|
@ -209,7 +212,9 @@ public abstract class NettyStreamTestBase {
|
||||||
return largeMessage;
|
return largeMessage;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract AbstractStream<Integer> createStream();
|
protected abstract T createStream();
|
||||||
|
|
||||||
|
protected abstract void sendHeadersIfServer();
|
||||||
|
|
||||||
protected abstract StreamListener listener();
|
protected abstract StreamListener listener();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -223,6 +223,7 @@ public class ServerCalls {
|
||||||
private static class ResponseObserver<RespT> implements StreamObserver<RespT> {
|
private static class ResponseObserver<RespT> implements StreamObserver<RespT> {
|
||||||
final ServerCall<RespT> call;
|
final ServerCall<RespT> call;
|
||||||
volatile boolean cancelled;
|
volatile boolean cancelled;
|
||||||
|
private boolean sentHeaders;
|
||||||
|
|
||||||
ResponseObserver(ServerCall<RespT> call) {
|
ResponseObserver(ServerCall<RespT> call) {
|
||||||
this.call = call;
|
this.call = call;
|
||||||
|
|
@ -233,6 +234,10 @@ public class ServerCalls {
|
||||||
if (cancelled) {
|
if (cancelled) {
|
||||||
throw Status.CANCELLED.asRuntimeException();
|
throw Status.CANCELLED.asRuntimeException();
|
||||||
}
|
}
|
||||||
|
if (!sentHeaders) {
|
||||||
|
call.sendHeaders(new Metadata());
|
||||||
|
sentHeaders = true;
|
||||||
|
}
|
||||||
call.sendMessage(response);
|
call.sendMessage(response);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -90,21 +90,10 @@ public class TestUtils {
|
||||||
ServerCallHandler<ReqT, RespT> next) {
|
ServerCallHandler<ReqT, RespT> next) {
|
||||||
return next.startCall(method,
|
return next.startCall(method,
|
||||||
new SimpleForwardingServerCall<RespT>(call) {
|
new SimpleForwardingServerCall<RespT>(call) {
|
||||||
boolean sentHeaders;
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void sendHeaders(Metadata responseHeaders) {
|
public void sendHeaders(Metadata responseHeaders) {
|
||||||
responseHeaders.merge(requestHeaders, keySet);
|
responseHeaders.merge(requestHeaders, keySet);
|
||||||
super.sendHeaders(responseHeaders);
|
super.sendHeaders(responseHeaders);
|
||||||
sentHeaders = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void sendMessage(RespT message) {
|
|
||||||
if (!sentHeaders) {
|
|
||||||
sendHeaders(new Metadata());
|
|
||||||
}
|
|
||||||
super.sendMessage(message);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue