netty: Fix ByteBuf leaks in tests (#11593)

Part of #3353
This commit is contained in:
vinodhabib 2024-12-03 00:39:25 +05:30 committed by GitHub
parent 7f9c1f39f3
commit f66d7fc54d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 39 additions and 36 deletions

View File

@ -217,6 +217,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
// Simulate receipt of initial remote settings. // Simulate receipt of initial remote settings.
ByteBuf serializedSettings = serializeSettings(new Http2Settings()); ByteBuf serializedSettings = serializeSettings(new Http2Settings());
channelRead(serializedSettings); channelRead(serializedSettings);
channel().releaseOutbound();
} }
@Test @Test
@ -342,11 +343,12 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
createStream(); createStream();
// Send a frame and verify that it was written. // Send a frame and verify that it was written.
ByteBuf content = content();
ChannelFuture future ChannelFuture future
= enqueue(new SendGrpcFrameCommand(streamTransportState, content(), true)); = enqueue(new SendGrpcFrameCommand(streamTransportState, content, true));
assertTrue(future.isSuccess()); assertTrue(future.isSuccess());
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(true), verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), same(content), eq(0), eq(true),
any(ChannelPromise.class)); any(ChannelPromise.class));
verify(mockKeepAliveManager, times(1)).onTransportActive(); // onStreamActive verify(mockKeepAliveManager, times(1)).onTransportActive(); // onStreamActive
verifyNoMoreInteractions(mockKeepAliveManager); verifyNoMoreInteractions(mockKeepAliveManager);

View File

