mirror of https://github.com/grpc/grpc-java.git
Disallow compressing zero length messages.
This commit is contained in:
parent
e967be8d3f
commit
f5d09ff0b2
|
|
@ -31,6 +31,7 @@
|
||||||
|
|
||||||
package io.grpc.internal;
|
package io.grpc.internal;
|
||||||
|
|
||||||
|
import static com.google.common.base.Preconditions.checkArgument;
|
||||||
import static com.google.common.base.Preconditions.checkNotNull;
|
import static com.google.common.base.Preconditions.checkNotNull;
|
||||||
import static java.lang.Math.min;
|
import static java.lang.Math.min;
|
||||||
|
|
||||||
|
|
@ -40,6 +41,7 @@ import io.grpc.Codec;
|
||||||
import io.grpc.Compressor;
|
import io.grpc.Compressor;
|
||||||
import io.grpc.Drainable;
|
import io.grpc.Drainable;
|
||||||
import io.grpc.KnownLength;
|
import io.grpc.KnownLength;
|
||||||
|
import io.grpc.Status;
|
||||||
|
|
||||||
import java.io.ByteArrayInputStream;
|
import java.io.ByteArrayInputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
@ -76,7 +78,7 @@ public class MessageFramer {
|
||||||
|
|
||||||
private final Sink sink;
|
private final Sink sink;
|
||||||
private WritableBuffer buffer;
|
private WritableBuffer buffer;
|
||||||
private Compressor compressor;
|
private Compressor compressor = Codec.Identity.NONE;
|
||||||
private boolean messageCompression;
|
private boolean messageCompression;
|
||||||
private final OutputStreamAdapter outputStreamAdapter = new OutputStreamAdapter();
|
private final OutputStreamAdapter outputStreamAdapter = new OutputStreamAdapter();
|
||||||
private final byte[] headerScratch = new byte[HEADER_LENGTH];
|
private final byte[] headerScratch = new byte[HEADER_LENGTH];
|
||||||
|
|
@ -84,34 +86,24 @@ public class MessageFramer {
|
||||||
private boolean closed;
|
private boolean closed;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a {@code MessageFramer} without compression.
|
* Creates a {@code MessageFramer}.
|
||||||
*
|
*
|
||||||
* @param sink the sink used to deliver frames to the transport
|
* @param sink the sink used to deliver frames to the transport
|
||||||
* @param bufferAllocator allocates buffers that the transport can commit to the wire.
|
* @param bufferAllocator allocates buffers that the transport can commit to the wire.
|
||||||
*/
|
*/
|
||||||
public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator) {
|
public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator) {
|
||||||
this(sink, bufferAllocator, Codec.Identity.NONE);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a {@code MessageFramer}.
|
|
||||||
*
|
|
||||||
* @param sink the sink used to deliver frames to the transport
|
|
||||||
* @param bufferAllocator allocates buffers that the transport can commit to the wire.
|
|
||||||
* @param compressor the compressor to use
|
|
||||||
*/
|
|
||||||
public MessageFramer(Sink sink, WritableBufferAllocator bufferAllocator, Compressor compressor) {
|
|
||||||
this.sink = checkNotNull(sink, "sink");
|
this.sink = checkNotNull(sink, "sink");
|
||||||
this.bufferAllocator = bufferAllocator;
|
this.bufferAllocator = checkNotNull(bufferAllocator, "bufferAllocator");
|
||||||
this.compressor = checkNotNull(compressor, "compressor");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void setCompressor(Compressor compressor) {
|
MessageFramer setCompressor(Compressor compressor) {
|
||||||
this.compressor = checkNotNull(compressor, "Can't pass an empty compressor");
|
this.compressor = checkNotNull(compressor, "Can't pass an empty compressor");
|
||||||
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setMessageCompression(boolean enable) {
|
MessageFramer setMessageCompression(boolean enable) {
|
||||||
messageCompression = enable;
|
messageCompression = enable;
|
||||||
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -121,47 +113,58 @@ public class MessageFramer {
|
||||||
*/
|
*/
|
||||||
public void writePayload(InputStream message) {
|
public void writePayload(InputStream message) {
|
||||||
verifyNotClosed();
|
verifyNotClosed();
|
||||||
|
boolean compressed = messageCompression && compressor != Codec.Identity.NONE;
|
||||||
|
int written = -1;
|
||||||
|
int messageLength = -2;
|
||||||
try {
|
try {
|
||||||
if (messageCompression && compressor != Codec.Identity.NONE) {
|
messageLength = getKnownLength(message);
|
||||||
writeCompressed(message);
|
if (messageLength != 0 && compressed) {
|
||||||
|
written = writeCompressed(message, messageLength);
|
||||||
} else {
|
} else {
|
||||||
writeUncompressed(message);
|
written = writeUncompressed(message, messageLength);
|
||||||
}
|
}
|
||||||
} catch (IOException ex) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(ex);
|
// This should not be possible, since sink#deliverFrame doesn't throw.
|
||||||
|
throw Status.INTERNAL
|
||||||
|
.withDescription("Failed to frame message")
|
||||||
|
.withCause(e)
|
||||||
|
.asRuntimeException();
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
throw Status.INTERNAL
|
||||||
|
.withDescription("Failed to frame message")
|
||||||
|
.withCause(e)
|
||||||
|
.asRuntimeException();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (messageLength != -1 && written != messageLength) {
|
||||||
|
String err = String.format("Message length inaccurate %s != %s", written, messageLength);
|
||||||
|
throw Status.INTERNAL.withDescription(err).asRuntimeException();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void writeUncompressed(InputStream message) throws IOException {
|
private int writeUncompressed(InputStream message, int messageLength) throws IOException {
|
||||||
int messageLength = getKnownLength(message);
|
|
||||||
if (messageLength != -1) {
|
if (messageLength != -1) {
|
||||||
writeKnownLength(message, messageLength, false);
|
return writeKnownLength(message, messageLength, false);
|
||||||
} else {
|
|
||||||
BufferChainOutputStream bufferChain = new BufferChainOutputStream();
|
|
||||||
writeToOutputStream(message, bufferChain);
|
|
||||||
writeBufferChain(bufferChain, false);
|
|
||||||
}
|
}
|
||||||
|
BufferChainOutputStream bufferChain = new BufferChainOutputStream();
|
||||||
|
int written = writeToOutputStream(message, bufferChain);
|
||||||
|
writeBufferChain(bufferChain, false);
|
||||||
|
return written;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void writeCompressed(InputStream message) throws IOException {
|
private int writeCompressed(InputStream message, int messageLength) throws IOException {
|
||||||
BufferChainOutputStream bufferChain = new BufferChainOutputStream();
|
BufferChainOutputStream bufferChain = new BufferChainOutputStream();
|
||||||
// Why this doesn't use getKnownLength() idk, but let's just roll with it.
|
|
||||||
int messageLength = -1;
|
|
||||||
if (message instanceof KnownLength) {
|
|
||||||
messageLength = message.available();
|
|
||||||
}
|
|
||||||
|
|
||||||
OutputStream compressingStream = compressor.compress(bufferChain);
|
OutputStream compressingStream = compressor.compress(bufferChain);
|
||||||
|
int written;
|
||||||
try {
|
try {
|
||||||
long written = writeToOutputStream(message, compressingStream);
|
written = writeToOutputStream(message, compressingStream);
|
||||||
if (messageLength != -1 && messageLength != written) {
|
|
||||||
throw new RuntimeException("Message length was inaccurate");
|
|
||||||
}
|
|
||||||
} finally {
|
} finally {
|
||||||
compressingStream.close();
|
compressingStream.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
writeBufferChain(bufferChain, true);
|
writeBufferChain(bufferChain, true);
|
||||||
|
return written;
|
||||||
}
|
}
|
||||||
|
|
||||||
private int getKnownLength(InputStream inputStream) throws IOException {
|
private int getKnownLength(InputStream inputStream) throws IOException {
|
||||||
|
|
@ -174,7 +177,7 @@ public class MessageFramer {
|
||||||
/**
|
/**
|
||||||
* Write an unserialized message with a known length.
|
* Write an unserialized message with a known length.
|
||||||
*/
|
*/
|
||||||
private void writeKnownLength(InputStream message, int messageLength, boolean compressed)
|
private int writeKnownLength(InputStream message, int messageLength, boolean compressed)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
ByteBuffer header = ByteBuffer.wrap(headerScratch);
|
ByteBuffer header = ByteBuffer.wrap(headerScratch);
|
||||||
header.put(compressed ? COMPRESSED : UNCOMPRESSED);
|
header.put(compressed ? COMPRESSED : UNCOMPRESSED);
|
||||||
|
|
@ -185,10 +188,7 @@ public class MessageFramer {
|
||||||
buffer = bufferAllocator.allocate(header.position() + messageLength);
|
buffer = bufferAllocator.allocate(header.position() + messageLength);
|
||||||
}
|
}
|
||||||
writeRaw(headerScratch, 0, header.position());
|
writeRaw(headerScratch, 0, header.position());
|
||||||
long written = writeToOutputStream(message, outputStreamAdapter);
|
return writeToOutputStream(message, outputStreamAdapter);
|
||||||
if (messageLength != written) {
|
|
||||||
throw new RuntimeException("Message length was inaccurate");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -220,14 +220,16 @@ public class MessageFramer {
|
||||||
buffer = bufferList.get(bufferList.size() - 1);
|
buffer = bufferList.get(bufferList.size() - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static long writeToOutputStream(InputStream message, OutputStream outputStream)
|
private static int writeToOutputStream(InputStream message, OutputStream outputStream)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
if (message instanceof Drainable) {
|
if (message instanceof Drainable) {
|
||||||
return ((Drainable) message).drainTo(outputStream);
|
return ((Drainable) message).drainTo(outputStream);
|
||||||
} else {
|
} else {
|
||||||
// This makes an unnecessary copy of the bytes when bytebuf supports array(). However, we
|
// This makes an unnecessary copy of the bytes when bytebuf supports array(). However, we
|
||||||
// expect performance-critical code to support flushTo().
|
// expect performance-critical code to support flushTo().
|
||||||
return ByteStreams.copy(message, outputStream);
|
long written = ByteStreams.copy(message, outputStream);
|
||||||
|
checkArgument(written <= Integer.MAX_VALUE, "Message size overflow: %s", written);
|
||||||
|
return (int) written;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ import com.google.common.collect.ImmutableMultimap;
|
||||||
import com.google.common.collect.Multimap;
|
import com.google.common.collect.Multimap;
|
||||||
|
|
||||||
import io.grpc.internal.AbstractStream.Phase;
|
import io.grpc.internal.AbstractStream.Phase;
|
||||||
|
import io.grpc.internal.MessageFramerTest.ByteWritableBuffer;
|
||||||
|
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
@ -59,6 +60,13 @@ public class AbstractStreamTest {
|
||||||
@Mock MessageFramer framer;
|
@Mock MessageFramer framer;
|
||||||
@Mock MessageDeframer deframer;
|
@Mock MessageDeframer deframer;
|
||||||
|
|
||||||
|
private final WritableBufferAllocator allocator = new WritableBufferAllocator() {
|
||||||
|
@Override
|
||||||
|
public WritableBuffer allocate(int capacityHint) {
|
||||||
|
return new ByteWritableBuffer(capacityHint);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
MockitoAnnotations.initMocks(this);
|
MockitoAnnotations.initMocks(this);
|
||||||
|
|
@ -114,7 +122,7 @@ public class AbstractStreamTest {
|
||||||
*/
|
*/
|
||||||
private class AbstractStreamBase<IdT> extends AbstractStream<IdT> {
|
private class AbstractStreamBase<IdT> extends AbstractStream<IdT> {
|
||||||
private AbstractStreamBase(WritableBufferAllocator bufferAllocator) {
|
private AbstractStreamBase(WritableBufferAllocator bufferAllocator) {
|
||||||
super(bufferAllocator, DEFAULT_MAX_MESSAGE_SIZE);
|
super(allocator, DEFAULT_MAX_MESSAGE_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
private AbstractStreamBase(MessageFramer framer, MessageDeframer deframer) {
|
private AbstractStreamBase(MessageFramer framer, MessageDeframer deframer) {
|
||||||
|
|
|
||||||
|
|
@ -238,8 +238,9 @@ public class MessageFramerTest {
|
||||||
@Test
|
@Test
|
||||||
public void compressed() throws Exception {
|
public void compressed() throws Exception {
|
||||||
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
|
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
|
||||||
framer = new MessageFramer(sink, allocator, new Codec.Gzip());
|
framer = new MessageFramer(sink, allocator)
|
||||||
framer.setMessageCompression(true);
|
.setCompressor(new Codec.Gzip())
|
||||||
|
.setMessageCompression(true);
|
||||||
writeKnownLength(framer, new byte[1000]);
|
writeKnownLength(framer, new byte[1000]);
|
||||||
framer.flush();
|
framer.flush();
|
||||||
// The GRPC header is written first as a separate frame.
|
// The GRPC header is written first as a separate frame.
|
||||||
|
|
@ -262,8 +263,8 @@ public class MessageFramerTest {
|
||||||
@Test
|
@Test
|
||||||
public void dontCompressIfNoEncoding() throws Exception {
|
public void dontCompressIfNoEncoding() throws Exception {
|
||||||
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
|
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
|
||||||
framer = new MessageFramer(sink, allocator, Codec.Identity.NONE);
|
framer = new MessageFramer(sink, allocator)
|
||||||
framer.setMessageCompression(true);
|
.setMessageCompression(true);
|
||||||
writeKnownLength(framer, new byte[1000]);
|
writeKnownLength(framer, new byte[1000]);
|
||||||
framer.flush();
|
framer.flush();
|
||||||
// The GRPC header is written first as a separate frame
|
// The GRPC header is written first as a separate frame
|
||||||
|
|
@ -286,8 +287,9 @@ public class MessageFramerTest {
|
||||||
@Test
|
@Test
|
||||||
public void dontCompressIfNotRequested() throws Exception {
|
public void dontCompressIfNotRequested() throws Exception {
|
||||||
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
|
allocator = new BytesWritableBufferAllocator(100, Integer.MAX_VALUE);
|
||||||
framer = new MessageFramer(sink, allocator, new Codec.Gzip());
|
framer = new MessageFramer(sink, allocator)
|
||||||
framer.setMessageCompression(false);
|
.setCompressor(new Codec.Gzip())
|
||||||
|
.setMessageCompression(false);
|
||||||
writeKnownLength(framer, new byte[1000]);
|
writeKnownLength(framer, new byte[1000]);
|
||||||
framer.flush();
|
framer.flush();
|
||||||
// The GRPC header is written first as a separate frame
|
// The GRPC header is written first as a separate frame
|
||||||
|
|
@ -321,11 +323,19 @@ public class MessageFramerTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
framer = new MessageFramer(reentrant, allocator, Codec.Identity.NONE);
|
framer = new MessageFramer(reentrant, allocator);
|
||||||
writeKnownLength(framer, new byte[]{3, 14});
|
writeKnownLength(framer, new byte[]{3, 14});
|
||||||
framer.close();
|
framer.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void zeroLengthCompressibleMessageIsNotCompressed() {
|
||||||
|
framer.setCompressor(new Codec.Gzip());
|
||||||
|
framer.setMessageCompression(true);
|
||||||
|
writeKnownLength(framer, new byte[]{});
|
||||||
|
framer.flush();
|
||||||
|
verify(sink).deliverFrame(toWriteBuffer(new byte[] {0, 0, 0, 0, 0}), false, true);
|
||||||
|
}
|
||||||
|
|
||||||
private static WritableBuffer toWriteBuffer(byte[] data) {
|
private static WritableBuffer toWriteBuffer(byte[] data) {
|
||||||
return toWriteBufferWithMinSize(data, 0);
|
return toWriteBufferWithMinSize(data, 0);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue