Add unit tests for SizeEnforcingInputStream

This commit is contained in:
Carl Mastrangelo 2015-10-27 10:45:16 -07:00
parent 353aabce51
commit 122f93c26d
2 changed files with 142 additions and 4 deletions

View File

@ -33,6 +33,7 @@ package io.grpc.internal;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.Codec; import io.grpc.Codec;
@ -366,8 +367,9 @@ public class MessageDeframer implements Closeable {
try { try {
// Enforce the maxMessageSize limit on the returned stream. // Enforce the maxMessageSize limit on the returned stream.
return new SizeEnforcingInputStream(decompressor.decompress( InputStream unlimitedStream =
ReadableBuffers.openStream(nextFrame, true))); decompressor.decompress(ReadableBuffers.openStream(nextFrame, true));
return new SizeEnforcingInputStream(unlimitedStream, maxMessageSize);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -376,12 +378,15 @@ public class MessageDeframer implements Closeable {
/** /**
* An {@link InputStream} that enforces the {@link #maxMessageSize} limit for compressed frames. * An {@link InputStream} that enforces the {@link #maxMessageSize} limit for compressed frames.
*/ */
private final class SizeEnforcingInputStream extends FilterInputStream { @VisibleForTesting
static final class SizeEnforcingInputStream extends FilterInputStream {
private final int maxMessageSize;
private long count; private long count;
private long mark = -1; private long mark = -1;
public SizeEnforcingInputStream(InputStream in) { SizeEnforcingInputStream(InputStream in, int maxMessageSize) {
super(in); super(in);
this.maxMessageSize = maxMessageSize;
} }
@Override @Override

View File

@ -42,13 +42,18 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.google.common.base.Charsets;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
import com.google.common.primitives.Bytes; import com.google.common.primitives.Bytes;
import io.grpc.Codec; import io.grpc.Codec;
import io.grpc.StatusRuntimeException;
import io.grpc.internal.MessageDeframer.Listener; import io.grpc.internal.MessageDeframer.Listener;
import io.grpc.internal.MessageDeframer.SizeEnforcingInputStream;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
@ -56,6 +61,7 @@ import org.mockito.Matchers;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -67,6 +73,8 @@ import java.util.zip.GZIPOutputStream;
*/ */
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class MessageDeframerTest { public class MessageDeframerTest {
@Rule public final ExpectedException thrown = ExpectedException.none();
private Listener listener = mock(Listener.class); private Listener listener = mock(Listener.class);
private MessageDeframer deframer = new MessageDeframer(listener, Codec.Identity.NONE, private MessageDeframer deframer = new MessageDeframer(listener, Codec.Identity.NONE,
DEFAULT_MAX_MESSAGE_SIZE); DEFAULT_MAX_MESSAGE_SIZE);
@ -211,6 +219,131 @@ public class MessageDeframerTest {
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
} }
@Test
public void sizeEnforcingInputStream_readByteBelowLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4);
while (stream.read() != -1) {}
stream.close();
}
@Test
public void sizeEnforcingInputStream_readByteAtLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3);
while (stream.read() != -1) {}
stream.close();
}
@Test
public void sizeEnforcingInputStream_readByteAboveLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2);
thrown.expect(StatusRuntimeException.class);
thrown.expectMessage("INTERNAL: Compressed frame exceeds");
while (stream.read() != -1) {}
// Never run, makes compiler nag go away
stream.close();
}
@Test
public void sizeEnforcingInputStream_readBelowLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4);
byte[] buf = new byte[10];
int read = stream.read(buf, 0, buf.length);
assertEquals(3, read);
stream.close();
}
@Test
public void sizeEnforcingInputStream_readAtLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3);
byte[] buf = new byte[10];
int read = stream.read(buf, 0, buf.length);
assertEquals(3, read);
stream.close();
}
@Test
public void sizeEnforcingInputStream_readAboveLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2);
byte[] buf = new byte[10];
thrown.expect(StatusRuntimeException.class);
thrown.expectMessage("INTERNAL: Compressed frame exceeds");
stream.read(buf, 0, buf.length);
// Never called
stream.close();
}
@Test
public void sizeEnforcingInputStream_skipBelowLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4);
long skipped = stream.skip(4);
assertEquals(3, skipped);
stream.close();
}
@Test
public void sizeEnforcingInputStream_skipAtLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3);
long skipped = stream.skip(4);
assertEquals(3, skipped);
stream.close();
}
@Test
public void sizeEnforcingInputStream_skipAboveLimit() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2);
thrown.expect(StatusRuntimeException.class);
thrown.expectMessage("INTERNAL: Compressed frame exceeds");
stream.skip(4);
// never run
stream.close();
}
@Test
public void sizeEnforcingInputStream_markReset() throws IOException {
ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8));
SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3);
// stream currently looks like: |foo
stream.skip(1); // f|oo
stream.mark(10); // any large number will work.
stream.skip(2); // foo|
stream.reset(); // f|oo
long skipped = stream.skip(2); // foo|
assertEquals(2, skipped);
stream.close();
}
private static List<Byte> bytes(ArgumentCaptor<InputStream> captor) { private static List<Byte> bytes(ArgumentCaptor<InputStream> captor) {
return bytes(captor.getValue()); return bytes(captor.getValue());
} }