diff --git a/core/src/main/java/io/grpc/internal/TransportFrameUtil.java b/core/src/main/java/io/grpc/internal/TransportFrameUtil.java index e015740b79..84780315fd 100644 --- a/core/src/main/java/io/grpc/internal/TransportFrameUtil.java +++ b/core/src/main/java/io/grpc/internal/TransportFrameUtil.java @@ -21,8 +21,11 @@ import static com.google.common.base.Charsets.US_ASCII; import com.google.common.io.BaseEncoding; import io.grpc.InternalMetadata; import io.grpc.Metadata; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.logging.Logger; +import javax.annotation.CheckReturnValue; /** * Utility functions for transport layer framing. @@ -84,21 +87,30 @@ public final class TransportFrameUtil { /** * Transform HTTP/2-compliant headers to the raw serialized format which can be deserialized by - * metadata marshallers. It decodes the Base64-encoded binary headers. This function modifies - * the headers in place. By modifying the input array. + * metadata marshallers. It decodes the Base64-encoded binary headers. + * + *

Warning: This function may partially modify the headers in place by modifying the input + * array (but not modifying any single byte), so the input reference {@code http2Headers} can not + * be used again. * * @param http2Headers the interleaved keys and values of HTTP/2-compliant headers * @return the interleaved keys and values in the raw serialized format */ @SuppressWarnings("BetaApi") // BaseEncoding is stable in Guava 20.0 + @CheckReturnValue public static byte[][] toRawSerializedHeaders(byte[][] http2Headers) { for (int i = 0; i < http2Headers.length; i += 2) { byte[] key = http2Headers[i]; byte[] value = http2Headers[i + 1]; - http2Headers[i] = key; if (endsWith(key, binaryHeaderSuffixBytes)) { // Binary header - http2Headers[i + 1] = BaseEncoding.base64().decode(new String(value, US_ASCII)); + for (int idx = 0; idx < value.length; idx++) { + if (value[idx] == (byte) ',') { + return serializeHeadersWithCommasInBin(http2Headers, i); + } + } + byte[] decodedVal = BaseEncoding.base64().decode(new String(value, US_ASCII)); + http2Headers[i + 1] = decodedVal; } else { // Non-binary header // Nothing to do, the value is already in the right place. @@ -107,6 +119,35 @@ public final class TransportFrameUtil { return http2Headers; } + private static byte[][] serializeHeadersWithCommasInBin(byte[][] http2Headers, int resumeFrom) { + List headerList = new ArrayList<>(http2Headers.length + 10); + for (int i = 0; i < resumeFrom; i++) { + headerList.add(http2Headers[i]); + } + for (int i = resumeFrom; i < http2Headers.length; i += 2) { + byte[] key = http2Headers[i]; + byte[] value = http2Headers[i + 1]; + if (!endsWith(key, binaryHeaderSuffixBytes)) { + headerList.add(key); + headerList.add(value); + continue; + } + // Binary header + int prevIdx = 0; + for (int idx = 0; idx <= value.length; idx++) { + if (idx != value.length && value[idx] != (byte) ',') { + continue; + } + byte[] decodedVal = + BaseEncoding.base64().decode(new String(value, prevIdx, idx - prevIdx, US_ASCII)); + prevIdx = idx + 1; + headerList.add(key); + headerList.add(decodedVal); + } + } + return headerList.toArray(new byte[0][]); + } + /** * Returns {@code true} if {@code subject} ends with {@code suffix}. */ diff --git a/core/src/test/java/io/grpc/internal/TransportFrameUtilTest.java b/core/src/test/java/io/grpc/internal/TransportFrameUtilTest.java index aa99fdb9f5..4d0f6ed0b5 100644 --- a/core/src/test/java/io/grpc/internal/TransportFrameUtilTest.java +++ b/core/src/test/java/io/grpc/internal/TransportFrameUtilTest.java @@ -19,11 +19,14 @@ package io.grpc.internal; import static com.google.common.base.Charsets.US_ASCII; import static com.google.common.base.Charsets.UTF_8; import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; +import static io.grpc.Metadata.BINARY_BYTE_MARSHALLER; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import com.google.common.collect.Iterables; import com.google.common.io.BaseEncoding; import io.grpc.InternalMetadata; import io.grpc.Metadata; @@ -59,6 +62,7 @@ public class TransportFrameUtilTest { private static final Key BINARY_STRING = Key.of("string-bin", UTF8_STRING_MARSHALLER); private static final Key BINARY_STRING_WITHOUT_SUFFIX = Key.of("string", ASCII_STRING_MARSHALLER); + private static final Key BINARY_BYTES = Key.of("bytes-bin", BINARY_BYTE_MARSHALLER); @Test public void testToHttp2Headers() { @@ -99,6 +103,31 @@ public class TransportFrameUtilTest { assertNull(recoveredHeaders.get(BINARY_STRING_WITHOUT_SUFFIX)); } + @Test + public void dupBinHeadersWithComma() { + byte[][] http2Headers = new byte[][] { + BINARY_BYTES.name().getBytes(US_ASCII), + "BaS,e6,,4+,padding==".getBytes(US_ASCII), + BINARY_BYTES.name().getBytes(US_ASCII), + "more".getBytes(US_ASCII), + BINARY_BYTES.name().getBytes(US_ASCII), + "".getBytes(US_ASCII)}; + byte[][] rawSerialized = TransportFrameUtil.toRawSerializedHeaders(http2Headers); + Metadata recoveredHeaders = InternalMetadata.newMetadata(rawSerialized); + byte[][] values = Iterables.toArray(recoveredHeaders.getAll(BINARY_BYTES), byte[].class); + + assertTrue(Arrays.deepEquals( + new byte[][] { + BaseEncoding.base64().decode("BaS"), + BaseEncoding.base64().decode("e6"), + BaseEncoding.base64().decode(""), + BaseEncoding.base64().decode("4+"), + BaseEncoding.base64().decode("padding"), + BaseEncoding.base64().decode("more"), + BaseEncoding.base64().decode("")}, + values)); + } + private static void assertContains(byte[][] headers, byte[] key, byte[] value) { String keyString = new String(key, US_ASCII); for (int i = 0; i < headers.length; i += 2) { diff --git a/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java b/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java index 920afca0c6..de52f0b043 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java +++ b/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java @@ -101,18 +101,37 @@ class GrpcHttp2HeadersUtils { values = new AsciiString[numHeadersGuess]; } + @SuppressWarnings("BetaApi") // BaseEncoding is stable in Guava 20.0 protected Http2Headers add(AsciiString name, AsciiString value) { + byte[] nameBytes = bytes(name); + byte[] valueBytes; + if (!name.endsWith(binaryHeaderSuffix)) { + valueBytes = bytes(value); + addHeader(value, nameBytes, valueBytes); + return this; + } + int startPos = 0; + int endPos = -1; + while (endPos < value.length()) { + int indexOfComma = value.indexOf(',', startPos); + endPos = indexOfComma == AsciiString.INDEX_NOT_FOUND ? value.length() : indexOfComma; + AsciiString curVal = value.subSequence(startPos, endPos, false); + valueBytes = BaseEncoding.base64().decode(curVal); + startPos = indexOfComma + 1; + addHeader(curVal, nameBytes, valueBytes); + } + return this; + } + + private void addHeader(AsciiString value, byte[] nameBytes, byte[] valueBytes) { if (namesAndValuesIdx == namesAndValues.length) { expandHeadersAndValues(); } - byte[] nameBytes = bytes(name); - byte[] valueBytes = toBinaryValue(name, value); values[namesAndValuesIdx / 2] = value; namesAndValues[namesAndValuesIdx] = nameBytes; namesAndValuesIdx++; namesAndValues[namesAndValuesIdx] = valueBytes; namesAndValuesIdx++; - return this; } protected CharSequence get(AsciiString name) { @@ -179,13 +198,6 @@ class GrpcHttp2HeadersUtils { return PlatformDependent.equals(bytes0, offset0, bytes1, offset1, length0); } - @SuppressWarnings("BetaApi") // BaseEncoding is stable in Guava 20.0 - private static byte[] toBinaryValue(AsciiString name, AsciiString value) { - return name.endsWith(binaryHeaderSuffix) - ? BaseEncoding.base64().decode(value) - : bytes(value); - } - protected static byte[] bytes(AsciiString str) { return str.isEntireArrayUsed() ? str.array() : str.toByteArray(); } diff --git a/netty/src/main/java/io/grpc/netty/Utils.java b/netty/src/main/java/io/grpc/netty/Utils.java index 3276592229..4b71ef0804 100644 --- a/netty/src/main/java/io/grpc/netty/Utils.java +++ b/netty/src/main/java/io/grpc/netty/Utils.java @@ -48,6 +48,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; +import javax.annotation.CheckReturnValue; /** * Common utility methods. @@ -72,9 +73,6 @@ class Utils { public static final Resource DEFAULT_WORKER_EVENT_LOOP_GROUP = new DefaultEventLoopGroupResource(0, "grpc-default-worker-ELG"); - @VisibleForTesting - static boolean validateHeaders = false; - public static Metadata convertHeaders(Http2Headers http2Headers) { if (http2Headers instanceof GrpcHttp2InboundHeaders) { GrpcHttp2InboundHeaders h = (GrpcHttp2InboundHeaders) http2Headers; @@ -83,6 +81,7 @@ class Utils { return InternalMetadata.newMetadata(convertHeadersToArray(http2Headers)); } + @CheckReturnValue private static byte[][] convertHeadersToArray(Http2Headers http2Headers) { // The Netty AsciiString class is really just a wrapper around a byte[] and supports // arbitrary binary data, not just ASCII. diff --git a/netty/src/test/java/io/grpc/netty/GrpcHttp2HeadersUtilsTest.java b/netty/src/test/java/io/grpc/netty/GrpcHttp2HeadersUtilsTest.java index 0c548e64e0..96d87a5378 100644 --- a/netty/src/test/java/io/grpc/netty/GrpcHttp2HeadersUtilsTest.java +++ b/netty/src/test/java/io/grpc/netty/GrpcHttp2HeadersUtilsTest.java @@ -16,13 +16,20 @@ package io.grpc.netty; +import static io.grpc.Metadata.BINARY_BYTE_MARSHALLER; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE; import static io.netty.util.AsciiString.of; import static org.hamcrest.CoreMatchers.containsString; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import com.google.common.collect.Iterables; +import com.google.common.io.BaseEncoding; +import io.grpc.Metadata; +import io.grpc.Metadata.Key; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ClientHeadersDecoder; +import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2RequestHeaders; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -33,6 +40,8 @@ import io.netty.handler.codec.http2.Http2Headers; import io.netty.handler.codec.http2.Http2HeadersDecoder; import io.netty.handler.codec.http2.Http2HeadersEncoder; import io.netty.handler.codec.http2.Http2HeadersEncoder.SensitivityDetector; +import io.netty.util.AsciiString; +import java.util.Arrays; import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; @@ -125,6 +134,28 @@ public class GrpcHttp2HeadersUtilsTest { assertThat(decodedHeaders.toString(), containsString("[]")); } + @Test + public void dupBinHeadersWithComma() { + Key key = Key.of("bytes-bin", BINARY_BYTE_MARSHALLER); + Http2Headers http2Headers = new GrpcHttp2RequestHeaders(2); + http2Headers.add(AsciiString.of("bytes-bin"), AsciiString.of("BaS,e6,,4+,padding==")); + http2Headers.add(AsciiString.of("bytes-bin"), AsciiString.of("more")); + http2Headers.add(AsciiString.of("bytes-bin"), AsciiString.of("")); + Metadata recoveredHeaders = Utils.convertHeaders(http2Headers); + byte[][] values = Iterables.toArray(recoveredHeaders.getAll(key), byte[].class); + + assertTrue(Arrays.deepEquals( + new byte[][] { + BaseEncoding.base64().decode("BaS"), + BaseEncoding.base64().decode("e6"), + BaseEncoding.base64().decode(""), + BaseEncoding.base64().decode("4+"), + BaseEncoding.base64().decode("padding"), + BaseEncoding.base64().decode("more"), + BaseEncoding.base64().decode("")}, + values)); + } + private static void assertContainsKeyAndValue(String str, CharSequence key, CharSequence value) { assertThat(str, containsString(key.toString())); assertThat(str, containsString(value.toString())); diff --git a/okhttp/src/main/java/io/grpc/okhttp/Utils.java b/okhttp/src/main/java/io/grpc/okhttp/Utils.java index 913369dc18..7ae35f9a37 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/Utils.java +++ b/okhttp/src/main/java/io/grpc/okhttp/Utils.java @@ -29,6 +29,7 @@ import java.net.SocketException; import java.util.List; import java.util.logging.Level; import java.util.logging.Logger; +import javax.annotation.CheckReturnValue; /** * Common utility methods for OkHttp transport. @@ -51,6 +52,7 @@ class Utils { return InternalMetadata.newMetadata(convertHeadersToArray(http2Headers)); } + @CheckReturnValue private static byte[][] convertHeadersToArray(List

