Pass timeout header in ChannelImpl

This commit is contained in:
Jack Coughlin 2015-06-17 15:47:27 -07:00 committed by Eric Anderson
parent 4322a43824
commit 77878a04ee
6 changed files with 150 additions and 0 deletions

View File

@ -31,6 +31,7 @@
package io.grpc;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
@ -39,6 +40,7 @@ import io.grpc.transport.ClientStreamListener;
import io.grpc.transport.ClientTransport;
import io.grpc.transport.ClientTransport.PingCallback;
import io.grpc.transport.ClientTransportFactory;
import io.grpc.transport.HttpUtil;
import java.io.InputStream;
import java.util.ArrayList;
@ -292,6 +294,7 @@ public final class ChannelImpl extends Channel {
new Metadata.Trailers());
return;
}
completeHeaders(headers);
try {
stream = transport.newStream(method, headers, listener);
} catch (IllegalStateException ex) {
@ -350,6 +353,15 @@ public final class ChannelImpl extends Channel {
return stream.isReady();
}
/**
* Set missing properties on the headers. The given headers will be mutated.
* @param headers the headers to complete
*/
private void completeHeaders(Metadata.Headers headers) {
headers.removeAll(TIMEOUT_KEY);
headers.put(TIMEOUT_KEY, method.getTimeout());
}
private class ClientStreamListenerImpl implements ClientStreamListener {
private final Listener<RespT> observer;
private boolean closed;
@ -423,4 +435,81 @@ public final class ChannelImpl extends Channel {
}
}
}
/**
* Intended for internal use only.
*/
// TODO(johnbcoughlin) make this package private when we can do so with the tests.
@VisibleForTesting
public static final Metadata.Key<Long> TIMEOUT_KEY =
Metadata.Key.of(HttpUtil.TIMEOUT, new TimeoutMarshaller());
/**
* Marshals a microseconds representation of the timeout to and from a string representation,
* consisting of an ASCII decimal representation of a number with at most 8 digits, followed by a
* unit:
* u = microseconds
* m = milliseconds
* S = seconds
* M = minutes
* H = hours
*
* <p>The representation is greedy with respect to precision. That is, 2 seconds will be
* represented as `2000000u`.</p>
*
* <p>See <a href="https://github.com/grpc/grpc-common/blob/master/PROTOCOL-HTTP2.md#requests">the
* request header definition</a></p>
*/
@VisibleForTesting
static class TimeoutMarshaller implements Metadata.AsciiMarshaller<Long> {
@Override
public String toAsciiString(Long timeoutMicros) {
Preconditions.checkArgument(timeoutMicros >= 0, "Negative timeout");
long timeout;
String timeoutUnit;
// the smallest integer with 9 digits
int cutoff = 100000000;
if (timeoutMicros < cutoff) {
timeout = timeoutMicros;
timeoutUnit = "u";
} else if (timeoutMicros / 1000 < cutoff) {
timeout = timeoutMicros / 1000;
timeoutUnit = "m";
} else if (timeoutMicros / (1000 * 1000) < cutoff) {
timeout = timeoutMicros / (1000 * 1000);
timeoutUnit = "S";
} else if (timeoutMicros / (60 * 1000 * 1000) < cutoff) {
timeout = timeoutMicros / (60 * 1000 * 1000);
timeoutUnit = "M";
} else if (timeoutMicros / (60L * 60L * 1000L * 1000L) < cutoff) {
timeout = timeoutMicros / (60L * 60L * 1000L * 1000L);
timeoutUnit = "H";
} else {
throw new IllegalArgumentException("Timeout too large");
}
return Long.toString(timeout) + timeoutUnit;
}
@Override
public Long parseAsciiString(String serialized) {
String valuePart = serialized.substring(0, serialized.length() - 1);
char unit = serialized.charAt(serialized.length() - 1);
long factor;
switch (unit) {
case 'u':
factor = 1; break;
case 'm':
factor = 1000L; break;
case 'S':
factor = 1000L * 1000L; break;
case 'M':
factor = 60L * 1000L * 1000L; break;
case 'H':
factor = 60L * 60L * 1000L * 1000L; break;
default:
throw new IllegalArgumentException(String.format("Invalid timeout unit: %s", unit));
}
return Long.parseLong(valuePart) * factor;
}
}
}

View File

