diff --git a/core/src/main/java/io/grpc/internal/MessageDeframer.java b/core/src/main/java/io/grpc/internal/MessageDeframer.java index 11077426ed..d31b759ea5 100644 --- a/core/src/main/java/io/grpc/internal/MessageDeframer.java +++ b/core/src/main/java/io/grpc/internal/MessageDeframer.java @@ -33,6 +33,7 @@ package io.grpc.internal; import static com.google.common.base.Preconditions.checkNotNull; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import io.grpc.Codec; @@ -366,8 +367,9 @@ public class MessageDeframer implements Closeable { try { // Enforce the maxMessageSize limit on the returned stream. - return new SizeEnforcingInputStream(decompressor.decompress( - ReadableBuffers.openStream(nextFrame, true))); + InputStream unlimitedStream = + decompressor.decompress(ReadableBuffers.openStream(nextFrame, true)); + return new SizeEnforcingInputStream(unlimitedStream, maxMessageSize); } catch (IOException e) { throw new RuntimeException(e); } @@ -376,12 +378,15 @@ public class MessageDeframer implements Closeable { /** * An {@link InputStream} that enforces the {@link #maxMessageSize} limit for compressed frames. */ - private final class SizeEnforcingInputStream extends FilterInputStream { + @VisibleForTesting + static final class SizeEnforcingInputStream extends FilterInputStream { + private final int maxMessageSize; private long count; private long mark = -1; - public SizeEnforcingInputStream(InputStream in) { + SizeEnforcingInputStream(InputStream in, int maxMessageSize) { super(in); + this.maxMessageSize = maxMessageSize; } @Override diff --git a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java index 81cbf876c1..7e9d8a2f37 100644 --- a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java @@ -42,13 +42,18 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import com.google.common.base.Charsets; import com.google.common.io.ByteStreams; import com.google.common.primitives.Bytes; import io.grpc.Codec; +import io.grpc.StatusRuntimeException; import io.grpc.internal.MessageDeframer.Listener; +import io.grpc.internal.MessageDeframer.SizeEnforcingInputStream; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -56,6 +61,7 @@ import org.mockito.Matchers; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; @@ -67,6 +73,8 @@ import java.util.zip.GZIPOutputStream; */ @RunWith(JUnit4.class) public class MessageDeframerTest { + @Rule public final ExpectedException thrown = ExpectedException.none(); + private Listener listener = mock(Listener.class); private MessageDeframer deframer = new MessageDeframer(listener, Codec.Identity.NONE, DEFAULT_MAX_MESSAGE_SIZE); @@ -211,6 +219,131 @@ public class MessageDeframerTest { verifyNoMoreInteractions(listener); } + @Test + public void sizeEnforcingInputStream_readByteBelowLimit() throws IOException { + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4); + + while (stream.read() != -1) {} + + stream.close(); + } + + @Test + public void sizeEnforcingInputStream_readByteAtLimit() throws IOException { + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3); + + while (stream.read() != -1) {} + + stream.close(); + } + + @Test + public void sizeEnforcingInputStream_readByteAboveLimit() throws IOException { + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2); + + thrown.expect(StatusRuntimeException.class); + thrown.expectMessage("INTERNAL: Compressed frame exceeds"); + + while (stream.read() != -1) {} + + // Never run, makes compiler nag go away + stream.close(); + } + + @Test + public void sizeEnforcingInputStream_readBelowLimit() throws IOException { + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4); + byte[] buf = new byte[10]; + + int read = stream.read(buf, 0, buf.length); + + assertEquals(3, read); + stream.close(); + } + + @Test + public void sizeEnforcingInputStream_readAtLimit() throws IOException { + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3); + byte[] buf = new byte[10]; + + int read = stream.read(buf, 0, buf.length); + + assertEquals(3, read); + stream.close(); + } + + @Test + public void sizeEnforcingInputStream_readAboveLimit() throws IOException { + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2); + byte[] buf = new byte[10]; + + thrown.expect(StatusRuntimeException.class); + thrown.expectMessage("INTERNAL: Compressed frame exceeds"); + + stream.read(buf, 0, buf.length); + + // Never called + stream.close(); + } + + @Test + public void sizeEnforcingInputStream_skipBelowLimit() throws IOException { + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4); + + long skipped = stream.skip(4); + + assertEquals(3, skipped); + + stream.close(); + } + + @Test + public void sizeEnforcingInputStream_skipAtLimit() throws IOException { + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3); + + long skipped = stream.skip(4); + + assertEquals(3, skipped); + stream.close(); + } + + @Test + public void sizeEnforcingInputStream_skipAboveLimit() throws IOException { + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2); + + thrown.expect(StatusRuntimeException.class); + thrown.expectMessage("INTERNAL: Compressed frame exceeds"); + + stream.skip(4); + + // never run + stream.close(); + } + + @Test + public void sizeEnforcingInputStream_markReset() throws IOException { + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3); + // stream currently looks like: |foo + stream.skip(1); // f|oo + stream.mark(10); // any large number will work. + stream.skip(2); // foo| + stream.reset(); // f|oo + long skipped = stream.skip(2); // foo| + + assertEquals(2, skipped); + stream.close(); + } + private static List bytes(ArgumentCaptor captor) { return bytes(captor.getValue()); }