From f5d09ff0b296fdb40ed5a6d6c6e744bd30783eda Mon Sep 17 00:00:00 2001 From: Carl Mastrangelo Date: Mon, 14 Dec 2015 14:56:59 -0800 Subject: [PATCH] Disallow compressing zero length messages. --- .../java/io/grpc/internal/MessageFramer.java | 96 ++++++++++--------- .../io/grpc/internal/AbstractStreamTest.java | 10 +- .../io/grpc/internal/MessageFramerTest.java | 24 +++-- 3 files changed, 75 insertions(+), 55 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/MessageFramer.java b/core/src/main/java/io/grpc/internal/MessageFramer.java index 0545599f46..593f3070f0 100644 --- a/core/src/main/java/io/grpc/internal/MessageFramer.java +++ b/core/src/main/java/io/grpc/internal/MessageFramer.java @@ -31,6 +31,7 @@ package io.grpc.internal; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static java.lang.Math.min; @@ -40,6 +41,7 @@ import io.grpc.Codec; import io.grpc.Compressor; import io.grpc.Drainable; import io.grpc.KnownLength; +import io.grpc.Status; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -76,7 +78,7 @@ public class MessageFramer { private final Sink sink; private WritableBuffer buffer; - private Compressor compressor; + private Compressor compressor = Codec.Identity.NONE; private boolean messageCompression; private final OutputStreamAdapter outputStreamAdapter = new OutputStreamAdapter(); private final byte[] headerScratch = new byte[HEADER_LENGTH]; @@ -84,34 +86,24 @@ public class MessageFramer { private boolean closed; /** - * Creates a {@code MessageFramer} without compression. + * Creates a {@code MessageFramer}. * * @param sink the sink used to deliver frames to the transport * @param bufferAllocator allocates buffers that the transport can commit to the wire. */ public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator) { - this(sink, bufferAllocator, Codec.Identity.NONE); - } - - /** - * Creates a {@code MessageFramer}. - * - * @param sink the sink used to deliver frames to the transport - * @param bufferAllocator allocates buffers that the transport can commit to the wire. - * @param compressor the compressor to use - */ - public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator, Compressor compressor) { this.sink = checkNotNull(sink, "sink"); - this.bufferAllocator = bufferAllocator; - this.compressor = checkNotNull(compressor, "compressor"); + this.bufferAllocator = checkNotNull(bufferAllocator, "bufferAllocator"); } - void setCompressor(Compressor compressor) { + MessageFramer setCompressor(Compressor compressor) { this.compressor = checkNotNull(compressor, "Can't pass an empty compressor"); + return this; } - void setMessageCompression(boolean enable) { + MessageFramer setMessageCompression(boolean enable) { messageCompression = enable; + return this; } /** @@ -121,47 +113,58 @@ public class MessageFramer { */ public void writePayload(InputStream message) { verifyNotClosed(); + boolean compressed = messageCompression && compressor != Codec.Identity.NONE; + int written = -1; + int messageLength = -2; try { - if (messageCompression && compressor != Codec.Identity.NONE) { - writeCompressed(message); + messageLength = getKnownLength(message); + if (messageLength != 0 && compressed) { + written = writeCompressed(message, messageLength); } else { - writeUncompressed(message); + written = writeUncompressed(message, messageLength); } - } catch (IOException ex) { - throw new RuntimeException(ex); + } catch (IOException e) { + // This should not be possible, since sink#deliverFrame doesn't throw. + throw Status.INTERNAL + .withDescription("Failed to frame message") + .withCause(e) + .asRuntimeException(); + } catch (RuntimeException e) { + throw Status.INTERNAL + .withDescription("Failed to frame message") + .withCause(e) + .asRuntimeException(); + } + + if (messageLength != -1 && written != messageLength) { + String err = String.format("Message length inaccurate %s != %s", written, messageLength); + throw Status.INTERNAL.withDescription(err).asRuntimeException(); } } - private void writeUncompressed(InputStream message) throws IOException { - int messageLength = getKnownLength(message); + private int writeUncompressed(InputStream message, int messageLength) throws IOException { if (messageLength != -1) { - writeKnownLength(message, messageLength, false); - } else { - BufferChainOutputStream bufferChain = new BufferChainOutputStream(); - writeToOutputStream(message, bufferChain); - writeBufferChain(bufferChain, false); + return writeKnownLength(message, messageLength, false); } + BufferChainOutputStream bufferChain = new BufferChainOutputStream(); + int written = writeToOutputStream(message, bufferChain); + writeBufferChain(bufferChain, false); + return written; } - private void writeCompressed(InputStream message) throws IOException { + private int writeCompressed(InputStream message, int messageLength) throws IOException { BufferChainOutputStream bufferChain = new BufferChainOutputStream(); - // Why this doesn't use getKnownLength() idk, but let's just roll with it. - int messageLength = -1; - if (message instanceof KnownLength) { - messageLength = message.available(); - } OutputStream compressingStream = compressor.compress(bufferChain); + int written; try { - long written = writeToOutputStream(message, compressingStream); - if (messageLength != -1 && messageLength != written) { - throw new RuntimeException("Message length was inaccurate"); - } + written = writeToOutputStream(message, compressingStream); } finally { compressingStream.close(); } writeBufferChain(bufferChain, true); + return written; } private int getKnownLength(InputStream inputStream) throws IOException { @@ -174,7 +177,7 @@ public class MessageFramer { /** * Write an unserialized message with a known length. */ - private void writeKnownLength(InputStream message, int messageLength, boolean compressed) + private int writeKnownLength(InputStream message, int messageLength, boolean compressed) throws IOException { ByteBuffer header = ByteBuffer.wrap(headerScratch); header.put(compressed ? COMPRESSED : UNCOMPRESSED); @@ -185,10 +188,7 @@ public class MessageFramer { buffer = bufferAllocator.allocate(header.position() + messageLength); } writeRaw(headerScratch, 0, header.position()); - long written = writeToOutputStream(message, outputStreamAdapter); - if (messageLength != written) { - throw new RuntimeException("Message length was inaccurate"); - } + return writeToOutputStream(message, outputStreamAdapter); } /** @@ -217,17 +217,19 @@ public class MessageFramer { } // Assign the current buffer to the last in the chain so it can be used // for future writes or written with end-of-stream=true on close. - buffer = bufferList.get(bufferList.size() - 1); + buffer = bufferList.get(bufferList.size() - 1); } - private static long writeToOutputStream(InputStream message, OutputStream outputStream) + private static int writeToOutputStream(InputStream message, OutputStream outputStream) throws IOException { if (message instanceof Drainable) { return ((Drainable) message).drainTo(outputStream); } else { // This makes an unnecessary copy of the bytes when bytebuf supports array(). However, we // expect performance-critical code to support flushTo(). - return ByteStreams.copy(message, outputStream); + long written = ByteStreams.copy(message, outputStream); + checkArgument(written <= Integer.MAX_VALUE, "Message size overflow: %s", written); + return (int) written; } } diff --git a/core/src/test/java/io/grpc/internal/AbstractStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractStreamTest.java index c0008ebfbc..056014bf83 100644 --- a/core/src/test/java/io/grpc/internal/AbstractStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractStreamTest.java @@ -40,6 +40,7 @@ import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Multimap; import io.grpc.internal.AbstractStream.Phase; +import io.grpc.internal.MessageFramerTest.ByteWritableBuffer; import org.junit.Before; import org.junit.Test; @@ -59,6 +60,13 @@ public class AbstractStreamTest { @Mock MessageFramer framer; @Mock MessageDeframer deframer; + private final WritableBufferAllocator allocator = new WritableBufferAllocator() { + @Override + public WritableBuffer allocate(int capacityHint) { + return new ByteWritableBuffer(capacityHint); + } + }; + @Before public void setUp() { MockitoAnnotations.initMocks(this); @@ -114,7 +122,7 @@ public class AbstractStreamTest { */ private class AbstractStreamBase extends AbstractStream { private AbstractStreamBase(WritableBufferAllocator bufferAllocator) { - super(bufferAllocator, DEFAULT_MAX_MESSAGE_SIZE); + super(allocator, DEFAULT_MAX_MESSAGE_SIZE); } private AbstractStreamBase(MessageFramer framer, MessageDeframer deframer) { diff --git a/core/src/test/java/io/grpc/internal/MessageFramerTest.java b/core/src/test/java/io/grpc/internal/MessageFramerTest.java index 79d7fce77d..5bbfee9245 100644 --- a/core/src/test/java/io/grpc/internal/MessageFramerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageFramerTest.java @@ -238,8 +238,9 @@ public class MessageFramerTest { @Test public void compressed() throws Exception { allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE); - framer = new MessageFramer(sink, allocator, new Codec.Gzip()); - framer.setMessageCompression(true); + framer = new MessageFramer(sink, allocator) + .setCompressor(new Codec.Gzip()) + .setMessageCompression(true); writeKnownLength(framer, new byte[1000]); framer.flush(); // The GRPC header is written first as a separate frame. @@ -262,8 +263,8 @@ public class MessageFramerTest { @Test public void dontCompressIfNoEncoding() throws Exception { allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE); - framer = new MessageFramer(sink, allocator, Codec.Identity.NONE); - framer.setMessageCompression(true); + framer = new MessageFramer(sink, allocator) + .setMessageCompression(true); writeKnownLength(framer, new byte[1000]); framer.flush(); // The GRPC header is written first as a separate frame @@ -286,8 +287,9 @@ public class MessageFramerTest { @Test public void dontCompressIfNotRequested() throws Exception { allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE); - framer = new MessageFramer(sink, allocator, new Codec.Gzip()); - framer.setMessageCompression(false); + framer = new MessageFramer(sink, allocator) + .setCompressor(new Codec.Gzip()) + .setMessageCompression(false); writeKnownLength(framer, new byte[1000]); framer.flush(); // The GRPC header is written first as a separate frame @@ -321,11 +323,19 @@ public class MessageFramerTest { } } }; - framer = new MessageFramer(reentrant, allocator, Codec.Identity.NONE); + framer = new MessageFramer(reentrant, allocator); writeKnownLength(framer, new byte[]{3, 14}); framer.close(); } + @Test + public void zeroLengthCompressibleMessageIsNotCompressed() { + framer.setCompressor(new Codec.Gzip()); + framer.setMessageCompression(true); + writeKnownLength(framer, new byte[]{}); + framer.flush(); + verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 0}), false, true); + } private static WritableBuffer toWriteBuffer(byte[] data) { return toWriteBufferWithMinSize(data, 0);