@ -38,7 +38,6 @@ import io.grpc.internal.TransportTracer;
import io.grpc.internal.WritableBuffer; import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.buffer.UnpooledByteBufAllocator;
@ -68,6 +67,7 @@ import java.io.ByteArrayInputStream;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.concurrent.Delayed; import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.junit.After;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -84,7 +84,6 @@ import org.mockito.verification.VerificationMode;
public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> { public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
protected static final int STREAM_ID = 3; protected static final int STREAM_ID = 3;
private ByteBuf content;
private EmbeddedChannel channel; private EmbeddedChannel channel;
@ -106,18 +105,24 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
protected final TransportTracer transportTracer = new TransportTracer(); protected final TransportTracer transportTracer = new TransportTracer();
protected int flowControlWindow = DEFAULT_WINDOW_SIZE; protected int flowControlWindow = DEFAULT_WINDOW_SIZE;
protected boolean autoFlowControl = false; protected boolean autoFlowControl = false;
private final FakeClock fakeClock = new FakeClock(); private final FakeClock fakeClock = new FakeClock();
FakeClock fakeClock() { FakeClock fakeClock() {
return fakeClock; return fakeClock;
} }
@After
public void tearDown() throws Exception {
if (channel() != null) {
channel().releaseInbound();
channel().releaseOutbound();
}
}
/** /**
* Must be called by subclasses to initialize the handler and channel. * Must be called by subclasses to initialize the handler and channel.
*/ */
protected final void initChannel(Http2HeadersDecoder headersDecoder) throws Exception { protected final void initChannel(Http2HeadersDecoder headersDecoder) throws Exception {
content = Unpooled.copiedBuffer("hello world", UTF_8);
frameWriter = mock(Http2FrameWriter.class, delegatesTo(new DefaultHttp2FrameWriter())); frameWriter = mock(Http2FrameWriter.class, delegatesTo(new DefaultHttp2FrameWriter()));
frameReader = new DefaultHttp2FrameReader(headersDecoder); frameReader = new DefaultHttp2FrameReader(headersDecoder);
@ -233,11 +238,11 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
} }
protected final ByteBuf content() { protected final ByteBuf content() {
return content; return Unpooled.copiedBuffer(contentAsArray());
} }
protected final byte[] contentAsArray() { protected final byte[] contentAsArray() {
return ByteBufUtil.getBytes(content()); return "\000\000\000\000\rhello world".getBytes(UTF_8);
} }
protected final Http2FrameWriter verifyWrite() { protected final Http2FrameWriter verifyWrite() {
@ -252,8 +257,8 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
channel.writeInbound(obj); channel.writeInbound(obj);
} }
protected ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) { protected ByteBuf grpcFrame(byte[] message) {
final ByteBuf compressionFrame = Unpooled.buffer(content.length); final ByteBuf compressionFrame = Unpooled.buffer(message.length);
MessageFramer framer = new MessageFramer( MessageFramer framer = new MessageFramer(
new MessageFramer.Sink() { new MessageFramer.Sink() {
@Override @Override
@ -262,23 +267,22 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
if (frame != null) { if (frame != null) {
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf(); ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf();
compressionFrame.writeBytes(bytebuf); compressionFrame.writeBytes(bytebuf);
bytebuf.release();
} }
} }
}, },
new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT), new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT),
StatsTraceContext.NOOP); StatsTraceContext.NOOP);
framer.writePayload(new ByteArrayInputStream(content)); framer.writePayload(new ByteArrayInputStream(message));
framer.flush(); framer.close();
ChannelHandlerContext ctx = newMockContext(); return compressionFrame;
new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream, }
newPromise());
return captureWrite(ctx); protected final ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) {
return dataFrame(streamId, endStream, grpcFrame(content));
} }
protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) { protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) {
// Need to retain the content since the frameWriter releases it.
content.retain();
ChannelHandlerContext ctx = newMockContext(); ChannelHandlerContext ctx = newMockContext();
new DefaultHttp2FrameWriter().writeData(ctx, streamId, content, 0, endStream, newPromise()); new DefaultHttp2FrameWriter().writeData(ctx, streamId, content, 0, endStream, newPromise());
return captureWrite(ctx); return captureWrite(ctx);
@ -410,6 +414,7 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
channelRead(dataFrame(3, false, buff.copy())); channelRead(dataFrame(3, false, buff.copy()));
assertEquals(length * 3, handler.flowControlPing().getDataSincePing()); assertEquals(length * 3, handler.flowControlPing().getDataSincePing());
buff.release();
} }
@Test @Test
@ -608,12 +613,14 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
private void readPingAck(long pingData) throws Exception { private void readPingAck(long pingData) throws Exception {
channelRead(pingFrame(true, pingData)); channelRead(pingFrame(true, pingData));
channel().releaseOutbound();
} }
private void readXCopies(int copies, byte[] data) throws Exception { private void readXCopies(int copies, byte[] data) throws Exception {
for (int i = 0; i < copies; i++) { for (int i = 0; i < copies; i++) {
channelRead(grpcDataFrame(STREAM_ID, false, data)); // buffer it channelRead(grpcDataFrame(STREAM_ID, false, data)); // buffer it
stream().request(1); // consume it stream().request(1); // consume it
channel().releaseOutbound();
} }
} }

View File

