diff --git a/netty/src/main/java/io/grpc/netty/WriteQueue.java b/netty/src/main/java/io/grpc/netty/WriteQueue.java index 2a58ab1e0a..3fbaf1bb73 100644 --- a/netty/src/main/java/io/grpc/netty/WriteQueue.java +++ b/netty/src/main/java/io/grpc/netty/WriteQueue.java @@ -110,7 +110,7 @@ class WriteQueue { */ ChannelFuture enqueue(QueuedCommand command, ChannelPromise promise, boolean flush) { // Detect erroneous code that tries to reuse command objects. - Preconditions.checkNotNull(command.promise() == null, "promise must not be set on command"); + Preconditions.checkArgument(command.promise() == null, "promise must not be set on command"); command.promise(promise); queue.add(command); diff --git a/netty/src/test/java/io/grpc/netty/WriteQueueTest.java b/netty/src/test/java/io/grpc/netty/WriteQueueTest.java index 8b460cc9df..696072c06f 100644 --- a/netty/src/test/java/io/grpc/netty/WriteQueueTest.java +++ b/netty/src/test/java/io/grpc/netty/WriteQueueTest.java @@ -33,6 +33,8 @@ package io.grpc.netty; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Matchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -65,9 +67,6 @@ public class WriteQueueTest { @Mock public ChannelPromise promise; - private QueuedCommand command = new WriteQueue.AbstractQueuedCommand() { - }; - private long flushCalledNanos; private long writeCalledNanos; @@ -102,24 +101,25 @@ public class WriteQueueTest { } }); - when(channel.write(command, promise)).thenAnswer(new Answer() { - @Override - public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { - writeCalledNanos = System.nanoTime(); - if (writeCalledNanos == flushCalledNanos) { - writeCalledNanos += 1; - } - return promise; - } - }); + when(channel.write(any(QueuedCommand.class), eq(promise))).thenAnswer( + new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { + writeCalledNanos = System.nanoTime(); + if (writeCalledNanos == flushCalledNanos) { + writeCalledNanos += 1; + } + return promise; + } + }); } @Test public void singleWriteShouldWork() { WriteQueue queue = new WriteQueue(channel); - queue.enqueue(command, true); + queue.enqueue(new CuteCommand(), true); - verify(channel).write(command, promise); + verify(channel).write(isA(QueuedCommand.class), eq(promise)); verify(channel).flush(); } @@ -127,11 +127,11 @@ public class WriteQueueTest { public void multipleWritesShouldBeBatched() { WriteQueue queue = new WriteQueue(channel); for (int i = 0; i < 5; i++) { - queue.enqueue(command, false); + queue.enqueue(new CuteCommand(), false); } queue.scheduleFlush(); - verify(channel, times(5)).write(command, promise); + verify(channel, times(5)).write(isA(QueuedCommand.class), eq(promise)); verify(channel).flush(); } @@ -140,11 +140,11 @@ public class WriteQueueTest { WriteQueue queue = new WriteQueue(channel); int writes = WriteQueue.DEQUE_CHUNK_SIZE + 10; for (int i = 0; i < writes; i++) { - queue.enqueue(command, false); + queue.enqueue(new CuteCommand(), false); } queue.scheduleFlush(); - verify(channel, times(writes)).write(command, promise); + verify(channel, times(writes)).write(isA(QueuedCommand.class), eq(promise)); verify(channel, times(2)).flush(); } @@ -195,12 +195,16 @@ public class WriteQueueTest { flusherStarted.await(); int writes = 10 * WriteQueue.DEQUE_CHUNK_SIZE; for (int i = 0; i < writes; i++) { - queue.enqueue(command, false); + queue.enqueue(new CuteCommand(), false); } doneWriting.set(true); flusher.join(); exHandler.checkException(); - verify(channel, times(writes)).write(command, promise); + verify(channel, times(writes)).write(isA(CuteCommand.class), eq(promise)); + } + + static class CuteCommand extends WriteQueue.AbstractQueuedCommand { + } }