diff --git a/protobuf/src/main/java/io/grpc/protobuf/ProtoInputStream.java b/protobuf/src/main/java/io/grpc/protobuf/ProtoInputStream.java index d5ebae188c..186e3c12e2 100644 --- a/protobuf/src/main/java/io/grpc/protobuf/ProtoInputStream.java +++ b/protobuf/src/main/java/io/grpc/protobuf/ProtoInputStream.java @@ -70,9 +70,11 @@ class ProtoInputStream extends InputStream implements Drainable, KnownLength { written = message.getSerializedSize(); message.writeTo(target); message = null; - } else { + } else if (partial != null) { written = (int) ByteStreams.copy(partial, target); partial = null; + } else { + written = 0; } return written; } diff --git a/protobuf/src/test/java/io/grpc/protobuf/ProtoUtilsTest.java b/protobuf/src/test/java/io/grpc/protobuf/ProtoUtilsTest.java index 7fc93e023d..7858568c5c 100644 --- a/protobuf/src/test/java/io/grpc/protobuf/ProtoUtilsTest.java +++ b/protobuf/src/test/java/io/grpc/protobuf/ProtoUtilsTest.java @@ -41,6 +41,7 @@ import com.google.protobuf.Empty; import com.google.protobuf.Enum; import com.google.protobuf.Type; +import io.grpc.Drainable; import io.grpc.MethodDescriptor.Marshaller; import org.junit.Test; @@ -48,6 +49,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.util.Arrays; @@ -115,4 +117,46 @@ public class ProtoUtilsTest { assertEquals(-1, is.read()); assertEquals(0, is.available()); } + + @Test + public void testDrainTo_all() throws Exception { + byte[] golden = ByteStreams.toByteArray(marshaller.stream(proto)); + InputStream is = marshaller.stream(proto); + Drainable d = (Drainable) is; + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + int drained = d.drainTo(baos); + assertEquals(baos.size(), drained); + assertArrayEquals(golden, baos.toByteArray()); + assertEquals(0, is.available()); + } + + @Test + public void testDrainTo_partial() throws Exception { + final byte[] golden; + { + InputStream is = marshaller.stream(proto); + is.read(); + golden = ByteStreams.toByteArray(is); + } + InputStream is = marshaller.stream(proto); + is.read(); + Drainable d = (Drainable) is; + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + int drained = d.drainTo(baos); + assertEquals(baos.size(), drained); + assertArrayEquals(golden, baos.toByteArray()); + assertEquals(0, is.available()); + } + + @Test + public void testDrainTo_none() throws Exception { + byte[] golden = ByteStreams.toByteArray(marshaller.stream(proto)); + InputStream is = marshaller.stream(proto); + ByteStreams.toByteArray(is); + Drainable d = (Drainable) is; + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + assertEquals(0, d.drainTo(baos)); + assertArrayEquals(new byte[0], baos.toByteArray()); + assertEquals(0, is.available()); + } }