diff --git a/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java b/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java index d9f5d96e06..d49e3bd672 100644 --- a/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java +++ b/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java @@ -27,10 +27,23 @@ import io.grpc.Status; final class CancelServerStreamCommand extends WriteQueue.AbstractQueuedCommand { private final NettyServerStream.TransportState stream; private final Status reason; + private final PeerNotify peerNotify; - CancelServerStreamCommand(NettyServerStream.TransportState stream, Status reason) { + private CancelServerStreamCommand( + NettyServerStream.TransportState stream, Status reason, PeerNotify peerNotify) { this.stream = Preconditions.checkNotNull(stream, "stream"); this.reason = Preconditions.checkNotNull(reason, "reason"); + this.peerNotify = Preconditions.checkNotNull(peerNotify, "peerNotify"); + } + + static CancelServerStreamCommand withReset( + NettyServerStream.TransportState stream, Status reason) { + return new CancelServerStreamCommand(stream, reason, PeerNotify.RESET); + } + + static CancelServerStreamCommand withReason( + NettyServerStream.TransportState stream, Status reason) { + return new CancelServerStreamCommand(stream, reason, PeerNotify.BEST_EFFORT_STATUS); } NettyServerStream.TransportState stream() { @@ -41,6 +54,10 @@ final class CancelServerStreamCommand extends WriteQueue.AbstractQueuedCommand { return reason; } + boolean wantsHeaders() { + return peerNotify == PeerNotify.BEST_EFFORT_STATUS; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -68,4 +85,11 @@ final class CancelServerStreamCommand extends WriteQueue.AbstractQueuedCommand { .add("reason", reason) .toString(); } + + private enum PeerNotify { + /** Notify the peer by sending a RST_STREAM with no other information. */ + RESET, + /** Notify the peer about the {@link #reason} by sending structured headers, if possible. */ + BEST_EFFORT_STATUS, + } } diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index 77b448446b..a6e855a199 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -788,11 +788,39 @@ class NettyServerHandler extends AbstractNettyHandler { PerfMark.linkIn(cmd.getLink()); // Notify the listener if we haven't already. cmd.stream().transportReportStatus(cmd.reason()); - // Terminate the stream. - encoder().writeRstStream(ctx, cmd.stream().id(), Http2Error.CANCEL.code(), promise); + + // Now we need to decide how we're going to notify the peer that this stream is closed. + // If possible, it's nice to inform the peer _why_ this stream was cancelled by sending + // a structured headers frame. + if (shouldCloseStreamWithHeaders(cmd, connection())) { + Metadata md = new Metadata(); + md.put(InternalStatus.CODE_KEY, cmd.reason()); + if (cmd.reason().getDescription() != null) { + md.put(InternalStatus.MESSAGE_KEY, cmd.reason().getDescription()); + } + Http2Headers headers = Utils.convertServerHeaders(md); + encoder().writeHeaders( + ctx, cmd.stream().id(), headers, /* padding = */ 0, /* endStream = */ true, promise); + } else { + // Terminate the stream. + encoder().writeRstStream(ctx, cmd.stream().id(), Http2Error.CANCEL.code(), promise); + } } } + // Determine whether a CancelServerStreamCommand should try to close the stream with a + // HEADERS or a RST_STREAM frame. The caller has some influence over this (they can + // configure cmd.wantsHeaders()). The state of the stream also has an influence: we + // only try to send HEADERS if the stream exists and hasn't already sent any headers. + private static boolean shouldCloseStreamWithHeaders( + CancelServerStreamCommand cmd, Http2Connection conn) { + if (!cmd.wantsHeaders()) { + return false; + } + Http2Stream stream = conn.stream(cmd.stream().id()); + return stream != null && !stream.isHeadersSent(); + } + private void gracefulClose(final ChannelHandlerContext ctx, final GracefulServerCloseCommand msg, ChannelPromise promise) throws Exception { // Ideally we'd adjust a pre-existing graceful shutdown's grace period to at least what is diff --git a/netty/src/main/java/io/grpc/netty/NettyServerStream.java b/netty/src/main/java/io/grpc/netty/NettyServerStream.java index a4304d5193..836f39ddf1 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerStream.java @@ -130,7 +130,7 @@ class NettyServerStream extends AbstractServerStream { @Override public void cancel(Status status) { try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.cancel")) { - writeQueue.enqueue(new CancelServerStreamCommand(transportState(), status), true); + writeQueue.enqueue(CancelServerStreamCommand.withReset(transportState(), status), true); } } } @@ -189,7 +189,7 @@ class NettyServerStream extends AbstractServerStream { log.log(Level.WARNING, "Exception processing message", cause); Status status = Status.fromThrowable(cause); transportReportStatus(status); - handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true); + handler.getWriteQueue().enqueue(CancelServerStreamCommand.withReason(this, status), true); } private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) { @@ -222,7 +222,7 @@ class NettyServerStream extends AbstractServerStream { */ protected void http2ProcessingFailed(Status status) { transportReportStatus(status); - handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true); + handler.getWriteQueue().enqueue(CancelServerStreamCommand.withReset(this, status), true); } void inboundDataReceived(ByteBuf frame, boolean endOfStream) { diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 281ff3b17d..ce902a9620 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -89,8 +89,10 @@ import io.netty.util.AsciiString; import java.io.InputStream; import java.nio.channels.ClosedChannelException; import java.util.Arrays; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Queue; import java.util.concurrent.TimeUnit; import org.junit.Before; @@ -469,11 +471,41 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase captor = ArgumentCaptor.forClass(Http2Headers.class); + verifyWrite() + .writeHeaders( + eq(ctx()), + eq(STREAM_ID), + captor.capture(), + eq(0), + eq(true), + any(ChannelPromise.class)); + + // For arcane reasons, the specific implementation of Http2Headers here doesn't actually support + // methods like `get(...)`, so we have to manually convert it into a map. + Map actualHeaders = new HashMap<>(); + for (Map.Entry entry : captor.getValue()) { + actualHeaders.put(entry.getKey().toString(), entry.getValue().toString()); + } + assertEquals("8", actualHeaders.get(InternalStatus.CODE_KEY.name())); + assertEquals("my custom description", actualHeaders.get(InternalStatus.MESSAGE_KEY.name())); + } + @Test public void headersWithInvalidContentTypeShouldFail() throws Exception { manualSetUp(); diff --git a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java index ab54d4e4e2..452f68341b 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java @@ -18,7 +18,6 @@ package io.grpc.netty; import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; -import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.netty.NettyTestUtil.messageFrame; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Exception.connectionError; @@ -37,6 +36,7 @@ import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ListMultimap; import io.grpc.Attributes; @@ -73,6 +73,8 @@ import org.mockito.stubbing.Answer; /** Unit tests for {@link NettyServerStream}. */ @RunWith(JUnit4.class) public class NettyServerStreamTest extends NettyStreamTestBase { + private static final int TEST_MAX_MESSAGE_SIZE = 128; + @Mock protected ServerStreamListener serverListener; @@ -380,10 +382,31 @@ public class NettyServerStreamTest extends NettyStreamTestBase cancelCmdCap = + ArgumentCaptor.forClass(CancelServerStreamCommand.class); + verify(writeQueue).enqueue(cancelCmdCap.capture(), eq(true)); + + Status status = Status.RESOURCE_EXHAUSTED + .withDescription("gRPC message exceeds maximum size 128: 129"); + + CancelServerStreamCommand actualCmd = cancelCmdCap.getValue(); + assertThat(actualCmd.reason().getCode()).isEqualTo(status.getCode()); + assertThat(actualCmd.reason().getDescription()).isEqualTo(status.getDescription()); + assertThat(actualCmd.wantsHeaders()).isTrue(); + } + @Override @SuppressWarnings("DirectInvocationOnMock") protected NettyServerStream createStream() { @@ -391,7 +414,7 @@ public class NettyServerStreamTest extends NettyStreamTestBase