diff --git a/core/src/main/java/io/grpc/ChannelImpl.java b/core/src/main/java/io/grpc/ChannelImpl.java index 3289626dd4..944bcc326b 100644 --- a/core/src/main/java/io/grpc/ChannelImpl.java +++ b/core/src/main/java/io/grpc/ChannelImpl.java @@ -31,6 +31,7 @@ package io.grpc; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; @@ -39,6 +40,7 @@ import io.grpc.transport.ClientStreamListener; import io.grpc.transport.ClientTransport; import io.grpc.transport.ClientTransport.PingCallback; import io.grpc.transport.ClientTransportFactory; +import io.grpc.transport.HttpUtil; import java.io.InputStream; import java.util.ArrayList; @@ -292,6 +294,7 @@ public final class ChannelImpl extends Channel { new Metadata.Trailers()); return; } + completeHeaders(headers); try { stream = transport.newStream(method, headers, listener); } catch (IllegalStateException ex) { @@ -350,6 +353,15 @@ public final class ChannelImpl extends Channel { return stream.isReady(); } + /** + * Set missing properties on the headers. The given headers will be mutated. + * @param headers the headers to complete + */ + private void completeHeaders(Metadata.Headers headers) { + headers.removeAll(TIMEOUT_KEY); + headers.put(TIMEOUT_KEY, method.getTimeout()); + } + private class ClientStreamListenerImpl implements ClientStreamListener { private final Listener observer; private boolean closed; @@ -423,4 +435,81 @@ public final class ChannelImpl extends Channel { } } } + + /** + * Intended for internal use only. + */ + // TODO(johnbcoughlin) make this package private when we can do so with the tests. + @VisibleForTesting + public static final Metadata.Key TIMEOUT_KEY = + Metadata.Key.of(HttpUtil.TIMEOUT, new TimeoutMarshaller()); + + /** + * Marshals a microseconds representation of the timeout to and from a string representation, + * consisting of an ASCII decimal representation of a number with at most 8 digits, followed by a + * unit: + * u = microseconds + * m = milliseconds + * S = seconds + * M = minutes + * H = hours + * + *

The representation is greedy with respect to precision. That is, 2 seconds will be + * represented as `2000000u`.

+ * + *

See the + * request header definition

