mirror of https://github.com/grpc/grpc-java.git
Pass timeout header in ChannelImpl
This commit is contained in:
parent
4322a43824
commit
77878a04ee
|
|
@ -31,6 +31,7 @@
|
|||
|
||||
package io.grpc;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Throwables;
|
||||
|
||||
|
|
@ -39,6 +40,7 @@ import io.grpc.transport.ClientStreamListener;
|
|||
import io.grpc.transport.ClientTransport;
|
||||
import io.grpc.transport.ClientTransport.PingCallback;
|
||||
import io.grpc.transport.ClientTransportFactory;
|
||||
import io.grpc.transport.HttpUtil;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.util.ArrayList;
|
||||
|
|
@ -292,6 +294,7 @@ public final class ChannelImpl extends Channel {
|
|||
new Metadata.Trailers());
|
||||
return;
|
||||
}
|
||||
completeHeaders(headers);
|
||||
try {
|
||||
stream = transport.newStream(method, headers, listener);
|
||||
} catch (IllegalStateException ex) {
|
||||
|
|
@ -350,6 +353,15 @@ public final class ChannelImpl extends Channel {
|
|||
return stream.isReady();
|
||||
}
|
||||
|
||||
/**
|
||||
* Set missing properties on the headers. The given headers will be mutated.
|
||||
* @param headers the headers to complete
|
||||
*/
|
||||
private void completeHeaders(Metadata.Headers headers) {
|
||||
headers.removeAll(TIMEOUT_KEY);
|
||||
headers.put(TIMEOUT_KEY, method.getTimeout());
|
||||
}
|
||||
|
||||
private class ClientStreamListenerImpl implements ClientStreamListener {
|
||||
private final Listener<RespT> observer;
|
||||
private boolean closed;
|
||||
|
|
@ -423,4 +435,81 @@ public final class ChannelImpl extends Channel {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Intended for internal use only.
|
||||
*/
|
||||
// TODO(johnbcoughlin) make this package private when we can do so with the tests.
|
||||
@VisibleForTesting
|
||||
public static final Metadata.Key<Long> TIMEOUT_KEY =
|
||||
Metadata.Key.of(HttpUtil.TIMEOUT, new TimeoutMarshaller());
|
||||
|
||||
/**
|
||||
* Marshals a microseconds representation of the timeout to and from a string representation,
|
||||
* consisting of an ASCII decimal representation of a number with at most 8 digits, followed by a
|
||||
* unit:
|
||||
* u = microseconds
|
||||
* m = milliseconds
|
||||
* S = seconds
|
||||
* M = minutes
|
||||
* H = hours
|
||||
*
|
||||
* <p>The representation is greedy with respect to precision. That is, 2 seconds will be
|
||||
* represented as `2000000u`.</p>
|
||||
*
|
||||
* <p>See <a href="https://github.com/grpc/grpc-common/blob/master/PROTOCOL-HTTP2.md#requests">the
|
||||
* request header definition</a></p>
|
||||
*/
|
||||
@VisibleForTesting
|
||||
static class TimeoutMarshaller implements Metadata.AsciiMarshaller<Long> {
|
||||
@Override
|
||||
public String toAsciiString(Long timeoutMicros) {
|
||||
Preconditions.checkArgument(timeoutMicros >= 0, "Negative timeout");
|
||||
long timeout;
|
||||
String timeoutUnit;
|
||||
// the smallest integer with 9 digits
|
||||
int cutoff = 100000000;
|
||||
if (timeoutMicros < cutoff) {
|
||||
timeout = timeoutMicros;
|
||||
timeoutUnit = "u";
|
||||
} else if (timeoutMicros / 1000 < cutoff) {
|
||||
timeout = timeoutMicros / 1000;
|
||||
timeoutUnit = "m";
|
||||
} else if (timeoutMicros / (1000 * 1000) < cutoff) {
|
||||
timeout = timeoutMicros / (1000 * 1000);
|
||||
timeoutUnit = "S";
|
||||
} else if (timeoutMicros / (60 * 1000 * 1000) < cutoff) {
|
||||
timeout = timeoutMicros / (60 * 1000 * 1000);
|
||||
timeoutUnit = "M";
|
||||
} else if (timeoutMicros / (60L * 60L * 1000L * 1000L) < cutoff) {
|
||||
timeout = timeoutMicros / (60L * 60L * 1000L * 1000L);
|
||||
timeoutUnit = "H";
|
||||
} else {
|
||||
throw new IllegalArgumentException("Timeout too large");
|
||||
}
|
||||
return Long.toString(timeout) + timeoutUnit;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long parseAsciiString(String serialized) {
|
||||
String valuePart = serialized.substring(0, serialized.length() - 1);
|
||||
char unit = serialized.charAt(serialized.length() - 1);
|
||||
long factor;
|
||||
switch (unit) {
|
||||
case 'u':
|
||||
factor = 1; break;
|
||||
case 'm':
|
||||
factor = 1000L; break;
|
||||
case 'S':
|
||||
factor = 1000L * 1000L; break;
|
||||
case 'M':
|
||||
factor = 60L * 1000L * 1000L; break;
|
||||
case 'H':
|
||||
factor = 60L * 60L * 1000L * 1000L; break;
|
||||
default:
|
||||
throw new IllegalArgumentException(String.format("Invalid timeout unit: %s", unit));
|
||||
}
|
||||
return Long.parseLong(valuePart) * factor;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -68,6 +68,11 @@ public final class HttpUtil {
|
|||
*/
|
||||
public static final String TE_TRAILERS = "trailers";
|
||||
|
||||
/**
|
||||
* The Timeout header name.
|
||||
*/
|
||||
public static final String TIMEOUT = "grpc-timeout";
|
||||
|
||||
/**
|
||||
* Maps HTTP error response status codes to transport codes.
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -145,6 +145,28 @@ public class MetadataTest {
|
|||
roundTripInteger(0x87654321);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void timeoutTest() {
|
||||
ChannelImpl.TimeoutMarshaller marshaller =
|
||||
new ChannelImpl.TimeoutMarshaller();
|
||||
assertEquals("1000u", marshaller.toAsciiString(1000L));
|
||||
assertEquals(1000L, (long) marshaller.parseAsciiString("1000u"));
|
||||
|
||||
assertEquals("100000m", marshaller.toAsciiString(100000000L));
|
||||
assertEquals(100000000L, (long) marshaller.parseAsciiString("100000m"));
|
||||
|
||||
assertEquals("100000S", marshaller.toAsciiString(100000000000L));
|
||||
assertEquals(100000000000L, (long) marshaller.parseAsciiString("100000S"));
|
||||
|
||||
// 1,666,667 * 60 has 9 digits
|
||||
assertEquals("1666666M", marshaller.toAsciiString(100000000000000L));
|
||||
assertEquals(60000000000000L, (long) marshaller.parseAsciiString("1000000M"));
|
||||
|
||||
// 1,666,667 * 60 has 9 digits
|
||||
assertEquals("1666666H", marshaller.toAsciiString(6000000000000000L));
|
||||
assertEquals(3600000000000000L, (long) marshaller.parseAsciiString("1000000H"));
|
||||
}
|
||||
|
||||
private void roundTripInteger(Integer i) {
|
||||
assertEquals(i, Metadata.INTEGER_MARSHALLER.parseAsciiString(
|
||||
Metadata.INTEGER_MARSHALLER.toAsciiString(i)));
|
||||
|
|
|
|||
|
|
@ -93,6 +93,8 @@ public abstract class AbstractTransportTest {
|
|||
|
||||
public static final Metadata.Key<Messages.SimpleContext> METADATA_KEY =
|
||||
ProtoUtils.keyForProto(Messages.SimpleContext.getDefaultInstance());
|
||||
private static final AtomicReference<Metadata.Headers> requestHeadersCapture =
|
||||
new AtomicReference<Metadata.Headers>();
|
||||
private static ScheduledExecutorService testServiceExecutor;
|
||||
private static ServerImpl server;
|
||||
private static int OPERATION_TIMEOUT = 5000;
|
||||
|
|
@ -102,6 +104,7 @@ public abstract class AbstractTransportTest {
|
|||
|
||||
builder.addService(ServerInterceptors.intercept(
|
||||
TestServiceGrpc.bindService(new TestServiceImpl(testServiceExecutor)),
|
||||
TestUtils.recordRequestHeadersInterceptor(requestHeadersCapture),
|
||||
TestUtils.echoRequestHeadersInterceptor(Util.METADATA_KEY)));
|
||||
try {
|
||||
server = builder.build().start();
|
||||
|
|
@ -127,6 +130,7 @@ public abstract class AbstractTransportTest {
|
|||
channel = createChannel();
|
||||
blockingStub = TestServiceGrpc.newBlockingStub(channel);
|
||||
asyncStub = TestServiceGrpc.newStub(channel);
|
||||
requestHeadersCapture.set(null);
|
||||
}
|
||||
|
||||
/** Clean up. */
|
||||
|
|
@ -572,6 +576,16 @@ public abstract class AbstractTransportTest {
|
|||
Assert.assertEquals(contextValue, trailersCapture.get().get(METADATA_KEY));
|
||||
}
|
||||
|
||||
@Test(timeout = 10000)
|
||||
public void sendsTimeoutHeader() {
|
||||
TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel)
|
||||
.configureNewStub()
|
||||
.setTimeout(572, TimeUnit.MILLISECONDS)
|
||||
.build();
|
||||
stub.emptyCall(Empty.getDefaultInstance());
|
||||
Assert.assertEquals(572000L, (long) requestHeadersCapture.get().get(ChannelImpl.TIMEOUT_KEY));
|
||||
}
|
||||
|
||||
|
||||
protected int unaryPayloadLength() {
|
||||
// 10MiB.
|
||||
|
|
|
|||
|
|
@ -91,6 +91,7 @@ class Utils {
|
|||
public static Metadata.Headers convertHeaders(Http2Headers http2Headers) {
|
||||
Metadata.Headers headers = new Metadata.Headers(convertHeadersToArray(http2Headers));
|
||||
if (http2Headers.authority() != null) {
|
||||
// toString() here is safe since it doesn't use the default Charset.
|
||||
headers.setAuthority(http2Headers.authority().toString());
|
||||
}
|
||||
if (http2Headers.path() != null) {
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ import java.util.Collections;
|
|||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
|
|
@ -112,6 +113,24 @@ public class TestUtils {
|
|||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Capture the request headers from a client. Useful for testing metadata propagation without
|
||||
* requiring that it be symmetric on client and server, as with
|
||||
* {@link #echoRequestHeadersInterceptor}.
|
||||
*/
|
||||
public static ServerInterceptor recordRequestHeadersInterceptor(
|
||||
final AtomicReference<Metadata.Headers> headersCapture) {
|
||||
return new ServerInterceptor() {
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(String method,
|
||||
ServerCall<RespT> call,
|
||||
Metadata.Headers requestHeaders,
|
||||
ServerCallHandler<ReqT, RespT> next) {
|
||||
headersCapture.set(requestHeaders);
|
||||
return next.startCall(method, call, requestHeaders);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Picks an unused port.
|
||||
|
|
|
|||
Loading…
Reference in New Issue