netty: Allow deframer errors to close stream with a status code

Today, deframer errors cancel the stream without communicating a status code
to the peer. This change causes deframer errors to trigger a best-effort
attempt to send trailers with a status code so that the peer understands
why the stream is being closed.

Fixes #3996
This commit is contained in:
Ryan P. Brewster 2024-04-24 13:05:51 -04:00 committed by Eric Anderson
parent 11612b484a
commit e036b1b198
5 changed files with 117 additions and 10 deletions

View File

@ -27,10 +27,23 @@ import io.grpc.Status;
final class CancelServerStreamCommand extends WriteQueue.AbstractQueuedCommand {
private final NettyServerStream.TransportState stream;
private final Status reason;
private final PeerNotify peerNotify;
CancelServerStreamCommand(NettyServerStream.TransportState stream, Status reason) {
private CancelServerStreamCommand(
NettyServerStream.TransportState stream, Status reason, PeerNotify peerNotify) {
this.stream = Preconditions.checkNotNull(stream, "stream");
this.reason = Preconditions.checkNotNull(reason, "reason");
this.peerNotify = Preconditions.checkNotNull(peerNotify, "peerNotify");
}
static CancelServerStreamCommand withReset(
NettyServerStream.TransportState stream, Status reason) {
return new CancelServerStreamCommand(stream, reason, PeerNotify.RESET);
}
static CancelServerStreamCommand withReason(
NettyServerStream.TransportState stream, Status reason) {
return new CancelServerStreamCommand(stream, reason, PeerNotify.BEST_EFFORT_STATUS);
}
NettyServerStream.TransportState stream() {
@ -41,6 +54,10 @@ final class CancelServerStreamCommand extends WriteQueue.AbstractQueuedCommand {
return reason;
}
boolean wantsHeaders() {
return peerNotify == PeerNotify.BEST_EFFORT_STATUS;
}
@Override
public boolean equals(Object o) {
if (this == o) {
@ -68,4 +85,11 @@ final class CancelServerStreamCommand extends WriteQueue.AbstractQueuedCommand {
.add("reason", reason)
.toString();
}
private enum PeerNotify {
/** Notify the peer by sending a RST_STREAM with no other information. */
RESET,
/** Notify the peer about the {@link #reason} by sending structured headers, if possible. */
BEST_EFFORT_STATUS,
}
}

View File

@ -788,10 +788,38 @@ class NettyServerHandler extends AbstractNettyHandler {
PerfMark.linkIn(cmd.getLink());
// Notify the listener if we haven't already.
cmd.stream().transportReportStatus(cmd.reason());
// Now we need to decide how we're going to notify the peer that this stream is closed.
// If possible, it's nice to inform the peer _why_ this stream was cancelled by sending
// a structured headers frame.
if (shouldCloseStreamWithHeaders(cmd, connection())) {
Metadata md = new Metadata();
md.put(InternalStatus.CODE_KEY, cmd.reason());
if (cmd.reason().getDescription() != null) {
md.put(InternalStatus.MESSAGE_KEY, cmd.reason().getDescription());
}
Http2Headers headers = Utils.convertServerHeaders(md);
encoder().writeHeaders(
ctx, cmd.stream().id(), headers, /* padding = */ 0, /* endStream = */ true, promise);
} else {
// Terminate the stream.
encoder().writeRstStream(ctx, cmd.stream().id(), Http2Error.CANCEL.code(), promise);
}
}
}
// Determine whether a CancelServerStreamCommand should try to close the stream with a
// HEADERS or a RST_STREAM frame. The caller has some influence over this (they can
// configure cmd.wantsHeaders()). The state of the stream also has an influence: we
// only try to send HEADERS if the stream exists and hasn't already sent any headers.
private static boolean shouldCloseStreamWithHeaders(
CancelServerStreamCommand cmd, Http2Connection conn) {
if (!cmd.wantsHeaders()) {
return false;
}
Http2Stream stream = conn.stream(cmd.stream().id());
return stream != null && !stream.isHeadersSent();
}
private void gracefulClose(final ChannelHandlerContext ctx, final GracefulServerCloseCommand msg,
ChannelPromise promise) throws Exception {

View File

@ -130,7 +130,7 @@ class NettyServerStream extends AbstractServerStream {
@Override
public void cancel(Status status) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.cancel")) {
writeQueue.enqueue(new CancelServerStreamCommand(transportState(), status), true);
writeQueue.enqueue(CancelServerStreamCommand.withReset(transportState(), status), true);
}
}
}
@ -189,7 +189,7 @@ class NettyServerStream extends AbstractServerStream {
log.log(Level.WARNING, "Exception processing message", cause);
Status status = Status.fromThrowable(cause);
transportReportStatus(status);
handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true);
handler.getWriteQueue().enqueue(CancelServerStreamCommand.withReason(this, status), true);
}
private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) {
@ -222,7 +222,7 @@ class NettyServerStream extends AbstractServerStream {
*/
protected void http2ProcessingFailed(Status status) {
transportReportStatus(status);
handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true);
handler.getWriteQueue().enqueue(CancelServerStreamCommand.withReset(this, status), true);
}
void inboundDataReceived(ByteBuf frame, boolean endOfStream) {

View File

@ -89,8 +89,10 @@ import io.netty.util.AsciiString;
import java.io.InputStream;
import java.nio.channels.ClosedChannelException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
@ -469,11 +471,41 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
public void cancelShouldSendRstStream() throws Exception {
manualSetUp();
createStream();
enqueue(new CancelServerStreamCommand(stream.transportState(), Status.DEADLINE_EXCEEDED));
enqueue(CancelServerStreamCommand.withReset(stream.transportState(), Status.DEADLINE_EXCEEDED));
verifyWrite().writeRstStream(eq(ctx()), eq(stream.transportState().id()),
eq(Http2Error.CANCEL.code()), any(ChannelPromise.class));
}
@Test
public void cancelWithNotify_shouldSendHeaders() throws Exception {
manualSetUp();
createStream();
enqueue(CancelServerStreamCommand.withReason(
stream.transportState(),
Status.RESOURCE_EXHAUSTED.withDescription("my custom description")
));
ArgumentCaptor<Http2Headers> captor = ArgumentCaptor.forClass(Http2Headers.class);
verifyWrite()
.writeHeaders(
eq(ctx()),
eq(STREAM_ID),
captor.capture(),
eq(0),
eq(true),
any(ChannelPromise.class));
// For arcane reasons, the specific implementation of Http2Headers here doesn't actually support
// methods like `get(...)`, so we have to manually convert it into a map.
Map<String, String> actualHeaders = new HashMap<>();
for (Map.Entry<CharSequence, CharSequence> entry : captor.getValue()) {
actualHeaders.put(entry.getKey().toString(), entry.getValue().toString());
}
assertEquals("8", actualHeaders.get(InternalStatus.CODE_KEY.name()));
assertEquals("my custom description", actualHeaders.get(InternalStatus.MESSAGE_KEY.name()));
}
@Test
public void headersWithInvalidContentTypeShouldFail() throws Exception {
manualSetUp();

View File

@ -18,7 +18,6 @@ package io.grpc.netty;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.NettyTestUtil.messageFrame;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError;
@ -37,6 +36,7 @@ import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ListMultimap;
import io.grpc.Attributes;
@ -73,6 +73,8 @@ import org.mockito.stubbing.Answer;
/** Unit tests for {@link NettyServerStream}. */
@RunWith(JUnit4.class)
public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream> {
private static final int TEST_MAX_MESSAGE_SIZE = 128;
@Mock
protected ServerStreamListener serverListener;
@ -380,10 +382,31 @@ public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream
public void cancelStreamShouldSucceed() {
stream().cancel(Status.DEADLINE_EXCEEDED);
verify(writeQueue).enqueue(
new CancelServerStreamCommand(stream().transportState(), Status.DEADLINE_EXCEEDED),
CancelServerStreamCommand.withReset(stream().transportState(), Status.DEADLINE_EXCEEDED),
true);
}
@Test
public void oversizedMessagesResultInResourceExhaustedTrailers() throws Exception {
@SuppressWarnings("InlineMeInliner") // Requires Java 11
String oversizedMsg = Strings.repeat("a", TEST_MAX_MESSAGE_SIZE + 1);
stream.request(1);
stream.transportState().inboundDataReceived(messageFrame(oversizedMsg), false);
assertNull("message should have caused a deframer error", listenerMessageQueue().poll());
ArgumentCaptor<CancelServerStreamCommand> cancelCmdCap =
ArgumentCaptor.forClass(CancelServerStreamCommand.class);
verify(writeQueue).enqueue(cancelCmdCap.capture(), eq(true));
Status status = Status.RESOURCE_EXHAUSTED
.withDescription("gRPC message exceeds maximum size 128: 129");
CancelServerStreamCommand actualCmd = cancelCmdCap.getValue();
assertThat(actualCmd.reason().getCode()).isEqualTo(status.getCode());
assertThat(actualCmd.reason().getDescription()).isEqualTo(status.getDescription());
assertThat(actualCmd.wantsHeaders()).isTrue();
}
@Override
@SuppressWarnings("DirectInvocationOnMock")
protected NettyServerStream createStream() {
@ -391,7 +414,7 @@ public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream
StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP;
TransportTracer transportTracer = new TransportTracer();
NettyServerStream.TransportState state = new NettyServerStream.TransportState(
handler, channel.eventLoop(), http2Stream, DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx,
handler, channel.eventLoop(), http2Stream, TEST_MAX_MESSAGE_SIZE, statsTraceCtx,
transportTracer, "method");
NettyServerStream stream = new NettyServerStream(channel, state, Attributes.EMPTY,
"test-authority", statsTraceCtx);