diff --git a/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java b/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java index 8d67b1feb9..76f0b4ba45 100644 --- a/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java +++ b/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java @@ -31,25 +31,40 @@ package io.grpc.netty; +import static io.netty.buffer.Unpooled.directBuffer; +import static io.netty.buffer.Unpooled.unreleasableBuffer; import static io.netty.handler.codec.http2.Http2CodecUtil.getEmbeddedHttp2Exception; import static java.util.concurrent.TimeUnit.SECONDS; +import com.google.common.annotations.VisibleForTesting; + +import io.grpc.netty.AbstractNettyHandler.FlowControlPinger; +import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http2.Http2ConnectionDecoder; import io.netty.handler.codec.http2.Http2ConnectionEncoder; import io.netty.handler.codec.http2.Http2ConnectionHandler; import io.netty.handler.codec.http2.Http2Exception; +import io.netty.handler.codec.http2.Http2LocalFlowController; import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.codec.http2.Http2Stream; +import java.util.concurrent.TimeUnit; + /** * Base class for all Netty gRPC handlers. This class standardizes exception handling (always * shutdown the connection) as well as sending the initial connection window at startup. */ abstract class AbstractNettyHandler extends Http2ConnectionHandler { private static long GRACEFUL_SHUTDOWN_TIMEOUT = SECONDS.toMillis(5); + private boolean autoTuneFlowControlOn = false; private int initialConnectionWindow; private ChannelHandlerContext ctx; + private final FlowControlPinger flowControlPing = new FlowControlPinger(); + + private static final int BDP_MEASUREMENT_PING = 1234; + private static final ByteBuf payloadBuf = + unreleasableBuffer(directBuffer(8).writeLong(BDP_MEASUREMENT_PING)); AbstractNettyHandler(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, @@ -108,4 +123,114 @@ abstract class AbstractNettyHandler extends Http2ConnectionHandler { ctx.flush(); } } + + @VisibleForTesting + FlowControlPinger flowControlPing() { + return flowControlPing; + } + + @VisibleForTesting + void setAutoTuneFlowControl(boolean isOn) { + autoTuneFlowControlOn = isOn; + } + + /** + * Class for handling flow control pinging and flow control window updates as necessary. + */ + final class FlowControlPinger { + + private static final int MAX_WINDOW_SIZE = 8 * 1024 * 1024; + private int pingCount; + private int pingReturn; + private boolean pinging; + private int dataSizeSincePing; + private float lastBandwidth; // bytes per second + private long lastPingTime; + + public int payload() { + return BDP_MEASUREMENT_PING; + } + + public int maxWindow() { + return MAX_WINDOW_SIZE; + } + + public void onDataRead(int dataLength, int paddingLength) { + if (!autoTuneFlowControlOn) { + return; + } + if (!isPinging()) { + setPinging(true); + sendPing(ctx()); + } + incrementDataSincePing(dataLength + paddingLength); + } + + public void updateWindow() throws Http2Exception { + if (!autoTuneFlowControlOn) { + return; + } + pingReturn++; + long elapsedTime = (System.nanoTime() - lastPingTime); + if (elapsedTime == 0) { + elapsedTime = 1; + } + long bandwidth = (getDataSincePing() * TimeUnit.SECONDS.toNanos(1)) / elapsedTime; + Http2LocalFlowController fc = decoder().flowController(); + // Calculate new window size by doubling the observed BDP, but cap at max window + int targetWindow = Math.min(getDataSincePing() * 2, MAX_WINDOW_SIZE); + setPinging(false); + int currentWindow = fc.initialWindowSize(connection().connectionStream()); + if (targetWindow > currentWindow && bandwidth > lastBandwidth) { + lastBandwidth = bandwidth; + int increase = targetWindow - currentWindow; + fc.incrementWindowSize(connection().connectionStream(), increase); + fc.initialWindowSize(targetWindow); + Http2Settings settings = new Http2Settings(); + settings.initialWindowSize(targetWindow); + frameWriter().writeSettings(ctx(), settings, ctx().newPromise()); + } + + } + + private boolean isPinging() { + return pinging; + } + + private void setPinging(boolean pingOut) { + pinging = pingOut; + } + + private void sendPing(ChannelHandlerContext ctx) { + setDataSizeSincePing(0); + lastPingTime = System.nanoTime(); + encoder().writePing(ctx, false, payloadBuf.slice(), ctx.newPromise()); + pingCount++; + } + + private void incrementDataSincePing(int increase) { + int currentSize = getDataSincePing(); + setDataSizeSincePing(currentSize + increase); + } + + @VisibleForTesting + int getPingCount() { + return pingCount; + } + + @VisibleForTesting + int getPingReturn() { + return pingReturn; + } + + @VisibleForTesting + int getDataSincePing() { + return dataSizeSincePing; + } + + @VisibleForTesting + void setDataSizeSincePing(int dataSize) { + dataSizeSincePing = dataSize; + } + } } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index fdda262ab8..a184358304 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -103,6 +103,7 @@ class NettyClientHandler extends AbstractNettyHandler { */ private static final Status EXHAUSTED_STREAMS_STATUS = Status.UNAVAILABLE.withDescription("Stream IDs have been exhausted"); + private static final long USER_PING_PAYLOAD = 1111; private final Http2Connection.PropertyKey streamKey; private final ClientTransportLifecycleManager lifecycleManager; @@ -120,6 +121,7 @@ class NettyClientHandler extends AbstractNettyHandler { Http2FrameReader frameReader = new DefaultHttp2FrameReader(headersDecoder); Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter(); Http2Connection connection = new DefaultHttp2Connection(false); + return newHandler( connection, frameReader, frameWriter, lifecycleManager, flowControlWindow, ticker); } @@ -145,8 +147,8 @@ class NettyClientHandler extends AbstractNettyHandler { new DefaultHttp2ConnectionEncoder(connection, frameWriter)); // Create the local flow controller configured to auto-refill the connection window. - connection.local().flowController(new DefaultHttp2LocalFlowController(connection, - DEFAULT_WINDOW_UPDATE_RATIO, true)); + connection.local().flowController( + new DefaultHttp2LocalFlowController(connection, DEFAULT_WINDOW_UPDATE_RATIO, true)); Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader); @@ -172,6 +174,7 @@ class NettyClientHandler extends AbstractNettyHandler { Http2Connection connection = encoder.connection(); streamKey = connection.newKey(); + connection.addListener(new Http2ConnectionAdapter() { @Override public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) { @@ -219,6 +222,11 @@ class NettyClientHandler extends AbstractNettyHandler { } } + // @VisibleForTesting + // FlowControlPinger flowControlPinger() { + // return flowControlPing; + // } + void startWriteQueue(Channel channel) { clientWriteQueue = new WriteQueue(channel); } @@ -246,11 +254,14 @@ class NettyClientHandler extends AbstractNettyHandler { /** * Handler for an inbound HTTP/2 DATA frame. */ - private void onDataRead(int streamId, ByteBuf data, boolean endOfStream) { + + private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfStream) { + flowControlPing().onDataRead(data.readableBytes(), padding); NettyClientStream stream = clientStream(requireHttp2Stream(streamId)); stream.transportDataReceived(data, endOfStream); } + /** * Handler for an inbound HTTP/2 RST_STREAM frame, terminating a stream. */ @@ -449,7 +460,7 @@ class NettyClientHandler extends AbstractNettyHandler { promise.setSuccess(); promise = ctx().newPromise(); // set outstanding operation - long data = random.nextLong(); + long data = USER_PING_PAYLOAD; ByteBuf buffer = ctx.alloc().buffer(8); buffer.writeLong(data); Stopwatch stopwatch = Stopwatch.createStarted(ticker); @@ -585,7 +596,7 @@ class NettyClientHandler extends AbstractNettyHandler { @Override public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) throws Http2Exception { - NettyClientHandler.this.onDataRead(streamId, data, endOfStream); + NettyClientHandler.this.onDataRead(streamId, data, padding, endOfStream); return padding; } @@ -607,17 +618,23 @@ class NettyClientHandler extends AbstractNettyHandler { NettyClientHandler.this.onRstStreamRead(streamId, errorCode); } - @Override public void onPingAckRead(ChannelHandlerContext ctx, ByteBuf data) - throws Http2Exception { + @Override + public void onPingAckRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception { Http2Ping p = ping; - if (p != null) { + if (data.getLong(data.readerIndex()) == flowControlPing().payload()) { + flowControlPing().updateWindow(); + if (logger.isLoggable(Level.FINE)) { + logger.log(Level.FINE, String.format("Window: %d", + decoder().flowController().initialWindowSize(connection().connectionStream()))); + } + } else if (p != null) { long ackPayload = data.readLong(); if (p.payload() == ackPayload) { p.complete(); ping = null; } else { - logger.log(Level.WARNING, String.format("Received unexpected ping ack. " - + "Expecting %d, got %d", p.payload(), ackPayload)); + logger.log(Level.WARNING, String.format( + "Received unexpected ping ack. Expecting %d, got %d", p.payload(), ackPayload)); } } else { logger.warning("Received unexpected ping ack. No ping outstanding"); diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index d519fc1e85..4b218ecb84 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -130,8 +130,9 @@ class NettyServerHandler extends AbstractNettyHandler { Http2Connection connection = new DefaultHttp2Connection(true); // Create the local flow controller configured to auto-refill the connection window. - connection.local().flowController(new DefaultHttp2LocalFlowController(connection, - DEFAULT_WINDOW_UPDATE_RATIO, true)); + connection.local().flowController( + new DefaultHttp2LocalFlowController(connection, DEFAULT_WINDOW_UPDATE_RATIO, true)); + Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter); Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, @@ -211,7 +212,9 @@ class NettyServerHandler extends AbstractNettyHandler { } } - private void onDataRead(int streamId, ByteBuf data, boolean endOfStream) throws Http2Exception { + private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfStream) + throws Http2Exception { + flowControlPing().onDataRead(data.readableBytes(), padding); try { NettyServerStream.TransportState stream = serverStream(requireHttp2Stream(streamId)); stream.inboundDataReceived(data, endOfStream); @@ -426,6 +429,7 @@ class NettyServerHandler extends AbstractNettyHandler { } } + /** * Returns the server stream associated to the given HTTP/2 stream object. */ @@ -443,7 +447,7 @@ class NettyServerHandler extends AbstractNettyHandler { @Override public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) throws Http2Exception { - NettyServerHandler.this.onDataRead(streamId, data, endOfStream); + NettyServerHandler.this.onDataRead(streamId, data, padding, endOfStream); return padding; } @@ -464,5 +468,18 @@ class NettyServerHandler extends AbstractNettyHandler { throws Http2Exception { NettyServerHandler.this.onRstStreamRead(streamId); } + + @Override + public void onPingAckRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception { + if (data.getLong(data.readerIndex()) == flowControlPing().payload()) { + flowControlPing().updateWindow(); + if (logger.isLoggable(Level.FINE)) { + logger.log(Level.FINE, String.format("Window: %d", + decoder().flowController().initialWindowSize(connection().connectionStream()))); + } + } else { + logger.warning("Received unexpected ping ack. No ping outstanding"); + } + } } } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index a88004ab03..c9e547ed45 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -94,7 +94,6 @@ import org.mockito.MockitoAnnotations; */ @RunWith(JUnit4.class) public class NettyClientHandlerTest extends NettyHandlerTestBase { - // TODO(zhangkun83): mocking concrete classes is not safe. Consider making NettyClientStream an // interface. @Mock @@ -368,6 +367,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase captor = ArgumentCaptor.forClass(ByteBuf.class); + verifyWrite().writePing(eq(ctx()), eq(false), captor.capture(), any(ChannelPromise.class)); + ByteBuf payload = captor.getValue(); + channelRead(dataFrame(3, false)); + long pingData = handler().flowControlPing().payload(); + ByteBuf buffer = handler().ctx().alloc().buffer(8); + buffer.writeLong(pingData); + channelRead(pingFrame(true, buffer)); + + assertEquals(1, handler().flowControlPing().getPingReturn()); + assertEquals(0, callback.invocationCount); + + channelRead(pingFrame(true, payload)); + + assertEquals(1, handler().flowControlPing().getPingReturn()); + assertEquals(1, callback.invocationCount); + } + @Test public void exceptionCaughtShouldCloseConnection() throws Exception { handler().exceptionCaught(ctx(), new RuntimeException("fake exception")); @@ -466,6 +491,11 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { protected abstract T newHandler() throws Http2Exception; protected abstract WriteQueue initWriteQueue(); + + protected abstract void makeStream() throws Exception; + + @Test + public void dataPingSentOnHeaderRecieved() throws Exception { + makeStream(); + AbstractNettyHandler handler = (AbstractNettyHandler) handler(); + handler.setAutoTuneFlowControl(true); + + channelRead(dataFrame(3, false, content())); + + assertEquals(1, handler.flowControlPing().getPingCount()); + } + + @Test + public void dataPingAckIsRecognized() throws Exception { + makeStream(); + AbstractNettyHandler handler = (AbstractNettyHandler) handler(); + handler.setAutoTuneFlowControl(true); + + channelRead(dataFrame(3, false, content())); + long pingData = handler.flowControlPing().payload(); + ByteBuf payload = handler.ctx().alloc().buffer(8); + payload.writeLong(pingData); + channelRead(pingFrame(true, payload)); + + assertEquals(1, handler.flowControlPing().getPingCount()); + assertEquals(1, handler.flowControlPing().getPingReturn()); + } + + @Test + public void dataSizeSincePingAccumulates() throws Exception { + makeStream(); + AbstractNettyHandler handler = (AbstractNettyHandler) handler(); + handler.setAutoTuneFlowControl(true); + long frameData = 123456; + ByteBuf buff = ctx().alloc().buffer(16); + buff.writeLong(frameData); + int length = buff.readableBytes(); + + channelRead(dataFrame(3, false, buff.copy())); + channelRead(dataFrame(3, false, buff.copy())); + channelRead(dataFrame(3, false, buff.copy())); + + assertEquals(length * 3, handler.flowControlPing().getDataSincePing()); + } + + @Test + public void windowUpdateMatchesTarget() throws Exception { + Http2Stream connectionStream = connection().connectionStream(); + Http2LocalFlowController localFlowController = connection().local().flowController(); + makeStream(); + AbstractNettyHandler handler = (AbstractNettyHandler) handler(); + handler.setAutoTuneFlowControl(true); + + ByteBuf data = ctx().alloc().buffer(1024); + while (data.isWritable()) { + data.writeLong(1111); + } + int length = data.readableBytes(); + ByteBuf frame = dataFrame(3, false, data.copy()); + channelRead(frame); + int accumulator = length; + // 40 is arbitrary, any number large enough to trigger a window update would work + for (int i = 0; i < 40; i++) { + channelRead(dataFrame(3, false, data.copy())); + accumulator += length; + } + long pingData = handler.flowControlPing().payload(); + ByteBuf buffer = handler.ctx().alloc().buffer(8); + buffer.writeLong(pingData); + channelRead(pingFrame(true, buffer)); + + assertEquals(accumulator, handler.flowControlPing().getDataSincePing()); + assertEquals(2 * accumulator, localFlowController.initialWindowSize(connectionStream)); + } + + @Test + public void windowShouldNotExceedMaxWindowSize() throws Exception { + makeStream(); + AbstractNettyHandler handler = (AbstractNettyHandler) handler(); + handler.setAutoTuneFlowControl(true); + Http2Stream connectionStream = connection().connectionStream(); + Http2LocalFlowController localFlowController = connection().local().flowController(); + int maxWindow = handler.flowControlPing().maxWindow(); + + handler.flowControlPing().setDataSizeSincePing(maxWindow); + int payload = handler.flowControlPing().payload(); + ByteBuf buffer = handler.ctx().alloc().buffer(8); + buffer.writeLong(payload); + channelRead(pingFrame(true, buffer)); + + assertEquals(maxWindow, localFlowController.initialWindowSize(connectionStream)); + } + } diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 01915a3880..34b25d7c45 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -353,4 +353,9 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase