diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java index bb95dcea6c..e281567a5b 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java @@ -5,16 +5,13 @@ import static com.google.net.stubby.newtransport.netty.NettyClientStream.PENDING import com.google.common.base.Preconditions; import com.google.net.stubby.Metadata; import com.google.net.stubby.Status; -import com.google.net.stubby.newtransport.StreamState; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; import io.netty.handler.codec.http2.DefaultHttp2InboundFlowController; -import io.netty.handler.codec.http2.Http2CodecUtil; import io.netty.handler.codec.http2.Http2Connection; import io.netty.handler.codec.http2.Http2ConnectionAdapter; import io.netty.handler.codec.http2.Http2ConnectionHandler; @@ -130,8 +127,7 @@ class NettyClientHandler extends Http2ConnectionHandler { } private void initListener() { - ((LazyFrameListener) ((DefaultHttp2ConnectionDecoder) this.decoder()).listener()).setHandler( - this); + ((LazyFrameListener) decoder().listener()).setHandler(this); } private void onHeadersRead(int streamId, Http2Headers headers, boolean endStream) @@ -143,17 +139,10 @@ class NettyClientHandler extends Http2ConnectionHandler { /** * Handler for an inbound HTTP/2 DATA frame. */ - private void onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, - boolean endOfStream) throws Http2Exception { + private void onDataRead(int streamId, ByteBuf data, boolean endOfStream) throws Http2Exception { Http2Stream http2Stream = connection().requireStream(streamId); NettyClientStream stream = clientStream(http2Stream); stream.inboundDataReceived(data, endOfStream); - if (stream.state() == StreamState.CLOSED && !endOfStream) { - // TODO(user): This is a hack due to the test server not consistently - // setting endOfStream on the last frame for the v1 protocol. - // Remove this once b/17692766 is fixed. - lifecycleManager().closeRemoteSide(http2Stream, ctx.newSucceededFuture()); - } } /** @@ -185,21 +174,25 @@ class NettyClientHandler extends Http2ConnectionHandler { } @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - // Force the conversion of any exceptions into HTTP/2 exceptions. - Http2Exception e = Http2CodecUtil.toHttp2Exception(cause); - if (e instanceof Http2StreamException) { - // Close the stream with a status that contains the cause. - Http2Stream stream = connection().stream(((Http2StreamException) e).streamId()); - if (stream != null) { - clientStream(stream).setStatus(Status.fromThrowable(cause), new Metadata.Trailers()); - } - } else { - connectionError = e; + protected void onConnectionError(ChannelHandlerContext ctx, Throwable cause, + Http2Exception http2Ex) { + // Save the error. + connectionError = cause; + + super.onConnectionError(ctx, cause, http2Ex); + } + + @Override + protected void onStreamError(ChannelHandlerContext ctx, Throwable cause, + Http2StreamException http2Ex) { + // Close the stream with a status that contains the cause. + Http2Stream stream = connection().stream(http2Ex.streamId()); + if (stream != null) { + clientStream(stream).setStatus(Status.fromThrowable(cause), new Metadata.Trailers()); } - // Delegate to the super class for proper handling of the Http2Exception. - super.exceptionCaught(ctx, e); + // Delegate to the base class to send a RST_STREAM. + super.onStreamError(ctx, cause, http2Ex); } /** @@ -244,21 +237,7 @@ class NettyClientHandler extends Http2ConnectionHandler { * Sends the given GRPC frame for the stream. */ private void sendGrpcFrame(ChannelHandlerContext ctx, SendGrpcFrameCommand cmd, - ChannelPromise promise) throws Http2Exception { - Http2Stream http2Stream = connection().requireStream(cmd.streamId()); - switch (http2Stream.state()) { - case CLOSED: - case HALF_CLOSED_LOCAL: - case IDLE: - case RESERVED_LOCAL: - case RESERVED_REMOTE: - cmd.release(); - promise.setFailure(new Exception("Closed before write could occur")); - return; - default: - break; - } - + ChannelPromise promise) { // Call the base class to write the HTTP/2 DATA frame. // Note: no need to flush since this is handled by the outbound flow controller. encoder().writeData(ctx, cmd.streamId(), cmd.content(), 0, cmd.endStream(), promise); @@ -272,7 +251,7 @@ class NettyClientHandler extends Http2ConnectionHandler { Status goAwayStatus = goAwayStatus(); failPendingStreams(goAwayStatus); - if (connection().local().isGoAwayReceived()) { + if (connection().goAwayReceived()) { // Received a GOAWAY from the remote endpoint. Fail any streams that were created after the // last known stream. int lastKnownStream = connection().local().lastKnownStream(); @@ -435,7 +414,7 @@ class NettyClientHandler extends Http2ConnectionHandler { @Override public void onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) throws Http2Exception { - handler.onDataRead(ctx, streamId, data, endOfStream); + handler.onDataRead(streamId, data, endOfStream); } @Override diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerHandler.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerHandler.java index a4681b96ce..f360d3ba1c 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerHandler.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerHandler.java @@ -1,7 +1,7 @@ package com.google.net.stubby.newtransport.netty; -import static com.google.net.stubby.newtransport.netty.Utils.CONTENT_TYPE_HEADER; import static com.google.net.stubby.newtransport.netty.Utils.CONTENT_TYPE_GRPC; +import static com.google.net.stubby.newtransport.netty.Utils.CONTENT_TYPE_HEADER; import static com.google.net.stubby.newtransport.netty.Utils.HTTP_METHOD; import static io.netty.buffer.Unpooled.EMPTY_BUFFER; import static io.netty.handler.codec.http2.Http2CodecUtil.toByteBuf; @@ -18,9 +18,7 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; import io.netty.handler.codec.http2.DefaultHttp2InboundFlowController; -import io.netty.handler.codec.http2.Http2CodecUtil; import io.netty.handler.codec.http2.Http2Connection; import io.netty.handler.codec.http2.Http2ConnectionHandler; import io.netty.handler.codec.http2.Http2Error; @@ -71,8 +69,7 @@ class NettyServerHandler extends Http2ConnectionHandler { } private void initListener() { - ((LazyFrameListener) ((DefaultHttp2ConnectionDecoder) this.decoder()).listener()).setHandler( - this); + ((LazyFrameListener) decoder().listener()).setHandler(this); } @Override @@ -103,7 +100,7 @@ class NettyServerHandler extends Http2ConnectionHandler { throw e; } catch (Throwable e) { logger.log(Level.WARNING, "Exception in onHeadersRead()", e); - throw new Http2StreamException(streamId, Http2Error.INTERNAL_ERROR, e.toString()); + throw newStreamException(streamId, e); } } @@ -115,7 +112,7 @@ class NettyServerHandler extends Http2ConnectionHandler { throw e; } catch (Throwable e) { logger.log(Level.WARNING, "Exception in onDataRead()", e); - throw new Http2StreamException(streamId, Http2Error.INTERNAL_ERROR, e.toString()); + throw newStreamException(streamId, e); } } @@ -127,29 +124,32 @@ class NettyServerHandler extends Http2ConnectionHandler { throw e; } catch (Throwable e) { logger.log(Level.WARNING, "Exception in onRstStreamRead()", e); - throw new Http2StreamException(streamId, Http2Error.INTERNAL_ERROR, e.toString()); + throw newStreamException(streamId, e); } } @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - // Force the conversion of any exceptions into HTTP/2 exceptions. - Http2Exception e = Http2CodecUtil.toHttp2Exception(cause); - if (e instanceof Http2StreamException) { - // Aborts the stream with a status that contains the cause. - Http2Stream stream = connection().stream(((Http2StreamException) cause).streamId()); - if (stream != null) { - // Send the error message to the client to help debugging. - serverStream(stream).abortStream(Status.fromThrowable(cause), true); - } else { - // Delegate to the base class for proper handling of the Http2Exception. - super.exceptionCaught(ctx, e); - } + protected void onConnectionError(ChannelHandlerContext ctx, Throwable cause, + Http2Exception http2Ex) { + connectionError = cause; + Http2Error error = http2Ex != null ? http2Ex.error() : Http2Error.INTERNAL_ERROR; + + // Write the GO_AWAY frame to the remote endpoint and then shutdown the channel. + goAwayAndClose(ctx, error.code(), toByteBuf(ctx, cause), ctx.newPromise()); + } + + @Override + protected void onStreamError(ChannelHandlerContext ctx, Throwable cause, + Http2StreamException http2Ex) { + Http2Stream stream = connection().stream(http2Ex.streamId()); + if (stream != null) { + // Abort the stream with a status to help the client with debugging. + // Don't need to send a RST_STREAM since the end-of-stream flag will + // be sent. + serverStream(stream).abortStream(Status.fromThrowable(cause), true); } else { - // Connection error... - connectionError = e; - // Write the GO_AWAY frame to the remote endpoint and then shutdown the channel. - goAwayAndClose(ctx, e.error().code(), toByteBuf(ctx, e), ctx.newPromise()); + // Delegate to the base class to send a RST_STREAM. + super.onStreamError(ctx, cause, http2Ex); } } @@ -172,23 +172,12 @@ class NettyServerHandler extends Http2ConnectionHandler { public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Http2Exception { if (msg instanceof SendGrpcFrameCommand) { - SendGrpcFrameCommand cmd = (SendGrpcFrameCommand) msg; - if (cmd.endStream()) { - closeStreamWhenDone(promise, cmd.streamId()); - } - // Call the base class to write the HTTP/2 DATA frame. - encoder().writeData(ctx, cmd.streamId(), cmd.content(), 0, cmd.endStream(), promise); - ctx.flush(); + sendGrpcFrame(ctx, (SendGrpcFrameCommand) msg, promise); } else if (msg instanceof SendResponseHeadersCommand) { - SendResponseHeadersCommand cmd = (SendResponseHeadersCommand) msg; - if (cmd.endOfStream()) { - closeStreamWhenDone(promise, cmd.streamId()); - } - encoder().writeHeaders(ctx, cmd.streamId(), cmd.headers(), 0, cmd.endOfStream(), promise); - ctx.flush(); + sendResponseHeaders(ctx, (SendResponseHeadersCommand) msg, promise); } else { - AssertionError e = new AssertionError("Write called for unexpected type: " - + msg.getClass().getName()); + AssertionError e = + new AssertionError("Write called for unexpected type: " + msg.getClass().getName()); ReferenceCountUtil.release(msg); promise.setFailure(e); throw e; @@ -205,21 +194,45 @@ class NettyServerHandler extends Http2ConnectionHandler { }); } + /** + * Sends the given gRPC frame to the client. + */ + private void sendGrpcFrame(ChannelHandlerContext ctx, SendGrpcFrameCommand cmd, + ChannelPromise promise) throws Http2Exception { + if (cmd.endStream()) { + closeStreamWhenDone(promise, cmd.streamId()); + } + // Call the base class to write the HTTP/2 DATA frame. + encoder().writeData(ctx, cmd.streamId(), cmd.content(), 0, cmd.endStream(), promise); + ctx.flush(); + } + + /** + * Sends the response headers to the client. + */ + private void sendResponseHeaders(ChannelHandlerContext ctx, SendResponseHeadersCommand cmd, + ChannelPromise promise) throws Http2Exception { + if (cmd.endOfStream()) { + closeStreamWhenDone(promise, cmd.streamId()); + } + encoder().writeHeaders(ctx, cmd.streamId(), cmd.headers(), 0, cmd.endOfStream(), promise); + ctx.flush(); + } + /** * Writes a {@code GO_AWAY} frame to the remote endpoint. When it completes, shuts down * the channel. */ private void goAwayAndClose(final ChannelHandlerContext ctx, int errorCode, ByteBuf data, ChannelPromise promise) { - if (connection().remote().isGoAwayReceived()) { + if (connection().goAwaySent()) { // Already sent the GO_AWAY. Do nothing. return; } // Write the GO_AWAY frame to the remote endpoint. int lastKnownStream = connection().remote().lastStreamCreated(); - ChannelFuture future = - lifecycleManager().writeGoAway(ctx, lastKnownStream, errorCode, data, promise); + ChannelFuture future = writeGoAway(ctx, lastKnownStream, errorCode, data, promise); // When the write completes, close this channel. future.addListener(new ChannelFutureListener() { @@ -256,6 +269,10 @@ class NettyServerHandler extends Http2ConnectionHandler { return stream.data(); } + private Http2StreamException newStreamException(int streamId, Throwable cause) { + return new Http2StreamException(streamId, Http2Error.INTERNAL_ERROR, cause.getMessage(), cause); + } + private static class LazyFrameListener extends Http2FrameAdapter { private NettyServerHandler handler; diff --git a/core/src/test/java/com/google/net/stubby/newtransport/MessageDeframer2Test.java b/core/src/test/java/com/google/net/stubby/newtransport/MessageDeframer2Test.java index 719548ee5e..f66724b697 100644 --- a/core/src/test/java/com/google/net/stubby/newtransport/MessageDeframer2Test.java +++ b/core/src/test/java/com/google/net/stubby/newtransport/MessageDeframer2Test.java @@ -4,7 +4,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; -import static org.mockito.Matchers.notNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -14,8 +13,6 @@ import com.google.common.io.ByteStreams; import com.google.common.primitives.Bytes; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; -import com.google.net.stubby.Metadata; -import com.google.net.stubby.Status; import com.google.net.stubby.newtransport.MessageDeframer2.Sink; import org.junit.Test; @@ -23,12 +20,11 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; -import java.util.Arrays; -import java.util.List; -import java.util.zip.GZIPOutputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; +import java.util.List; +import java.util.zip.GZIPOutputStream; /** * Tests for {@link MessageDeframer2}. diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientStreamTest.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientStreamTest.java index 0cd5800895..4492027190 100644 --- a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientStreamTest.java +++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientStreamTest.java @@ -1,7 +1,9 @@ package com.google.net.stubby.newtransport.netty; -import static com.google.net.stubby.newtransport.netty.Utils.CONTENT_TYPE_HEADER; +import static com.google.net.stubby.newtransport.netty.NettyTestUtil.messageFrame; +import static com.google.net.stubby.newtransport.netty.NettyTestUtil.statusFrame; import static com.google.net.stubby.newtransport.netty.Utils.CONTENT_TYPE_GRPC; +import static com.google.net.stubby.newtransport.netty.Utils.CONTENT_TYPE_HEADER; import static com.google.net.stubby.newtransport.netty.Utils.STATUS_OK; import static io.netty.util.CharsetUtil.UTF_8; import static org.junit.Assert.assertEquals; @@ -9,9 +11,7 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.same; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; import com.google.net.stubby.Metadata; import com.google.net.stubby.Status; @@ -19,11 +19,6 @@ import com.google.net.stubby.newtransport.AbstractStream; import com.google.net.stubby.newtransport.ClientStreamListener; import com.google.net.stubby.newtransport.StreamState; -import io.netty.buffer.Unpooled; -import io.netty.handler.codec.AsciiString; -import io.netty.handler.codec.http2.DefaultHttp2Headers; -import io.netty.handler.codec.http2.Http2Headers; - import org.junit.After; import org.junit.Assume; import org.junit.Before; @@ -34,6 +29,11 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.AsciiString; +import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.Http2Headers; + /** * Tests for {@link NettyClientStream}. */ @@ -47,6 +47,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase { return listener; } + @Override @Before public void setup() { AbstractStream.GRPC_V2_PROTOCOL = false; @@ -78,7 +79,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase { stream().id(STREAM_ID); stream.writeMessage(input, input.available(), accepted); stream.flush(); - verify(channel).writeAndFlush(new SendGrpcFrameCommand(1, messageFrame(), false)); + verify(channel).writeAndFlush(new SendGrpcFrameCommand(1, messageFrame(MESSAGE), false)); verify(accepted).run(); } diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerHandlerTest.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerHandlerTest.java index 8e98f644d2..b000cb357c 100644 --- a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerHandlerTest.java +++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerHandlerTest.java @@ -4,6 +4,7 @@ import static com.google.net.stubby.newtransport.netty.Utils.CONTENT_TYPE_HEADER import static com.google.net.stubby.newtransport.netty.Utils.CONTENT_TYPE_GRPC; import static com.google.net.stubby.newtransport.netty.Utils.HTTP_METHOD; import static io.netty.handler.codec.http2.Http2CodecUtil.toByteBuf; +import static io.netty.handler.codec.http2.Http2Exception.protocolError; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -52,7 +53,6 @@ import io.netty.handler.codec.http2.Http2FrameWriter; import io.netty.handler.codec.http2.Http2Headers; import io.netty.handler.codec.http2.Http2OutboundFlowController; import io.netty.handler.codec.http2.Http2Settings; -import io.netty.handler.codec.http2.Http2StreamException; import java.io.ByteArrayInputStream; import java.io.InputStream; @@ -149,8 +149,10 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase { public void clientHalfCloseShouldForwardToStreamListener() throws Exception { createStream(); - handler.channelRead(ctx, emptyDataFrame(STREAM_ID, true)); - verify(streamListener, never()).messageRead(any(InputStream.class), anyInt()); + handler.channelRead(ctx, emptyGrpcFrame(STREAM_ID, true)); + ArgumentCaptor captor = ArgumentCaptor.forClass(InputStream.class); + verify(streamListener).messageRead(captor.capture(), anyInt()); + assertArrayEquals(new byte[0], ByteStreams.toByteArray(captor.getValue())); verify(streamListener).halfClosed(); verifyNoMoreInteractions(streamListener); } @@ -169,8 +171,13 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase { public void streamErrorShouldNotCloseChannel() throws Exception { createStream(); - Http2StreamException e = new Http2StreamException(STREAM_ID, Http2Error.REFUSED_STREAM); - handler.exceptionCaught(ctx, e); + // When a DATA frame is read, throw an exception. It will be converted into an + // Http2StreamException. + RuntimeException e = new RuntimeException("Fake Exception"); + when(streamListener.messageRead(any(InputStream.class), anyInt())).thenThrow(e); + + // Read a DATA frame to trigger the exception. + handler.channelRead(ctx, emptyGrpcFrame(STREAM_ID, true)); // Verify that the context was NOT closed. verify(ctx, never()).close(); @@ -178,18 +185,20 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase { // Verify the stream was closed. ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); verify(streamListener).closed(captor.capture()); - assertEquals(e, Http2CodecUtil.toHttp2Exception(captor.getValue().asException())); + assertEquals(e, captor.getValue().asException().getCause().getCause()); assertEquals(Code.INTERNAL, captor.getValue().getCode()); } @Test public void connectionErrorShouldCloseChannel() throws Exception { - // Non-HTTP/2 exceptions are automatically interpreted as connection errors. - Exception e = new Exception("Fake Exception"); - handler.exceptionCaught(ctx, e); + createStream(); + + // Read a DATA frame to trigger the exception. + handler.channelRead(ctx, badFrame()); // Verify the expected GO_AWAY frame was written. - ByteBuf expected = goAwayFrame(0, Http2Error.INTERNAL_ERROR.code(), toByteBuf(ctx, e)); + Exception e = protocolError("Frame length 0 incorrect size for ping."); + ByteBuf expected = goAwayFrame(STREAM_ID, Http2Error.PROTOCOL_ERROR.code(), toByteBuf(ctx, e)); ByteBuf actual = captureWrite(ctx); assertEquals(expected, actual); @@ -244,9 +253,17 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase { return captureWrite(ctx); } - private ByteBuf emptyDataFrame(int streamId, boolean endStream) { + private ByteBuf emptyGrpcFrame(int streamId, boolean endStream) throws Exception { ChannelHandlerContext ctx = newContext(); - frameWriter.writeData(ctx, streamId, Unpooled.EMPTY_BUFFER, 0, endStream, newPromise()); + ByteBuf buf = NettyTestUtil.messageFrame(""); + frameWriter.writeData(ctx, streamId, buf, 0, endStream, newPromise()); + return captureWrite(ctx); + } + + private ByteBuf badFrame() throws Exception { + ChannelHandlerContext ctx = newContext(); + // Write an empty PING frame - this is invalid. + frameWriter.writePing(ctx, false, Unpooled.EMPTY_BUFFER, newPromise()); return captureWrite(ctx); } diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerStreamTest.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerStreamTest.java index b964b429ff..689a40e81e 100644 --- a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerStreamTest.java +++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerStreamTest.java @@ -1,5 +1,7 @@ package com.google.net.stubby.newtransport.netty; +import static com.google.net.stubby.newtransport.netty.NettyTestUtil.messageFrame; +import static com.google.net.stubby.newtransport.netty.NettyTestUtil.statusFrame; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; import static org.mockito.Matchers.same; @@ -38,7 +40,7 @@ public class NettyServerStreamTest extends NettyStreamTestBase { .status(Utils.STATUS_OK) .set(Utils.CONTENT_TYPE_HEADER, Utils.CONTENT_TYPE_GRPC); verify(channel).writeAndFlush(new SendResponseHeadersCommand(STREAM_ID, headers, false)); - verify(channel).writeAndFlush(new SendGrpcFrameCommand(STREAM_ID, messageFrame(), false)); + verify(channel).writeAndFlush(new SendGrpcFrameCommand(STREAM_ID, messageFrame(MESSAGE), false)); verify(accepted).run(); } diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyStreamTestBase.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyStreamTestBase.java index a88965faa2..7fcc8c7c44 100644 --- a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyStreamTestBase.java +++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyStreamTestBase.java @@ -1,7 +1,6 @@ package com.google.net.stubby.newtransport.netty; -import static com.google.net.stubby.GrpcFramingUtil.PAYLOAD_FRAME; -import static com.google.net.stubby.GrpcFramingUtil.STATUS_FRAME; +import static com.google.net.stubby.newtransport.netty.NettyTestUtil.messageFrame; import static io.netty.handler.codec.http2.DefaultHttp2InboundFlowController.DEFAULT_WINDOW_UPDATE_RATIO; import static io.netty.handler.codec.http2.DefaultHttp2InboundFlowController.WINDOW_UPDATE_OFF; import static io.netty.util.CharsetUtil.UTF_8; @@ -15,14 +14,17 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.SettableFuture; -import com.google.net.stubby.Status; -import com.google.net.stubby.newtransport.AbstractStream; import com.google.net.stubby.newtransport.StreamListener; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; @@ -33,17 +35,7 @@ import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoop; import io.netty.handler.codec.http2.DefaultHttp2InboundFlowController; -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.DataOutputStream; import java.io.InputStream; import java.util.concurrent.TimeUnit; @@ -115,7 +107,7 @@ public abstract class NettyStreamTestBase { @Test public void inboundMessageShouldCallListener() throws Exception { - stream.inboundDataReceived(messageFrame(), false); + stream.inboundDataReceived(messageFrame(MESSAGE), false); ArgumentCaptor captor = ArgumentCaptor.forClass(InputStream.class); verify(listener()).messageRead(captor.capture(), eq(MESSAGE.length())); @@ -123,7 +115,7 @@ public abstract class NettyStreamTestBase { verify(inboundFlow).setWindowUpdateRatio(eq(ctx), eq(STREAM_ID), eq(WINDOW_UPDATE_OFF)); verify(inboundFlow, never()).setWindowUpdateRatio(eq(ctx), eq(STREAM_ID), eq(DEFAULT_WINDOW_UPDATE_RATIO)); - assertEquals(MESSAGE, toString(captor.getValue())); + assertEquals(MESSAGE, NettyTestUtil.toString(captor.getValue())); // Verify that inbound flow control window update has been re-enabled for the stream after // the future completes. @@ -136,49 +128,6 @@ public abstract class NettyStreamTestBase { protected abstract StreamListener listener(); - private String toString(InputStream in) throws Exception { - byte[] bytes = new byte[in.available()]; - ByteStreams.readFully(in, bytes); - return new String(bytes, UTF_8); - } - - protected final ByteBuf messageFrame() throws Exception { - ByteArrayOutputStream os = new ByteArrayOutputStream(); - DataOutputStream dos = new DataOutputStream(os); - if (!AbstractStream.GRPC_V2_PROTOCOL) { - dos.write(PAYLOAD_FRAME); - dos.writeInt(MESSAGE.length()); - } - dos.write(MESSAGE.getBytes(UTF_8)); - dos.close(); - - // Write the compression header followed by the context frame. - return compressionFrame(os.toByteArray()); - } - - protected final ByteBuf statusFrame(Status status) throws Exception { - ByteArrayOutputStream os = new ByteArrayOutputStream(); - DataOutputStream dos = new DataOutputStream(os); - short code = (short) status.getCode().value(); - dos.write(STATUS_FRAME); - int length = 2; - dos.writeInt(length); - dos.writeShort(code); - - // Write the compression header followed by the context frame. - return compressionFrame(os.toByteArray()); - } - - protected final ByteBuf compressionFrame(byte[] data) { - ByteBuf buf = Unpooled.buffer(); - if (AbstractStream.GRPC_V2_PROTOCOL) { - buf.writeByte(0); - } - buf.writeInt(data.length); - buf.writeBytes(data); - return buf; - } - private void mockChannelFuture(boolean succeeded) { when(future.isDone()).thenReturn(true); when(future.isCancelled()).thenReturn(false); diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyTestUtil.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyTestUtil.java new file mode 100644 index 0000000000..2abea486d6 --- /dev/null +++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyTestUtil.java @@ -0,0 +1,65 @@ +package com.google.net.stubby.newtransport.netty; + +import static com.google.net.stubby.GrpcFramingUtil.PAYLOAD_FRAME; +import static com.google.net.stubby.GrpcFramingUtil.STATUS_FRAME; +import static io.netty.util.CharsetUtil.UTF_8; + +import com.google.common.io.ByteStreams; +import com.google.net.stubby.Status; +import com.google.net.stubby.newtransport.AbstractStream; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.InputStream; + +/** + * Utility methods for supporting Netty tests. + */ +public class NettyTestUtil { + + static String toString(InputStream in) throws Exception { + byte[] bytes = new byte[in.available()]; + ByteStreams.readFully(in, bytes); + return new String(bytes, UTF_8); + } + + static ByteBuf messageFrame(String message) throws Exception { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(os); + if (!AbstractStream.GRPC_V2_PROTOCOL) { + dos.write(PAYLOAD_FRAME); + dos.writeInt(message.length()); + } + dos.write(message.getBytes(UTF_8)); + dos.close(); + + // Write the compression header followed by the context frame. + return compressionFrame(os.toByteArray()); + } + + static ByteBuf statusFrame(Status status) throws Exception { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(os); + short code = (short) status.getCode().value(); + dos.write(STATUS_FRAME); + int length = 2; + dos.writeInt(length); + dos.writeShort(code); + + // Write the compression header followed by the context frame. + return compressionFrame(os.toByteArray()); + } + + static ByteBuf compressionFrame(byte[] data) { + ByteBuf buf = Unpooled.buffer(); + if (AbstractStream.GRPC_V2_PROTOCOL) { + buf.writeByte(0); + } + buf.writeInt(data.length); + buf.writeBytes(data); + return buf; + } +}