netty: Fix ByteBuf leaks in tests (#11593)

Part of #3353
This commit is contained in:
vinodhabib 2024-12-03 00:39:25 +05:30 committed by GitHub
parent 7f9c1f39f3
commit f66d7fc54d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 39 additions and 36 deletions

View File

@ -217,6 +217,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
// Simulate receipt of initial remote settings.
ByteBuf serializedSettings = serializeSettings(new Http2Settings());
channelRead(serializedSettings);
channel().releaseOutbound();
}
@Test
@ -342,11 +343,12 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase<NettyClientHand
createStream();
// Send a frame and verify that it was written.
ByteBuf content = content();
ChannelFuture future
= enqueue(new SendGrpcFrameCommand(streamTransportState, content(), true));
= enqueue(new SendGrpcFrameCommand(streamTransportState, content, true));
assertTrue(future.isSuccess());
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(true),
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), same(content), eq(0), eq(true),
any(ChannelPromise.class));
verify(mockKeepAliveManager, times(1)).onTransportActive(); // onStreamActive
verifyNoMoreInteractions(mockKeepAliveManager);

View File

@ -38,7 +38,6 @@ import io.grpc.internal.TransportTracer;
import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
@ -68,6 +67,7 @@ import java.io.ByteArrayInputStream;
import java.nio.ByteBuffer;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -84,7 +84,6 @@ import org.mockito.verification.VerificationMode;
public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
protected static final int STREAM_ID = 3;
private ByteBuf content;
private EmbeddedChannel channel;
@ -106,18 +105,24 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
protected final TransportTracer transportTracer = new TransportTracer();
protected int flowControlWindow = DEFAULT_WINDOW_SIZE;
protected boolean autoFlowControl = false;
private final FakeClock fakeClock = new FakeClock();
FakeClock fakeClock() {
return fakeClock;
}
@After
public void tearDown() throws Exception {
if (channel() != null) {
channel().releaseInbound();
channel().releaseOutbound();
}
}
/**
* Must be called by subclasses to initialize the handler and channel.
*/
protected final void initChannel(Http2HeadersDecoder headersDecoder) throws Exception {
content = Unpooled.copiedBuffer("hello world", UTF_8);
frameWriter = mock(Http2FrameWriter.class, delegatesTo(new DefaultHttp2FrameWriter()));
frameReader = new DefaultHttp2FrameReader(headersDecoder);
@ -233,11 +238,11 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
}
protected final ByteBuf content() {
return content;
return Unpooled.copiedBuffer(contentAsArray());
}
protected final byte[] contentAsArray() {
return ByteBufUtil.getBytes(content());
return "\000\000\000\000\rhello world".getBytes(UTF_8);
}
protected final Http2FrameWriter verifyWrite() {
@ -252,8 +257,8 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
channel.writeInbound(obj);
}
protected ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) {
final ByteBuf compressionFrame = Unpooled.buffer(content.length);
protected ByteBuf grpcFrame(byte[] message) {
final ByteBuf compressionFrame = Unpooled.buffer(message.length);
MessageFramer framer = new MessageFramer(
new MessageFramer.Sink() {
@Override
@ -262,23 +267,22 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
if (frame != null) {
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf();
compressionFrame.writeBytes(bytebuf);
bytebuf.release();
}
}
},
new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT),
StatsTraceContext.NOOP);
framer.writePayload(new ByteArrayInputStream(content));
framer.flush();
ChannelHandlerContext ctx = newMockContext();
new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream,
newPromise());
return captureWrite(ctx);
framer.writePayload(new ByteArrayInputStream(message));
framer.close();
return compressionFrame;
}
protected final ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) {
return dataFrame(streamId, endStream, grpcFrame(content));
}
protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) {
// Need to retain the content since the frameWriter releases it.
content.retain();
ChannelHandlerContext ctx = newMockContext();
new DefaultHttp2FrameWriter().writeData(ctx, streamId, content, 0, endStream, newPromise());
return captureWrite(ctx);
@ -410,6 +414,7 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
channelRead(dataFrame(3, false, buff.copy()));
assertEquals(length * 3, handler.flowControlPing().getDataSincePing());
buff.release();
}
@Test
@ -608,12 +613,14 @@ public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
private void readPingAck(long pingData) throws Exception {
channelRead(pingFrame(true, pingData));
channel().releaseOutbound();
}
private void readXCopies(int copies, byte[] data) throws Exception {
for (int i = 0; i < copies; i++) {
channelRead(grpcDataFrame(STREAM_ID, false, data)); // buffer it
stream().request(1); // consume it
channel().releaseOutbound();
}
}