http2Headers) { byte[][] headerValues = new byte[http2Headers.size() * 2][]; int i = 0; diff --git a/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java b/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java index 45a0741a7d..8b8582bb60 100644 --- a/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java +++ b/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java @@ -722,6 +722,7 @@ public abstract class AbstractTransportTest { clientHeaders.put(asciiKey, "dupvalue"); clientHeaders.put(asciiKey, "dupvalue"); clientHeaders.put(binaryKey, "äbinaryclient"); + clientHeaders.put(binaryKey, "dup,value"); Metadata clientHeadersCopy = new Metadata(); clientHeadersCopy.merge(clientHeaders); @@ -790,14 +791,13 @@ public abstract class AbstractTransportTest { serverHeaders.put(asciiKey, "dupvalue"); serverHeaders.put(asciiKey, "dupvalue"); serverHeaders.put(binaryKey, "äbinaryserver"); + serverHeaders.put(binaryKey, "dup,value"); Metadata serverHeadersCopy = new Metadata(); serverHeadersCopy.merge(serverHeaders); serverStream.writeHeaders(serverHeaders); Metadata headers = clientStreamListener.headers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNotNull(headers); - assertEquals( - Lists.newArrayList(serverHeadersCopy.getAll(asciiKey)), - Lists.newArrayList(headers.getAll(asciiKey))); + assertAsciiMetadataValuesEqual(serverHeadersCopy.getAll(asciiKey), headers.getAll(asciiKey)); assertEquals( Lists.newArrayList(serverHeadersCopy.getAll(binaryKey)), Lists.newArrayList(headers.getAll(binaryKey))); @@ -842,6 +842,7 @@ public abstract class AbstractTransportTest { trailers.put(asciiKey, "dupvalue"); trailers.put(asciiKey, "dupvalue"); trailers.put(binaryKey, "äbinarytrailers"); + trailers.put(binaryKey, "dup,value"); serverStream.close(status, trailers); assertNull(serverStreamTracer1.nextInboundEvent()); assertNull(serverStreamTracer1.nextOutboundEvent()); @@ -855,9 +856,8 @@ public abstract class AbstractTransportTest { assertNull(clientStreamTracer1.nextOutboundEvent()); assertEquals(status.getCode(), clientStreamStatus.getCode()); assertEquals(status.getDescription(), clientStreamStatus.getDescription()); - assertEquals( - Lists.newArrayList(trailers.getAll(asciiKey)), - Lists.newArrayList(clientStreamTrailers.getAll(asciiKey))); + assertAsciiMetadataValuesEqual( + trailers.getAll(asciiKey), clientStreamTrailers.getAll(asciiKey)); assertEquals( Lists.newArrayList(trailers.getAll(binaryKey)), Lists.newArrayList(clientStreamTrailers.getAll(binaryKey))); @@ -883,6 +883,18 @@ public abstract class AbstractTransportTest { assertEquals(testAuthority(server), serverStream.getAuthority()); } + private void assertAsciiMetadataValuesEqual(Iterable expected, Iterable actural) { + StringBuilder sbExpected = new StringBuilder(); + for (String str : expected) { + sbExpected.append(str).append(","); + } + StringBuilder sbActual = new StringBuilder(); + for (String str : actural) { + sbActual.append(str).append(","); + } + assertEquals(sbExpected.toString(), sbActual.toString()); + } + @Test public void zeroMessageStream() throws Exception { server.start(serverListener);