+ */ + @VisibleForTesting + static class TimeoutMarshaller implements Metadata.AsciiMarshaller { + @Override + public String toAsciiString(Long timeoutMicros) { + Preconditions.checkArgument(timeoutMicros >= 0, "Negative timeout"); + long timeout; + String timeoutUnit; + // the smallest integer with 9 digits + int cutoff = 100000000; + if (timeoutMicros < cutoff) { + timeout = timeoutMicros; + timeoutUnit = "u"; + } else if (timeoutMicros / 1000 < cutoff) { + timeout = timeoutMicros / 1000; + timeoutUnit = "m"; + } else if (timeoutMicros / (1000 * 1000) < cutoff) { + timeout = timeoutMicros / (1000 * 1000); + timeoutUnit = "S"; + } else if (timeoutMicros / (60 * 1000 * 1000) < cutoff) { + timeout = timeoutMicros / (60 * 1000 * 1000); + timeoutUnit = "M"; + } else if (timeoutMicros / (60L * 60L * 1000L * 1000L) < cutoff) { + timeout = timeoutMicros / (60L * 60L * 1000L * 1000L); + timeoutUnit = "H"; + } else { + throw new IllegalArgumentException("Timeout too large"); + } + return Long.toString(timeout) + timeoutUnit; + } + + @Override + public Long parseAsciiString(String serialized) { + String valuePart = serialized.substring(0, serialized.length() - 1); + char unit = serialized.charAt(serialized.length() - 1); + long factor; + switch (unit) { + case 'u': + factor = 1; break; + case 'm': + factor = 1000L; break; + case 'S': + factor = 1000L * 1000L; break; + case 'M': + factor = 60L * 1000L * 1000L; break; + case 'H': + factor = 60L * 60L * 1000L * 1000L; break; + default: + throw new IllegalArgumentException(String.format("Invalid timeout unit: %s", unit)); + } + return Long.parseLong(valuePart) * factor; + } + } } diff --git a/core/src/main/java/io/grpc/transport/HttpUtil.java b/core/src/main/java/io/grpc/transport/HttpUtil.java index 1ad5258a71..398702cd2d 100644 --- a/core/src/main/java/io/grpc/transport/HttpUtil.java +++ b/core/src/main/java/io/grpc/transport/HttpUtil.java @@ -68,6 +68,11 @@ public final class HttpUtil { */ public static final String TE_TRAILERS = "trailers"; + /** + * The Timeout header name. + */ + public static final String TIMEOUT = "grpc-timeout"; + /** * Maps HTTP error response status codes to transport codes. */ diff --git a/core/src/test/java/io/grpc/MetadataTest.java b/core/src/test/java/io/grpc/MetadataTest.java index 347de50b8f..82c301c06d 100644 --- a/core/src/test/java/io/grpc/MetadataTest.java +++ b/core/src/test/java/io/grpc/MetadataTest.java @@ -145,6 +145,28 @@ public class MetadataTest { roundTripInteger(0x87654321); } + @Test + public void timeoutTest() { + ChannelImpl.TimeoutMarshaller marshaller = + new ChannelImpl.TimeoutMarshaller(); + assertEquals("1000u", marshaller.toAsciiString(1000L)); + assertEquals(1000L, (long) marshaller.parseAsciiString("1000u")); + + assertEquals("100000m", marshaller.toAsciiString(100000000L)); + assertEquals(100000000L, (long) marshaller.parseAsciiString("100000m")); + + assertEquals("100000S", marshaller.toAsciiString(100000000000L)); + assertEquals(100000000000L, (long) marshaller.parseAsciiString("100000S")); + + // 1,666,667 * 60 has 9 digits + assertEquals("1666666M", marshaller.toAsciiString(100000000000000L)); + assertEquals(60000000000000L, (long) marshaller.parseAsciiString("1000000M")); + + // 1,666,667 * 60 has 9 digits + assertEquals("1666666H", marshaller.toAsciiString(6000000000000000L)); + assertEquals(3600000000000000L, (long) marshaller.parseAsciiString("1000000H")); + } + private void roundTripInteger(Integer i) { assertEquals(i, Metadata.INTEGER_MARSHALLER.parseAsciiString( Metadata.INTEGER_MARSHALLER.toAsciiString(i))); diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractTransportTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractTransportTest.java index ee5094fe56..54b00d8fd1 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractTransportTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractTransportTest.java @@ -93,6 +93,8 @@ public abstract class AbstractTransportTest { public static final Metadata.Key METADATA_KEY = ProtoUtils.keyForProto(Messages.SimpleContext.getDefaultInstance()); + private static final AtomicReference requestHeadersCapture = + new AtomicReference(); private static ScheduledExecutorService testServiceExecutor; private static ServerImpl server; private static int OPERATION_TIMEOUT = 5000; @@ -102,6 +104,7 @@ public abstract class AbstractTransportTest { builder.addService(ServerInterceptors.intercept( TestServiceGrpc.bindService(new TestServiceImpl(testServiceExecutor)), + TestUtils.recordRequestHeadersInterceptor(requestHeadersCapture), TestUtils.echoRequestHeadersInterceptor(Util.METADATA_KEY))); try { server = builder.build().start(); @@ -127,6 +130,7 @@ public abstract class AbstractTransportTest { channel = createChannel(); blockingStub = TestServiceGrpc.newBlockingStub(channel); asyncStub = TestServiceGrpc.newStub(channel); + requestHeadersCapture.set(null); } /** Clean up. */ @@ -572,6 +576,16 @@ public abstract class AbstractTransportTest { Assert.assertEquals(contextValue, trailersCapture.get().get(METADATA_KEY)); } + @Test(timeout = 10000) + public void sendsTimeoutHeader() { + TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel) + .configureNewStub() + .setTimeout(572, TimeUnit.MILLISECONDS) + .build(); + stub.emptyCall(Empty.getDefaultInstance()); + Assert.assertEquals(572000L, (long) requestHeadersCapture.get().get(ChannelImpl.TIMEOUT_KEY)); + } + protected int unaryPayloadLength() { // 10MiB. diff --git a/netty/src/main/java/io/grpc/transport/netty/Utils.java b/netty/src/main/java/io/grpc/transport/netty/Utils.java index 3b2d3b051a..0a9506189f 100644 --- a/netty/src/main/java/io/grpc/transport/netty/Utils.java +++ b/netty/src/main/java/io/grpc/transport/netty/Utils.java @@ -91,6 +91,7 @@ class Utils { public static Metadata.Headers convertHeaders(Http2Headers http2Headers) { Metadata.Headers headers = new Metadata.Headers(convertHeadersToArray(http2Headers)); if (http2Headers.authority() != null) { + // toString() here is safe since it doesn't use the default Charset. headers.setAuthority(http2Headers.authority().toString()); } if (http2Headers.path() != null) { diff --git a/testing/src/main/java/io/grpc/testing/TestUtils.java b/testing/src/main/java/io/grpc/testing/TestUtils.java index 8b173a2030..026510d86f 100644 --- a/testing/src/main/java/io/grpc/testing/TestUtils.java +++ b/testing/src/main/java/io/grpc/testing/TestUtils.java @@ -59,6 +59,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; @@ -112,6 +113,24 @@ public class TestUtils { }; } + /** + * Capture the request headers from a client. Useful for testing metadata propagation without + * requiring that it be symmetric on client and server, as with + * {@link #echoRequestHeadersInterceptor}. + */ + public static ServerInterceptor recordRequestHeadersInterceptor( + final AtomicReference headersCapture) { + return new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(String method, + ServerCall call, + Metadata.Headers requestHeaders, + ServerCallHandler next) { + headersCapture.set(requestHeaders); + return next.startCall(method, call, requestHeaders); + } + }; + } /** * Picks an unused port.