diff --git a/protobuf-nano/src/test/java/io/grpc/protobuf/nano/NanoUtilsTest.java b/protobuf-nano/src/test/java/io/grpc/protobuf/nano/NanoUtilsTest.java index 621f11cffb..f49184c620 100644 --- a/protobuf-nano/src/test/java/io/grpc/protobuf/nano/NanoUtilsTest.java +++ b/protobuf-nano/src/test/java/io/grpc/protobuf/nano/NanoUtilsTest.java @@ -34,10 +34,10 @@ package io.grpc.protobuf.nano; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import com.google.protobuf.nano.InvalidProtocolBufferNanoException; import com.google.protobuf.nano.MessageNano; import io.grpc.MethodDescriptor.Marshaller; @@ -49,6 +49,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; @@ -77,20 +78,14 @@ public class NanoUtilsTest { } @Test - public void testIoException() { - final IOException ioException = new IOException(); - InputStream is = new InputStream() { - @Override - public int read() throws IOException { - throw ioException; - } - }; + public void parseInvalid() throws Exception { + InputStream is = new ByteArrayInputStream(new byte[] {-127}); try { marshaller.parse(is); fail("Expected exception"); } catch (StatusRuntimeException ex) { assertEquals(Status.Code.INTERNAL, ex.getStatus().getCode()); - assertSame(ioException, ex.getCause()); + assertTrue(ex.getCause() instanceof InvalidProtocolBufferNanoException); } } diff --git a/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java b/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java index 61a44fef7b..dbc5d1eb4e 100644 --- a/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java +++ b/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java @@ -31,6 +31,7 @@ package io.grpc.protobuf; +import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.CodedInputStream; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; @@ -71,9 +72,15 @@ public class ProtoUtils { // if not, using the same MethodDescriptor would ensure the parser matches and permit us // to enable this optimization. if (protoStream.parser() == parser) { - @SuppressWarnings("unchecked") - T message = (T) ((ProtoInputStream) stream).message(); - return message; + try { + @SuppressWarnings("unchecked") + T message = (T) ((ProtoInputStream) stream).message(); + return message; + } catch (IllegalStateException ex) { + // Stream must have been read from, which is a strange state. Since the point of this + // optimization is to be transparent, instead of throwing an error we'll continue, + // even though it seems likely there's a bug. + } } } try { @@ -105,25 +112,30 @@ public class ProtoUtils { /** * Produce a metadata key for a generated protobuf type. */ - public static Metadata.Key keyForProto(final T instance) { + public static Metadata.Key keyForProto(T instance) { return Metadata.Key.of( instance.getDescriptorForType().getFullName() + Metadata.BINARY_HEADER_SUFFIX, - new Metadata.BinaryMarshaller() { - @Override - public byte[] toBytes(T value) { - return value.toByteArray(); - } + keyMarshaller(instance)); + } - @Override - @SuppressWarnings("unchecked") - public T parseBytes(byte[] serialized) { - try { - return (T) instance.getParserForType().parseFrom(serialized); - } catch (InvalidProtocolBufferException ipbe) { - throw new IllegalArgumentException(ipbe); - } - } - }); + @VisibleForTesting + static Metadata.BinaryMarshaller keyMarshaller(final T instance) { + return new Metadata.BinaryMarshaller() { + @Override + public byte[] toBytes(T value) { + return value.toByteArray(); + } + + @Override + @SuppressWarnings("unchecked") + public T parseBytes(byte[] serialized) { + try { + return (T) instance.getParserForType().parseFrom(serialized); + } catch (InvalidProtocolBufferException ipbe) { + throw new IllegalArgumentException(ipbe); + } + } + }; } private ProtoUtils() { diff --git a/protobuf/src/test/java/io/grpc/protobuf/ProtoUtilsTest.java b/protobuf/src/test/java/io/grpc/protobuf/ProtoUtilsTest.java index 7858568c5c..dc9ede460e 100644 --- a/protobuf/src/test/java/io/grpc/protobuf/ProtoUtilsTest.java +++ b/protobuf/src/test/java/io/grpc/protobuf/ProtoUtilsTest.java @@ -33,16 +33,22 @@ package io.grpc.protobuf; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.fail; import com.google.common.io.ByteStreams; import com.google.protobuf.ByteString; import com.google.protobuf.Empty; import com.google.protobuf.Enum; +import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Type; import io.grpc.Drainable; +import io.grpc.Metadata; import io.grpc.MethodDescriptor.Marshaller; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; import org.junit.Test; import org.junit.runner.RunWith; @@ -72,6 +78,33 @@ public class ProtoUtilsTest { assertEquals(proto, marshaller.parse(is)); } + @Test + public void testInvalidatedMessage() throws Exception { + InputStream is = marshaller.stream(proto); + // Invalidates message, and drains all bytes + ByteStreams.toByteArray(is); + try { + ((ProtoInputStream) is).message(); + fail("Expected exception"); + } catch (IllegalStateException ex) { + // expected + } + // Zero bytes is the default message + assertEquals(Type.getDefaultInstance(), marshaller.parse(is)); + } + + @Test + public void parseInvalid() throws Exception { + InputStream is = new ByteArrayInputStream(new byte[] {-127}); + try { + marshaller.parse(is); + fail("Expected exception"); + } catch (StatusRuntimeException ex) { + assertEquals(Status.Code.INTERNAL, ex.getStatus().getCode()); + assertNotNull(((InvalidProtocolBufferException) ex.getCause()).getUnfinishedMessage()); + } + } + @Test public void testMismatch() throws Exception { Marshaller enumMarshaller = ProtoUtils.marshaller(Enum.getDefaultInstance()); @@ -159,4 +192,31 @@ public class ProtoUtilsTest { assertArrayEquals(new byte[0], baos.toByteArray()); assertEquals(0, is.available()); } + + @Test + public void keyForProto() { + assertEquals("google.protobuf.Type-bin", + ProtoUtils.keyForProto(Type.getDefaultInstance()).originalName()); + } + + @Test + public void keyMarshaller_roundtrip() { + Metadata.BinaryMarshaller keyMarshaller = + ProtoUtils.keyMarshaller(Type.getDefaultInstance()); + assertEquals(proto, keyMarshaller.parseBytes(keyMarshaller.toBytes(proto))); + } + + @Test + public void keyMarshaller_invalid() { + Metadata.BinaryMarshaller keyMarshaller = + ProtoUtils.keyMarshaller(Type.getDefaultInstance()); + try { + keyMarshaller.parseBytes(new byte[] {-127}); + fail("Expected exception"); + } catch (IllegalArgumentException ex) { + assertNotNull(((InvalidProtocolBufferException) ex.getCause()).getUnfinishedMessage()); + } + } + + }