From f66d7fc54d23fa441d0feb9873952a325346cbc4 Mon Sep 17 00:00:00 2001 From: vinodhabib <47808007+vinodhabib@users.noreply.github.com> Date: Tue, 3 Dec 2024 00:39:25 +0530 Subject: [PATCH] netty: Fix ByteBuf leaks in tests (#11593) Part of #3353 --- .../io/grpc/netty/NettyClientHandlerTest.java | 6 ++- .../io/grpc/netty/NettyHandlerTestBase.java | 41 +++++++++++-------- .../io/grpc/netty/NettyServerHandlerTest.java | 28 +++++-------- 3 files changed, 39 insertions(+), 36 deletions(-) diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index 6c5dd6b18b..e5c97e9efd 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -217,6 +217,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { protected static final int STREAM_ID = 3; - private ByteBuf content; private EmbeddedChannel channel; @@ -106,18 +105,24 @@ public abstract class NettyHandlerTestBase { protected final TransportTracer transportTracer = new TransportTracer(); protected int flowControlWindow = DEFAULT_WINDOW_SIZE; protected boolean autoFlowControl = false; - private final FakeClock fakeClock = new FakeClock(); FakeClock fakeClock() { return fakeClock; } + @After + public void tearDown() throws Exception { + if (channel() != null) { + channel().releaseInbound(); + channel().releaseOutbound(); + } + } + /** * Must be called by subclasses to initialize the handler and channel. */ protected final void initChannel(Http2HeadersDecoder headersDecoder) throws Exception { - content = Unpooled.copiedBuffer("hello world", UTF_8); frameWriter = mock(Http2FrameWriter.class, delegatesTo(new DefaultHttp2FrameWriter())); frameReader = new DefaultHttp2FrameReader(headersDecoder); @@ -233,11 +238,11 @@ public abstract class NettyHandlerTestBase { } protected final ByteBuf content() { - return content; + return Unpooled.copiedBuffer(contentAsArray()); } protected final byte[] contentAsArray() { - return ByteBufUtil.getBytes(content()); + return "\000\000\000\000\rhello world".getBytes(UTF_8); } protected final Http2FrameWriter verifyWrite() { @@ -252,8 +257,8 @@ public abstract class NettyHandlerTestBase { channel.writeInbound(obj); } - protected ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) { - final ByteBuf compressionFrame = Unpooled.buffer(content.length); + protected ByteBuf grpcFrame(byte[] message) { + final ByteBuf compressionFrame = Unpooled.buffer(message.length); MessageFramer framer = new MessageFramer( new MessageFramer.Sink() { @Override @@ -262,23 +267,22 @@ public abstract class NettyHandlerTestBase { if (frame != null) { ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf(); compressionFrame.writeBytes(bytebuf); + bytebuf.release(); } } }, new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT), StatsTraceContext.NOOP); - framer.writePayload(new ByteArrayInputStream(content)); - framer.flush(); - ChannelHandlerContext ctx = newMockContext(); - new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream, - newPromise()); - return captureWrite(ctx); + framer.writePayload(new ByteArrayInputStream(message)); + framer.close(); + return compressionFrame; + } + + protected final ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) { + return dataFrame(streamId, endStream, grpcFrame(content)); } protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) { - // Need to retain the content since the frameWriter releases it. - content.retain(); - ChannelHandlerContext ctx = newMockContext(); new DefaultHttp2FrameWriter().writeData(ctx, streamId, content, 0, endStream, newPromise()); return captureWrite(ctx); @@ -410,6 +414,7 @@ public abstract class NettyHandlerTestBase { channelRead(dataFrame(3, false, buff.copy())); assertEquals(length * 3, handler.flowControlPing().getDataSincePing()); + buff.release(); } @Test @@ -608,12 +613,14 @@ public abstract class NettyHandlerTestBase { private void readPingAck(long pingData) throws Exception { channelRead(pingFrame(true, pingData)); + channel().releaseOutbound(); } private void readXCopies(int copies, byte[] data) throws Exception { for (int i = 0; i < copies; i++) { channelRead(grpcDataFrame(STREAM_ID, false, data)); // buffer it stream().request(1); // consume it + channel().releaseOutbound(); } } diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 308079ff62..54c1375eef 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -43,6 +43,7 @@ import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; @@ -74,7 +75,6 @@ import io.grpc.internal.StreamListener; import io.grpc.internal.testing.TestServerStreamTracer; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; @@ -120,23 +120,16 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase streamListenerMessageQueue = new LinkedList<>(); private int maxConcurrentStreams = Integer.MAX_VALUE; @@ -208,6 +201,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase rapidReset(maxRstCount + 1)); + assertFalse(channel().isOpen()); } @@ -1344,11 +1342,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase