diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index a07baafd01..4a93c1b1d1 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -97,8 +97,10 @@ public abstract class AbstractClientStream extends AbstractStream * responsible for properly closing streams when protocol errors occur. * * @param errorStatus the error to report + * @param metadata any metadata received */ - protected void inboundTransportError(Status errorStatus) { + protected void inboundTransportError(Status errorStatus, Metadata metadata) { + Preconditions.checkNotNull(metadata, "metadata"); if (inboundPhase() == Phase.STATUS) { log.log(Level.INFO, "Received transport error on closed stream {0} {1}", new Object[]{id(), errorStatus}); @@ -106,7 +108,7 @@ public abstract class AbstractClientStream extends AbstractStream } // For transport errors we immediately report status to the application layer // and do not wait for additional payloads. - transportReportStatus(errorStatus, false, new Metadata()); + transportReportStatus(errorStatus, false, metadata); } /** @@ -130,7 +132,7 @@ public abstract class AbstractClientStream extends AbstractStream .withCause(e); // TODO(carl-mastrangelo): look back into tearing down this stream. sendCancel() can be // buffered. - inboundTransportError(status); + inboundTransportError(status, headers); sendCancel(status); return; } @@ -155,7 +157,7 @@ public abstract class AbstractClientStream extends AbstractStream if (inboundPhase() == Phase.HEADERS) { // Have not received headers yet so error inboundTransportError(Status.INTERNAL - .withDescription("headers not received before payload")); + .withDescription("headers not received before payload"), new Metadata()); return; } inboundPhase(Phase.MESSAGE); diff --git a/core/src/main/java/io/grpc/internal/Http2ClientStream.java b/core/src/main/java/io/grpc/internal/Http2ClientStream.java index 288f6df779..8ce4adba1f 100644 --- a/core/src/main/java/io/grpc/internal/Http2ClientStream.java +++ b/core/src/main/java/io/grpc/internal/Http2ClientStream.java @@ -65,7 +65,9 @@ public abstract class Http2ClientStream extends AbstractClientStream { private static final Metadata.Key HTTP2_STATUS = Metadata.Key.of(":status", HTTP_STATUS_LINE_MARSHALLER); + /** When non-{@code null}, {@link #transportErrorMetadata} must also be non-{@code null}. */ private Status transportError; + private Metadata transportErrorMetadata; private Charset errorCharset = Charsets.UTF_8; private boolean contentTypeChecked; @@ -99,6 +101,7 @@ public abstract class Http2ClientStream extends AbstractClientStream { // Note we don't immediately report the transport error, instead we wait for more data on the // stream so we can accumulate more detail into the error before reporting it. transportError = transportError.augmentDescription("\n" + headers.toString()); + transportErrorMetadata = headers; errorCharset = extractCharset(headers); } else { stripTransportDetails(headers); @@ -117,6 +120,7 @@ public abstract class Http2ClientStream extends AbstractClientStream { // Must receive headers prior to receiving any payload as we use headers to check for // protocol correctness. transportError = Status.INTERNAL.withDescription("no headers received prior to data"); + transportErrorMetadata = new Metadata(); } if (transportError != null) { // We've already detected a transport error and now we're just accumulating more detail @@ -125,7 +129,7 @@ public abstract class Http2ClientStream extends AbstractClientStream { + ReadableBuffers.readAsString(frame, errorCharset)); frame.close(); if (transportError.getDescription().length() > 1000 || endOfStream) { - inboundTransportError(transportError); + inboundTransportError(transportError, transportErrorMetadata); // We have enough error detail so lets cancel. sendCancel(Status.CANCELLED); } @@ -134,7 +138,8 @@ public abstract class Http2ClientStream extends AbstractClientStream { if (endOfStream) { // This is a protocol violation as we expect to receive trailers. transportError = Status.INTERNAL.withDescription("Recevied EOS on DATA frame"); - inboundTransportError(transportError); + transportErrorMetadata = new Metadata(); + inboundTransportError(transportError, transportErrorMetadata); } } } @@ -151,9 +156,10 @@ public abstract class Http2ClientStream extends AbstractClientStream { transportError = transportError.augmentDescription(trailers.toString()); } else { transportError = checkContentType(trailers); + transportErrorMetadata = trailers; } if (transportError != null) { - inboundTransportError(transportError); + inboundTransportError(transportError, transportErrorMetadata); sendCancel(Status.CANCELLED); } else { Status status = statusFromTrailers(trailers); diff --git a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java index 29dbc9c199..ed6209102b 100644 --- a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java @@ -245,11 +245,15 @@ public class AbstractClientStreamTest { stream.start(mockListener); Metadata headers = new Metadata(); headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, "bad"); + Metadata.Key randomKey = Metadata.Key.of("random", Metadata.ASCII_STRING_MARSHALLER); + headers.put(randomKey, "4"); stream.inboundHeadersReceived(headers); - verify(mockListener).closed(statusCaptor.capture(), isA(Metadata.class)); + ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(mockListener).closed(statusCaptor.capture(), metadataCaptor.capture()); assertEquals(Code.INTERNAL, statusCaptor.getValue().getCode()); + assertEquals("4", metadataCaptor.getValue().get(randomKey)); } @Test diff --git a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java index d0390d25eb..78b56f3d03 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java @@ -236,6 +236,7 @@ public class NettyClientStreamTest extends NettyStreamTestBase captor = ArgumentCaptor.forClass(Status.class); - verify(listener).closed(captor.capture(), any(Metadata.class)); + ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(listener).closed(captor.capture(), metadataCaptor.capture()); assertEquals(Status.UNKNOWN.getCode(), captor.getValue().getCode()); + assertEquals("4", metadataCaptor.getValue() + .get(Metadata.Key.of("random", Metadata.ASCII_STRING_MARSHALLER))); assertTrue(stream.isClosed()); } @@ -269,10 +273,13 @@ public class NettyClientStreamTest extends NettyStreamTestBase captor = ArgumentCaptor.forClass(Status.class); - verify(listener).closed(captor.capture(), any(Metadata.class)); + ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(listener).closed(captor.capture(), metadataCaptor.capture()); Status status = captor.getValue(); assertEquals(Status.Code.INTERNAL, status.getCode()); assertTrue(status.getDescription().contains("content-type")); + assertEquals("application/bad", metadataCaptor.getValue() + .get(Metadata.Key.of("Content-Type", Metadata.ASCII_STRING_MARSHALLER))); } @Test diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index 6c090422d9..fa6174436f 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -325,8 +325,8 @@ public class OkHttpClientTransportTest { stream.start(listener); stream.request(1); assertContainStream(3); - // Empty headers block without correct content type or status - frameHandler().headers(false, false, 3, 0, new ArrayList
(), + // Headers block without correct content type or status + frameHandler().headers(false, false, 3, 0, Arrays.asList(new Header("random", "4")), HeadersMode.HTTP_20_HEADERS); // Now wait to receive 1000 bytes of data so we can have a better error message before // cancelling the streaam. @@ -335,6 +335,27 @@ public class OkHttpClientTransportTest { assertNull(listener.headers); assertEquals(Status.INTERNAL.getCode(), listener.status.getCode()); assertNotNull(listener.trailers); + assertEquals("4", listener.trailers + .get(Metadata.Key.of("random", Metadata.ASCII_STRING_MARSHALLER))); + shutdownAndVerify(); + } + + @Test + public void invalidInboundTrailersPropagateToMetadata() throws Exception { + initTransport(); + MockStreamListener listener = new MockStreamListener(); + OkHttpClientStream stream = clientTransport.newStream(method, new Metadata()); + stream.start(listener); + stream.request(1); + assertContainStream(3); + // Headers block with EOS without correct content type or status + frameHandler().headers(true, true, 3, 0, Arrays.asList(new Header("random", "4")), + HeadersMode.HTTP_20_HEADERS); + assertNull(listener.headers); + assertEquals(Status.INTERNAL.getCode(), listener.status.getCode()); + assertNotNull(listener.trailers); + assertEquals("4", listener.trailers + .get(Metadata.Key.of("random", Metadata.ASCII_STRING_MARSHALLER))); shutdownAndVerify(); }