diff --git a/core/src/main/java/com/google/net/stubby/ChannelImpl.java b/core/src/main/java/com/google/net/stubby/ChannelImpl.java index b9b8c3e3ce..f44af56f36 100644 --- a/core/src/main/java/com/google/net/stubby/ChannelImpl.java +++ b/core/src/main/java/com/google/net/stubby/ChannelImpl.java @@ -138,8 +138,6 @@ public final class ChannelImpl extends AbstractService implements Channel { @Override public void start(Listener observer, Metadata.Headers headers) { Preconditions.checkState(stream == null, "Already started"); - headers.setPath(method.getName()); - headers.setAuthority("fixme"); stream = obtainActiveTransport().newStream(method, headers, new StreamListenerImpl(observer)); } diff --git a/core/src/main/java/com/google/net/stubby/Metadata.java b/core/src/main/java/com/google/net/stubby/Metadata.java index ca83e0c7ad..c757294653 100644 --- a/core/src/main/java/com/google/net/stubby/Metadata.java +++ b/core/src/main/java/com/google/net/stubby/Metadata.java @@ -301,6 +301,26 @@ public abstract class Metadata { public void setAuthority(String authority) { this.authority = authority; } + + @Override + public void merge(Metadata other) { + super.merge(other); + mergePathAndAuthority(other); + } + + @Override + public void merge(Metadata other, Set keys) { + super.merge(other, keys); + mergePathAndAuthority(other); + } + + private void mergePathAndAuthority(Metadata other) { + if (other instanceof Headers) { + Headers otherHeaders = (Headers) other; + path = otherHeaders.path != null ? otherHeaders.path : path; + authority = otherHeaders.authority != null ? otherHeaders.authority : authority; + } + } } /** diff --git a/core/src/main/java/com/google/net/stubby/http2/okhttp/Http2Request.java b/core/src/main/java/com/google/net/stubby/http2/okhttp/Http2Request.java index afaa92186c..9691635339 100644 --- a/core/src/main/java/com/google/net/stubby/http2/okhttp/Http2Request.java +++ b/core/src/main/java/com/google/net/stubby/http2/okhttp/Http2Request.java @@ -21,8 +21,10 @@ import java.util.List; public class Http2Request extends Http2Operation implements Request { private final Response response; - public Http2Request(FrameWriter frameWriter, String operationName, + public Http2Request(FrameWriter frameWriter, Metadata.Headers headers, + String defaultPath, + String defaultAuthority, Response response, RequestRegistry requestRegistry, Framer framer) { super(response.getId(), frameWriter, framer); @@ -31,8 +33,8 @@ public class Http2Request extends Http2Operation implements Request { // Register this request. requestRegistry.register(this); - List
requestHeaders = Headers.createRequestHeaders(operationName, - headers.serialize()); + List
requestHeaders = + Headers.createRequestHeaders(headers, defaultPath, defaultAuthority); frameWriter.synStream(false, false, getId(), 0, requestHeaders); } catch (IOException ioe) { close(new Status(Transport.Code.UNKNOWN, ioe)); diff --git a/core/src/main/java/com/google/net/stubby/http2/okhttp/OkHttpSession.java b/core/src/main/java/com/google/net/stubby/http2/okhttp/OkHttpSession.java index c8f9d39892..33a3c2af7e 100644 --- a/core/src/main/java/com/google/net/stubby/http2/okhttp/OkHttpSession.java +++ b/core/src/main/java/com/google/net/stubby/http2/okhttp/OkHttpSession.java @@ -32,6 +32,7 @@ import okio.ByteString; import okio.Okio; import java.io.IOException; +import java.net.InetSocketAddress; import java.net.Socket; import java.util.List; import java.util.concurrent.Executor; @@ -81,6 +82,7 @@ public class OkHttpSession implements Session { } } + private final String defaultAuthority; private final FrameReader frameReader; private final FrameWriter frameWriter; private final AtomicInteger sessionId; @@ -108,6 +110,10 @@ public class OkHttpSession implements Session { this.serverSession = null; this.requestRegistry = requestRegistry; executor.execute(new FrameHandler()); + + // Determine the default :authority header to use. + InetSocketAddress remoteAddress = (InetSocketAddress) socket.getRemoteSocketAddress(); + defaultAuthority = remoteAddress.getHostString() + ":" + remoteAddress.getPort(); } /** @@ -129,6 +135,9 @@ public class OkHttpSession implements Session { this.serverSession = server; this.requestRegistry = requestRegistry; executor.execute(new FrameHandler()); + + // Authority is not used for server-side sessions. + defaultAuthority = null; } @Override @@ -147,13 +156,18 @@ public class OkHttpSession implements Session { } @Override - public Request startRequest(String operationName, - Metadata.Headers headers, - Response.ResponseBuilder responseBuilder) { + public Request startRequest(String operationName, Metadata.Headers headers, + Response.ResponseBuilder responseBuilder) { int nextStreamId = getNextStreamId(); Response response = responseBuilder.build(nextStreamId); - Http2Request request = new Http2Request(frameWriter, operationName, headers, response, - requestRegistry, new MessageFramer(4096)); + String defaultPath = "/" + operationName; + Http2Request request = new Http2Request(frameWriter, + headers, + defaultPath, + defaultAuthority, + response, + requestRegistry, + new MessageFramer(4096)); return request; } @@ -259,7 +273,22 @@ public class OkHttpSession implements Session { // Start an Operation for SYN_STREAM if (op == null && headersMode == HeadersMode.HTTP_20_HEADERS) { + // TODO(user): Throwing inside this method seems to cause a request to + // hang indefinitely ... possibly an OkHttp bug? We should investigate + // this and come up with a solution that works for any handler method that encounters + // an exception. String path = findReservedHeader(Header.TARGET_PATH.utf8(), headers); + if (path == null) { + try { + // The :path MUST be provided. This is a protocol error. + frameWriter.rstStream(streamId, ErrorCode.PROTOCOL_ERROR); + frameWriter.flush(); + return; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + byte[][] binaryHeaders = new byte[headers.size() * 2][]; for (int i = 0; i < headers.size(); i++) { Header header = headers.get(i); @@ -269,13 +298,10 @@ public class OkHttpSession implements Session { Metadata.Headers grpcHeaders = new Metadata.Headers(binaryHeaders); grpcHeaders.setPath(path); grpcHeaders.setAuthority(findReservedHeader(Header.TARGET_AUTHORITY.utf8(), headers)); - if (path != null) { - Request request = serverSession.startRequest(path, - grpcHeaders, - Http2Response.builder(streamId, frameWriter, new MessageFramer(4096))); - requestRegistry.register(request); - op = request; - } + Request request = serverSession.startRequest(path, grpcHeaders, + Http2Response.builder(streamId, frameWriter, new MessageFramer(4096))); + requestRegistry.register(request); + op = request; } if (op == null) { return; @@ -291,10 +317,11 @@ public class OkHttpSession implements Session { for (Header header : headers) { // Reserved headers must come before non-reserved headers, so we can exit the loop // early if we see a non-reserved header. - if (!header.name.utf8().startsWith(":")) { - return null; + String headerString = header.name.utf8(); + if (!headerString.startsWith(":")) { + break; } - if (header.name.utf8().equals(name)) { + if (headerString.equals(name)) { return header.value.utf8(); } } diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/CreateStreamCommand.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/CreateStreamCommand.java index 98fea676cd..7e965a392b 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/CreateStreamCommand.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/CreateStreamCommand.java @@ -1,33 +1,28 @@ package com.google.net.stubby.newtransport.netty; import com.google.common.base.Preconditions; -import com.google.net.stubby.MethodDescriptor; + +import io.netty.handler.codec.http2.Http2Headers; /** * A command to create a new stream. This is created by {@link NettyClientStream} and passed to the * {@link NettyClientHandler} for processing in the Channel thread. */ class CreateStreamCommand { - private final MethodDescriptor method; - private final String[] headers; + private final Http2Headers headers; private final NettyClientStream stream; - CreateStreamCommand(MethodDescriptor method, String[] headers, + CreateStreamCommand(Http2Headers headers, NettyClientStream stream) { - this.method = Preconditions.checkNotNull(method, "method"); this.stream = Preconditions.checkNotNull(stream, "stream"); this.headers = Preconditions.checkNotNull(headers, "headers"); } - MethodDescriptor method() { - return method; - } - NettyClientStream stream() { return stream; } - String[] headers() { + Http2Headers headers() { return headers; } } diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java index cada74ed59..86407ddad2 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java @@ -1,13 +1,9 @@ package com.google.net.stubby.newtransport.netty; -import static com.google.net.stubby.newtransport.HttpUtil.CONTENT_TYPE_HEADER; -import static com.google.net.stubby.newtransport.HttpUtil.CONTENT_TYPE_PROTORPC; -import static com.google.net.stubby.newtransport.HttpUtil.HTTP_METHOD; import static com.google.net.stubby.newtransport.netty.NettyClientStream.PENDING_STREAM_ID; import com.google.common.base.Preconditions; import com.google.net.stubby.Metadata; -import com.google.net.stubby.MethodDescriptor; import com.google.net.stubby.Status; import com.google.net.stubby.transport.Transport; @@ -17,7 +13,6 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http2.AbstractHttp2ConnectionHandler; -import io.netty.handler.codec.http2.DefaultHttp2Headers; import io.netty.handler.codec.http2.DefaultHttp2InboundFlowController; import io.netty.handler.codec.http2.Http2Connection; import io.netty.handler.codec.http2.Http2ConnectionAdapter; @@ -45,35 +40,27 @@ class NettyClientHandler extends AbstractHttp2ConnectionHandler { * A pending stream creation. */ private final class PendingStream { - private final MethodDescriptor method; - private final String[] headers; + private final Http2Headers headers; private final NettyClientStream stream; private final ChannelPromise promise; public PendingStream(CreateStreamCommand command, ChannelPromise promise) { - method = command.method(); headers = command.headers(); stream = command.stream(); this.promise = promise; } } - private final String host; - private final String scheme; private final DefaultHttp2InboundFlowController inboundFlow; private final Deque pendingStreams = new ArrayDeque(); private Status goAwayStatus = GOAWAY_STATUS; - public NettyClientHandler(String host, - boolean ssl, - Http2Connection connection, + public NettyClientHandler(Http2Connection connection, Http2FrameReader frameReader, Http2FrameWriter frameWriter, DefaultHttp2InboundFlowController inboundFlow, Http2OutboundFlowController outboundFlow) { super(connection, frameReader, frameWriter, inboundFlow, outboundFlow); - this.host = Preconditions.checkNotNull(host, "host"); - this.scheme = ssl ? "https" : "http"; this.inboundFlow = Preconditions.checkNotNull(inboundFlow, "inboundFlow"); // Disallow stream creation by the server. @@ -320,22 +307,7 @@ class NettyClientHandler extends AbstractHttp2ConnectionHandler { // Finish creation of the stream by writing a headers frame. final PendingStream pendingStream = pendingStreams.remove(); - // TODO(user): Change Netty to not send priority, just use default. - // TODO(user): Switch to binary headers when Netty supports it. - DefaultHttp2Headers.Builder headersBuilder = DefaultHttp2Headers.newBuilder(); - for (int i = 0; i < pendingStream.headers.length; i++) { - headersBuilder.add( - pendingStream.headers[i], - pendingStream.headers[++i]); - } - headersBuilder - .method(HTTP_METHOD) - .authority(host) - .scheme(scheme) - .add(CONTENT_TYPE_HEADER, CONTENT_TYPE_PROTORPC) - .path("/" + pendingStream.method.getName()) - .build(); - writeHeaders(ctx(), streamId, headersBuilder.build(), 0, false, ctx().newPromise()) + writeHeaders(ctx(), streamId, pendingStream.headers, 0, false, ctx().newPromise()) .addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransport.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransport.java index b23dd5dd38..6d3056fa0b 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransport.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransport.java @@ -29,11 +29,13 @@ import io.netty.handler.codec.http2.Http2Connection; import io.netty.handler.codec.http2.Http2FrameLogger; import io.netty.handler.codec.http2.Http2FrameReader; import io.netty.handler.codec.http2.Http2FrameWriter; +import io.netty.handler.codec.http2.Http2Headers; import io.netty.handler.codec.http2.Http2InboundFrameLogger; import io.netty.handler.codec.http2.Http2OutboundFlowController; import io.netty.handler.codec.http2.Http2OutboundFrameLogger; import io.netty.util.internal.logging.InternalLogLevel; +import java.net.InetSocketAddress; import java.util.concurrent.ExecutionException; import javax.net.ssl.SSLEngine; @@ -43,39 +45,41 @@ import javax.net.ssl.SSLEngine; */ class NettyClientTransport extends AbstractClientTransport { - private final String host; - private final int port; + private final InetSocketAddress address; private final EventLoopGroup eventGroup; private final Http2Negotiator.Negotiation negotiation; private final NettyClientHandler handler; + private final boolean ssl; + private final String authority; private Channel channel; - NettyClientTransport(String host, int port, NegotiationType negotiationType) { - this(host, port, negotiationType, new NioEventLoopGroup()); + NettyClientTransport(InetSocketAddress address, NegotiationType negotiationType) { + this(address, negotiationType, new NioEventLoopGroup()); } - NettyClientTransport(String host, int port, NegotiationType negotiationType, + NettyClientTransport(InetSocketAddress address, NegotiationType negotiationType, EventLoopGroup eventGroup) { - Preconditions.checkNotNull(host, "host"); - Preconditions.checkArgument(port >= 0, "port must be positive"); - Preconditions.checkNotNull(eventGroup, "eventGroup"); Preconditions.checkNotNull(negotiationType, "negotiationType"); - this.host = host; - this.port = port; - this.eventGroup = eventGroup; + this.address = Preconditions.checkNotNull(address, "address"); + this.eventGroup = Preconditions.checkNotNull(eventGroup, "eventGroup"); - handler = newHandler(host, negotiationType == NegotiationType.TLS); + authority = address.getHostString() + ":" + address.getPort(); + + handler = newHandler(); switch (negotiationType) { case PLAINTEXT: negotiation = Http2Negotiator.plaintext(handler); + ssl = false; break; case PLAINTEXT_UPGRADE: negotiation = Http2Negotiator.plaintextUpgrade(handler); + ssl = false; break; case TLS: SSLEngine sslEngine = SslContextFactory.getClientContext().createSSLEngine(); sslEngine.setUseClientMode(true); negotiation = Http2Negotiator.tls(handler, sslEngine); + ssl = true; break; default: throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType); @@ -83,17 +87,18 @@ class NettyClientTransport extends AbstractClientTransport { } @Override - protected ClientStream newStreamInternal(MethodDescriptor method, - Metadata.Headers headers, - StreamListener listener) { + protected ClientStream newStreamInternal(MethodDescriptor method, Metadata.Headers headers, + StreamListener listener) { // Create the stream. NettyClientStream stream = new NettyClientStream(listener, channel, handler.inboundFlow()); try { + // Convert the headers into Netty HTTP/2 headers. + String defaultPath = "/" + method.getName(); + Http2Headers http2Headers = Utils.convertHeaders(headers, ssl, defaultPath, authority); + // Write the request and await creation of the stream. - channel.writeAndFlush(new CreateStreamCommand(method, - headers.serializeAscii(), - stream)).get(); + channel.writeAndFlush(new CreateStreamCommand(http2Headers, stream)).get(); } catch (InterruptedException e) { // Restore the interrupt. Thread.currentThread().interrupt(); @@ -116,7 +121,7 @@ class NettyClientTransport extends AbstractClientTransport { b.handler(negotiation.initializer()); // Start the connection operation to the server. - b.connect(host, port).addListener(new ChannelFutureListener() { + b.connect(address).addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { @@ -154,7 +159,7 @@ class NettyClientTransport extends AbstractClientTransport { } } - private static NettyClientHandler newHandler(String host, boolean ssl) { + private static NettyClientHandler newHandler() { Http2Connection connection = new DefaultHttp2Connection(false, new DefaultHttp2StreamRemovalPolicy()); Http2FrameReader frameReader = new DefaultHttp2FrameReader(); @@ -168,12 +173,6 @@ class NettyClientTransport extends AbstractClientTransport { new DefaultHttp2InboundFlowController(connection, frameWriter); Http2OutboundFlowController outboundFlow = new DefaultHttp2OutboundFlowController(connection, frameWriter); - return new NettyClientHandler(host, - ssl, - connection, - frameReader, - frameWriter, - inboundFlow, - outboundFlow); + return new NettyClientHandler(connection, frameReader, frameWriter, inboundFlow, outboundFlow); } } diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransportFactory.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransportFactory.java index 3c37fd55b1..300e833dff 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransportFactory.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransportFactory.java @@ -3,6 +3,8 @@ package com.google.net.stubby.newtransport.netty; import com.google.common.base.Preconditions; import com.google.net.stubby.newtransport.ClientTransportFactory; +import java.net.InetSocketAddress; + import io.netty.channel.EventLoopGroup; /** @@ -31,22 +33,19 @@ public class NettyClientTransportFactory implements ClientTransportFactory { PLAINTEXT } - private final String host; - private final int port; + private final InetSocketAddress address; private final NegotiationType negotiationType; private final EventLoopGroup group; - public NettyClientTransportFactory(String host, int port, NegotiationType negotiationType, + public NettyClientTransportFactory(InetSocketAddress address, NegotiationType negotiationType, EventLoopGroup group) { + this.address = Preconditions.checkNotNull(address, "address"); this.group = Preconditions.checkNotNull(group, "group"); - Preconditions.checkArgument(port > 0, "Port must be positive"); - this.host = Preconditions.checkNotNull(host, "host"); this.negotiationType = Preconditions.checkNotNull(negotiationType, "negotiationType"); - this.port = port; } @Override public NettyClientTransport newClientTransport() { - return new NettyClientTransport(host, port, negotiationType, group); + return new NettyClientTransport(address, negotiationType, group); } } diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/Utils.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/Utils.java index 9119460972..4d7f4ebaf6 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/Utils.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/Utils.java @@ -1,9 +1,17 @@ package com.google.net.stubby.newtransport.netty; +import static com.google.net.stubby.newtransport.HttpUtil.CONTENT_TYPE_HEADER; +import static com.google.net.stubby.newtransport.HttpUtil.CONTENT_TYPE_PROTORPC; +import static com.google.net.stubby.newtransport.HttpUtil.HTTP_METHOD; +import static io.netty.util.CharsetUtil.UTF_8; + +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; +import com.google.net.stubby.Metadata; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.handler.codec.http2.DefaultHttp2Headers; import io.netty.handler.codec.http2.Http2Headers; import java.nio.ByteBuffer; @@ -25,6 +33,43 @@ class Utils { return buf; } + public static Http2Headers convertHeaders(Metadata.Headers headers, + boolean ssl, + String defaultPath, + String defaultAuthority) { + Preconditions.checkNotNull(headers, "headers"); + Preconditions.checkNotNull(defaultPath, "defaultPath"); + Preconditions.checkNotNull(defaultAuthority, "defaultAuthority"); + + DefaultHttp2Headers.Builder headersBuilder = DefaultHttp2Headers.newBuilder(); + + // Add any application-provided headers first. + byte[][] serializedHeaders = headers.serialize(); + for (int i = 0; i < serializedHeaders.length; i++) { + String key = new String(serializedHeaders[i], UTF_8); + String value = new String(serializedHeaders[++i], UTF_8); + headersBuilder.add(key, value); + } + + // Now set GRPC-specific default headers. + headersBuilder + .authority(defaultAuthority) + .path(defaultPath) + .method(HTTP_METHOD) + .scheme(ssl? "https" : "http") + .add(CONTENT_TYPE_HEADER, CONTENT_TYPE_PROTORPC); + + // Override the default authority and path if provided by the headers. + if (headers.getAuthority() != null) { + headersBuilder.authority(headers.getAuthority()); + } + if (headers.getPath() != null) { + headersBuilder.path(headers.getPath()); + } + + return headersBuilder.build(); + } + public static ImmutableMap> convertHeaders(Http2Headers headers) { ImmutableMap.Builder> grpcHeaders = new ImmutableMap.Builder>(); diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/Headers.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/Headers.java index f63640aadc..62023116b0 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/Headers.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/Headers.java @@ -1,6 +1,9 @@ package com.google.net.stubby.newtransport.okhttp; +import com.google.common.base.Preconditions; import com.google.common.collect.Lists; +import com.google.net.stubby.Metadata; +import com.google.net.stubby.newtransport.HttpUtil; import com.squareup.okhttp.internal.spdy.Header; @@ -12,19 +15,47 @@ import java.util.List; * Constants for request/response headers. */ public class Headers { + public static final Header SCHEME_HEADER = new Header(Header.TARGET_SCHEME, "https"); + public static final Header METHOD_HEADER = new Header(Header.TARGET_METHOD, HttpUtil.HTTP_METHOD); public static final Header CONTENT_TYPE_HEADER = - new Header("content-type", "application/protorpc"); + new Header(HttpUtil.CONTENT_TYPE_HEADER, HttpUtil.CONTENT_TYPE_PROTORPC); public static final Header RESPONSE_STATUS_OK = new Header(Header.RESPONSE_STATUS, "200"); - public static List
createRequestHeaders(String operationName, byte[][] headers) { + /** + * Serializes the given headers and creates a list of OkHttp {@link Header}s to be used when + * creating a stream. Since this serializes the headers, this method should be called in the + * application thread context. + */ + public static List
createRequestHeaders(Metadata.Headers headers, String defaultPath, + String defaultAuthority) { + Preconditions.checkNotNull(headers, "headers"); + Preconditions.checkNotNull(defaultPath, "defaultPath"); + Preconditions.checkNotNull(defaultAuthority, "defaultAuthority"); + List
okhttpHeaders = Lists.newArrayListWithCapacity(6); - okhttpHeaders.add(new Header(Header.TARGET_PATH, operationName)); + + // Set GRPC-specific headers. okhttpHeaders.add(SCHEME_HEADER); + okhttpHeaders.add(METHOD_HEADER); + String authority = headers.getAuthority() != null ? headers.getAuthority() : defaultAuthority; + okhttpHeaders.add(new Header(Header.TARGET_AUTHORITY, authority)); + String path = headers.getPath() != null ? headers.getPath() : defaultPath; + okhttpHeaders.add(new Header(Header.TARGET_PATH, path)); + + // All non-pseudo headers must come after pseudo headers. okhttpHeaders.add(CONTENT_TYPE_HEADER); - for (int i = 0; i < headers.length; i++) { - okhttpHeaders.add(new Header(ByteString.of(headers[i]), ByteString.of(headers[++i]))); + + // Now add any application-provided headers. + byte[][] serializedHeaders = headers.serialize(); + for (int i = 0; i < serializedHeaders.length; i++) { + ByteString key = ByteString.of(serializedHeaders[i]); + ByteString value = ByteString.of(serializedHeaders[++i]); + if (isApplicationHeader(key)) { + okhttpHeaders.add(new Header(key, value)); + } } + return okhttpHeaders; } @@ -34,4 +65,15 @@ public class Headers { headers.add(RESPONSE_STATUS_OK); return headers; } + + /** + * Returns {@code true} if the given header is an application-provided header. Otherwise, returns + * {@code false} if the header is reserved by GRPC. + */ + private static boolean isApplicationHeader(ByteString key) { + String keyString = key.utf8(); + // Don't allow HTTP/2 pseudo headers or content-type to be added by the applciation. + return (!keyString.startsWith(":") + && !HttpUtil.CONTENT_TYPE_HEADER.equalsIgnoreCase(keyString)); + } } diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java index 4d94ac1ddb..852591f0ea 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransport.java @@ -32,6 +32,7 @@ import okio.ByteString; import okio.Okio; import java.io.IOException; +import java.net.InetSocketAddress; import java.net.Socket; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -84,8 +85,8 @@ public class OkHttpClientTransport extends AbstractClientTransport { ERROR_CODE_TO_STATUS = Collections.unmodifiableMap(errorToStatus); } - private final String host; - private final int port; + private final InetSocketAddress address; + private final String defaultAuthority; private FrameReader frameReader; private AsyncFrameWriter frameWriter; private final Object lock = new Object(); @@ -102,9 +103,9 @@ public class OkHttpClientTransport extends AbstractClientTransport { @GuardedBy("lock") private Status goAwayStatus; - OkHttpClientTransport(String host, int port, Executor executor) { - this.host = Preconditions.checkNotNull(host); - this.port = port; + OkHttpClientTransport(InetSocketAddress address, Executor executor) { + this.address = Preconditions.checkNotNull(address); + defaultAuthority = address.getHostString() + ":" + address.getPort(); this.executor = Preconditions.checkNotNull(executor); // Client initiated streams are odd, server initiated ones are even. Server should not need to // use it. We start clients at 3 to avoid conflicting with HTTP negotiation. @@ -117,8 +118,8 @@ public class OkHttpClientTransport extends AbstractClientTransport { @VisibleForTesting OkHttpClientTransport(Executor executor, FrameReader frameReader, AsyncFrameWriter frameWriter, int nextStreamId) { - host = null; - port = -1; + address = null; + defaultAuthority = "notarealauthority:80"; this.executor = Preconditions.checkNotNull(executor); this.frameReader = Preconditions.checkNotNull(frameReader); this.frameWriter = Preconditions.checkNotNull(frameWriter); @@ -129,17 +130,17 @@ public class OkHttpClientTransport extends AbstractClientTransport { protected ClientStream newStreamInternal(MethodDescriptor method, Metadata.Headers headers, StreamListener listener) { - return new OkHttpClientStream(method, headers.serialize(), listener); + return new OkHttpClientStream(method, headers, listener); } @Override protected void doStart() { // We set host to null for test. - if (host != null) { + if (address != null) { BufferedSource source; BufferedSink sink; try { - Socket socket = new Socket(host, port); + Socket socket = new Socket(address.getAddress(), address.getPort()); source = Okio.buffer(Okio.source(socket)); sink = Okio.buffer(Okio.sink(socket)); } catch (IOException e) { @@ -401,7 +402,8 @@ public class OkHttpClientTransport extends AbstractClientTransport { final InputStreamDeframer deframer; int unacknowledgedBytesRead; - OkHttpClientStream(MethodDescriptor method, byte[][] headers, StreamListener listener) { + OkHttpClientStream(MethodDescriptor method, Metadata.Headers headers, + StreamListener listener) { super(listener); deframer = new InputStreamDeframer(inboundMessageHandler()); synchronized (lock) { @@ -411,8 +413,9 @@ public class OkHttpClientTransport extends AbstractClientTransport { } assignStreamId(this); } + String defaultPath = "/" + method.getName(); frameWriter.synStream(false, false, streamId, 0, - Headers.createRequestHeaders(method.getName(), headers)); + Headers.createRequestHeaders(headers, defaultPath, defaultAuthority)); } InputStreamDeframer getDeframer() { diff --git a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportFactory.java b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportFactory.java index adf7ba2697..7ed6a368fb 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportFactory.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/okhttp/OkHttpClientTransportFactory.java @@ -1,27 +1,27 @@ package com.google.net.stubby.newtransport.okhttp; +import com.google.common.base.Preconditions; import com.google.net.stubby.newtransport.ClientTransport; import com.google.net.stubby.newtransport.ClientTransportFactory; +import java.net.InetSocketAddress; import java.util.concurrent.ExecutorService; /** * Factory that manufactures instances of {@link OkHttpClientTransport}. */ public class OkHttpClientTransportFactory implements ClientTransportFactory { - private final String host; - private final int port; + private final InetSocketAddress address; private final ExecutorService executor; - public OkHttpClientTransportFactory(String host, int port, ExecutorService executor) { - this.host = host; - this.port = port; - this.executor = executor; + public OkHttpClientTransportFactory(InetSocketAddress address, ExecutorService executor) { + this.address = Preconditions.checkNotNull(address, "address"); + this.executor = Preconditions.checkNotNull(executor, "executor"); } @Override public ClientTransport newClientTransport() { - return new OkHttpClientTransport(host, port, executor); + return new OkHttpClientTransport(address, executor); } } diff --git a/core/src/test/java/com/google/net/stubby/MetadataTest.java b/core/src/test/java/com/google/net/stubby/MetadataTest.java index 168c539a7f..a0f79a62fa 100644 --- a/core/src/test/java/com/google/net/stubby/MetadataTest.java +++ b/core/src/test/java/com/google/net/stubby/MetadataTest.java @@ -5,7 +5,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; import org.junit.Test; import org.junit.runner.RunWith; @@ -20,8 +19,7 @@ import java.util.Iterator; @RunWith(JUnit4.class) public class MetadataTest { - private static final Metadata.Marshaller FISH_MARSHALLER = - new Metadata.Marshaller() { + private static final Metadata.Marshaller FISH_MARSHALLER = new Metadata.Marshaller() { @Override public byte[] toBytes(Fish fish) { return fish.name.getBytes(StandardCharsets.UTF_8); @@ -32,7 +30,7 @@ public class MetadataTest { return value.name; } - @Override + @Override public Fish parseBytes(byte[] serialized) { return new Fish(new String(serialized, StandardCharsets.UTF_8)); } @@ -70,38 +68,44 @@ public class MetadataTest { @Test public void testWriteRaw() { - Metadata.Headers raw = new Metadata.Headers( - KEY.asciiName(), LANCE_BYTES); + Metadata.Headers raw = new Metadata.Headers(KEY.asciiName(), LANCE_BYTES); Fish lance = raw.get(KEY); assertEquals(lance, new Fish(LANCE)); // Reading again should return the same parsed instance assertSame(lance, raw.get(KEY)); } - @Test + @Test(expected = IllegalStateException.class) public void testFailSerializeRaw() { - Metadata.Headers raw = new Metadata.Headers( - KEY.asciiName(), LANCE_BYTES); + Metadata.Headers raw = new Metadata.Headers(KEY.asciiName(), LANCE_BYTES); + raw.serialize(); + } - try { - raw.serialize(); - fail("Can't serialize raw metadata"); - } catch (IllegalStateException ise) { - // Success - } + @Test(expected = IllegalArgumentException.class) + public void testFailMergeRawIntoSerializable() { + Metadata.Headers raw = new Metadata.Headers(KEY.asciiName(), LANCE_BYTES); + Metadata.Headers serializable = new Metadata.Headers(); + serializable.merge(raw); } @Test - public void testFailMergeRawIntoSerializable() { - Metadata.Headers raw = new Metadata.Headers( - KEY.asciiName(), LANCE_BYTES); - Metadata.Headers serializable = new Metadata.Headers(); - try { - serializable.merge(raw); - fail("Can't serialize raw metadata"); - } catch (IllegalArgumentException iae) { - // Success - } + public void headerMergeShouldCopyValues() { + Fish lance = new Fish(LANCE); + Metadata.Headers h1 = new Metadata.Headers(); + + Metadata.Headers h2 = new Metadata.Headers(); + h2.setPath("/some/path"); + h2.setAuthority("authority"); + h2.put(KEY, lance); + + h1.merge(h2); + + Iterator fishes = h1.getAll(KEY).iterator(); + assertTrue(fishes.hasNext()); + assertSame(fishes.next(), lance); + assertFalse(fishes.hasNext()); + assertEquals("/some/path", h1.getPath()); + assertEquals("authority", h1.getAuthority()); } private static class Fish { @@ -113,10 +117,16 @@ public class MetadataTest { @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } Fish fish = (Fish) o; - if (name != null ? !name.equals(fish.name) : fish.name != null) return false; + if (name != null ? !name.equals(fish.name) : fish.name != null) { + return false; + } return true; } } diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientHandlerTest.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientHandlerTest.java index 9c7ce95fab..0b7c00d3e6 100644 --- a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientHandlerTest.java +++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientHandlerTest.java @@ -15,7 +15,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.net.stubby.Metadata; -import com.google.net.stubby.MethodDescriptor; import com.google.net.stubby.Status; import com.google.net.stubby.newtransport.HttpUtil; import com.google.net.stubby.newtransport.StreamState; @@ -61,10 +60,8 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { @Mock private NettyClientStream stream; - @Mock - private MethodDescriptor method; private ByteBuf content; - private Metadata.Headers grpcHeaders; + private Http2Headers grpcHeaders; @Before public void setup() throws Exception { @@ -72,18 +69,23 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { frameWriter = new DefaultHttp2FrameWriter(); frameReader = new DefaultHttp2FrameReader(); - handler = newHandler("www.fake.com", true); + handler = newHandler(); content = Unpooled.copiedBuffer("hello world", UTF_8); when(channel.isActive()).thenReturn(true); mockContext(); mockFuture(true); - Metadata.Key key = new Metadata.Key("auth", Metadata.STRING_MARSHALLER); - grpcHeaders = new Metadata.Headers(); - grpcHeaders.put(key, "sometoken"); + grpcHeaders = DefaultHttp2Headers + .newBuilder() + .scheme("https") + .authority("www.fake.com") + .path("/fakemethod") + .method(HTTP_METHOD) + .add("auth", "sometoken") + .add(CONTENT_TYPE_HEADER, CONTENT_TYPE_PROTORPC) + .build(); - when(method.getName()).thenReturn("fakemethod"); when(stream.state()).thenReturn(StreamState.OPEN); // Simulate activation of the handler to force writing of the initial settings @@ -100,7 +102,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { @Test public void createStreamShouldSucceed() throws Exception { - handler.write(ctx, new CreateStreamCommand(method, grpcHeaders.serializeAscii(), stream), + handler.write(ctx, new CreateStreamCommand(grpcHeaders, stream), promise); verify(promise).setSuccess(); verify(stream).id(eq(3)); @@ -190,7 +192,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { public void createShouldQueueStream() throws Exception { // Disallow stream creation to force the stream to get added to the pending queue. setMaxConcurrentStreams(0); - handler.write(ctx, new CreateStreamCommand(method, grpcHeaders.serializeAscii(), stream), + handler.write(ctx, new CreateStreamCommand(grpcHeaders, stream), promise); // Make sure the write never occurred. @@ -208,7 +210,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { public void receivedGoAwayShouldFailQueuedStreams() throws Exception { // Force a stream to get added to the pending queue. setMaxConcurrentStreams(0); - handler.write(ctx, new CreateStreamCommand(method, grpcHeaders.serializeAscii(), stream), + handler.write(ctx, new CreateStreamCommand(grpcHeaders, stream), promise); handler.channelRead(ctx, goAwayFrame(0)); @@ -218,7 +220,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { @Test public void receivedGoAwayShouldFailUnknownStreams() throws Exception { // Force a stream to get added to the pending queue. - handler.write(ctx, new CreateStreamCommand(method, grpcHeaders.serializeAscii(), stream), + handler.write(ctx, new CreateStreamCommand(grpcHeaders, stream), promise); // Read a GOAWAY that indicates our stream was never processed by the server. @@ -246,14 +248,14 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { private void createStream() throws Exception { // Create the stream. - handler.write(ctx, new CreateStreamCommand(method, grpcHeaders.serializeAscii(), stream), + handler.write(ctx, new CreateStreamCommand(grpcHeaders, stream), promise); when(stream.id()).thenReturn(3); // Reset the context mock to clear recording of sent headers frame. mockContext(); } - private static NettyClientHandler newHandler(String host, boolean ssl) { + private static NettyClientHandler newHandler() { Http2Connection connection = new DefaultHttp2Connection(false); Http2FrameReader frameReader = new DefaultHttp2FrameReader(); Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter(); @@ -261,9 +263,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase { new DefaultHttp2InboundFlowController(connection, frameWriter); Http2OutboundFlowController outboundFlow = new DefaultHttp2OutboundFlowController(connection, frameWriter); - return new NettyClientHandler(host, - ssl, - connection, + return new NettyClientHandler(connection, frameReader, frameWriter, inboundFlow, diff --git a/stub/src/main/java/com/google/net/stubby/stub/HeadersInterceptor.java b/stub/src/main/java/com/google/net/stubby/stub/HeadersInterceptor.java index 6a02c51a77..830722525e 100644 --- a/stub/src/main/java/com/google/net/stubby/stub/HeadersInterceptor.java +++ b/stub/src/main/java/com/google/net/stubby/stub/HeadersInterceptor.java @@ -1,6 +1,7 @@ package com.google.net.stubby.stub; import com.google.net.stubby.Call; +import com.google.net.stubby.Channel; import com.google.net.stubby.Metadata; import com.google.net.stubby.MethodDescriptor; import com.google.net.stubby.context.ForwardingChannel; @@ -20,18 +21,30 @@ public class HeadersInterceptor { public static T intercept( T stub, final Metadata.Headers extraHeaders) { - return (T) stub.configureNewStub().setChannel( - new ForwardingChannel(stub.getChannel()) { - @Override - public Call newCall(MethodDescriptor method) { - return new ForwardingCall(delegate.newCall(method)) { - @Override - public void start(Listener responseListener, Metadata.Headers headers) { - headers.merge(extraHeaders); - delegate.start(responseListener, headers); - } - }; - } - }).build(); + return (T) stub.configureNewStub().setChannel(interceptChannel(stub.getChannel(), extraHeaders)) + .build(); + } + + /** + * Attach a set of request headers to a channel. + * + * @param channel to channel to intercept. + * @param extraHeaders the headers to be passed by each call on the returned stub. + * @return an implementation of the channel with extraHeaders bound to each call. + */ + @SuppressWarnings("unchecked") + public static Channel interceptChannel(Channel channel, final Metadata.Headers extraHeaders) { + return new ForwardingChannel(channel) { + @Override + public Call newCall(MethodDescriptor method) { + return new ForwardingCall(delegate.newCall(method)) { + @Override + public void start(Listener responseListener, Metadata.Headers headers) { + headers.merge(extraHeaders); + delegate.start(responseListener, headers); + } + }; + } + }; } }