diff --git a/core/src/main/java/io/grpc/transport/MessageDeframer.java b/core/src/main/java/io/grpc/transport/MessageDeframer.java index 638dd51dfd..9aa88eb097 100644 --- a/core/src/main/java/io/grpc/transport/MessageDeframer.java +++ b/core/src/main/java/io/grpc/transport/MessageDeframer.java @@ -103,6 +103,7 @@ public class MessageDeframer implements Closeable { private CompositeReadableBuffer unprocessed = new CompositeReadableBuffer(); private long pendingDeliveries; private boolean deliveryStalled = true; + private boolean inDelivery = false; /** * Creates a deframer. Compression will not be supported. @@ -216,49 +217,59 @@ public class MessageDeframer implements Closeable { * Reads and delivers as many messages to the sink as possible. */ private void deliver() { - // Process the uncompressed bytes. - boolean stalled = false; - while (pendingDeliveries > 0 && readRequiredBytes()) { - switch (state) { - case HEADER: - processHeader(); - break; - case BODY: - // Read the body and deliver the message. - processBody(); - - // Since we've delivered a message, decrement the number of pending - // deliveries remaining. - pendingDeliveries--; - break; - default: - throw new AssertionError("Invalid state: " + state); - } + // We can have reentrancy here when using a direct executor, triggered by calls to + // request more messages. This is safe as we simply loop until pendingDelivers = 0 + if (inDelivery) { + return; } - // We are stalled when there are no more bytes to process. This allows delivering errors as soon - // as the buffered input has been consumed, independent of whether the application has requested - // another message. - stalled = !isDataAvailable(); + inDelivery = true; + try { + // Process the uncompressed bytes. + boolean stalled = false; + while (pendingDeliveries > 0 && readRequiredBytes()) { + switch (state) { + case HEADER: + processHeader(); + break; + case BODY: + // Read the body and deliver the message. + processBody(); - if (endOfStream) { - if (!isDataAvailable()) { - listener.endOfStream(); - } else if (stalled) { - // We've received the entire stream and have data available but we don't have - // enough to read the next frame ... this is bad. - throw Status.INTERNAL.withDescription("Encountered end-of-stream mid-frame") - .asRuntimeException(); + // Since we've delivered a message, decrement the number of pending + // deliveries remaining. + pendingDeliveries--; + break; + default: + throw new AssertionError("Invalid state: " + state); + } } - } + // We are stalled when there are no more bytes to process. This allows delivering errors as + // soon as the buffered input has been consumed, independent of whether the application + // has requested another message. + stalled = !isDataAvailable(); - // Never indicate that we're stalled if we've received all the data for the stream. - stalled &= !endOfStream; + if (endOfStream) { + if (!isDataAvailable()) { + listener.endOfStream(); + } else if (stalled) { + // We've received the entire stream and have data available but we don't have + // enough to read the next frame ... this is bad. + throw Status.INTERNAL.withDescription("Encountered end-of-stream mid-frame") + .asRuntimeException(); + } + } - // If we're transitioning to the stalled state, notify the listener. - boolean previouslyStalled = deliveryStalled; - deliveryStalled = stalled; - if (stalled && !previouslyStalled) { - listener.deliveryStalled(); + // Never indicate that we're stalled if we've received all the data for the stream. + stalled &= !endOfStream; + + // If we're transitioning to the stalled state, notify the listener. + boolean previouslyStalled = deliveryStalled; + deliveryStalled = stalled; + if (stalled && !previouslyStalled) { + listener.deliveryStalled(); + } + } finally { + inDelivery = false; } } diff --git a/core/src/test/java/io/grpc/transport/MessageDeframerTest.java b/core/src/test/java/io/grpc/transport/MessageDeframerTest.java index 7e943c3f54..b61591cff3 100644 --- a/core/src/test/java/io/grpc/transport/MessageDeframerTest.java +++ b/core/src/test/java/io/grpc/transport/MessageDeframerTest.java @@ -35,6 +35,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.anyInt; import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -49,6 +50,9 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.Matchers; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -184,6 +188,26 @@ public class MessageDeframerTest { verifyNoMoreInteractions(listener); } + @Test + public void deliverIsReentrantSafe() { + doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + deframer.request(1); + return null; + } + }).when(listener).messageRead(Matchers.any()); + deframer.deframe(buffer(new byte[] {0, 0, 0, 0, 1, 3}), true); + verifyNoMoreInteractions(listener); + + deframer.request(1); + verify(listener).messageRead(messages.capture()); + assertEquals(Bytes.asList(new byte[] {3}), bytes(messages)); + verify(listener).endOfStream(); + verify(listener, atLeastOnce()).bytesRead(anyInt()); + verifyNoMoreInteractions(listener); + } + private static List bytes(ArgumentCaptor captor) { return bytes(captor.getValue()); } diff --git a/lib/netty b/lib/netty index 1cce998bb0..9d70cf33c2 160000 --- a/lib/netty +++ b/lib/netty @@ -1 +1 @@ -Subproject commit 1cce998bb06f42ff47925390872dcc5c487bbf59 +Subproject commit 9d70cf33c2ccea42d0fe651be61b2e0a6579fcb1 diff --git a/netty/src/main/java/io/grpc/transport/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/transport/netty/NettyClientHandler.java index 8991c302b1..a6b2ab3f7b 100644 --- a/netty/src/main/java/io/grpc/transport/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/transport/netty/NettyClientHandler.java @@ -130,6 +130,7 @@ class NettyClientHandler extends Http2ConnectionHandler { // Initialize the connection window if we haven't already. initConnectionWindow(); + ctx.flush(); } /** diff --git a/netty/src/main/java/io/grpc/transport/netty/NettyClientStream.java b/netty/src/main/java/io/grpc/transport/netty/NettyClientStream.java index 2c4173704a..0573f20d1b 100644 --- a/netty/src/main/java/io/grpc/transport/netty/NettyClientStream.java +++ b/netty/src/main/java/io/grpc/transport/netty/NettyClientStream.java @@ -132,5 +132,7 @@ class NettyClientStream extends Http2ClientStream { @Override protected void returnProcessedBytes(int processedBytes) { handler.returnProcessedBytes(http2Stream, processedBytes); + // Need to flush as window update may have been written + channel.flush(); } } diff --git a/netty/src/main/java/io/grpc/transport/netty/NettyServerStream.java b/netty/src/main/java/io/grpc/transport/netty/NettyServerStream.java index 6e52d78038..6178cdd492 100644 --- a/netty/src/main/java/io/grpc/transport/netty/NettyServerStream.java +++ b/netty/src/main/java/io/grpc/transport/netty/NettyServerStream.java @@ -105,5 +105,7 @@ class NettyServerStream extends AbstractServerStream { @Override protected void returnProcessedBytes(int processedBytes) { handler.returnProcessedBytes(http2Stream, processedBytes); + // Need to flush as window update may have been written + channel.flush(); } } diff --git a/netty/src/test/java/io/grpc/transport/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/transport/netty/NettyClientHandlerTest.java index 8f60d114ac..e6cdc45d7b 100644 --- a/netty/src/test/java/io/grpc/transport/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/transport/netty/NettyClientHandlerTest.java @@ -190,7 +190,6 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { ByteBuf expected = rstStreamFrame(3, (int) Http2Error.CANCEL.code()); verify(ctx).write(eq(expected), eq(promise)); - verify(ctx).flush(); } @Test @@ -221,7 +220,6 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { ByteBuf expected = rstStreamFrame(3, (int) Http2Error.CANCEL.code()); verify(ctx).write(eq(expected), eq(promise)); - verify(ctx).flush(); promise = mock(ChannelPromise.class); handler.write(ctx, new CancelStreamCommand(stream), promise); @@ -235,7 +233,6 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { handler.write(ctx, new SendGrpcFrameCommand(stream, content, true), promise); verify(promise, never()).setFailure(any(Throwable.class)); ByteBuf bufWritten = captureWrite(ctx); - verify(ctx).flush(); int startIndex = bufWritten.readerIndex() + Http2CodecUtil.FRAME_HEADER_LENGTH; int length = bufWritten.writerIndex() - startIndex; ByteBuf writtenContent = bufWritten.slice(startIndex, length);