mirror of https://github.com/grpc/grpc-java.git
Disallow compressing zero length messages.
This commit is contained in:
parent
e967be8d3f
commit
f5d09ff0b2
|
|
@ -31,6 +31,7 @@
|
|||
|
||||
package io.grpc.internal;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkArgument;
|
||||
import static com.google.common.base.Preconditions.checkNotNull;
|
||||
import static java.lang.Math.min;
|
||||
|
||||
|
|
@ -40,6 +41,7 @@ import io.grpc.Codec;
|
|||
import io.grpc.Compressor;
|
||||
import io.grpc.Drainable;
|
||||
import io.grpc.KnownLength;
|
||||
import io.grpc.Status;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
|
|
@ -76,7 +78,7 @@ public class MessageFramer {
|
|||
|
||||
private final Sink sink;
|
||||
private WritableBuffer buffer;
|
||||
private Compressor compressor;
|
||||
private Compressor compressor = Codec.Identity.NONE;
|
||||
private boolean messageCompression;
|
||||
private final OutputStreamAdapter outputStreamAdapter = new OutputStreamAdapter();
|
||||
private final byte[] headerScratch = new byte[HEADER_LENGTH];
|
||||
|
|
@ -84,34 +86,24 @@ public class MessageFramer {
|
|||
private boolean closed;
|
||||
|
||||
/**
|
||||
* Creates a {@code MessageFramer} without compression.
|
||||
* Creates a {@code MessageFramer}.
|
||||
*
|
||||
* @param sink the sink used to deliver frames to the transport
|
||||
* @param bufferAllocator allocates buffers that the transport can commit to the wire.
|
||||
*/
|
||||
public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator) {
|
||||
this(sink, bufferAllocator, Codec.Identity.NONE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@code MessageFramer}.
|
||||
*
|
||||
* @param sink the sink used to deliver frames to the transport
|
||||
* @param bufferAllocator allocates buffers that the transport can commit to the wire.
|
||||
* @param compressor the compressor to use
|
||||
*/
|
||||
public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator, Compressor compressor) {
|
||||
this.sink = checkNotNull(sink, "sink");
|
||||
this.bufferAllocator = bufferAllocator;
|
||||
this.compressor = checkNotNull(compressor, "compressor");
|
||||
this.bufferAllocator = checkNotNull(bufferAllocator, "bufferAllocator");
|
||||
}
|
||||
|
||||
void setCompressor(Compressor compressor) {
|
||||
MessageFramer setCompressor(Compressor compressor) {
|
||||
this.compressor = checkNotNull(compressor, "Can't pass an empty compressor");
|
||||
return this;
|
||||
}
|
||||
|
||||
void setMessageCompression(boolean enable) {
|
||||
MessageFramer setMessageCompression(boolean enable) {
|
||||
messageCompression = enable;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -121,47 +113,58 @@ public class MessageFramer {
|
|||
*/
|
||||
public void writePayload(InputStream message) {
|
||||
verifyNotClosed();
|
||||
boolean compressed = messageCompression && compressor != Codec.Identity.NONE;
|
||||
int written = -1;
|
||||
int messageLength = -2;
|
||||
try {
|
||||
if (messageCompression && compressor != Codec.Identity.NONE) {
|
||||
writeCompressed(message);
|
||||
messageLength = getKnownLength(message);
|
||||
if (messageLength != 0 && compressed) {
|
||||
written = writeCompressed(message, messageLength);
|
||||
} else {
|
||||
writeUncompressed(message);
|
||||
written = writeUncompressed(message, messageLength);
|
||||
}
|
||||
} catch (IOException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
} catch (IOException e) {
|
||||
// This should not be possible, since sink#deliverFrame doesn't throw.
|
||||
throw Status.INTERNAL
|
||||
.withDescription("Failed to frame message")
|
||||
.withCause(e)
|
||||
.asRuntimeException();
|
||||
} catch (RuntimeException e) {
|
||||
throw Status.INTERNAL
|
||||
.withDescription("Failed to frame message")
|
||||
.withCause(e)
|
||||
.asRuntimeException();
|
||||
}
|
||||
|
||||
if (messageLength != -1 && written != messageLength) {
|
||||
String err = String.format("Message length inaccurate %s != %s", written, messageLength);
|
||||
throw Status.INTERNAL.withDescription(err).asRuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
private void writeUncompressed(InputStream message) throws IOException {
|
||||
int messageLength = getKnownLength(message);
|
||||
private int writeUncompressed(InputStream message, int messageLength) throws IOException {
|
||||
if (messageLength != -1) {
|
||||
writeKnownLength(message, messageLength, false);
|
||||
} else {
|
||||
BufferChainOutputStream bufferChain = new BufferChainOutputStream();
|
||||
writeToOutputStream(message, bufferChain);
|
||||
writeBufferChain(bufferChain, false);
|
||||
return writeKnownLength(message, messageLength, false);
|
||||
}
|
||||
BufferChainOutputStream bufferChain = new BufferChainOutputStream();
|
||||
int written = writeToOutputStream(message, bufferChain);
|
||||
writeBufferChain(bufferChain, false);
|
||||
return written;
|
||||
}
|
||||
|
||||
private void writeCompressed(InputStream message) throws IOException {
|
||||
private int writeCompressed(InputStream message, int messageLength) throws IOException {
|
||||
BufferChainOutputStream bufferChain = new BufferChainOutputStream();
|
||||
// Why this doesn't use getKnownLength() idk, but let's just roll with it.
|
||||
int messageLength = -1;
|
||||
if (message instanceof KnownLength) {
|
||||
messageLength = message.available();
|
||||
}
|
||||
|
||||
OutputStream compressingStream = compressor.compress(bufferChain);
|
||||
int written;
|
||||
try {
|
||||
long written = writeToOutputStream(message, compressingStream);
|
||||
if (messageLength != -1 && messageLength != written) {
|
||||
throw new RuntimeException("Message length was inaccurate");
|
||||
}
|
||||
written = writeToOutputStream(message, compressingStream);
|
||||
} finally {
|
||||
compressingStream.close();
|
||||
}
|
||||
|
||||
writeBufferChain(bufferChain, true);
|
||||
return written;
|
||||
}
|
||||
|
||||
private int getKnownLength(InputStream inputStream) throws IOException {
|
||||
|
|
@ -174,7 +177,7 @@ public class MessageFramer {
|
|||
/**
|
||||
* Write an unserialized message with a known length.
|
||||
*/
|
||||
private void writeKnownLength(InputStream message, int messageLength, boolean compressed)
|
||||
private int writeKnownLength(InputStream message, int messageLength, boolean compressed)
|
||||
throws IOException {
|
||||
ByteBuffer header = ByteBuffer.wrap(headerScratch);
|
||||
header.put(compressed ? COMPRESSED : UNCOMPRESSED);
|
||||
|
|
@ -185,10 +188,7 @@ public class MessageFramer {
|
|||
buffer = bufferAllocator.allocate(header.position() + messageLength);
|
||||
}
|
||||
writeRaw(headerScratch, 0, header.position());
|
||||
long written = writeToOutputStream(message, outputStreamAdapter);
|
||||
if (messageLength != written) {
|
||||
throw new RuntimeException("Message length was inaccurate");
|
||||
}
|
||||
return writeToOutputStream(message, outputStreamAdapter);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -217,17 +217,19 @@ public class MessageFramer {
|
|||
}
|
||||
// Assign the current buffer to the last in the chain so it can be used
|
||||
// for future writes or written with end-of-stream=true on close.
|
||||
buffer = bufferList.get(bufferList.size() - 1);
|
||||
buffer = bufferList.get(bufferList.size() - 1);
|
||||
}
|
||||
|
||||
private static long writeToOutputStream(InputStream message, OutputStream outputStream)
|
||||
private static int writeToOutputStream(InputStream message, OutputStream outputStream)
|
||||
throws IOException {
|
||||
if (message instanceof Drainable) {
|
||||
return ((Drainable) message).drainTo(outputStream);
|
||||
} else {
|
||||
// This makes an unnecessary copy of the bytes when bytebuf supports array(). However, we
|
||||
// expect performance-critical code to support flushTo().
|
||||
return ByteStreams.copy(message, outputStream);
|
||||
long written = ByteStreams.copy(message, outputStream);
|
||||
checkArgument(written <= Integer.MAX_VALUE, "Message size overflow: %s", written);
|
||||
return (int) written;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ import com.google.common.collect.ImmutableMultimap;
|
|||
import com.google.common.collect.Multimap;
|
||||
|
||||
import io.grpc.internal.AbstractStream.Phase;
|
||||
import io.grpc.internal.MessageFramerTest.ByteWritableBuffer;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
|
@ -59,6 +60,13 @@ public class AbstractStreamTest {
|
|||
@Mock MessageFramer framer;
|
||||
@Mock MessageDeframer deframer;
|
||||
|
||||
private final WritableBufferAllocator allocator = new WritableBufferAllocator() {
|
||||
@Override
|
||||
public WritableBuffer allocate(int capacityHint) {
|
||||
return new ByteWritableBuffer(capacityHint);
|
||||
}
|
||||
};
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
MockitoAnnotations.initMocks(this);
|
||||
|
|
@ -114,7 +122,7 @@ public class AbstractStreamTest {
|
|||
*/
|
||||
private class AbstractStreamBase<IdT> extends AbstractStream<IdT> {
|
||||
private AbstractStreamBase(WritableBufferAllocator bufferAllocator) {
|
||||
super(bufferAllocator, DEFAULT_MAX_MESSAGE_SIZE);
|
||||
super(allocator, DEFAULT_MAX_MESSAGE_SIZE);
|
||||
}
|
||||
|
||||
private AbstractStreamBase(MessageFramer framer, MessageDeframer deframer) {
|
||||
|
|
|
|||
|
|
@ -238,8 +238,9 @@ public class MessageFramerTest {
|
|||
@Test
|
||||
public void compressed() throws Exception {
|
||||
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
|
||||
framer = new MessageFramer(sink, allocator, new Codec.Gzip());
|
||||
framer.setMessageCompression(true);
|
||||
framer = new MessageFramer(sink, allocator)
|
||||
.setCompressor(new Codec.Gzip())
|
||||
.setMessageCompression(true);
|
||||
writeKnownLength(framer, new byte[1000]);
|
||||
framer.flush();
|
||||
// The GRPC header is written first as a separate frame.
|
||||
|
|
@ -262,8 +263,8 @@ public class MessageFramerTest {
|
|||
@Test
|
||||
public void dontCompressIfNoEncoding() throws Exception {
|
||||
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
|
||||
framer = new MessageFramer(sink, allocator, Codec.Identity.NONE);
|
||||
framer.setMessageCompression(true);
|
||||
framer = new MessageFramer(sink, allocator)
|
||||
.setMessageCompression(true);
|
||||
writeKnownLength(framer, new byte[1000]);
|
||||
framer.flush();
|
||||
// The GRPC header is written first as a separate frame
|
||||
|
|
@ -286,8 +287,9 @@ public class MessageFramerTest {
|
|||
@Test
|
||||
public void dontCompressIfNotRequested() throws Exception {
|
||||
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
|
||||
framer = new MessageFramer(sink, allocator, new Codec.Gzip());
|
||||
framer.setMessageCompression(false);
|
||||
framer = new MessageFramer(sink, allocator)
|
||||
.setCompressor(new Codec.Gzip())
|
||||
.setMessageCompression(false);
|
||||
writeKnownLength(framer, new byte[1000]);
|
||||
framer.flush();
|
||||
// The GRPC header is written first as a separate frame
|
||||
|
|
@ -321,11 +323,19 @@ public class MessageFramerTest {
|
|||
}
|
||||
}
|
||||
};
|
||||
framer = new MessageFramer(reentrant, allocator, Codec.Identity.NONE);
|
||||
framer = new MessageFramer(reentrant, allocator);
|
||||
writeKnownLength(framer, new byte[]{3, 14});
|
||||
framer.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void zeroLengthCompressibleMessageIsNotCompressed() {
|
||||
framer.setCompressor(new Codec.Gzip());
|
||||
framer.setMessageCompression(true);
|
||||
writeKnownLength(framer, new byte[]{});
|
||||
framer.flush();
|
||||
verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 0}), false, true);
|
||||
}
|
||||
|
||||
private static WritableBuffer toWriteBuffer(byte[] data) {
|
||||
return toWriteBufferWithMinSize(data, 0);
|
||||
|
|
|
|||
Loading…
Reference in New Issue