View File

@ -43,6 +43,7 @@ import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
@ -74,7 +75,6 @@ import io.grpc.internal.StreamListener;
import io.grpc.internal.testing.TestServerStreamTracer;
import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
@ -120,23 +120,16 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(10));
@Rule
public final MockitoRule mocks = MockitoJUnit.rule();
private static final AsciiString HTTP_FAKE_METHOD = AsciiString.of("FAKE");
@Mock
private ServerStreamListener streamListener;
@Mock
private ServerStreamTracer.Factory streamTracerFactory;
private final ServerTransportListener transportListener =
mock(ServerTransportListener.class, delegatesTo(new ServerTransportListenerImpl()));
private final TestServerStreamTracer streamTracer = new TestServerStreamTracer();
private NettyServerStream stream;
private KeepAliveManager spyKeepAliveManager;
final Queue<InputStream> streamListenerMessageQueue = new LinkedList<>();
private int maxConcurrentStreams = Integer.MAX_VALUE;
@ -208,6 +201,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
// Simulate receipt of initial remote settings.
ByteBuf serializedSettings = serializeSettings(new Http2Settings());
channelRead(serializedSettings);
channel().releaseOutbound();
}
@Test
@ -229,10 +223,11 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
createStream();
// Send a frame and verify that it was written.
ByteBuf content = content();
ChannelFuture future = enqueue(
new SendGrpcFrameCommand(stream.transportState(), content(), false));
new SendGrpcFrameCommand(stream.transportState(), content, false));
assertTrue(future.isSuccess());
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(false),
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), same(content), eq(0), eq(false),
any(ChannelPromise.class));
}
@ -267,10 +262,11 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
// Create a data frame and then trigger the handler to read it.
ByteBuf frame = grpcDataFrame(STREAM_ID, endStream, contentAsArray());
channelRead(frame);
channel().releaseOutbound();
verify(streamListener, atLeastOnce())
.messagesAvailable(any(StreamListener.MessageProducer.class));
InputStream message = streamListenerMessageQueue.poll();
assertArrayEquals(ByteBufUtil.getBytes(content()), ByteStreams.toByteArray(message));
assertArrayEquals(contentAsArray(), ByteStreams.toByteArray(message));
message.close();
assertNull("no additional message expected", streamListenerMessageQueue.poll());
@ -870,7 +866,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
future.get();
for (int i = 0; i < 10; i++) {
future = enqueue(
new SendGrpcFrameCommand(stream.transportState(), content().retainedSlice(), false));
new SendGrpcFrameCommand(stream.transportState(), content(), false));
future.get();
channel().releaseOutbound();
channelRead(pingFrame(false /* isAck */, 1L));
@ -1293,6 +1289,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
manualSetUp();
rapidReset(maxRstCount);
assertTrue(channel().isOpen());
}
@ -1302,6 +1299,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
manualSetUp();
assertThrows(ClosedChannelException.class, () -> rapidReset(maxRstCount + 1));
assertFalse(channel().isOpen());
}
@ -1344,11 +1342,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
private ByteBuf emptyGrpcFrame(int streamId, boolean endStream) throws Exception {
ByteBuf buf = NettyTestUtil.messageFrame("");
try {
return dataFrame(streamId, endStream, buf);
} finally {
buf.release();
}
}
@Override