@ -43,6 +43,7 @@ import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
@ -74,7 +75,6 @@ import io.grpc.internal.StreamListener;
import io.grpc.internal.testing.TestServerStreamTracer; import io.grpc.internal.testing.TestServerStreamTracer;
import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
@ -120,23 +120,16 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(10)); public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(10));
@Rule @Rule
public final MockitoRule mocks = MockitoJUnit.rule(); public final MockitoRule mocks = MockitoJUnit.rule();
private static final AsciiString HTTP_FAKE_METHOD = AsciiString.of("FAKE"); private static final AsciiString HTTP_FAKE_METHOD = AsciiString.of("FAKE");
@Mock @Mock
private ServerStreamListener streamListener; private ServerStreamListener streamListener;
@Mock @Mock
private ServerStreamTracer.Factory streamTracerFactory; private ServerStreamTracer.Factory streamTracerFactory;
private final ServerTransportListener transportListener = private final ServerTransportListener transportListener =
mock(ServerTransportListener.class, delegatesTo(new ServerTransportListenerImpl())); mock(ServerTransportListener.class, delegatesTo(new ServerTransportListenerImpl()));
private final TestServerStreamTracer streamTracer = new TestServerStreamTracer(); private final TestServerStreamTracer streamTracer = new TestServerStreamTracer();
private NettyServerStream stream; private NettyServerStream stream;
private KeepAliveManager spyKeepAliveManager; private KeepAliveManager spyKeepAliveManager;
final Queue<InputStream> streamListenerMessageQueue = new LinkedList<>(); final Queue<InputStream> streamListenerMessageQueue = new LinkedList<>();
private int maxConcurrentStreams = Integer.MAX_VALUE; private int maxConcurrentStreams = Integer.MAX_VALUE;
@ -208,6 +201,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
// Simulate receipt of initial remote settings. // Simulate receipt of initial remote settings.
ByteBuf serializedSettings = serializeSettings(new Http2Settings()); ByteBuf serializedSettings = serializeSettings(new Http2Settings());
channelRead(serializedSettings); channelRead(serializedSettings);
channel().releaseOutbound();
} }
@Test @Test
@ -229,10 +223,11 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
createStream(); createStream();
// Send a frame and verify that it was written. // Send a frame and verify that it was written.
ByteBuf content = content();
ChannelFuture future = enqueue( ChannelFuture future = enqueue(
new SendGrpcFrameCommand(stream.transportState(), content(), false)); new SendGrpcFrameCommand(stream.transportState(), content, false));
assertTrue(future.isSuccess()); assertTrue(future.isSuccess());
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(false), verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), same(content), eq(0), eq(false),
any(ChannelPromise.class)); any(ChannelPromise.class));
} }
@ -267,10 +262,11 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
// Create a data frame and then trigger the handler to read it. // Create a data frame and then trigger the handler to read it.
ByteBuf frame = grpcDataFrame(STREAM_ID, endStream, contentAsArray()); ByteBuf frame = grpcDataFrame(STREAM_ID, endStream, contentAsArray());
channelRead(frame); channelRead(frame);
channel().releaseOutbound();
verify(streamListener, atLeastOnce()) verify(streamListener, atLeastOnce())
.messagesAvailable(any(StreamListener.MessageProducer.class)); .messagesAvailable(any(StreamListener.MessageProducer.class));
InputStream message = streamListenerMessageQueue.poll(); InputStream message = streamListenerMessageQueue.poll();
assertArrayEquals(ByteBufUtil.getBytes(content()), ByteStreams.toByteArray(message)); assertArrayEquals(contentAsArray(), ByteStreams.toByteArray(message));
message.close(); message.close();
assertNull("no additional message expected", streamListenerMessageQueue.poll()); assertNull("no additional message expected", streamListenerMessageQueue.poll());
@ -870,7 +866,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
future.get(); future.get();
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
future = enqueue( future = enqueue(
new SendGrpcFrameCommand(stream.transportState(), content().retainedSlice(), false)); new SendGrpcFrameCommand(stream.transportState(), content(), false));
future.get(); future.get();
channel().releaseOutbound(); channel().releaseOutbound();
channelRead(pingFrame(false /* isAck */, 1L)); channelRead(pingFrame(false /* isAck */, 1L));
@ -1293,6 +1289,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100); maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
manualSetUp(); manualSetUp();
rapidReset(maxRstCount); rapidReset(maxRstCount);
assertTrue(channel().isOpen()); assertTrue(channel().isOpen());
} }
@ -1302,6 +1299,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100); maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
manualSetUp(); manualSetUp();
assertThrows(ClosedChannelException.class, () -> rapidReset(maxRstCount + 1)); assertThrows(ClosedChannelException.class, () -> rapidReset(maxRstCount + 1));
assertFalse(channel().isOpen()); assertFalse(channel().isOpen());
} }
@ -1344,11 +1342,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
private ByteBuf emptyGrpcFrame(int streamId, boolean endStream) throws Exception { private ByteBuf emptyGrpcFrame(int streamId, boolean endStream) throws Exception {
ByteBuf buf = NettyTestUtil.messageFrame(""); ByteBuf buf = NettyTestUtil.messageFrame("");
try {
return dataFrame(streamId, endStream, buf); return dataFrame(streamId, endStream, buf);
} finally {
buf.release();
}
} }
@Override @Override