diff --git a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java index adeb635e00..98c1661440 100644 --- a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java +++ b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java @@ -18,13 +18,18 @@ package io.grpc.netty; import static com.google.common.base.Charsets.UTF_8; import static org.junit.Assert.assertEquals; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyLong; import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.errorprone.annotations.CanIgnoreReturnValue; +import io.grpc.internal.FakeClock; import io.grpc.internal.MessageFramer; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.WritableBuffer; @@ -35,6 +40,7 @@ import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoop; @@ -51,12 +57,18 @@ import io.netty.handler.codec.http2.Http2HeadersDecoder; import io.netty.handler.codec.http2.Http2LocalFlowController; import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.codec.http2.Http2Stream; +import io.netty.util.concurrent.DefaultPromise; +import io.netty.util.concurrent.Promise; +import io.netty.util.concurrent.ScheduledFuture; import java.io.ByteArrayInputStream; +import java.util.concurrent.Delayed; +import java.util.concurrent.TimeUnit; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import org.mockito.verification.VerificationMode; /** @@ -84,6 +96,12 @@ public abstract class NettyHandlerTestBase { */ protected void manualSetUp() throws Exception {} + private final FakeClock fakeClock = new FakeClock(); + + FakeClock fakeClock() { + return fakeClock; + } + /** * Must be called by subclasses to initialize the handler and channel. */ @@ -94,12 +112,91 @@ public abstract class NettyHandlerTestBase { handler = newHandler(); - channel = new EmbeddedChannel(handler); + channel = new FakeClockSupportedChanel(handler); ctx = channel.pipeline().context(handler); writeQueue = initWriteQueue(); } + private final class FakeClockSupportedChanel extends EmbeddedChannel { + EventLoop eventLoop; + + FakeClockSupportedChanel(ChannelHandler... handlers) { + super(handlers); + } + + @Override + public EventLoop eventLoop() { + if (eventLoop == null) { + createEventLoop(); + } + return eventLoop; + } + + void createEventLoop() { + EventLoop realEventLoop = super.eventLoop(); + if (realEventLoop == null) { + return; + } + eventLoop = mock(EventLoop.class, delegatesTo(realEventLoop)); + doAnswer( + new Answer>() { + @Override + public ScheduledFuture answer(InvocationOnMock invocation) throws Throwable { + Runnable command = (Runnable) invocation.getArguments()[0]; + Long delay = (Long) invocation.getArguments()[1]; + TimeUnit timeUnit = (TimeUnit) invocation.getArguments()[2]; + return new FakeClockScheduledNettyFuture(eventLoop, command, delay, timeUnit); + } + }).when(eventLoop).schedule(any(Runnable.class), anyLong(), any(TimeUnit.class)); + } + } + + private final class FakeClockScheduledNettyFuture extends DefaultPromise + implements ScheduledFuture { + final java.util.concurrent.ScheduledFuture future; + + FakeClockScheduledNettyFuture( + EventLoop eventLoop, final Runnable command, long delay, TimeUnit timeUnit) { + super(eventLoop); + Runnable wrap = new Runnable() { + @Override + public void run() { + try { + command.run(); + } catch (Throwable t) { + setFailure(t); + return; + } + if (!isDone()) { + Promise unused = setSuccess(null); + } + // else: The command itself, such as a shutdown task, might have cancelled all the + // scheduled tasks already. + } + }; + future = fakeClock.getScheduledExecutorService().schedule(wrap, delay, timeUnit); + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + if (future.cancel(mayInterruptIfRunning)) { + return super.cancel(mayInterruptIfRunning); + } + return false; + } + + @Override + public long getDelay(TimeUnit unit) { + return Math.max(future.getDelay(unit), 1L); // never return zero or negative delay. + } + + @Override + public int compareTo(Delayed o) { + return future.compareTo(o); + } + } + protected final T handler() { return handler; } @@ -221,9 +318,9 @@ public abstract class NettyHandlerTestBase { } protected final ChannelHandlerContext newMockContext() { - ChannelHandlerContext ctx = Mockito.mock(ChannelHandlerContext.class); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); - EventLoop eventLoop = Mockito.mock(EventLoop.class); + EventLoop eventLoop = mock(EventLoop.class); when(ctx.executor()).thenReturn(eventLoop); return ctx; } diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index c99b8614ab..e1549c3458 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -20,7 +20,6 @@ import static com.google.common.base.Charsets.UTF_8; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS; import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIME_NANOS; -import static io.grpc.internal.testing.TestUtils.sleepAtLeast; import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE; import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_NANOS_DISABLED; import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED; @@ -488,14 +487,12 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase