Disallow compressing zero length messages.

This commit is contained in:
Carl Mastrangelo 2015-12-14 14:56:59 -08:00
parent e967be8d3f
commit f5d09ff0b2
3 changed files with 75 additions and 55 deletions

View File

@ -31,6 +31,7 @@
package io.grpc.internal;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static java.lang.Math.min;
@ -40,6 +41,7 @@ import io.grpc.Codec;
import io.grpc.Compressor;
import io.grpc.Drainable;
import io.grpc.KnownLength;
import io.grpc.Status;
import java.io.ByteArrayInputStream;
import java.io.IOException;
@ -76,7 +78,7 @@ public class MessageFramer {
private final Sink sink;
private WritableBuffer buffer;
private Compressor compressor;
private Compressor compressor = Codec.Identity.NONE;
private boolean messageCompression;
private final OutputStreamAdapter outputStreamAdapter = new OutputStreamAdapter();
private final byte[] headerScratch = new byte[HEADER_LENGTH];
@ -84,34 +86,24 @@ public class MessageFramer {
private boolean closed;
/**
* Creates a {@code MessageFramer} without compression.
* Creates a {@code MessageFramer}.
*
* @param sink the sink used to deliver frames to the transport
* @param bufferAllocator allocates buffers that the transport can commit to the wire.
*/
public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator) {
this(sink, bufferAllocator, Codec.Identity.NONE);
}
/**
* Creates a {@code MessageFramer}.
*
* @param sink the sink used to deliver frames to the transport
* @param bufferAllocator allocates buffers that the transport can commit to the wire.
* @param compressor the compressor to use
*/
public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator, Compressor compressor) {
this.sink = checkNotNull(sink, "sink");
this.bufferAllocator = bufferAllocator;
this.compressor = checkNotNull(compressor, "compressor");
this.bufferAllocator = checkNotNull(bufferAllocator, "bufferAllocator");
}
void setCompressor(Compressor compressor) {
MessageFramer setCompressor(Compressor compressor) {
this.compressor = checkNotNull(compressor, "Can't pass an empty compressor");
return this;
}
void setMessageCompression(boolean enable) {
MessageFramer setMessageCompression(boolean enable) {
messageCompression = enable;
return this;
}
/**
@ -121,47 +113,58 @@ public class MessageFramer {
*/
public void writePayload(InputStream message) {
verifyNotClosed();
boolean compressed = messageCompression && compressor != Codec.Identity.NONE;
int written = -1;
int messageLength = -2;
try {
if (messageCompression && compressor != Codec.Identity.NONE) {
writeCompressed(message);
messageLength = getKnownLength(message);
if (messageLength != 0 && compressed) {
written = writeCompressed(message, messageLength);
} else {
writeUncompressed(message);
written = writeUncompressed(message, messageLength);
}
} catch (IOException ex) {
throw new RuntimeException(ex);
} catch (IOException e) {
// This should not be possible, since sink#deliverFrame doesn't throw.
throw Status.INTERNAL
.withDescription("Failed to frame message")
.withCause(e)
.asRuntimeException();
} catch (RuntimeException e) {
throw Status.INTERNAL
.withDescription("Failed to frame message")
.withCause(e)
.asRuntimeException();
}
if (messageLength != -1 && written != messageLength) {
String err = String.format("Message length inaccurate %s != %s", written, messageLength);
throw Status.INTERNAL.withDescription(err).asRuntimeException();
}
}
private void writeUncompressed(InputStream message) throws IOException {
int messageLength = getKnownLength(message);
private int writeUncompressed(InputStream message, int messageLength) throws IOException {
if (messageLength != -1) {
writeKnownLength(message, messageLength, false);
} else {
BufferChainOutputStream bufferChain = new BufferChainOutputStream();
writeToOutputStream(message, bufferChain);
writeBufferChain(bufferChain, false);
return writeKnownLength(message, messageLength, false);
}
BufferChainOutputStream bufferChain = new BufferChainOutputStream();
int written = writeToOutputStream(message, bufferChain);
writeBufferChain(bufferChain, false);
return written;
}
private void writeCompressed(InputStream message) throws IOException {
private int writeCompressed(InputStream message, int messageLength) throws IOException {
BufferChainOutputStream bufferChain = new BufferChainOutputStream();
// Why this doesn't use getKnownLength() idk, but let's just roll with it.
int messageLength = -1;
if (message instanceof KnownLength) {
messageLength = message.available();
}
OutputStream compressingStream = compressor.compress(bufferChain);
int written;
try {
long written = writeToOutputStream(message, compressingStream);
if (messageLength != -1 && messageLength != written) {
throw new RuntimeException("Message length was inaccurate");
}
written = writeToOutputStream(message, compressingStream);
} finally {
compressingStream.close();
}
writeBufferChain(bufferChain, true);
return written;
}
private int getKnownLength(InputStream inputStream) throws IOException {
@ -174,7 +177,7 @@ public class MessageFramer {
/**
* Write an unserialized message with a known length.
*/
private void writeKnownLength(InputStream message, int messageLength, boolean compressed)
private int writeKnownLength(InputStream message, int messageLength, boolean compressed)
throws IOException {
ByteBuffer header = ByteBuffer.wrap(headerScratch);
header.put(compressed ? COMPRESSED : UNCOMPRESSED);
@ -185,10 +188,7 @@ public class MessageFramer {
buffer = bufferAllocator.allocate(header.position() + messageLength);
}
writeRaw(headerScratch, 0, header.position());
long written = writeToOutputStream(message, outputStreamAdapter);
if (messageLength != written) {
throw new RuntimeException("Message length was inaccurate");
}
return writeToOutputStream(message, outputStreamAdapter);
}
/**
@ -217,17 +217,19 @@ public class MessageFramer {
}
// Assign the current buffer to the last in the chain so it can be used
// for future writes or written with end-of-stream=true on close.
buffer = bufferList.get(bufferList.size() - 1);
buffer = bufferList.get(bufferList.size() - 1);
}
private static long writeToOutputStream(InputStream message, OutputStream outputStream)
private static int writeToOutputStream(InputStream message, OutputStream outputStream)
throws IOException {
if (message instanceof Drainable) {
return ((Drainable) message).drainTo(outputStream);
} else {
// This makes an unnecessary copy of the bytes when bytebuf supports array(). However, we
// expect performance-critical code to support flushTo().
return ByteStreams.copy(message, outputStream);
long written = ByteStreams.copy(message, outputStream);
checkArgument(written <= Integer.MAX_VALUE, "Message size overflow: %s", written);
return (int) written;
}
}

View File

@ -40,6 +40,7 @@ import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Multimap;
import io.grpc.internal.AbstractStream.Phase;
import io.grpc.internal.MessageFramerTest.ByteWritableBuffer;
import org.junit.Before;
import org.junit.Test;
@ -59,6 +60,13 @@ public class AbstractStreamTest {
@Mock MessageFramer framer;
@Mock MessageDeframer deframer;
private final WritableBufferAllocator allocator = new WritableBufferAllocator() {
@Override
public WritableBuffer allocate(int capacityHint) {
return new ByteWritableBuffer(capacityHint);
}
};
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
@ -114,7 +122,7 @@ public class AbstractStreamTest {
*/
private class AbstractStreamBase<IdT> extends AbstractStream<IdT> {
private AbstractStreamBase(WritableBufferAllocator bufferAllocator) {
super(bufferAllocator, DEFAULT_MAX_MESSAGE_SIZE);
super(allocator, DEFAULT_MAX_MESSAGE_SIZE);
}
private AbstractStreamBase(MessageFramer framer, MessageDeframer deframer) {

View File

@ -238,8 +238,9 @@ public class MessageFramerTest {
@Test
public void compressed() throws Exception {
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
framer = new MessageFramer(sink, allocator, new Codec.Gzip());
framer.setMessageCompression(true);
framer = new MessageFramer(sink, allocator)
.setCompressor(new Codec.Gzip())
.setMessageCompression(true);
writeKnownLength(framer, new byte[1000]);
framer.flush();
// The GRPC header is written first as a separate frame.
@ -262,8 +263,8 @@ public class MessageFramerTest {
@Test
public void dontCompressIfNoEncoding() throws Exception {
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
framer = new MessageFramer(sink, allocator, Codec.Identity.NONE);
framer.setMessageCompression(true);
framer = new MessageFramer(sink, allocator)
.setMessageCompression(true);
writeKnownLength(framer, new byte[1000]);
framer.flush();
// The GRPC header is written first as a separate frame
@ -286,8 +287,9 @@ public class MessageFramerTest {
@Test
public void dontCompressIfNotRequested() throws Exception {
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
framer = new MessageFramer(sink, allocator, new Codec.Gzip());
framer.setMessageCompression(false);
framer = new MessageFramer(sink, allocator)
.setCompressor(new Codec.Gzip())
.setMessageCompression(false);
writeKnownLength(framer, new byte[1000]);
framer.flush();
// The GRPC header is written first as a separate frame
@ -321,11 +323,19 @@ public class MessageFramerTest {
}
}
};
framer = new MessageFramer(reentrant, allocator, Codec.Identity.NONE);
framer = new MessageFramer(reentrant, allocator);
writeKnownLength(framer, new byte[]{3, 14});
framer.close();
}
@Test
public void zeroLengthCompressibleMessageIsNotCompressed() {
framer.setCompressor(new Codec.Gzip());
framer.setMessageCompression(true);
writeKnownLength(framer, new byte[]{});
framer.flush();
verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 0}), false, true);
}
private static WritableBuffer toWriteBuffer(byte[] data) {
return toWriteBufferWithMinSize(data, 0);