diff --git a/core/src/main/java/io/grpc/internal/MessageFramer.java b/core/src/main/java/io/grpc/internal/MessageFramer.java index 5e75fa2e6f..8b5ccb864a 100644 --- a/core/src/main/java/io/grpc/internal/MessageFramer.java +++ b/core/src/main/java/io/grpc/internal/MessageFramer.java @@ -75,6 +75,10 @@ public class MessageFramer implements Framer { // effectively final. Can only be set once. private int maxOutboundMessageSize = NO_MAX_OUTBOUND_MESSAGE_SIZE; private WritableBuffer buffer; + /** + * if > 0 - the number of bytes to allocate for the current known-length message. + */ + private int knownLengthPendingAllocation; private Compressor compressor = Codec.Identity.NONE; private boolean messageCompression = true; private final OutputStreamAdapter outputStreamAdapter = new OutputStreamAdapter(); @@ -222,9 +226,7 @@ public class MessageFramer implements Framer { headerScratch.put(UNCOMPRESSED).putInt(messageLength); // Allocate the initial buffer chunk based on frame header + payload length. // Note that the allocator may allocate a buffer larger or smaller than this length - if (buffer == null) { - buffer = bufferAllocator.allocate(headerScratch.position() + messageLength); - } + knownLengthPendingAllocation = HEADER_LENGTH + messageLength; writeRaw(headerScratch.array(), 0, headerScratch.position()); return writeToOutputStream(message, outputStreamAdapter); } @@ -288,8 +290,9 @@ public class MessageFramer implements Framer { commitToSink(false, false); } if (buffer == null) { - // Request a buffer allocation using the message length as a hint. - buffer = bufferAllocator.allocate(len); + checkState(knownLengthPendingAllocation > 0, "knownLengthPendingAllocation reached 0"); + buffer = bufferAllocator.allocate(knownLengthPendingAllocation); + knownLengthPendingAllocation -= min(knownLengthPendingAllocation, buffer.writableBytes()); } int toWrite = min(len, buffer.writableBytes()); buffer.write(b, off, toWrite); @@ -388,6 +391,8 @@ public class MessageFramer implements Framer { * {@link OutputStream}. */ private final class BufferChainOutputStream extends OutputStream { + private static final int FIRST_BUFFER_SIZE = 4096; + private final List bufferList = new ArrayList<>(); private WritableBuffer current; @@ -397,7 +402,7 @@ public class MessageFramer implements Framer { * {@link #write(byte[], int, int)}. */ @Override - public void write(int b) throws IOException { + public void write(int b) { if (current != null && current.writableBytes() > 0) { current.write((byte)b); return; @@ -410,7 +415,7 @@ public class MessageFramer implements Framer { public void write(byte[] b, int off, int len) { if (current == null) { // Request len bytes initially from the allocator, it may give us more. - current = bufferAllocator.allocate(len); + current = bufferAllocator.allocate(Math.max(FIRST_BUFFER_SIZE, len)); bufferList.add(current); } while (len > 0) { diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java index 208eb40c43..5307c26949 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java @@ -24,6 +24,8 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; import com.google.protobuf.ByteString; import io.grpc.CallOptions; import io.grpc.Channel; @@ -53,8 +55,6 @@ import io.grpc.testing.integration.Messages.SimpleResponse; import io.grpc.testing.integration.TestServiceGrpc.TestServiceBlockingStub; import io.grpc.testing.integration.TransportCompressionTest.Fzip; import java.nio.charset.Charset; -import java.util.ArrayList; -import java.util.Collection; import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @@ -146,25 +146,16 @@ public class CompressionTest { * Parameters for test. */ @Parameters - public static Collection params() { - boolean[] bools = new boolean[]{false, true}; - List combos = new ArrayList<>(64); - for (boolean enableClientMessageCompression : bools) { - for (boolean clientAcceptEncoding : bools) { - for (boolean clientEncoding : bools) { - for (boolean enableServerMessageCompression : bools) { - for (boolean serverAcceptEncoding : bools) { - for (boolean serverEncoding : bools) { - combos.add(new Object[] { - enableClientMessageCompression, clientAcceptEncoding, clientEncoding, - enableServerMessageCompression, serverAcceptEncoding, serverEncoding}); - } - } - } - } - } - } - return combos; + public static Iterable params() { + List bools = Lists.newArrayList(false, true); + return Iterables.transform(Lists.cartesianProduct( + bools, // enableClientMessageCompression + bools, // clientAcceptEncoding + bools, // clientEncoding + bools, // enableServerMessageCompression + bools, // serverAcceptEncoding + bools // serverEncoding + ), List::toArray); } @Test diff --git a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java index a44a196ac8..4dd24c3fd4 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java @@ -233,7 +233,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase