Add unit tests for SizeEnforcingInputStream

This commit is contained in:
Carl Mastrangelo 2015-10-27 10:45:16 -07:00
parent 353aabce51
commit 122f93c26d
2 changed files with 142 additions and 4 deletions

View File

@ -33,6 +33,7 @@ package io.grpc.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.grpc.Codec;
@ -366,8 +367,9 @@ public class MessageDeframer implements Closeable {
try {
// Enforce the maxMessageSize limit on the returned stream.
return new SizeEnforcingInputStream(decompressor.decompress(
ReadableBuffers.openStream(nextFrame, true)));
InputStream unlimitedStream =
decompressor.decompress(ReadableBuffers.openStream(nextFrame, true));
return new SizeEnforcingInputStream(unlimitedStream, maxMessageSize);
} catch (IOException e) {
throw new RuntimeException(e);
}
@ -376,12 +378,15 @@ public class MessageDeframer implements Closeable {
/**
* An {@link InputStream} that enforces the {@link #maxMessageSize} limit for compressed frames.
*/
private final class SizeEnforcingInputStream extends FilterInputStream {
@VisibleForTesting
static final class SizeEnforcingInputStream extends FilterInputStream {
private final int maxMessageSize;
private long count;
private long mark = -1;
public SizeEnforcingInputStream(InputStream in) {
SizeEnforcingInputStream(InputStream in, int maxMessageSize) {
super(in);
this.maxMessageSize = maxMessageSize;
}
@Override

View File

@ -42,13 +42,18 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.google.common.base.Charsets;
import com.google.common.io.ByteStreams;
import com.google.common.primitives.Bytes;
import io.grpc.Codec;
import io.grpc.StatusRuntimeException;
import io.grpc.internal.MessageDeframer.Listener;
import io.grpc.internal.MessageDeframer.SizeEnforcingInputStream;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
@ -56,6 +61,7 @@ import org.mockito.Matchers;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
@ -67,6 +73,8 @@ import java.util.zip.GZIPOutputStream;
*/
@RunWith(JUnit4.class)
public class MessageDeframerTest {
@Rule public final ExpectedException thrown = ExpectedException.none();
private Listener listener = mock(Listener.class);
private MessageDeframer deframer = new MessageDeframer(listener, Codec.Identity.NONE,
DEFAULT_MAX_MESSAGE_SIZE);
@ -211,6 +219,131 @@ public class MessageDeframerTest {
verifyNoMoreInteractions(listener);
}
@Test
public void sizeEnforcingInputStream_readByteBelowLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4);
while (stream.read() != -1) {}
stream.close();
}
@Test
public void sizeEnforcingInputStream_readByteAtLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3);
while (stream.read() != -1) {}
stream.close();
}
@Test
public void sizeEnforcingInputStream_readByteAboveLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2);
thrown.expect(StatusRuntimeException.class);
thrown.expectMessage("INTERNAL: Compressed frame exceeds");
while (stream.read() != -1) {}
// Never run, makes compiler nag go away
stream.close();
}
@Test
public void sizeEnforcingInputStream_readBelowLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4);
byte[] buf = new byte[10];
int read = stream.read(buf, 0, buf.length);
assertEquals(3, read);
stream.close();
}
@Test
public void sizeEnforcingInputStream_readAtLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3);
byte[] buf = new byte[10];
int read = stream.read(buf, 0, buf.length);
assertEquals(3, read);
stream.close();
}
@Test
public void sizeEnforcingInputStream_readAboveLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2);
byte[] buf = new byte[10];
thrown.expect(StatusRuntimeException.class);
thrown.expectMessage("INTERNAL: Compressed frame exceeds");
stream.read(buf, 0, buf.length);
// Never called
stream.close();
}
@Test
public void sizeEnforcingInputStream_skipBelowLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4);
long skipped = stream.skip(4);
assertEquals(3, skipped);
stream.close();
}
@Test
public void sizeEnforcingInputStream_skipAtLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3);
long skipped = stream.skip(4);
assertEquals(3, skipped);
stream.close();
}
@Test
public void sizeEnforcingInputStream_skipAboveLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2);
thrown.expect(StatusRuntimeException.class);
thrown.expectMessage("INTERNAL: Compressed frame exceeds");
stream.skip(4);
// never run
stream.close();
}
@Test
public void sizeEnforcingInputStream_markReset() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3);
// stream currently looks like: |foo
stream.skip(1); // f|oo
stream.mark(10); // any large number will work.
stream.skip(2); // foo|
stream.reset(); // f|oo
long skipped = stream.skip(2); // foo|
assertEquals(2, skipped);
stream.close();
}
private static List<Byte> bytes(ArgumentCaptor<InputStream> captor) {
return bytes(captor.getValue());
}