Adding support for manually specifying HTTP/2 :authority and :path headers.

-------------
Created by MOE: http://code.google.com/p/moe-java
MOE_MIGRATED_REVID=75316631
This commit is contained in:
nathanmittler 2014-09-11 12:50:16 -07:00 committed by Eric Anderson
parent a7d735e69b
commit 23fbc7cb5e
15 changed files with 303 additions and 178 deletions

View File

@ -138,8 +138,6 @@ public final class ChannelImpl extends AbstractService implements Channel {
@Override
public void start(Listener<RespT> observer, Metadata.Headers headers) {
Preconditions.checkState(stream == null, "Already started");
headers.setPath(method.getName());
headers.setAuthority("fixme");
stream = obtainActiveTransport().newStream(method, headers,
new StreamListenerImpl(observer));
}

View File

@ -301,6 +301,26 @@ public abstract class Metadata<S extends Metadata> {
public void setAuthority(String authority) {
this.authority = authority;
}
@Override
public void merge(Metadata other) {
super.merge(other);
mergePathAndAuthority(other);
}
@Override
public void merge(Metadata other, Set<Key> keys) {
super.merge(other, keys);
mergePathAndAuthority(other);
}
private void mergePathAndAuthority(Metadata other) {
if (other instanceof Headers) {
Headers otherHeaders = (Headers) other;
path = otherHeaders.path != null ? otherHeaders.path : path;
authority = otherHeaders.authority != null ? otherHeaders.authority : authority;
}
}
}
/**

View File

@ -21,8 +21,10 @@ import java.util.List;
public class Http2Request extends Http2Operation implements Request {
private final Response response;
public Http2Request(FrameWriter frameWriter, String operationName,
public Http2Request(FrameWriter frameWriter,
Metadata.Headers headers,
String defaultPath,
String defaultAuthority,
Response response, RequestRegistry requestRegistry,
Framer framer) {
super(response.getId(), frameWriter, framer);
@ -31,8 +33,8 @@ public class Http2Request extends Http2Operation implements Request {
// Register this request.
requestRegistry.register(this);
List<Header> requestHeaders = Headers.createRequestHeaders(operationName,
headers.serialize());
List<Header> requestHeaders =
Headers.createRequestHeaders(headers, defaultPath, defaultAuthority);
frameWriter.synStream(false, false, getId(), 0, requestHeaders);
} catch (IOException ioe) {
close(new Status(Transport.Code.UNKNOWN, ioe));

View File

@ -32,6 +32,7 @@ import okio.ByteString;
import okio.Okio;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.List;
import java.util.concurrent.Executor;
@ -81,6 +82,7 @@ public class OkHttpSession implements Session {
}
}
private final String defaultAuthority;
private final FrameReader frameReader;
private final FrameWriter frameWriter;
private final AtomicInteger sessionId;
@ -108,6 +110,10 @@ public class OkHttpSession implements Session {
this.serverSession = null;
this.requestRegistry = requestRegistry;
executor.execute(new FrameHandler());
// Determine the default :authority header to use.
InetSocketAddress remoteAddress = (InetSocketAddress) socket.getRemoteSocketAddress();
defaultAuthority = remoteAddress.getHostString() + ":" + remoteAddress.getPort();
}
/**
@ -129,6 +135,9 @@ public class OkHttpSession implements Session {
this.serverSession = server;
this.requestRegistry = requestRegistry;
executor.execute(new FrameHandler());
// Authority is not used for server-side sessions.
defaultAuthority = null;
}
@Override
@ -147,13 +156,18 @@ public class OkHttpSession implements Session {
}
@Override
public Request startRequest(String operationName,
Metadata.Headers headers,
Response.ResponseBuilder responseBuilder) {
public Request startRequest(String operationName, Metadata.Headers headers,
Response.ResponseBuilder responseBuilder) {
int nextStreamId = getNextStreamId();
Response response = responseBuilder.build(nextStreamId);
Http2Request request = new Http2Request(frameWriter, operationName, headers, response,
requestRegistry, new MessageFramer(4096));
String defaultPath = "/" + operationName;
Http2Request request = new Http2Request(frameWriter,
headers,
defaultPath,
defaultAuthority,
response,
requestRegistry,
new MessageFramer(4096));
return request;
}
@ -259,7 +273,22 @@ public class OkHttpSession implements Session {
// Start an Operation for SYN_STREAM
if (op == null && headersMode == HeadersMode.HTTP_20_HEADERS) {
// TODO(user): Throwing inside this method seems to cause a request to
// hang indefinitely ... possibly an OkHttp bug? We should investigate
// this and come up with a solution that works for any handler method that encounters
// an exception.
String path = findReservedHeader(Header.TARGET_PATH.utf8(), headers);
if (path == null) {
try {
// The :path MUST be provided. This is a protocol error.
frameWriter.rstStream(streamId, ErrorCode.PROTOCOL_ERROR);
frameWriter.flush();
return;
} catch (IOException e) {
throw new RuntimeException(e);
}
}
byte[][] binaryHeaders = new byte[headers.size() * 2][];
for (int i = 0; i < headers.size(); i++) {
Header header = headers.get(i);
@ -269,13 +298,10 @@ public class OkHttpSession implements Session {
Metadata.Headers grpcHeaders = new Metadata.Headers(binaryHeaders);
grpcHeaders.setPath(path);
grpcHeaders.setAuthority(findReservedHeader(Header.TARGET_AUTHORITY.utf8(), headers));
if (path != null) {
Request request = serverSession.startRequest(path,
grpcHeaders,
Http2Response.builder(streamId, frameWriter, new MessageFramer(4096)));
requestRegistry.register(request);
op = request;
}
Request request = serverSession.startRequest(path, grpcHeaders,
Http2Response.builder(streamId, frameWriter, new MessageFramer(4096)));
requestRegistry.register(request);
op = request;
}
if (op == null) {
return;
@ -291,10 +317,11 @@ public class OkHttpSession implements Session {
for (Header header : headers) {
// Reserved headers must come before non-reserved headers, so we can exit the loop
// early if we see a non-reserved header.
if (!header.name.utf8().startsWith(":")) {
return null;
String headerString = header.name.utf8();
if (!headerString.startsWith(":")) {
break;
}
if (header.name.utf8().equals(name)) {
if (headerString.equals(name)) {
return header.value.utf8();
}
}

View File

@ -1,33 +1,28 @@
package com.google.net.stubby.newtransport.netty;
import com.google.common.base.Preconditions;
import com.google.net.stubby.MethodDescriptor;
import io.netty.handler.codec.http2.Http2Headers;
/**
* A command to create a new stream. This is created by {@link NettyClientStream} and passed to the
* {@link NettyClientHandler} for processing in the Channel thread.
*/
class CreateStreamCommand {
private final MethodDescriptor<?, ?> method;
private final String[] headers;
private final Http2Headers headers;
private final NettyClientStream stream;
CreateStreamCommand(MethodDescriptor<?, ?> method, String[] headers,
CreateStreamCommand(Http2Headers headers,
NettyClientStream stream) {
this.method = Preconditions.checkNotNull(method, "method");
this.stream = Preconditions.checkNotNull(stream, "stream");
this.headers = Preconditions.checkNotNull(headers, "headers");
}
MethodDescriptor<?, ?> method() {
return method;
}
NettyClientStream stream() {
return stream;
}
String[] headers() {
Http2Headers headers() {
return headers;
}
}

View File

@ -1,13 +1,9 @@
package com.google.net.stubby.newtransport.netty;
import static com.google.net.stubby.newtransport.HttpUtil.CONTENT_TYPE_HEADER;
import static com.google.net.stubby.newtransport.HttpUtil.CONTENT_TYPE_PROTORPC;
import static com.google.net.stubby.newtransport.HttpUtil.HTTP_METHOD;
import static com.google.net.stubby.newtransport.netty.NettyClientStream.PENDING_STREAM_ID;
import com.google.common.base.Preconditions;
import com.google.net.stubby.Metadata;
import com.google.net.stubby.MethodDescriptor;
import com.google.net.stubby.Status;
import com.google.net.stubby.transport.Transport;
@ -17,7 +13,6 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.AbstractHttp2ConnectionHandler;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.DefaultHttp2InboundFlowController;
import io.netty.handler.codec.http2.Http2Connection;
import io.netty.handler.codec.http2.Http2ConnectionAdapter;
@ -45,35 +40,27 @@ class NettyClientHandler extends AbstractHttp2ConnectionHandler {
* A pending stream creation.
*/
private final class PendingStream {
private final MethodDescriptor<?, ?> method;
private final String[] headers;
private final Http2Headers headers;
private final NettyClientStream stream;
private final ChannelPromise promise;
public PendingStream(CreateStreamCommand command, ChannelPromise promise) {
method = command.method();
headers = command.headers();
stream = command.stream();
this.promise = promise;
}
}
private final String host;
private final String scheme;
private final DefaultHttp2InboundFlowController inboundFlow;
private final Deque<PendingStream> pendingStreams = new ArrayDeque<PendingStream>();
private Status goAwayStatus = GOAWAY_STATUS;
public NettyClientHandler(String host,
boolean ssl,
Http2Connection connection,
public NettyClientHandler(Http2Connection connection,
Http2FrameReader frameReader,
Http2FrameWriter frameWriter,
DefaultHttp2InboundFlowController inboundFlow,
Http2OutboundFlowController outboundFlow) {
super(connection, frameReader, frameWriter, inboundFlow, outboundFlow);
this.host = Preconditions.checkNotNull(host, "host");
this.scheme = ssl ? "https" : "http";
this.inboundFlow = Preconditions.checkNotNull(inboundFlow, "inboundFlow");
// Disallow stream creation by the server.
@ -320,22 +307,7 @@ class NettyClientHandler extends AbstractHttp2ConnectionHandler {
// Finish creation of the stream by writing a headers frame.
final PendingStream pendingStream = pendingStreams.remove();
// TODO(user): Change Netty to not send priority, just use default.
// TODO(user): Switch to binary headers when Netty supports it.
DefaultHttp2Headers.Builder headersBuilder = DefaultHttp2Headers.newBuilder();
for (int i = 0; i < pendingStream.headers.length; i++) {
headersBuilder.add(
pendingStream.headers[i],
pendingStream.headers[++i]);
}
headersBuilder
.method(HTTP_METHOD)
.authority(host)
.scheme(scheme)
.add(CONTENT_TYPE_HEADER, CONTENT_TYPE_PROTORPC)
.path("/" + pendingStream.method.getName())
.build();
writeHeaders(ctx(), streamId, headersBuilder.build(), 0, false, ctx().newPromise())
writeHeaders(ctx(), streamId, pendingStream.headers, 0, false, ctx().newPromise())
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {

View File

@ -29,11 +29,13 @@ import io.netty.handler.codec.http2.Http2Connection;
import io.netty.handler.codec.http2.Http2FrameLogger;
import io.netty.handler.codec.http2.Http2FrameReader;
import io.netty.handler.codec.http2.Http2FrameWriter;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.handler.codec.http2.Http2InboundFrameLogger;
import io.netty.handler.codec.http2.Http2OutboundFlowController;
import io.netty.handler.codec.http2.Http2OutboundFrameLogger;
import io.netty.util.internal.logging.InternalLogLevel;
import java.net.InetSocketAddress;
import java.util.concurrent.ExecutionException;
import javax.net.ssl.SSLEngine;
@ -43,39 +45,41 @@ import javax.net.ssl.SSLEngine;
*/
class NettyClientTransport extends AbstractClientTransport {
private final String host;
private final int port;
private final InetSocketAddress address;
private final EventLoopGroup eventGroup;
private final Http2Negotiator.Negotiation negotiation;
private final NettyClientHandler handler;
private final boolean ssl;
private final String authority;
private Channel channel;
NettyClientTransport(String host, int port, NegotiationType negotiationType) {
this(host, port, negotiationType, new NioEventLoopGroup());
NettyClientTransport(InetSocketAddress address, NegotiationType negotiationType) {
this(address, negotiationType, new NioEventLoopGroup());
}
NettyClientTransport(String host, int port, NegotiationType negotiationType,
NettyClientTransport(InetSocketAddress address, NegotiationType negotiationType,
EventLoopGroup eventGroup) {
Preconditions.checkNotNull(host, "host");
Preconditions.checkArgument(port >= 0, "port must be positive");
Preconditions.checkNotNull(eventGroup, "eventGroup");
Preconditions.checkNotNull(negotiationType, "negotiationType");
this.host = host;
this.port = port;
this.eventGroup = eventGroup;
this.address = Preconditions.checkNotNull(address, "address");
this.eventGroup = Preconditions.checkNotNull(eventGroup, "eventGroup");
handler = newHandler(host, negotiationType == NegotiationType.TLS);
authority = address.getHostString() + ":" + address.getPort();
handler = newHandler();
switch (negotiationType) {
case PLAINTEXT:
negotiation = Http2Negotiator.plaintext(handler);
ssl = false;
break;
case PLAINTEXT_UPGRADE:
negotiation = Http2Negotiator.plaintextUpgrade(handler);
ssl = false;
break;
case TLS:
SSLEngine sslEngine = SslContextFactory.getClientContext().createSSLEngine();
sslEngine.setUseClientMode(true);
negotiation = Http2Negotiator.tls(handler, sslEngine);
ssl = true;
break;
default:
throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType);
@ -83,17 +87,18 @@ class NettyClientTransport extends AbstractClientTransport {
}
@Override
protected ClientStream newStreamInternal(MethodDescriptor<?, ?> method,
Metadata.Headers headers,
StreamListener listener) {
protected ClientStream newStreamInternal(MethodDescriptor<?, ?> method, Metadata.Headers headers,
StreamListener listener) {
// Create the stream.
NettyClientStream stream = new NettyClientStream(listener, channel, handler.inboundFlow());
try {
// Convert the headers into Netty HTTP/2 headers.
String defaultPath = "/" + method.getName();
Http2Headers http2Headers = Utils.convertHeaders(headers, ssl, defaultPath, authority);
// Write the request and await creation of the stream.
channel.writeAndFlush(new CreateStreamCommand(method,
headers.serializeAscii(),
stream)).get();
channel.writeAndFlush(new CreateStreamCommand(http2Headers, stream)).get();
} catch (InterruptedException e) {
// Restore the interrupt.
Thread.currentThread().interrupt();
@ -116,7 +121,7 @@ class NettyClientTransport extends AbstractClientTransport {
b.handler(negotiation.initializer());
// Start the connection operation to the server.
b.connect(host, port).addListener(new ChannelFutureListener() {
b.connect(address).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
@ -154,7 +159,7 @@ class NettyClientTransport extends AbstractClientTransport {
}
}
private static NettyClientHandler newHandler(String host, boolean ssl) {
private static NettyClientHandler newHandler() {
Http2Connection connection =
new DefaultHttp2Connection(false, new DefaultHttp2StreamRemovalPolicy());
Http2FrameReader frameReader = new DefaultHttp2FrameReader();
@ -168,12 +173,6 @@ class NettyClientTransport extends AbstractClientTransport {
new DefaultHttp2InboundFlowController(connection, frameWriter);
Http2OutboundFlowController outboundFlow =
new DefaultHttp2OutboundFlowController(connection, frameWriter);
return new NettyClientHandler(host,
ssl,
connection,
frameReader,
frameWriter,
inboundFlow,
outboundFlow);
return new NettyClientHandler(connection, frameReader, frameWriter, inboundFlow, outboundFlow);
}
}

View File

@ -3,6 +3,8 @@ package com.google.net.stubby.newtransport.netty;
import com.google.common.base.Preconditions;
import com.google.net.stubby.newtransport.ClientTransportFactory;
import java.net.InetSocketAddress;
import io.netty.channel.EventLoopGroup;
/**
@ -31,22 +33,19 @@ public class NettyClientTransportFactory implements ClientTransportFactory {
PLAINTEXT
}
private final String host;
private final int port;
private final InetSocketAddress address;
private final NegotiationType negotiationType;
private final EventLoopGroup group;
public NettyClientTransportFactory(String host, int port, NegotiationType negotiationType,
public NettyClientTransportFactory(InetSocketAddress address, NegotiationType negotiationType,
EventLoopGroup group) {
this.address = Preconditions.checkNotNull(address, "address");
this.group = Preconditions.checkNotNull(group, "group");
Preconditions.checkArgument(port > 0, "Port must be positive");
this.host = Preconditions.checkNotNull(host, "host");
this.negotiationType = Preconditions.checkNotNull(negotiationType, "negotiationType");
this.port = port;
}
@Override
public NettyClientTransport newClientTransport() {
return new NettyClientTransport(host, port, negotiationType, group);
return new NettyClientTransport(address, negotiationType, group);
}
}

View File

@ -1,9 +1,17 @@
package com.google.net.stubby.newtransport.netty;
import static com.google.net.stubby.newtransport.HttpUtil.CONTENT_TYPE_HEADER;
import static com.google.net.stubby.newtransport.HttpUtil.CONTENT_TYPE_PROTORPC;
import static com.google.net.stubby.newtransport.HttpUtil.HTTP_METHOD;
import static io.netty.util.CharsetUtil.UTF_8;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.net.stubby.Metadata;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.Http2Headers;
import java.nio.ByteBuffer;
@ -25,6 +33,43 @@ class Utils {
return buf;
}
public static Http2Headers convertHeaders(Metadata.Headers headers,
boolean ssl,
String defaultPath,
String defaultAuthority) {
Preconditions.checkNotNull(headers, "headers");
Preconditions.checkNotNull(defaultPath, "defaultPath");
Preconditions.checkNotNull(defaultAuthority, "defaultAuthority");
DefaultHttp2Headers.Builder headersBuilder = DefaultHttp2Headers.newBuilder();
// Add any application-provided headers first.
byte[][] serializedHeaders = headers.serialize();
for (int i = 0; i < serializedHeaders.length; i++) {
String key = new String(serializedHeaders[i], UTF_8);
String value = new String(serializedHeaders[++i], UTF_8);
headersBuilder.add(key, value);
}
// Now set GRPC-specific default headers.
headersBuilder
.authority(defaultAuthority)
.path(defaultPath)
.method(HTTP_METHOD)
.scheme(ssl? "https" : "http")
.add(CONTENT_TYPE_HEADER, CONTENT_TYPE_PROTORPC);
// Override the default authority and path if provided by the headers.
if (headers.getAuthority() != null) {
headersBuilder.authority(headers.getAuthority());
}
if (headers.getPath() != null) {
headersBuilder.path(headers.getPath());
}
return headersBuilder.build();
}
public static ImmutableMap<String, Provider<String>> convertHeaders(Http2Headers headers) {
ImmutableMap.Builder<String, Provider<String>> grpcHeaders =
new ImmutableMap.Builder<String, Provider<String>>();

View File

@ -1,6 +1,9 @@
package com.google.net.stubby.newtransport.okhttp;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.net.stubby.Metadata;
import com.google.net.stubby.newtransport.HttpUtil;
import com.squareup.okhttp.internal.spdy.Header;
@ -12,19 +15,47 @@ import java.util.List;
* Constants for request/response headers.
*/
public class Headers {
public static final Header SCHEME_HEADER = new Header(Header.TARGET_SCHEME, "https");
public static final Header METHOD_HEADER = new Header(Header.TARGET_METHOD, HttpUtil.HTTP_METHOD);
public static final Header CONTENT_TYPE_HEADER =
new Header("content-type", "application/protorpc");
new Header(HttpUtil.CONTENT_TYPE_HEADER, HttpUtil.CONTENT_TYPE_PROTORPC);
public static final Header RESPONSE_STATUS_OK = new Header(Header.RESPONSE_STATUS, "200");
public static List<Header> createRequestHeaders(String operationName, byte[][] headers) {
/**
* Serializes the given headers and creates a list of OkHttp {@link Header}s to be used when
* creating a stream. Since this serializes the headers, this method should be called in the
* application thread context.
*/
public static List<Header> createRequestHeaders(Metadata.Headers headers, String defaultPath,
String defaultAuthority) {
Preconditions.checkNotNull(headers, "headers");
Preconditions.checkNotNull(defaultPath, "defaultPath");
Preconditions.checkNotNull(defaultAuthority, "defaultAuthority");
List<Header> okhttpHeaders = Lists.newArrayListWithCapacity(6);
okhttpHeaders.add(new Header(Header.TARGET_PATH, operationName));
// Set GRPC-specific headers.
okhttpHeaders.add(SCHEME_HEADER);
okhttpHeaders.add(METHOD_HEADER);
String authority = headers.getAuthority() != null ? headers.getAuthority() : defaultAuthority;
okhttpHeaders.add(new Header(Header.TARGET_AUTHORITY, authority));
String path = headers.getPath() != null ? headers.getPath() : defaultPath;
okhttpHeaders.add(new Header(Header.TARGET_PATH, path));
// All non-pseudo headers must come after pseudo headers.
okhttpHeaders.add(CONTENT_TYPE_HEADER);
for (int i = 0; i < headers.length; i++) {
okhttpHeaders.add(new Header(ByteString.of(headers[i]), ByteString.of(headers[++i])));
// Now add any application-provided headers.
byte[][] serializedHeaders = headers.serialize();
for (int i = 0; i < serializedHeaders.length; i++) {
ByteString key = ByteString.of(serializedHeaders[i]);
ByteString value = ByteString.of(serializedHeaders[++i]);
if (isApplicationHeader(key)) {
okhttpHeaders.add(new Header(key, value));
}
}
return okhttpHeaders;
}
@ -34,4 +65,15 @@ public class Headers {
headers.add(RESPONSE_STATUS_OK);
return headers;
}
/**
* Returns {@code true} if the given header is an application-provided header. Otherwise, returns
* {@code false} if the header is reserved by GRPC.
*/
private static boolean isApplicationHeader(ByteString key) {
String keyString = key.utf8();
// Don't allow HTTP/2 pseudo headers or content-type to be added by the applciation.
return (!keyString.startsWith(":")
&& !HttpUtil.CONTENT_TYPE_HEADER.equalsIgnoreCase(keyString));
}
}

View File

@ -32,6 +32,7 @@ import okio.ByteString;
import okio.Okio;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.util.ArrayList;
@ -84,8 +85,8 @@ public class OkHttpClientTransport extends AbstractClientTransport {
ERROR_CODE_TO_STATUS = Collections.unmodifiableMap(errorToStatus);
}
private final String host;
private final int port;
private final InetSocketAddress address;
private final String defaultAuthority;
private FrameReader frameReader;
private AsyncFrameWriter frameWriter;
private final Object lock = new Object();
@ -102,9 +103,9 @@ public class OkHttpClientTransport extends AbstractClientTransport {
@GuardedBy("lock")
private Status goAwayStatus;
OkHttpClientTransport(String host, int port, Executor executor) {
this.host = Preconditions.checkNotNull(host);
this.port = port;
OkHttpClientTransport(InetSocketAddress address, Executor executor) {
this.address = Preconditions.checkNotNull(address);
defaultAuthority = address.getHostString() + ":" + address.getPort();
this.executor = Preconditions.checkNotNull(executor);
// Client initiated streams are odd, server initiated ones are even. Server should not need to
// use it. We start clients at 3 to avoid conflicting with HTTP negotiation.
@ -117,8 +118,8 @@ public class OkHttpClientTransport extends AbstractClientTransport {
@VisibleForTesting
OkHttpClientTransport(Executor executor, FrameReader frameReader, AsyncFrameWriter frameWriter,
int nextStreamId) {
host = null;
port = -1;
address = null;
defaultAuthority = "notarealauthority:80";
this.executor = Preconditions.checkNotNull(executor);
this.frameReader = Preconditions.checkNotNull(frameReader);
this.frameWriter = Preconditions.checkNotNull(frameWriter);
@ -129,17 +130,17 @@ public class OkHttpClientTransport extends AbstractClientTransport {
protected ClientStream newStreamInternal(MethodDescriptor<?, ?> method,
Metadata.Headers headers,
StreamListener listener) {
return new OkHttpClientStream(method, headers.serialize(), listener);
return new OkHttpClientStream(method, headers, listener);
}
@Override
protected void doStart() {
// We set host to null for test.
if (host != null) {
if (address != null) {
BufferedSource source;
BufferedSink sink;
try {
Socket socket = new Socket(host, port);
Socket socket = new Socket(address.getAddress(), address.getPort());
source = Okio.buffer(Okio.source(socket));
sink = Okio.buffer(Okio.sink(socket));
} catch (IOException e) {
@ -401,7 +402,8 @@ public class OkHttpClientTransport extends AbstractClientTransport {
final InputStreamDeframer deframer;
int unacknowledgedBytesRead;
OkHttpClientStream(MethodDescriptor<?, ?> method, byte[][] headers, StreamListener listener) {
OkHttpClientStream(MethodDescriptor<?, ?> method, Metadata.Headers headers,
StreamListener listener) {
super(listener);
deframer = new InputStreamDeframer(inboundMessageHandler());
synchronized (lock) {
@ -411,8 +413,9 @@ public class OkHttpClientTransport extends AbstractClientTransport {
}
assignStreamId(this);
}
String defaultPath = "/" + method.getName();
frameWriter.synStream(false, false, streamId, 0,
Headers.createRequestHeaders(method.getName(), headers));
Headers.createRequestHeaders(headers, defaultPath, defaultAuthority));
}
InputStreamDeframer getDeframer() {

View File

@ -1,27 +1,27 @@
package com.google.net.stubby.newtransport.okhttp;
import com.google.common.base.Preconditions;
import com.google.net.stubby.newtransport.ClientTransport;
import com.google.net.stubby.newtransport.ClientTransportFactory;
import java.net.InetSocketAddress;
import java.util.concurrent.ExecutorService;
/**
* Factory that manufactures instances of {@link OkHttpClientTransport}.
*/
public class OkHttpClientTransportFactory implements ClientTransportFactory {
private final String host;
private final int port;
private final InetSocketAddress address;
private final ExecutorService executor;
public OkHttpClientTransportFactory(String host, int port, ExecutorService executor) {
this.host = host;
this.port = port;
this.executor = executor;
public OkHttpClientTransportFactory(InetSocketAddress address, ExecutorService executor) {
this.address = Preconditions.checkNotNull(address, "address");
this.executor = Preconditions.checkNotNull(executor, "executor");
}
@Override
public ClientTransport newClientTransport() {
return new OkHttpClientTransport(host, port, executor);
return new OkHttpClientTransport(address, executor);
}
}

View File

@ -5,7 +5,6 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -20,8 +19,7 @@ import java.util.Iterator;
@RunWith(JUnit4.class)
public class MetadataTest {
private static final Metadata.Marshaller<Fish> FISH_MARSHALLER =
new Metadata.Marshaller<Fish>() {
private static final Metadata.Marshaller<Fish> FISH_MARSHALLER = new Metadata.Marshaller<Fish>() {
@Override
public byte[] toBytes(Fish fish) {
return fish.name.getBytes(StandardCharsets.UTF_8);
@ -32,7 +30,7 @@ public class MetadataTest {
return value.name;
}
@Override
@Override
public Fish parseBytes(byte[] serialized) {
return new Fish(new String(serialized, StandardCharsets.UTF_8));
}
@ -70,38 +68,44 @@ public class MetadataTest {
@Test
public void testWriteRaw() {
Metadata.Headers raw = new Metadata.Headers(
KEY.asciiName(), LANCE_BYTES);
Metadata.Headers raw = new Metadata.Headers(KEY.asciiName(), LANCE_BYTES);
Fish lance = raw.get(KEY);
assertEquals(lance, new Fish(LANCE));
// Reading again should return the same parsed instance
assertSame(lance, raw.get(KEY));
}
@Test
@Test(expected = IllegalStateException.class)
public void testFailSerializeRaw() {
Metadata.Headers raw = new Metadata.Headers(
KEY.asciiName(), LANCE_BYTES);
Metadata.Headers raw = new Metadata.Headers(KEY.asciiName(), LANCE_BYTES);
raw.serialize();
}
try {
raw.serialize();
fail("Can't serialize raw metadata");
} catch (IllegalStateException ise) {
// Success
}
@Test(expected = IllegalArgumentException.class)
public void testFailMergeRawIntoSerializable() {
Metadata.Headers raw = new Metadata.Headers(KEY.asciiName(), LANCE_BYTES);
Metadata.Headers serializable = new Metadata.Headers();
serializable.merge(raw);
}
@Test
public void testFailMergeRawIntoSerializable() {
Metadata.Headers raw = new Metadata.Headers(
KEY.asciiName(), LANCE_BYTES);
Metadata.Headers serializable = new Metadata.Headers();
try {
serializable.merge(raw);
fail("Can't serialize raw metadata");
} catch (IllegalArgumentException iae) {
// Success
}
public void headerMergeShouldCopyValues() {
Fish lance = new Fish(LANCE);
Metadata.Headers h1 = new Metadata.Headers();
Metadata.Headers h2 = new Metadata.Headers();
h2.setPath("/some/path");
h2.setAuthority("authority");
h2.put(KEY, lance);
h1.merge(h2);
Iterator<Fish> fishes = h1.<Fish>getAll(KEY).iterator();
assertTrue(fishes.hasNext());
assertSame(fishes.next(), lance);
assertFalse(fishes.hasNext());
assertEquals("/some/path", h1.getPath());
assertEquals("authority", h1.getAuthority());
}
private static class Fish {
@ -113,10 +117,16 @@ public class MetadataTest {
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
Fish fish = (Fish) o;
if (name != null ? !name.equals(fish.name) : fish.name != null) return false;
if (name != null ? !name.equals(fish.name) : fish.name != null) {
return false;
}
return true;
}
}

View File

@ -15,7 +15,6 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.net.stubby.Metadata;
import com.google.net.stubby.MethodDescriptor;
import com.google.net.stubby.Status;
import com.google.net.stubby.newtransport.HttpUtil;
import com.google.net.stubby.newtransport.StreamState;
@ -61,10 +60,8 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase {
@Mock
private NettyClientStream stream;
@Mock
private MethodDescriptor<?, ?> method;
private ByteBuf content;
private Metadata.Headers grpcHeaders;
private Http2Headers grpcHeaders;
@Before
public void setup() throws Exception {
@ -72,18 +69,23 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase {
frameWriter = new DefaultHttp2FrameWriter();
frameReader = new DefaultHttp2FrameReader();
handler = newHandler("www.fake.com", true);
handler = newHandler();
content = Unpooled.copiedBuffer("hello world", UTF_8);
when(channel.isActive()).thenReturn(true);
mockContext();
mockFuture(true);
Metadata.Key key = new Metadata.Key("auth", Metadata.STRING_MARSHALLER);
grpcHeaders = new Metadata.Headers();
grpcHeaders.put(key, "sometoken");
grpcHeaders = DefaultHttp2Headers
.newBuilder()
.scheme("https")
.authority("www.fake.com")
.path("/fakemethod")
.method(HTTP_METHOD)
.add("auth", "sometoken")
.add(CONTENT_TYPE_HEADER, CONTENT_TYPE_PROTORPC)
.build();
when(method.getName()).thenReturn("fakemethod");
when(stream.state()).thenReturn(StreamState.OPEN);
// Simulate activation of the handler to force writing of the initial settings
@ -100,7 +102,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase {
@Test
public void createStreamShouldSucceed() throws Exception {
handler.write(ctx, new CreateStreamCommand(method, grpcHeaders.serializeAscii(), stream),
handler.write(ctx, new CreateStreamCommand(grpcHeaders, stream),
promise);
verify(promise).setSuccess();
verify(stream).id(eq(3));
@ -190,7 +192,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase {
public void createShouldQueueStream() throws Exception {
// Disallow stream creation to force the stream to get added to the pending queue.
setMaxConcurrentStreams(0);
handler.write(ctx, new CreateStreamCommand(method, grpcHeaders.serializeAscii(), stream),
handler.write(ctx, new CreateStreamCommand(grpcHeaders, stream),
promise);
// Make sure the write never occurred.
@ -208,7 +210,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase {
public void receivedGoAwayShouldFailQueuedStreams() throws Exception {
// Force a stream to get added to the pending queue.
setMaxConcurrentStreams(0);
handler.write(ctx, new CreateStreamCommand(method, grpcHeaders.serializeAscii(), stream),
handler.write(ctx, new CreateStreamCommand(grpcHeaders, stream),
promise);
handler.channelRead(ctx, goAwayFrame(0));
@ -218,7 +220,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase {
@Test
public void receivedGoAwayShouldFailUnknownStreams() throws Exception {
// Force a stream to get added to the pending queue.
handler.write(ctx, new CreateStreamCommand(method, grpcHeaders.serializeAscii(), stream),
handler.write(ctx, new CreateStreamCommand(grpcHeaders, stream),
promise);
// Read a GOAWAY that indicates our stream was never processed by the server.
@ -246,14 +248,14 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase {
private void createStream() throws Exception {
// Create the stream.
handler.write(ctx, new CreateStreamCommand(method, grpcHeaders.serializeAscii(), stream),
handler.write(ctx, new CreateStreamCommand(grpcHeaders, stream),
promise);
when(stream.id()).thenReturn(3);
// Reset the context mock to clear recording of sent headers frame.
mockContext();
}
private static NettyClientHandler newHandler(String host, boolean ssl) {
private static NettyClientHandler newHandler() {
Http2Connection connection = new DefaultHttp2Connection(false);
Http2FrameReader frameReader = new DefaultHttp2FrameReader();
Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter();
@ -261,9 +263,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase {
new DefaultHttp2InboundFlowController(connection, frameWriter);
Http2OutboundFlowController outboundFlow =
new DefaultHttp2OutboundFlowController(connection, frameWriter);
return new NettyClientHandler(host,
ssl,
connection,
return new NettyClientHandler(connection,
frameReader,
frameWriter,
inboundFlow,

View File

@ -1,6 +1,7 @@
package com.google.net.stubby.stub;
import com.google.net.stubby.Call;
import com.google.net.stubby.Channel;
import com.google.net.stubby.Metadata;
import com.google.net.stubby.MethodDescriptor;
import com.google.net.stubby.context.ForwardingChannel;
@ -20,18 +21,30 @@ public class HeadersInterceptor {
public static <T extends AbstractStub> T intercept(
T stub,
final Metadata.Headers extraHeaders) {
return (T) stub.configureNewStub().setChannel(
new ForwardingChannel(stub.getChannel()) {
@Override
public <ReqT, RespT> Call<ReqT, RespT> newCall(MethodDescriptor<ReqT, RespT> method) {
return new ForwardingCall<ReqT, RespT>(delegate.newCall(method)) {
@Override
public void start(Listener<RespT> responseListener, Metadata.Headers headers) {
headers.merge(extraHeaders);
delegate.start(responseListener, headers);
}
};
}
}).build();
return (T) stub.configureNewStub().setChannel(interceptChannel(stub.getChannel(), extraHeaders))
.build();
}
/**
* Attach a set of request headers to a channel.
*
* @param channel to channel to intercept.
* @param extraHeaders the headers to be passed by each call on the returned stub.
* @return an implementation of the channel with extraHeaders bound to each call.
*/
@SuppressWarnings("unchecked")
public static Channel interceptChannel(Channel channel, final Metadata.Headers extraHeaders) {
return new ForwardingChannel(channel) {
@Override
public <ReqT, RespT> Call<ReqT, RespT> newCall(MethodDescriptor<ReqT, RespT> method) {
return new ForwardingCall<ReqT, RespT>(delegate.newCall(method)) {
@Override
public void start(Listener<RespT> responseListener, Metadata.Headers headers) {
headers.merge(extraHeaders);
delegate.start(responseListener, headers);
}
};
}
};
}
}