diff --git a/netty/src/main/java/io/grpc/transport/netty/BufferingHttp2ConnectionEncoder.java b/netty/src/main/java/io/grpc/transport/netty/BufferingHttp2ConnectionEncoder.java index 6372b8c948..9592c69392 100644 --- a/netty/src/main/java/io/grpc/transport/netty/BufferingHttp2ConnectionEncoder.java +++ b/netty/src/main/java/io/grpc/transport/netty/BufferingHttp2ConnectionEncoder.java @@ -92,6 +92,13 @@ class BufferingHttp2ConnectionEncoder extends DecoratingHttp2ConnectionEncoder { }); } + /** + * Indicates the number of streams that are currently buffered, awaiting creation. + */ + public int numBufferedStreams() { + return pendingStreams.size(); + } + @Override public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, boolean endStream, ChannelPromise promise) { @@ -199,6 +206,7 @@ class BufferingHttp2ConnectionEncoder extends DecoratingHttp2ConnectionEncoder { while (iter.hasNext()) { PendingStream stream = iter.next(); if (stream.streamId > lastStreamId) { + iter.remove(); stream.close(e); } } diff --git a/netty/src/test/java/io/grpc/transport/netty/BufferingHttp2ConnectionEncoderTest.java b/netty/src/test/java/io/grpc/transport/netty/BufferingHttp2ConnectionEncoderTest.java index d19950944d..a8e6bee5f9 100644 --- a/netty/src/test/java/io/grpc/transport/netty/BufferingHttp2ConnectionEncoderTest.java +++ b/netty/src/test/java/io/grpc/transport/netty/BufferingHttp2ConnectionEncoderTest.java @@ -58,7 +58,6 @@ import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder; import io.netty.handler.codec.http2.DefaultHttp2Headers; import io.netty.handler.codec.http2.Http2Connection; -import io.netty.handler.codec.http2.Http2ConnectionEncoder; import io.netty.handler.codec.http2.Http2ConnectionHandler; import io.netty.handler.codec.http2.Http2FrameListener; import io.netty.handler.codec.http2.Http2FrameReader; @@ -66,6 +65,7 @@ import io.netty.handler.codec.http2.Http2FrameSizePolicy; import io.netty.handler.codec.http2.Http2FrameWriter; import io.netty.handler.codec.http2.Http2Headers; import io.netty.util.concurrent.ImmediateEventExecutor; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -82,7 +82,7 @@ import org.mockito.verification.VerificationMode; @RunWith(JUnit4.class) public class BufferingHttp2ConnectionEncoderTest { - private Http2ConnectionEncoder encoder; + private BufferingHttp2ConnectionEncoder encoder; private Http2Connection connection; @@ -133,6 +133,7 @@ public class BufferingHttp2ConnectionEncoderTest { @Test public void multipleWritesToActiveStream() { encoderWriteHeaders(3, promise); + assertEquals(0, encoder.numBufferedStreams()); encoder.writeData(ctx, 3, data(), 0, false, promise); encoder.writeData(ctx, 3, data(), 0, false, promise); encoder.writeData(ctx, 3, data(), 0, false, promise); @@ -148,9 +149,12 @@ public class BufferingHttp2ConnectionEncoderTest { connection.local().maxActiveStreams(1); encoderWriteHeaders(3, promise); + assertEquals(0, encoder.numBufferedStreams()); + // This one gets buffered. encoderWriteHeaders(5, promise); assertEquals(1, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); // Now prevent us from creating another stream. connection.local().maxActiveStreams(0); @@ -163,6 +167,7 @@ public class BufferingHttp2ConnectionEncoderTest { writeVerifyWriteHeaders(times(1), 3, promise); writeVerifyWriteHeaders(never(), 5, promise); assertEquals(0, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); } @Test @@ -170,8 +175,11 @@ public class BufferingHttp2ConnectionEncoderTest { connection.local().maxActiveStreams(1); encoderWriteHeaders(3, promise); + assertEquals(0, encoder.numBufferedStreams()); + encoderWriteHeaders(5, promise); assertEquals(1, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); encoder.writeData(ctx, 3, Unpooled.buffer(0), 0, false, promise); writeVerifyWriteHeaders(times(1), 3, promise); @@ -187,6 +195,7 @@ public class BufferingHttp2ConnectionEncoderTest { promise = mock(ChannelPromise.class); encoderWriteHeaders(3, promise); + assertEquals(0, encoder.numBufferedStreams()); verify(promise).setFailure(any(Throwable.class)); } @@ -199,12 +208,14 @@ public class BufferingHttp2ConnectionEncoderTest { encoderWriteHeaders(streamId, promise); streamId += 2; } + assertEquals(4, encoder.numBufferedStreams()); connection.goAwayReceived(11, 8, null); assertEquals(5, connection.numActiveStreams()); // The 4 buffered streams must have been failed. verify(promise, times(4)).setFailure(any(Throwable.class)); + assertEquals(0, encoder.numBufferedStreams()); } @Test @@ -212,13 +223,17 @@ public class BufferingHttp2ConnectionEncoderTest { connection.local().maxActiveStreams(1); encoderWriteHeaders(3, promise); + assertEquals(0, encoder.numBufferedStreams()); encoderWriteHeaders(5, promise); + assertEquals(1, encoder.numBufferedStreams()); encoderWriteHeaders(7, promise); + assertEquals(2, encoder.numBufferedStreams()); ByteBuf empty = Unpooled.buffer(0); encoder.writeGoAway(ctx, 3, CANCEL.code(), empty, promise); assertEquals(1, connection.numActiveStreams()); + assertEquals(2, encoder.numBufferedStreams()); verify(promise, never()).setFailure(any(GoAwayClosedStreamException.class)); } @@ -227,11 +242,13 @@ public class BufferingHttp2ConnectionEncoderTest { connection.local().maxActiveStreams(0); encoderWriteHeaders(3, promise); + assertEquals(1, encoder.numBufferedStreams()); ByteBuf empty = Unpooled.buffer(0); encoder.writeData(ctx, 3, empty, 0, true, promise); assertEquals(0, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); // Simulate that we received a SETTINGS frame which // increased MAX_CONCURRENT_STREAMS to 1. @@ -239,6 +256,7 @@ public class BufferingHttp2ConnectionEncoderTest { encoder.writeSettingsAck(ctx, promise); assertEquals(1, connection.numActiveStreams()); + assertEquals(0, encoder.numBufferedStreams()); assertEquals(HALF_CLOSED_LOCAL, connection.stream(3).state()); } @@ -247,12 +265,14 @@ public class BufferingHttp2ConnectionEncoderTest { connection.local().maxActiveStreams(0); encoderWriteHeaders(3, promise); + assertEquals(1, encoder.numBufferedStreams()); verify(promise, never()).setSuccess(); ChannelPromise rstStreamPromise = mock(ChannelPromise.class); encoder.writeRstStream(ctx, 3, CANCEL.code(), rstStreamPromise); verify(promise).setSuccess(); verify(rstStreamPromise).setSuccess(); + assertEquals(0, encoder.numBufferedStreams()); } @Test @@ -260,8 +280,11 @@ public class BufferingHttp2ConnectionEncoderTest { connection.local().maxActiveStreams(1); encoderWriteHeaders(3, promise); + assertEquals(0, encoder.numBufferedStreams()); encoderWriteHeaders(5, promise); + assertEquals(1, encoder.numBufferedStreams()); encoderWriteHeaders(7, promise); + assertEquals(2, encoder.numBufferedStreams()); writeVerifyWriteHeaders(times(1), 3, promise); writeVerifyWriteHeaders(never(), 5, promise); @@ -269,10 +292,13 @@ public class BufferingHttp2ConnectionEncoderTest { encoder.writeRstStream(ctx, 3, CANCEL.code(), promise); assertEquals(1, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); encoder.writeRstStream(ctx, 5, CANCEL.code(), promise); assertEquals(1, connection.numActiveStreams()); + assertEquals(0, encoder.numBufferedStreams()); encoder.writeRstStream(ctx, 7, CANCEL.code(), promise); assertEquals(0, connection.numActiveStreams()); + assertEquals(0, encoder.numBufferedStreams()); } @Test @@ -283,6 +309,7 @@ public class BufferingHttp2ConnectionEncoderTest { encoderWriteHeaders(5, promise); encoderWriteHeaders(7, promise); encoderWriteHeaders(9, promise); + assertEquals(2, encoder.numBufferedStreams()); writeVerifyWriteHeaders(times(1), 3, promise); writeVerifyWriteHeaders(times(1), 5, promise); @@ -294,6 +321,7 @@ public class BufferingHttp2ConnectionEncoderTest { connection.local().maxActiveStreams(5); encoder.writeSettingsAck(ctx, promise); + assertEquals(0, encoder.numBufferedStreams()); writeVerifyWriteHeaders(times(1), 7, promise); writeVerifyWriteHeaders(times(1), 9, promise); @@ -309,11 +337,13 @@ public class BufferingHttp2ConnectionEncoderTest { connection.local().maxActiveStreams(0); ByteBuf data = mock(ByteBuf.class); encoderWriteHeaders(3, promise); + assertEquals(1, encoder.numBufferedStreams()); encoder.writeData(ctx, 3, data, 0, false, promise); ChannelPromise rstPromise = mock(ChannelPromise.class); encoder.writeRstStream(ctx, 3, CANCEL.code(), rstPromise); + assertEquals(0, encoder.numBufferedStreams()); verify(rstPromise).setSuccess(); verify(promise, times(2)).setSuccess(); verify(data).release();