diff --git a/core/src/main/java/io/grpc/internal/TransportFrameUtil.java b/core/src/main/java/io/grpc/internal/TransportFrameUtil.java
index e015740b79..84780315fd 100644
--- a/core/src/main/java/io/grpc/internal/TransportFrameUtil.java
+++ b/core/src/main/java/io/grpc/internal/TransportFrameUtil.java
@@ -21,8 +21,11 @@ import static com.google.common.base.Charsets.US_ASCII;
import com.google.common.io.BaseEncoding;
import io.grpc.InternalMetadata;
import io.grpc.Metadata;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
import java.util.logging.Logger;
+import javax.annotation.CheckReturnValue;
/**
* Utility functions for transport layer framing.
@@ -84,21 +87,30 @@ public final class TransportFrameUtil {
/**
* Transform HTTP/2-compliant headers to the raw serialized format which can be deserialized by
- * metadata marshallers. It decodes the Base64-encoded binary headers. This function modifies
- * the headers in place. By modifying the input array.
+ * metadata marshallers. It decodes the Base64-encoded binary headers.
+ *
+ *
Warning: This function may partially modify the headers in place by modifying the input
+ * array (but not modifying any single byte), so the input reference {@code http2Headers} can not
+ * be used again.
*
* @param http2Headers the interleaved keys and values of HTTP/2-compliant headers
* @return the interleaved keys and values in the raw serialized format
*/
@SuppressWarnings("BetaApi") // BaseEncoding is stable in Guava 20.0
+ @CheckReturnValue
public static byte[][] toRawSerializedHeaders(byte[][] http2Headers) {
for (int i = 0; i < http2Headers.length; i += 2) {
byte[] key = http2Headers[i];
byte[] value = http2Headers[i + 1];
- http2Headers[i] = key;
if (endsWith(key, binaryHeaderSuffixBytes)) {
// Binary header
- http2Headers[i + 1] = BaseEncoding.base64().decode(new String(value, US_ASCII));
+ for (int idx = 0; idx < value.length; idx++) {
+ if (value[idx] == (byte) ',') {
+ return serializeHeadersWithCommasInBin(http2Headers, i);
+ }
+ }
+ byte[] decodedVal = BaseEncoding.base64().decode(new String(value, US_ASCII));
+ http2Headers[i + 1] = decodedVal;
} else {
// Non-binary header
// Nothing to do, the value is already in the right place.
@@ -107,6 +119,35 @@ public final class TransportFrameUtil {
return http2Headers;
}
+ private static byte[][] serializeHeadersWithCommasInBin(byte[][] http2Headers, int resumeFrom) {
+ List headerList = new ArrayList<>(http2Headers.length + 10);
+ for (int i = 0; i < resumeFrom; i++) {
+ headerList.add(http2Headers[i]);
+ }
+ for (int i = resumeFrom; i < http2Headers.length; i += 2) {
+ byte[] key = http2Headers[i];
+ byte[] value = http2Headers[i + 1];
+ if (!endsWith(key, binaryHeaderSuffixBytes)) {
+ headerList.add(key);
+ headerList.add(value);
+ continue;
+ }
+ // Binary header
+ int prevIdx = 0;
+ for (int idx = 0; idx <= value.length; idx++) {
+ if (idx != value.length && value[idx] != (byte) ',') {
+ continue;
+ }
+ byte[] decodedVal =
+ BaseEncoding.base64().decode(new String(value, prevIdx, idx - prevIdx, US_ASCII));
+ prevIdx = idx + 1;
+ headerList.add(key);
+ headerList.add(decodedVal);
+ }
+ }
+ return headerList.toArray(new byte[0][]);
+ }
+
/**
* Returns {@code true} if {@code subject} ends with {@code suffix}.
*/
diff --git a/core/src/test/java/io/grpc/internal/TransportFrameUtilTest.java b/core/src/test/java/io/grpc/internal/TransportFrameUtilTest.java
index aa99fdb9f5..4d0f6ed0b5 100644
--- a/core/src/test/java/io/grpc/internal/TransportFrameUtilTest.java
+++ b/core/src/test/java/io/grpc/internal/TransportFrameUtilTest.java
@@ -19,11 +19,14 @@ package io.grpc.internal;
import static com.google.common.base.Charsets.US_ASCII;
import static com.google.common.base.Charsets.UTF_8;
import static io.grpc.Metadata.ASCII_STRING_MARSHALLER;
+import static io.grpc.Metadata.BINARY_BYTE_MARSHALLER;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
+import com.google.common.collect.Iterables;
import com.google.common.io.BaseEncoding;
import io.grpc.InternalMetadata;
import io.grpc.Metadata;
@@ -59,6 +62,7 @@ public class TransportFrameUtilTest {
private static final Key BINARY_STRING = Key.of("string-bin", UTF8_STRING_MARSHALLER);
private static final Key BINARY_STRING_WITHOUT_SUFFIX =
Key.of("string", ASCII_STRING_MARSHALLER);
+ private static final Key BINARY_BYTES = Key.of("bytes-bin", BINARY_BYTE_MARSHALLER);
@Test
public void testToHttp2Headers() {
@@ -99,6 +103,31 @@ public class TransportFrameUtilTest {
assertNull(recoveredHeaders.get(BINARY_STRING_WITHOUT_SUFFIX));
}
+ @Test
+ public void dupBinHeadersWithComma() {
+ byte[][] http2Headers = new byte[][] {
+ BINARY_BYTES.name().getBytes(US_ASCII),
+ "BaS,e6,,4+,padding==".getBytes(US_ASCII),
+ BINARY_BYTES.name().getBytes(US_ASCII),
+ "more".getBytes(US_ASCII),
+ BINARY_BYTES.name().getBytes(US_ASCII),
+ "".getBytes(US_ASCII)};
+ byte[][] rawSerialized = TransportFrameUtil.toRawSerializedHeaders(http2Headers);
+ Metadata recoveredHeaders = InternalMetadata.newMetadata(rawSerialized);
+ byte[][] values = Iterables.toArray(recoveredHeaders.getAll(BINARY_BYTES), byte[].class);
+
+ assertTrue(Arrays.deepEquals(
+ new byte[][] {
+ BaseEncoding.base64().decode("BaS"),
+ BaseEncoding.base64().decode("e6"),
+ BaseEncoding.base64().decode(""),
+ BaseEncoding.base64().decode("4+"),
+ BaseEncoding.base64().decode("padding"),
+ BaseEncoding.base64().decode("more"),
+ BaseEncoding.base64().decode("")},
+ values));
+ }
+
private static void assertContains(byte[][] headers, byte[] key, byte[] value) {
String keyString = new String(key, US_ASCII);
for (int i = 0; i < headers.length; i += 2) {
diff --git a/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java b/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java
index 920afca0c6..de52f0b043 100644
--- a/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java
+++ b/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java
@@ -101,18 +101,37 @@ class GrpcHttp2HeadersUtils {
values = new AsciiString[numHeadersGuess];
}
+ @SuppressWarnings("BetaApi") // BaseEncoding is stable in Guava 20.0
protected Http2Headers add(AsciiString name, AsciiString value) {
+ byte[] nameBytes = bytes(name);
+ byte[] valueBytes;
+ if (!name.endsWith(binaryHeaderSuffix)) {
+ valueBytes = bytes(value);
+ addHeader(value, nameBytes, valueBytes);
+ return this;
+ }
+ int startPos = 0;
+ int endPos = -1;
+ while (endPos < value.length()) {
+ int indexOfComma = value.indexOf(',', startPos);
+ endPos = indexOfComma == AsciiString.INDEX_NOT_FOUND ? value.length() : indexOfComma;
+ AsciiString curVal = value.subSequence(startPos, endPos, false);
+ valueBytes = BaseEncoding.base64().decode(curVal);
+ startPos = indexOfComma + 1;
+ addHeader(curVal, nameBytes, valueBytes);
+ }
+ return this;
+ }
+
+ private void addHeader(AsciiString value, byte[] nameBytes, byte[] valueBytes) {
if (namesAndValuesIdx == namesAndValues.length) {
expandHeadersAndValues();
}
- byte[] nameBytes = bytes(name);
- byte[] valueBytes = toBinaryValue(name, value);
values[namesAndValuesIdx / 2] = value;
namesAndValues[namesAndValuesIdx] = nameBytes;
namesAndValuesIdx++;
namesAndValues[namesAndValuesIdx] = valueBytes;
namesAndValuesIdx++;
- return this;
}
protected CharSequence get(AsciiString name) {
@@ -179,13 +198,6 @@ class GrpcHttp2HeadersUtils {
return PlatformDependent.equals(bytes0, offset0, bytes1, offset1, length0);
}
- @SuppressWarnings("BetaApi") // BaseEncoding is stable in Guava 20.0
- private static byte[] toBinaryValue(AsciiString name, AsciiString value) {
- return name.endsWith(binaryHeaderSuffix)
- ? BaseEncoding.base64().decode(value)
- : bytes(value);
- }
-
protected static byte[] bytes(AsciiString str) {
return str.isEntireArrayUsed() ? str.array() : str.toByteArray();
}
diff --git a/netty/src/main/java/io/grpc/netty/Utils.java b/netty/src/main/java/io/grpc/netty/Utils.java
index 3276592229..4b71ef0804 100644
--- a/netty/src/main/java/io/grpc/netty/Utils.java
+++ b/netty/src/main/java/io/grpc/netty/Utils.java
@@ -48,6 +48,7 @@ import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
+import javax.annotation.CheckReturnValue;
/**
* Common utility methods.
@@ -72,9 +73,6 @@ class Utils {
public static final Resource DEFAULT_WORKER_EVENT_LOOP_GROUP =
new DefaultEventLoopGroupResource(0, "grpc-default-worker-ELG");
- @VisibleForTesting
- static boolean validateHeaders = false;
-
public static Metadata convertHeaders(Http2Headers http2Headers) {
if (http2Headers instanceof GrpcHttp2InboundHeaders) {
GrpcHttp2InboundHeaders h = (GrpcHttp2InboundHeaders) http2Headers;
@@ -83,6 +81,7 @@ class Utils {
return InternalMetadata.newMetadata(convertHeadersToArray(http2Headers));
}
+ @CheckReturnValue
private static byte[][] convertHeadersToArray(Http2Headers http2Headers) {
// The Netty AsciiString class is really just a wrapper around a byte[] and supports
// arbitrary binary data, not just ASCII.
diff --git a/netty/src/test/java/io/grpc/netty/GrpcHttp2HeadersUtilsTest.java b/netty/src/test/java/io/grpc/netty/GrpcHttp2HeadersUtilsTest.java
index 0c548e64e0..96d87a5378 100644
--- a/netty/src/test/java/io/grpc/netty/GrpcHttp2HeadersUtilsTest.java
+++ b/netty/src/test/java/io/grpc/netty/GrpcHttp2HeadersUtilsTest.java
@@ -16,13 +16,20 @@
package io.grpc.netty;
+import static io.grpc.Metadata.BINARY_BYTE_MARSHALLER;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE;
import static io.netty.util.AsciiString.of;
import static org.hamcrest.CoreMatchers.containsString;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import com.google.common.collect.Iterables;
+import com.google.common.io.BaseEncoding;
+import io.grpc.Metadata;
+import io.grpc.Metadata.Key;
import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ClientHeadersDecoder;
+import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2RequestHeaders;
import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
@@ -33,6 +40,8 @@ import io.netty.handler.codec.http2.Http2Headers;
import io.netty.handler.codec.http2.Http2HeadersDecoder;
import io.netty.handler.codec.http2.Http2HeadersEncoder;
import io.netty.handler.codec.http2.Http2HeadersEncoder.SensitivityDetector;
+import io.netty.util.AsciiString;
+import java.util.Arrays;
import org.junit.After;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -125,6 +134,28 @@ public class GrpcHttp2HeadersUtilsTest {
assertThat(decodedHeaders.toString(), containsString("[]"));
}
+ @Test
+ public void dupBinHeadersWithComma() {
+ Key key = Key.of("bytes-bin", BINARY_BYTE_MARSHALLER);
+ Http2Headers http2Headers = new GrpcHttp2RequestHeaders(2);
+ http2Headers.add(AsciiString.of("bytes-bin"), AsciiString.of("BaS,e6,,4+,padding=="));
+ http2Headers.add(AsciiString.of("bytes-bin"), AsciiString.of("more"));
+ http2Headers.add(AsciiString.of("bytes-bin"), AsciiString.of(""));
+ Metadata recoveredHeaders = Utils.convertHeaders(http2Headers);
+ byte[][] values = Iterables.toArray(recoveredHeaders.getAll(key), byte[].class);
+
+ assertTrue(Arrays.deepEquals(
+ new byte[][] {
+ BaseEncoding.base64().decode("BaS"),
+ BaseEncoding.base64().decode("e6"),
+ BaseEncoding.base64().decode(""),
+ BaseEncoding.base64().decode("4+"),
+ BaseEncoding.base64().decode("padding"),
+ BaseEncoding.base64().decode("more"),
+ BaseEncoding.base64().decode("")},
+ values));
+ }
+
private static void assertContainsKeyAndValue(String str, CharSequence key, CharSequence value) {
assertThat(str, containsString(key.toString()));
assertThat(str, containsString(value.toString()));
diff --git a/okhttp/src/main/java/io/grpc/okhttp/Utils.java b/okhttp/src/main/java/io/grpc/okhttp/Utils.java
index 913369dc18..7ae35f9a37 100644
--- a/okhttp/src/main/java/io/grpc/okhttp/Utils.java
+++ b/okhttp/src/main/java/io/grpc/okhttp/Utils.java
@@ -29,6 +29,7 @@ import java.net.SocketException;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
+import javax.annotation.CheckReturnValue;
/**
* Common utility methods for OkHttp transport.
@@ -51,6 +52,7 @@ class Utils {
return InternalMetadata.newMetadata(convertHeadersToArray(http2Headers));
}
+ @CheckReturnValue
private static byte[][] convertHeadersToArray(List http2Headers) {
byte[][] headerValues = new byte[http2Headers.size() * 2][];
int i = 0;
diff --git a/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java b/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java
index 45a0741a7d..8b8582bb60 100644
--- a/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java
+++ b/testing/src/main/java/io/grpc/internal/testing/AbstractTransportTest.java
@@ -722,6 +722,7 @@ public abstract class AbstractTransportTest {
clientHeaders.put(asciiKey, "dupvalue");
clientHeaders.put(asciiKey, "dupvalue");
clientHeaders.put(binaryKey, "äbinaryclient");
+ clientHeaders.put(binaryKey, "dup,value");
Metadata clientHeadersCopy = new Metadata();
clientHeadersCopy.merge(clientHeaders);
@@ -790,14 +791,13 @@ public abstract class AbstractTransportTest {
serverHeaders.put(asciiKey, "dupvalue");
serverHeaders.put(asciiKey, "dupvalue");
serverHeaders.put(binaryKey, "äbinaryserver");
+ serverHeaders.put(binaryKey, "dup,value");
Metadata serverHeadersCopy = new Metadata();
serverHeadersCopy.merge(serverHeaders);
serverStream.writeHeaders(serverHeaders);
Metadata headers = clientStreamListener.headers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS);
assertNotNull(headers);
- assertEquals(
- Lists.newArrayList(serverHeadersCopy.getAll(asciiKey)),
- Lists.newArrayList(headers.getAll(asciiKey)));
+ assertAsciiMetadataValuesEqual(serverHeadersCopy.getAll(asciiKey), headers.getAll(asciiKey));
assertEquals(
Lists.newArrayList(serverHeadersCopy.getAll(binaryKey)),
Lists.newArrayList(headers.getAll(binaryKey)));
@@ -842,6 +842,7 @@ public abstract class AbstractTransportTest {
trailers.put(asciiKey, "dupvalue");
trailers.put(asciiKey, "dupvalue");
trailers.put(binaryKey, "äbinarytrailers");
+ trailers.put(binaryKey, "dup,value");
serverStream.close(status, trailers);
assertNull(serverStreamTracer1.nextInboundEvent());
assertNull(serverStreamTracer1.nextOutboundEvent());
@@ -855,9 +856,8 @@ public abstract class AbstractTransportTest {
assertNull(clientStreamTracer1.nextOutboundEvent());
assertEquals(status.getCode(), clientStreamStatus.getCode());
assertEquals(status.getDescription(), clientStreamStatus.getDescription());
- assertEquals(
- Lists.newArrayList(trailers.getAll(asciiKey)),
- Lists.newArrayList(clientStreamTrailers.getAll(asciiKey)));
+ assertAsciiMetadataValuesEqual(
+ trailers.getAll(asciiKey), clientStreamTrailers.getAll(asciiKey));
assertEquals(
Lists.newArrayList(trailers.getAll(binaryKey)),
Lists.newArrayList(clientStreamTrailers.getAll(binaryKey)));
@@ -883,6 +883,18 @@ public abstract class AbstractTransportTest {
assertEquals(testAuthority(server), serverStream.getAuthority());
}
+ private void assertAsciiMetadataValuesEqual(Iterable expected, Iterable actural) {
+ StringBuilder sbExpected = new StringBuilder();
+ for (String str : expected) {
+ sbExpected.append(str).append(",");
+ }
+ StringBuilder sbActual = new StringBuilder();
+ for (String str : actural) {
+ sbActual.append(str).append(",");
+ }
+ assertEquals(sbExpected.toString(), sbActual.toString());
+ }
+
@Test
public void zeroMessageStream() throws Exception {
server.start(serverListener);