diff --git a/core/src/main/java/io/grpc/internal/AbstractServerStream.java b/core/src/main/java/io/grpc/internal/AbstractServerStream.java index e2a0829200..a535330f4b 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerStream.java @@ -293,6 +293,7 @@ public abstract class AbstractServerStream extends AbstractStream */ public final void transportReportStatus(final Status status) { Preconditions.checkArgument(!status.isOk(), "status must not be OK"); + onStreamDeallocated(); if (deframerClosed) { deframerClosedTask = null; closeListener(status); @@ -315,6 +316,7 @@ public abstract class AbstractServerStream extends AbstractStream * #transportReportStatus}. */ public void complete() { + onStreamDeallocated(); if (deframerClosed) { deframerClosedTask = null; closeListener(Status.OK); @@ -350,7 +352,6 @@ public abstract class AbstractServerStream extends AbstractStream getTransportTracer().reportStreamClosed(closedStatus.isOk()); } listenerClosed = true; - onStreamDeallocated(); listener().closed(newStatus); } } diff --git a/core/src/main/java/io/grpc/internal/AbstractStream.java b/core/src/main/java/io/grpc/internal/AbstractStream.java index 4b7486e466..9efc488657 100644 --- a/core/src/main/java/io/grpc/internal/AbstractStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractStream.java @@ -322,6 +322,12 @@ public abstract class AbstractStream implements Stream { } } + protected boolean isStreamDeallocated() { + synchronized (onReadyLock) { + return deallocated; + } + } + /** * Event handler to be called by the subclass when a number of bytes are being queued for * sending to the remote endpoint. diff --git a/netty/src/main/java/io/grpc/netty/NettyClientStream.java b/netty/src/main/java/io/grpc/netty/NettyClientStream.java index d88520a749..2939eed2e3 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientStream.java @@ -182,20 +182,10 @@ class NettyClientStream extends AbstractClientStream { if (numBytes > 0) { // Add the bytes to outbound flow control. onSendingBytes(numBytes); + ChannelFutureListener failureListener = + future -> transportState().onWriteFrameData(future, numMessages, numBytes); writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, endOfStream), flush) - .addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - // If the future succeeds when http2stream is null, the stream has been cancelled - // before it began and Netty is purging pending writes from the flow-controller. - if (future.isSuccess() && transportState().http2Stream() != null) { - // Remove the bytes from outbound flow control, optionally notifying - // the client that they can send more bytes. - transportState().onSentBytes(numBytes); - NettyClientStream.this.getTransportTracer().reportMessageSent(numMessages); - } - } - }); + .addListener(failureListener); } else { // The frame is empty and will not impact outbound flow control. Just send it. writeQueue.enqueue( @@ -307,6 +297,29 @@ class NettyClientStream extends AbstractClientStream { handler.getWriteQueue().enqueue(new CancelClientStreamCommand(this, status), true); } + private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) { + // If the future succeeds when http2stream is null, the stream has been cancelled + // before it began and Netty is purging pending writes from the flow-controller. + if (future.isSuccess() && http2Stream() == null) { + return; + } + + if (future.isSuccess()) { + // Remove the bytes from outbound flow control, optionally notifying + // the client that they can send more bytes. + onSentBytes(numBytes); + getTransportTracer().reportMessageSent(numMessages); + } else if (!isStreamDeallocated()) { + // Future failed, fail RPC. + // Normally we don't need to do anything here because the cause of a failed future + // while writing DATA frames would be an IO error and the stream is already closed. + // However, we still need handle any unexpected failures raised in Netty. + // Note: isStreamDeallocated() protects from spamming stream resets by scheduling multiple + // CancelClientStreamCommand commands. + http2ProcessingFailed(statusFromFailedFuture(future), true, new Metadata()); + } + } + @Override public void runOnTransportThread(final Runnable r) { if (eventLoop.inEventLoop()) { diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index 2b06a3fcf5..500368e880 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -502,8 +502,7 @@ class NettyServerHandler extends AbstractNettyHandler { state, attributes, authority, - statsTraceCtx, - transportTracer); + statsTraceCtx); transportListener.streamCreated(stream, method, metadata); state.onStreamAllocated(); http2Stream.setProperty(streamKey, state); diff --git a/netty/src/main/java/io/grpc/netty/NettyServerStream.java b/netty/src/main/java/io/grpc/netty/NettyServerStream.java index a44d8b4a64..a4304d5193 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerStream.java @@ -52,7 +52,6 @@ class NettyServerStream extends AbstractServerStream { private final WriteQueue writeQueue; private final Attributes attributes; private final String authority; - private final TransportTracer transportTracer; private final int streamId; public NettyServerStream( @@ -60,14 +59,12 @@ class NettyServerStream extends AbstractServerStream { TransportState state, Attributes transportAttrs, String authority, - StatsTraceContext statsTraceCtx, - TransportTracer transportTracer) { + StatsTraceContext statsTraceCtx) { super(new NettyWritableBufferAllocator(channel.alloc()), statsTraceCtx); this.state = checkNotNull(state, "transportState"); this.writeQueue = state.handler.getWriteQueue(); this.attributes = checkNotNull(transportAttrs); this.authority = authority; - this.transportTracer = checkNotNull(transportTracer, "transportTracer"); // Read the id early to avoid reading transportState later. this.streamId = transportState().id(); } @@ -96,38 +93,26 @@ class NettyServerStream extends AbstractServerStream { @Override public void writeHeaders(Metadata headers, boolean flush) { try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeHeaders")) { - writeQueue.enqueue( - SendResponseHeadersCommand.createHeaders( - transportState(), - Utils.convertServerHeaders(headers)), - flush); + Http2Headers http2headers = Utils.convertServerHeaders(headers); + SendResponseHeadersCommand headersCommand = + SendResponseHeadersCommand.createHeaders(transportState(), http2headers); + writeQueue.enqueue(headersCommand, flush) + .addListener((ChannelFutureListener) transportState()::handleWriteFutureFailures); } } - private void writeFrameInternal(WritableBuffer frame, boolean flush, final int numMessages) { - Preconditions.checkArgument(numMessages >= 0); - ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch(); - final int numBytes = bytebuf.readableBytes(); - // Add the bytes to outbound flow control. - onSendingBytes(numBytes); - writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, false), flush) - .addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - // Remove the bytes from outbound flow control, optionally notifying - // the client that they can send more bytes. - transportState().onSentBytes(numBytes); - if (future.isSuccess()) { - transportTracer.reportMessageSent(numMessages); - } - } - }); - } - @Override public void writeFrame(WritableBuffer frame, boolean flush, final int numMessages) { try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeFrame")) { - writeFrameInternal(frame, flush, numMessages); + Preconditions.checkArgument(numMessages >= 0); + ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch(); + final int numBytes = bytebuf.readableBytes(); + // Add the bytes to outbound flow control. + onSendingBytes(numBytes); + ChannelFutureListener failureListener = + future -> transportState().onWriteFrameData(future, numMessages, numBytes); + writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, false), flush) + .addListener(failureListener); } } @@ -135,9 +120,10 @@ class NettyServerStream extends AbstractServerStream { public void writeTrailers(Metadata trailers, boolean headersSent, Status status) { try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeTrailers")) { Http2Headers http2Trailers = Utils.convertTrailers(trailers, headersSent); - writeQueue.enqueue( - SendResponseHeadersCommand.createTrailers(transportState(), http2Trailers, status), - true); + SendResponseHeadersCommand trailersCommand = + SendResponseHeadersCommand.createTrailers(transportState(), http2Trailers, status); + writeQueue.enqueue(trailersCommand, true) + .addListener((ChannelFutureListener) transportState()::handleWriteFutureFailures); } } @@ -206,6 +192,39 @@ class NettyServerStream extends AbstractServerStream { handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true); } + private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) { + // Remove the bytes from outbound flow control, optionally notifying + // the client that they can send more bytes. + if (future.isSuccess()) { + onSentBytes(numBytes); + getTransportTracer().reportMessageSent(numMessages); + } else { + handleWriteFutureFailures(future); + } + } + + private void handleWriteFutureFailures(ChannelFuture future) { + // isStreamDeallocated() check protects from spamming stream resets by scheduling multiple + // CancelServerStreamCommand commands. + if (future.isSuccess() || isStreamDeallocated()) { + return; + } + + // Future failed, fail RPC. + // Normally we don't need to do anything on frame write failures because the cause of + // the failed future would be an IO error that closed the stream. + // However, we still need handle any unexpected failures raised in Netty. + http2ProcessingFailed(Utils.statusFromThrowable(future.cause())); + } + + /** + * Called to process a failure in HTTP/2 processing. + */ + protected void http2ProcessingFailed(Status status) { + transportReportStatus(status); + handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true); + } + void inboundDataReceived(ByteBuf frame, boolean endOfStream) { super.inboundDataReceived(new NettyReadableBuffer(frame.retain()), endOfStream); } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java index 7a9d937a5a..2a5a0df279 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java @@ -23,6 +23,8 @@ import static io.grpc.netty.NettyTestUtil.messageFrame; import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC; import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER; import static io.grpc.netty.Utils.STATUS_OK; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.util.CharsetUtil.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -34,6 +36,7 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -62,6 +65,7 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelPromise; import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.Http2Exception; import io.netty.handler.codec.http2.Http2Headers; import io.netty.util.AsciiString; import java.io.BufferedInputStream; @@ -75,6 +79,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; +import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; @@ -205,6 +210,50 @@ public class NettyClientStreamTest extends NettyStreamTestBase invocation.getMethod().getName().equals("enqueue")) + // Get the third invocation of enqueue() + .skip(2).findFirst().get() + // Get the first argument (QueuedCommand command) + .getArgument(0); + + Status cancelReason = cancelCommand.reason(); + assertThat(cancelReason.getCode()).isEqualTo(Status.INTERNAL.getCode()); + assertThat(cancelReason.getCause()).isEqualTo(h2Error); + // Verify listener closed. + verify(listener).closed(same(cancelReason), eq(PROCESSED), any(Metadata.class)); + } + @Test public void setStatusWithOkShouldCloseStream() { stream().transportState().setId(STREAM_ID); diff --git a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java index e95a2a52bc..ab54d4e4e2 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java @@ -17,12 +17,17 @@ package io.grpc.netty; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.netty.NettyTestUtil.messageFrame; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static org.junit.Assert.assertNull; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.never; @@ -43,19 +48,25 @@ import io.grpc.internal.StreamListener; import io.grpc.internal.TransportTracer; import io.netty.buffer.EmptyByteBuf; import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.DefaultChannelPromise; import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.Http2Exception; import io.netty.util.AsciiString; import java.io.ByteArrayInputStream; import java.io.InputStream; import java.util.LinkedList; +import java.util.List; import java.util.Queue; +import java.util.stream.Collectors; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; +import org.mockito.InOrder; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -124,6 +135,99 @@ public class NettyServerStreamTest extends NettyStreamTestBase headersCommandClass = SendResponseHeadersCommand.class; + when(writeQueue.enqueue(any(headersCommandClass), anyBoolean())).thenReturn( + new DefaultChannelPromise(channel).setFailure(h2Error)); + + // Prepare different headers to make it easier to distinguish in the error message. + Metadata headers1 = new Metadata(); + headers1.put(Metadata.Key.of("writeHeaders", Metadata.ASCII_STRING_MARSHALLER), "1"); + Metadata headers2 = new Metadata(); + headers2.put(Metadata.Key.of("writeHeaders", Metadata.ASCII_STRING_MARSHALLER), "2"); + Metadata headers3 = new Metadata(); + headers3.put(Metadata.Key.of("writeHeaders", Metadata.ASCII_STRING_MARSHALLER), "3"); + + // Note writeHeaders flush argument shouldn't matter for this test. + stream().writeHeaders(headers1, false); + stream().writeHeaders(headers2, false); + stream().writeHeaders(headers3, true); + stream.flush(); + + verifyWriteFutureFailure(h2Error); + // Verify CancelServerStreamCommand enqueued once, right after first SendResponseHeadersCommand. + InOrder inOrder = Mockito.inOrder(writeQueue); + inOrder.verify(writeQueue).enqueue(any(headersCommandClass), anyBoolean()); + inOrder.verify(writeQueue).enqueue(any(CancelServerStreamCommand.class), eq(true)); + inOrder.verify(writeQueue, atLeast(1)).enqueue(any(headersCommandClass), anyBoolean()); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void writeTrailersFutureFailedShouldCancelRpc() { + Http2Exception h2Error = connectionError(PROTOCOL_ERROR, "Stream does not exist %d", STREAM_ID); + when(writeQueue.enqueue(any(SendResponseHeadersCommand.class), eq(true))).thenReturn( + new DefaultChannelPromise(channel).setFailure(h2Error)); + + stream().close(Status.OK, trailers); + + verifyWriteFutureFailure(h2Error); + verify(writeQueue).enqueue(any(CancelServerStreamCommand.class), eq(true)); + } + + private void verifyWriteFutureFailure(Http2Exception h2Error) { + // Check the error that caused the future write failure propagated via Status. + Status cancelReason = findCancelServerStreamCommand().reason(); + assertThat(cancelReason.getCode()).isEqualTo(Status.INTERNAL.getCode()); + assertThat(cancelReason.getCause()).isEqualTo(h2Error); + // Verify the listener has closed. + verify(serverListener).closed(same(cancelReason)); + } + + private CancelServerStreamCommand findCancelServerStreamCommand() { + // Ensure there's no CancelServerStreamCommand enqueued with flush=false. + verify(writeQueue, never()).enqueue(any(CancelServerStreamCommand.class), eq(false)); + + List commands = Mockito.mockingDetails(writeQueue).getInvocations() + .stream() + // Get enqueue() innovations only. + .filter(invocation -> invocation.getMethod().getName().equals("enqueue")) + // Find the cancel commands. + .filter(invocation -> invocation.getArgument(0) instanceof CancelServerStreamCommand) + .map(invocation -> invocation.getArgument(0, CancelServerStreamCommand.class)) + .collect(Collectors.toList()); + + assertWithMessage("Expected exactly one CancelClientStreamCommand").that(commands).hasSize(1); + return commands.get(0); + } + @Test public void writeHeadersShouldSendHeaders() throws Exception { Metadata headers = new Metadata(); @@ -290,7 +394,7 @@ public class NettyServerStreamTest extends NettyStreamTestBase