From e490273edd37dba1cf3c3b93de26bfc40ac99eaa Mon Sep 17 00:00:00 2001 From: Sergii Tkachenko Date: Tue, 16 Apr 2024 16:27:51 -0700 Subject: [PATCH] netty: Handle write queue promise failures (#11016) Handles Netty write frame failures caused by issues in the Netty itself. Normally we don't need to do anything on frame write failures because the cause of a failed future would be an IO error that resulted in the stream closure. Prior to this PR we treated these issues as a noop, except the initial headers write on the client side. However, a case like netty/netty#13805 (a bug in generating next stream id) resulted in an unclosed stream on our side. This PR adds write frame future failure handlers that ensures the stream is cancelled, and the cause is propagated via Status. Fixes #10849 --- .../grpc/internal/AbstractServerStream.java | 3 +- .../java/io/grpc/internal/AbstractStream.java | 6 + .../java/io/grpc/netty/NettyClientStream.java | 39 ++++--- .../io/grpc/netty/NettyServerHandler.java | 3 +- .../java/io/grpc/netty/NettyServerStream.java | 85 ++++++++------ .../io/grpc/netty/NettyClientStreamTest.java | 49 ++++++++ .../io/grpc/netty/NettyServerStreamTest.java | 106 +++++++++++++++++- 7 files changed, 241 insertions(+), 50 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/AbstractServerStream.java b/core/src/main/java/io/grpc/internal/AbstractServerStream.java index e2a0829200..a535330f4b 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerStream.java @@ -293,6 +293,7 @@ public abstract class AbstractServerStream extends AbstractStream */ public final void transportReportStatus(final Status status) { Preconditions.checkArgument(!status.isOk(), "status must not be OK"); + onStreamDeallocated(); if (deframerClosed) { deframerClosedTask = null; closeListener(status); @@ -315,6 +316,7 @@ public abstract class AbstractServerStream extends AbstractStream * #transportReportStatus}. */ public void complete() { + onStreamDeallocated(); if (deframerClosed) { deframerClosedTask = null; closeListener(Status.OK); @@ -350,7 +352,6 @@ public abstract class AbstractServerStream extends AbstractStream getTransportTracer().reportStreamClosed(closedStatus.isOk()); } listenerClosed = true; - onStreamDeallocated(); listener().closed(newStatus); } } diff --git a/core/src/main/java/io/grpc/internal/AbstractStream.java b/core/src/main/java/io/grpc/internal/AbstractStream.java index 4b7486e466..9efc488657 100644 --- a/core/src/main/java/io/grpc/internal/AbstractStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractStream.java @@ -322,6 +322,12 @@ public abstract class AbstractStream implements Stream { } } + protected boolean isStreamDeallocated() { + synchronized (onReadyLock) { + return deallocated; + } + } + /** * Event handler to be called by the subclass when a number of bytes are being queued for * sending to the remote endpoint. diff --git a/netty/src/main/java/io/grpc/netty/NettyClientStream.java b/netty/src/main/java/io/grpc/netty/NettyClientStream.java index d88520a749..2939eed2e3 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientStream.java @@ -182,20 +182,10 @@ class NettyClientStream extends AbstractClientStream { if (numBytes > 0) { // Add the bytes to outbound flow control. onSendingBytes(numBytes); + ChannelFutureListener failureListener = + future -> transportState().onWriteFrameData(future, numMessages, numBytes); writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, endOfStream), flush) - .addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - // If the future succeeds when http2stream is null, the stream has been cancelled - // before it began and Netty is purging pending writes from the flow-controller. - if (future.isSuccess() && transportState().http2Stream() != null) { - // Remove the bytes from outbound flow control, optionally notifying - // the client that they can send more bytes. - transportState().onSentBytes(numBytes); - NettyClientStream.this.getTransportTracer().reportMessageSent(numMessages); - } - } - }); + .addListener(failureListener); } else { // The frame is empty and will not impact outbound flow control. Just send it. writeQueue.enqueue( @@ -307,6 +297,29 @@ class NettyClientStream extends AbstractClientStream { handler.getWriteQueue().enqueue(new CancelClientStreamCommand(this, status), true); } + private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) { + // If the future succeeds when http2stream is null, the stream has been cancelled + // before it began and Netty is purging pending writes from the flow-controller. + if (future.isSuccess() && http2Stream() == null) { + return; + } + + if (future.isSuccess()) { + // Remove the bytes from outbound flow control, optionally notifying + // the client that they can send more bytes. + onSentBytes(numBytes); + getTransportTracer().reportMessageSent(numMessages); + } else if (!isStreamDeallocated()) { + // Future failed, fail RPC. + // Normally we don't need to do anything here because the cause of a failed future + // while writing DATA frames would be an IO error and the stream is already closed. + // However, we still need handle any unexpected failures raised in Netty. + // Note: isStreamDeallocated() protects from spamming stream resets by scheduling multiple + // CancelClientStreamCommand commands. + http2ProcessingFailed(statusFromFailedFuture(future), true, new Metadata()); + } + } + @Override public void runOnTransportThread(final Runnable r) { if (eventLoop.inEventLoop()) { diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index 2b06a3fcf5..500368e880 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -502,8 +502,7 @@ class NettyServerHandler extends AbstractNettyHandler { state, attributes, authority, - statsTraceCtx, - transportTracer); + statsTraceCtx); transportListener.streamCreated(stream, method, metadata); state.onStreamAllocated(); http2Stream.setProperty(streamKey, state); diff --git a/netty/src/main/java/io/grpc/netty/NettyServerStream.java b/netty/src/main/java/io/grpc/netty/NettyServerStream.java index a44d8b4a64..a4304d5193 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerStream.java @@ -52,7 +52,6 @@ class NettyServerStream extends AbstractServerStream { private final WriteQueue writeQueue; private final Attributes attributes; private final String authority; - private final TransportTracer transportTracer; private final int streamId; public NettyServerStream( @@ -60,14 +59,12 @@ class NettyServerStream extends AbstractServerStream { TransportState state, Attributes transportAttrs, String authority, - StatsTraceContext statsTraceCtx, - TransportTracer transportTracer) { + StatsTraceContext statsTraceCtx) { super(new NettyWritableBufferAllocator(channel.alloc()), statsTraceCtx); this.state = checkNotNull(state, "transportState"); this.writeQueue = state.handler.getWriteQueue(); this.attributes = checkNotNull(transportAttrs); this.authority = authority; - this.transportTracer = checkNotNull(transportTracer, "transportTracer"); // Read the id early to avoid reading transportState later. this.streamId = transportState().id(); } @@ -96,38 +93,26 @@ class NettyServerStream extends AbstractServerStream { @Override public void writeHeaders(Metadata headers, boolean flush) { try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeHeaders")) { - writeQueue.enqueue( - SendResponseHeadersCommand.createHeaders( - transportState(), - Utils.convertServerHeaders(headers)), - flush); + Http2Headers http2headers = Utils.convertServerHeaders(headers); + SendResponseHeadersCommand headersCommand = + SendResponseHeadersCommand.createHeaders(transportState(), http2headers); + writeQueue.enqueue(headersCommand, flush) + .addListener((ChannelFutureListener) transportState()::handleWriteFutureFailures); } } - private void writeFrameInternal(WritableBuffer frame, boolean flush, final int numMessages) { - Preconditions.checkArgument(numMessages >= 0); - ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch(); - final int numBytes = bytebuf.readableBytes(); - // Add the bytes to outbound flow control. - onSendingBytes(numBytes); - writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, false), flush) - .addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - // Remove the bytes from outbound flow control, optionally notifying - // the client that they can send more bytes. - transportState().onSentBytes(numBytes); - if (future.isSuccess()) { - transportTracer.reportMessageSent(numMessages); - } - } - }); - } - @Override public void writeFrame(WritableBuffer frame, boolean flush, final int numMessages) { try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeFrame")) { - writeFrameInternal(frame, flush, numMessages); + Preconditions.checkArgument(numMessages >= 0); + ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch(); + final int numBytes = bytebuf.readableBytes(); + // Add the bytes to outbound flow control. + onSendingBytes(numBytes); + ChannelFutureListener failureListener = + future -> transportState().onWriteFrameData(future, numMessages, numBytes); + writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, false), flush) + .addListener(failureListener); } } @@ -135,9 +120,10 @@ class NettyServerStream extends AbstractServerStream { public void writeTrailers(Metadata trailers, boolean headersSent, Status status) { try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeTrailers")) { Http2Headers http2Trailers = Utils.convertTrailers(trailers, headersSent); - writeQueue.enqueue( - SendResponseHeadersCommand.createTrailers(transportState(), http2Trailers, status), - true); + SendResponseHeadersCommand trailersCommand = + SendResponseHeadersCommand.createTrailers(transportState(), http2Trailers, status); + writeQueue.enqueue(trailersCommand, true) + .addListener((ChannelFutureListener) transportState()::handleWriteFutureFailures); } } @@ -206,6 +192,39 @@ class NettyServerStream extends AbstractServerStream { handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true); } + private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) { + // Remove the bytes from outbound flow control, optionally notifying + // the client that they can send more bytes. + if (future.isSuccess()) { + onSentBytes(numBytes); + getTransportTracer().reportMessageSent(numMessages); + } else { + handleWriteFutureFailures(future); + } + } + + private void handleWriteFutureFailures(ChannelFuture future) { + // isStreamDeallocated() check protects from spamming stream resets by scheduling multiple + // CancelServerStreamCommand commands. + if (future.isSuccess() || isStreamDeallocated()) { + return; + } + + // Future failed, fail RPC. + // Normally we don't need to do anything on frame write failures because the cause of + // the failed future would be an IO error that closed the stream. + // However, we still need handle any unexpected failures raised in Netty. + http2ProcessingFailed(Utils.statusFromThrowable(future.cause())); + } + + /** + * Called to process a failure in HTTP/2 processing. + */ + protected void http2ProcessingFailed(Status status) { + transportReportStatus(status); + handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true); + } + void inboundDataReceived(ByteBuf frame, boolean endOfStream) { super.inboundDataReceived(new NettyReadableBuffer(frame.retain()), endOfStream); } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java index 7a9d937a5a..2a5a0df279 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java @@ -23,6 +23,8 @@ import static io.grpc.netty.NettyTestUtil.messageFrame; import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC; import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER; import static io.grpc.netty.Utils.STATUS_OK; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.util.CharsetUtil.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -34,6 +36,7 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -62,6 +65,7 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelPromise; import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.Http2Exception; import io.netty.handler.codec.http2.Http2Headers; import io.netty.util.AsciiString; import java.io.BufferedInputStream; @@ -75,6 +79,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; +import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; @@ -205,6 +210,50 @@ public class NettyClientStreamTest extends NettyStreamTestBase invocation.getMethod().getName().equals("enqueue")) + // Get the third invocation of enqueue() + .skip(2).findFirst().get() + // Get the first argument (QueuedCommand command) + .getArgument(0); + + Status cancelReason = cancelCommand.reason(); + assertThat(cancelReason.getCode()).isEqualTo(Status.INTERNAL.getCode()); + assertThat(cancelReason.getCause()).isEqualTo(h2Error); + // Verify listener closed. + verify(listener).closed(same(cancelReason), eq(PROCESSED), any(Metadata.class)); + } + @Test public void setStatusWithOkShouldCloseStream() { stream().transportState().setId(STREAM_ID); diff --git a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java index e95a2a52bc..ab54d4e4e2 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java @@ -17,12 +17,17 @@ package io.grpc.netty; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.netty.NettyTestUtil.messageFrame; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static org.junit.Assert.assertNull; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.never; @@ -43,19 +48,25 @@ import io.grpc.internal.StreamListener; import io.grpc.internal.TransportTracer; import io.netty.buffer.EmptyByteBuf; import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.DefaultChannelPromise; import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.Http2Exception; import io.netty.util.AsciiString; import java.io.ByteArrayInputStream; import java.io.InputStream; import java.util.LinkedList; +import java.util.List; import java.util.Queue; +import java.util.stream.Collectors; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; +import org.mockito.InOrder; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -124,6 +135,99 @@ public class NettyServerStreamTest extends NettyStreamTestBase headersCommandClass = SendResponseHeadersCommand.class; + when(writeQueue.enqueue(any(headersCommandClass), anyBoolean())).thenReturn( + new DefaultChannelPromise(channel).setFailure(h2Error)); + + // Prepare different headers to make it easier to distinguish in the error message. + Metadata headers1 = new Metadata(); + headers1.put(Metadata.Key.of("writeHeaders", Metadata.ASCII_STRING_MARSHALLER), "1"); + Metadata headers2 = new Metadata(); + headers2.put(Metadata.Key.of("writeHeaders", Metadata.ASCII_STRING_MARSHALLER), "2"); + Metadata headers3 = new Metadata(); + headers3.put(Metadata.Key.of("writeHeaders", Metadata.ASCII_STRING_MARSHALLER), "3"); + + // Note writeHeaders flush argument shouldn't matter for this test. + stream().writeHeaders(headers1, false); + stream().writeHeaders(headers2, false); + stream().writeHeaders(headers3, true); + stream.flush(); + + verifyWriteFutureFailure(h2Error); + // Verify CancelServerStreamCommand enqueued once, right after first SendResponseHeadersCommand. + InOrder inOrder = Mockito.inOrder(writeQueue); + inOrder.verify(writeQueue).enqueue(any(headersCommandClass), anyBoolean()); + inOrder.verify(writeQueue).enqueue(any(CancelServerStreamCommand.class), eq(true)); + inOrder.verify(writeQueue, atLeast(1)).enqueue(any(headersCommandClass), anyBoolean()); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void writeTrailersFutureFailedShouldCancelRpc() { + Http2Exception h2Error = connectionError(PROTOCOL_ERROR, "Stream does not exist %d", STREAM_ID); + when(writeQueue.enqueue(any(SendResponseHeadersCommand.class), eq(true))).thenReturn( + new DefaultChannelPromise(channel).setFailure(h2Error)); + + stream().close(Status.OK, trailers); + + verifyWriteFutureFailure(h2Error); + verify(writeQueue).enqueue(any(CancelServerStreamCommand.class), eq(true)); + } + + private void verifyWriteFutureFailure(Http2Exception h2Error) { + // Check the error that caused the future write failure propagated via Status. + Status cancelReason = findCancelServerStreamCommand().reason(); + assertThat(cancelReason.getCode()).isEqualTo(Status.INTERNAL.getCode()); + assertThat(cancelReason.getCause()).isEqualTo(h2Error); + // Verify the listener has closed. + verify(serverListener).closed(same(cancelReason)); + } + + private CancelServerStreamCommand findCancelServerStreamCommand() { + // Ensure there's no CancelServerStreamCommand enqueued with flush=false. + verify(writeQueue, never()).enqueue(any(CancelServerStreamCommand.class), eq(false)); + + List commands = Mockito.mockingDetails(writeQueue).getInvocations() + .stream() + // Get enqueue() innovations only. + .filter(invocation -> invocation.getMethod().getName().equals("enqueue")) + // Find the cancel commands. + .filter(invocation -> invocation.getArgument(0) instanceof CancelServerStreamCommand) + .map(invocation -> invocation.getArgument(0, CancelServerStreamCommand.class)) + .collect(Collectors.toList()); + + assertWithMessage("Expected exactly one CancelClientStreamCommand").that(commands).hasSize(1); + return commands.get(0); + } + @Test public void writeHeadersShouldSendHeaders() throws Exception { Metadata headers = new Metadata(); @@ -290,7 +394,7 @@ public class NettyServerStreamTest extends NettyStreamTestBase