Disallow compressing zero length messages.

This commit is contained in:
Carl Mastrangelo 2015-12-14 14:56:59 -08:00
parent e967be8d3f
commit f5d09ff0b2
3 changed files with 75 additions and 55 deletions

View File

@ -31,6 +31,7 @@
package io.grpc.internal; package io.grpc.internal;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static java.lang.Math.min; import static java.lang.Math.min;
@ -40,6 +41,7 @@ import io.grpc.Codec;
import io.grpc.Compressor; import io.grpc.Compressor;
import io.grpc.Drainable; import io.grpc.Drainable;
import io.grpc.KnownLength; import io.grpc.KnownLength;
import io.grpc.Status;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
@ -76,7 +78,7 @@ public class MessageFramer {
private final Sink sink; private final Sink sink;
private WritableBuffer buffer; private WritableBuffer buffer;
private Compressor compressor; private Compressor compressor = Codec.Identity.NONE;
private boolean messageCompression; private boolean messageCompression;
private final OutputStreamAdapter outputStreamAdapter = new OutputStreamAdapter(); private final OutputStreamAdapter outputStreamAdapter = new OutputStreamAdapter();
private final byte[] headerScratch = new byte[HEADER_LENGTH]; private final byte[] headerScratch = new byte[HEADER_LENGTH];
@ -84,34 +86,24 @@ public class MessageFramer {
private boolean closed; private boolean closed;
/** /**
* Creates a {@code MessageFramer} without compression. * Creates a {@code MessageFramer}.
* *
* @param sink the sink used to deliver frames to the transport * @param sink the sink used to deliver frames to the transport
* @param bufferAllocator allocates buffers that the transport can commit to the wire. * @param bufferAllocator allocates buffers that the transport can commit to the wire.
*/ */
public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator) { public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator) {
this(sink, bufferAllocator, Codec.Identity.NONE);
}
/**
* Creates a {@code MessageFramer}.
*
* @param sink the sink used to deliver frames to the transport
* @param bufferAllocator allocates buffers that the transport can commit to the wire.
* @param compressor the compressor to use
*/
public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator, Compressor compressor) {
this.sink = checkNotNull(sink, "sink"); this.sink = checkNotNull(sink, "sink");
this.bufferAllocator = bufferAllocator; this.bufferAllocator = checkNotNull(bufferAllocator, "bufferAllocator");
this.compressor = checkNotNull(compressor, "compressor");
} }
void setCompressor(Compressor compressor) { MessageFramer setCompressor(Compressor compressor) {
this.compressor = checkNotNull(compressor, "Can't pass an empty compressor"); this.compressor = checkNotNull(compressor, "Can't pass an empty compressor");
return this;
} }
void setMessageCompression(boolean enable) { MessageFramer setMessageCompression(boolean enable) {
messageCompression = enable; messageCompression = enable;
return this;
} }
/** /**
@ -121,47 +113,58 @@ public class MessageFramer {
*/ */
public void writePayload(InputStream message) { public void writePayload(InputStream message) {
verifyNotClosed(); verifyNotClosed();
boolean compressed = messageCompression && compressor != Codec.Identity.NONE;
int written = -1;
int messageLength = -2;
try { try {
if (messageCompression && compressor != Codec.Identity.NONE) { messageLength = getKnownLength(message);
writeCompressed(message); if (messageLength != 0 && compressed) {
written = writeCompressed(message, messageLength);
} else { } else {
writeUncompressed(message); written = writeUncompressed(message, messageLength);
} }
} catch (IOException ex) { } catch (IOException e) {
throw new RuntimeException(ex); // This should not be possible, since sink#deliverFrame doesn't throw.
throw Status.INTERNAL
.withDescription("Failed to frame message")
.withCause(e)
.asRuntimeException();
} catch (RuntimeException e) {
throw Status.INTERNAL
.withDescription("Failed to frame message")
.withCause(e)
.asRuntimeException();
}
if (messageLength != -1 && written != messageLength) {
String err = String.format("Message length inaccurate %s != %s", written, messageLength);
throw Status.INTERNAL.withDescription(err).asRuntimeException();
} }
} }
private void writeUncompressed(InputStream message) throws IOException { private int writeUncompressed(InputStream message, int messageLength) throws IOException {
int messageLength = getKnownLength(message);
if (messageLength != -1) { if (messageLength != -1) {
writeKnownLength(message, messageLength, false); return writeKnownLength(message, messageLength, false);
} else {
BufferChainOutputStream bufferChain = new BufferChainOutputStream();
writeToOutputStream(message, bufferChain);
writeBufferChain(bufferChain, false);
} }
BufferChainOutputStream bufferChain = new BufferChainOutputStream();
int written = writeToOutputStream(message, bufferChain);
writeBufferChain(bufferChain, false);
return written;
} }
private void writeCompressed(InputStream message) throws IOException { private int writeCompressed(InputStream message, int messageLength) throws IOException {
BufferChainOutputStream bufferChain = new BufferChainOutputStream(); BufferChainOutputStream bufferChain = new BufferChainOutputStream();
// Why this doesn't use getKnownLength() idk, but let's just roll with it.
int messageLength = -1;
if (message instanceof KnownLength) {
messageLength = message.available();
}
OutputStream compressingStream = compressor.compress(bufferChain); OutputStream compressingStream = compressor.compress(bufferChain);
int written;
try { try {
long written = writeToOutputStream(message, compressingStream); written = writeToOutputStream(message, compressingStream);
if (messageLength != -1 && messageLength != written) {
throw new RuntimeException("Message length was inaccurate");
}
} finally { } finally {
compressingStream.close(); compressingStream.close();
} }
writeBufferChain(bufferChain, true); writeBufferChain(bufferChain, true);
return written;
} }
private int getKnownLength(InputStream inputStream) throws IOException { private int getKnownLength(InputStream inputStream) throws IOException {
@ -174,7 +177,7 @@ public class MessageFramer {
/** /**
* Write an unserialized message with a known length. * Write an unserialized message with a known length.
*/ */
private void writeKnownLength(InputStream message, int messageLength, boolean compressed) private int writeKnownLength(InputStream message, int messageLength, boolean compressed)
throws IOException { throws IOException {
ByteBuffer header = ByteBuffer.wrap(headerScratch); ByteBuffer header = ByteBuffer.wrap(headerScratch);
header.put(compressed ? COMPRESSED : UNCOMPRESSED); header.put(compressed ? COMPRESSED : UNCOMPRESSED);
@ -185,10 +188,7 @@ public class MessageFramer {
buffer = bufferAllocator.allocate(header.position() + messageLength); buffer = bufferAllocator.allocate(header.position() + messageLength);
} }
writeRaw(headerScratch, 0, header.position()); writeRaw(headerScratch, 0, header.position());
long written = writeToOutputStream(message, outputStreamAdapter); return writeToOutputStream(message, outputStreamAdapter);
if (messageLength != written) {
throw new RuntimeException("Message length was inaccurate");
}
} }
/** /**
@ -220,14 +220,16 @@ public class MessageFramer {
buffer = bufferList.get(bufferList.size() - 1); buffer = bufferList.get(bufferList.size() - 1);
} }
private static long writeToOutputStream(InputStream message, OutputStream outputStream) private static int writeToOutputStream(InputStream message, OutputStream outputStream)
throws IOException { throws IOException {
if (message instanceof Drainable) { if (message instanceof Drainable) {
return ((Drainable) message).drainTo(outputStream); return ((Drainable) message).drainTo(outputStream);
} else { } else {
// This makes an unnecessary copy of the bytes when bytebuf supports array(). However, we // This makes an unnecessary copy of the bytes when bytebuf supports array(). However, we
// expect performance-critical code to support flushTo(). // expect performance-critical code to support flushTo().
return ByteStreams.copy(message, outputStream); long written = ByteStreams.copy(message, outputStream);
checkArgument(written <= Integer.MAX_VALUE, "Message size overflow: %s", written);
return (int) written;
} }
} }

View File

@ -40,6 +40,7 @@ import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Multimap; import com.google.common.collect.Multimap;
import io.grpc.internal.AbstractStream.Phase; import io.grpc.internal.AbstractStream.Phase;
import io.grpc.internal.MessageFramerTest.ByteWritableBuffer;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -59,6 +60,13 @@ public class AbstractStreamTest {
@Mock MessageFramer framer; @Mock MessageFramer framer;
@Mock MessageDeframer deframer; @Mock MessageDeframer deframer;
private final WritableBufferAllocator allocator = new WritableBufferAllocator() {
@Override
public WritableBuffer allocate(int capacityHint) {
return new ByteWritableBuffer(capacityHint);
}
};
@Before @Before
public void setUp() { public void setUp() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
@ -114,7 +122,7 @@ public class AbstractStreamTest {
*/ */
private class AbstractStreamBase<IdT> extends AbstractStream<IdT> { private class AbstractStreamBase<IdT> extends AbstractStream<IdT> {
private AbstractStreamBase(WritableBufferAllocator bufferAllocator) { private AbstractStreamBase(WritableBufferAllocator bufferAllocator) {
super(bufferAllocator, DEFAULT_MAX_MESSAGE_SIZE); super(allocator, DEFAULT_MAX_MESSAGE_SIZE);
} }
private AbstractStreamBase(MessageFramer framer, MessageDeframer deframer) { private AbstractStreamBase(MessageFramer framer, MessageDeframer deframer) {

View File

@ -238,8 +238,9 @@ public class MessageFramerTest {
@Test @Test
public void compressed() throws Exception { public void compressed() throws Exception {
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE); allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
framer = new MessageFramer(sink, allocator, new Codec.Gzip()); framer = new MessageFramer(sink, allocator)
framer.setMessageCompression(true); .setCompressor(new Codec.Gzip())
.setMessageCompression(true);
writeKnownLength(framer, new byte[1000]); writeKnownLength(framer, new byte[1000]);
framer.flush(); framer.flush();
// The GRPC header is written first as a separate frame. // The GRPC header is written first as a separate frame.
@ -262,8 +263,8 @@ public class MessageFramerTest {
@Test @Test
public void dontCompressIfNoEncoding() throws Exception { public void dontCompressIfNoEncoding() throws Exception {
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE); allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
framer = new MessageFramer(sink, allocator, Codec.Identity.NONE); framer = new MessageFramer(sink, allocator)
framer.setMessageCompression(true); .setMessageCompression(true);
writeKnownLength(framer, new byte[1000]); writeKnownLength(framer, new byte[1000]);
framer.flush(); framer.flush();
// The GRPC header is written first as a separate frame // The GRPC header is written first as a separate frame
@ -286,8 +287,9 @@ public class MessageFramerTest {
@Test @Test
public void dontCompressIfNotRequested() throws Exception { public void dontCompressIfNotRequested() throws Exception {
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE); allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
framer = new MessageFramer(sink, allocator, new Codec.Gzip()); framer = new MessageFramer(sink, allocator)
framer.setMessageCompression(false); .setCompressor(new Codec.Gzip())
.setMessageCompression(false);
writeKnownLength(framer, new byte[1000]); writeKnownLength(framer, new byte[1000]);
framer.flush(); framer.flush();
// The GRPC header is written first as a separate frame // The GRPC header is written first as a separate frame
@ -321,11 +323,19 @@ public class MessageFramerTest {
} }
} }
}; };
framer = new MessageFramer(reentrant, allocator, Codec.Identity.NONE); framer = new MessageFramer(reentrant, allocator);
writeKnownLength(framer, new byte[]{3, 14}); writeKnownLength(framer, new byte[]{3, 14});
framer.close(); framer.close();
} }
@Test
public void zeroLengthCompressibleMessageIsNotCompressed() {
framer.setCompressor(new Codec.Gzip());
framer.setMessageCompression(true);
writeKnownLength(framer, new byte[]{});
framer.flush();
verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 0}), false, true);
}
private static WritableBuffer toWriteBuffer(byte[] data) { private static WritableBuffer toWriteBuffer(byte[] data) {
return toWriteBufferWithMinSize(data, 0); return toWriteBufferWithMinSize(data, 0);