diff --git a/core/src/main/java/io/grpc/internal/ReadableBuffers.java b/core/src/main/java/io/grpc/internal/ReadableBuffers.java index 6832940289..8eba4fdf75 100644 --- a/core/src/main/java/io/grpc/internal/ReadableBuffers.java +++ b/core/src/main/java/io/grpc/internal/ReadableBuffers.java @@ -312,7 +312,7 @@ public final class ReadableBuffers { /** * An {@link InputStream} that is backed by a {@link ReadableBuffer}. */ - private static class BufferInputStream extends InputStream implements KnownLength { + private static final class BufferInputStream extends InputStream implements KnownLength { final ReadableBuffer buffer; public BufferInputStream(ReadableBuffer buffer) { diff --git a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java index 3f51550389..a66f57b5a3 100644 --- a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java +++ b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java @@ -40,10 +40,13 @@ import com.google.protobuf.MessageLite; import com.google.protobuf.Parser; import io.grpc.ExperimentalApi; +import io.grpc.KnownLength; import io.grpc.Metadata; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; +import java.io.IOException; import java.io.InputStream; /** @@ -107,23 +110,45 @@ public class ProtoLiteUtils { } } } + CodedInputStream cis = null; try { - return parseFrom(stream); + if (stream instanceof KnownLength) { + int size = stream.available(); + if (size > 0 && size <= GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE) { + byte[] buf = new byte[size]; + int chunkSize; + int position = 0; + while ((chunkSize = stream.read(buf, position, buf.length - position)) != -1) { + position += chunkSize; + } + if (buf.length != position) { + throw new RuntimeException("size inaccurate: " + buf.length + " != " + position); + } + cis = CodedInputStream.newInstance(buf); + } + } + } catch (IOException e) { + throw new RuntimeException(e); + } + if (cis == null) { + cis = CodedInputStream.newInstance(stream); + } + // Pre-create the CodedInputStream so that we can remove the size limit restriction + // when parsing. + cis.setSizeLimit(Integer.MAX_VALUE); + + try { + return parseFrom(cis); } catch (InvalidProtocolBufferException ipbe) { throw Status.INTERNAL.withDescription("Invalid protobuf byte sequence") .withCause(ipbe).asRuntimeException(); } } - private T parseFrom(InputStream stream) throws InvalidProtocolBufferException { - // Pre-create the CodedInputStream so that we can remove the size limit restriction - // when parsing. - CodedInputStream codedInput = CodedInputStream.newInstance(stream); - codedInput.setSizeLimit(Integer.MAX_VALUE); - - T message = parser.parseFrom(codedInput, globalRegistry); + private T parseFrom(CodedInputStream stream) throws InvalidProtocolBufferException { + T message = parser.parseFrom(stream, globalRegistry); try { - codedInput.checkLastTagWas(0); + stream.checkLastTagWas(0); return message; } catch (InvalidProtocolBufferException e) { e.setUnfinishedMessage(message);