@ -68,6 +68,11 @@ public final class HttpUtil {
*/
public static final String TE_TRAILERS = "trailers";
/**
* The Timeout header name.
*/
public static final String TIMEOUT = "grpc-timeout";
/**
* Maps HTTP error response status codes to transport codes.
*/

View File

@ -145,6 +145,28 @@ public class MetadataTest {
roundTripInteger(0x87654321);
}
@Test
public void timeoutTest() {
ChannelImpl.TimeoutMarshaller marshaller =
new ChannelImpl.TimeoutMarshaller();
assertEquals("1000u", marshaller.toAsciiString(1000L));
assertEquals(1000L, (long) marshaller.parseAsciiString("1000u"));
assertEquals("100000m", marshaller.toAsciiString(100000000L));
assertEquals(100000000L, (long) marshaller.parseAsciiString("100000m"));
assertEquals("100000S", marshaller.toAsciiString(100000000000L));
assertEquals(100000000000L, (long) marshaller.parseAsciiString("100000S"));
// 1,666,667 * 60 has 9 digits
assertEquals("1666666M", marshaller.toAsciiString(100000000000000L));
assertEquals(60000000000000L, (long) marshaller.parseAsciiString("1000000M"));
// 1,666,667 * 60 has 9 digits
assertEquals("1666666H", marshaller.toAsciiString(6000000000000000L));
assertEquals(3600000000000000L, (long) marshaller.parseAsciiString("1000000H"));
}
private void roundTripInteger(Integer i) {
assertEquals(i, Metadata.INTEGER_MARSHALLER.parseAsciiString(
Metadata.INTEGER_MARSHALLER.toAsciiString(i)));

View File

@ -93,6 +93,8 @@ public abstract class AbstractTransportTest {
public static final Metadata.Key<Messages.SimpleContext> METADATA_KEY =
ProtoUtils.keyForProto(Messages.SimpleContext.getDefaultInstance());
private static final AtomicReference<Metadata.Headers> requestHeadersCapture =
new AtomicReference<Metadata.Headers>();
private static ScheduledExecutorService testServiceExecutor;
private static ServerImpl server;
private static int OPERATION_TIMEOUT = 5000;
@ -102,6 +104,7 @@ public abstract class AbstractTransportTest {
builder.addService(ServerInterceptors.intercept(
TestServiceGrpc.bindService(new TestServiceImpl(testServiceExecutor)),
TestUtils.recordRequestHeadersInterceptor(requestHeadersCapture),
TestUtils.echoRequestHeadersInterceptor(Util.METADATA_KEY)));
try {
server = builder.build().start();
@ -127,6 +130,7 @@ public abstract class AbstractTransportTest {
channel = createChannel();
blockingStub = TestServiceGrpc.newBlockingStub(channel);
asyncStub = TestServiceGrpc.newStub(channel);
requestHeadersCapture.set(null);
}
/** Clean up. */
@ -572,6 +576,16 @@ public abstract class AbstractTransportTest {
Assert.assertEquals(contextValue, trailersCapture.get().get(METADATA_KEY));
}
@Test(timeout = 10000)
public void sendsTimeoutHeader() {
TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel)
.configureNewStub()
.setTimeout(572, TimeUnit.MILLISECONDS)
.build();
stub.emptyCall(Empty.getDefaultInstance());
Assert.assertEquals(572000L, (long) requestHeadersCapture.get().get(ChannelImpl.TIMEOUT_KEY));
}
protected int unaryPayloadLength() {
// 10MiB.

View File

@ -91,6 +91,7 @@ class Utils {
public static Metadata.Headers convertHeaders(Http2Headers http2Headers) {
Metadata.Headers headers = new Metadata.Headers(convertHeadersToArray(http2Headers));
if (http2Headers.authority() != null) {
// toString() here is safe since it doesn't use the default Charset.
headers.setAuthority(http2Headers.authority().toString());
}
if (http2Headers.path() != null) {

View File

@ -59,6 +59,7 @@ import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
@ -112,6 +113,24 @@ public class TestUtils {
};
}
/**
* Capture the request headers from a client. Useful for testing metadata propagation without
* requiring that it be symmetric on client and server, as with
* {@link #echoRequestHeadersInterceptor}.
*/
public static ServerInterceptor recordRequestHeadersInterceptor(
final AtomicReference<Metadata.Headers> headersCapture) {
return new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(String method,
ServerCall<RespT> call,
Metadata.Headers requestHeaders,
ServerCallHandler<ReqT, RespT> next) {
headersCapture.set(requestHeaders);
return next.startCall(method, call, requestHeaders);
}
};
}
/**
* Picks an unused port.