diff --git a/core/src/main/java/com/google/net/stubby/transport/Http2ClientStream.java b/core/src/main/java/com/google/net/stubby/transport/Http2ClientStream.java index cc32fedbd4..f4b9c1a71d 100644 --- a/core/src/main/java/com/google/net/stubby/transport/Http2ClientStream.java +++ b/core/src/main/java/com/google/net/stubby/transport/Http2ClientStream.java @@ -47,6 +47,7 @@ public abstract class Http2ClientStream extends AbstractClientStream { private Status transportError; private Charset errorCharset = Charsets.UTF_8; + private boolean contentTypeChecked; protected Http2ClientStream(ClientStreamListener listener, @Nullable Decompressor decompressor, @@ -61,23 +62,20 @@ public abstract class Http2ClientStream extends AbstractClientStream { transportError = transportError.augmentDescription(headers.toString()); return; } - String contentType = headers.get(HttpUtil.CONTENT_TYPE); Status httpStatus = statusFromHttpStatus(headers); if (httpStatus == null) { transportError = Status.INTERNAL.withDescription( "received non-terminal headers with no :status"); } else if (!httpStatus.isOk()) { transportError = httpStatus; - } else if (TEMP_CHECK_CONTENT_TYPE && - !HttpUtil.CONTENT_TYPE_GRPC.equalsIgnoreCase(contentType)) { - // Malformed content-type so report an error - transportError = Status.INTERNAL.withDescription("invalid content-type " + contentType); + } else { + transportError = checkContentType(headers); } if (transportError != null) { // Note we don't immediately report the transport error, instead we wait for more data on the // stream so we can accumulate more detail into the error before reporting it. transportError = transportError.withDescription("\n" + headers.toString()); - errorCharset = charsetFromContentType(contentType); + errorCharset = extractCharset(headers); } else { stripTransportDetails(headers); inboundHeadersReceived(headers); @@ -137,6 +135,10 @@ public abstract class Http2ClientStream extends AbstractClientStream { if (transportError != null) { // Already received a transport error so just augment it. transportError = transportError.augmentDescription(trailers.toString()); + } else { + transportError = checkContentType(trailers); + } + if (transportError != null) { inboundTransportError(transportError); } else { Status status = statusFromTrailers(trailers); @@ -178,10 +180,29 @@ public abstract class Http2ClientStream extends AbstractClientStream { return status; } + /** + * Inspect the content type field from received headers or trailers and return an error Status if + * content type is invalid or not present. Returns null if no error was found. + */ + @Nullable + private Status checkContentType(Metadata headers) { + if (contentTypeChecked) { + return null; + } + contentTypeChecked = true; + String contentType = headers.get(HttpUtil.CONTENT_TYPE); + if (TEMP_CHECK_CONTENT_TYPE && !HttpUtil.CONTENT_TYPE_GRPC.equalsIgnoreCase(contentType)) { + // Malformed content-type so report an error + return Status.INTERNAL.withDescription("invalid content-type " + contentType); + } + return null; + } + /** * Inspect the raw metadata and figure out what charset is being used. */ - private static Charset charsetFromContentType(String contentType) { + private static Charset extractCharset(Metadata headers) { + String contentType = headers.get(HttpUtil.CONTENT_TYPE); if (contentType != null) { String[] split = contentType.split("charset="); try {