diff --git a/core/src/main/java/com/google/net/stubby/newtransport/AbstractClientStream.java b/core/src/main/java/com/google/net/stubby/newtransport/AbstractClientStream.java index f2c4acaa99..6180109162 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/AbstractClientStream.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/AbstractClientStream.java @@ -34,6 +34,10 @@ public abstract class AbstractClientStream extends AbstractStream implements Cli this.listener = Preconditions.checkNotNull(listener); } + protected ListenableFuture receiveHeaders(Metadata.Headers headers) { + return listener.headersRead(headers); + } + @Override protected ListenableFuture receiveMessage(InputStream is, int length) { return listener.messageRead(is, length); @@ -51,8 +55,8 @@ public abstract class AbstractClientStream extends AbstractStream implements Cli * If using gRPC v2 protocol, this method must be called with received trailers before notifying * deframer of end of stream. */ - public void stashTrailers(Metadata.Trailers trailers) { - Preconditions.checkNotNull(status, "trailers"); + protected void stashTrailers(Metadata.Trailers trailers) { + Preconditions.checkNotNull(trailers, "trailers"); stashedStatus = trailers.get(Status.CODE_KEY) .withDescription(trailers.get(Status.MESSAGE_KEY)); trailers.removeAll(Status.CODE_KEY); @@ -62,6 +66,14 @@ public abstract class AbstractClientStream extends AbstractStream implements Cli @Override protected void remoteEndClosed() { + // TODO(user): Delete this hack when trailers are supported by GFE with v2. Currently GFE + // doesn't support trailers, so when using gRPC v2 protocol GFE will not send any status. We + // paper over this for now by just assuming OK. For all properly functioning servers (both v1 + // and v2), stashedStatus should not be null here. + if (stashedStatus == null) { + stashedStatus = Status.OK; + stashedTrailers = new Metadata.Trailers(); + } Preconditions.checkState(stashedStatus != null, "Status and trailers should have been set"); setStatus(stashedStatus, stashedTrailers); } diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java index 1e67fd1018..e11ecadbe6 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java @@ -11,6 +11,7 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.net.stubby.Metadata; import com.google.net.stubby.Status; import com.google.net.stubby.newtransport.AbstractClientStream; +import com.google.net.stubby.newtransport.Buffers; import com.google.net.stubby.newtransport.ClientStreamListener; import com.google.net.stubby.newtransport.GrpcDeframer; import com.google.net.stubby.newtransport.HttpUtil; @@ -40,7 +41,8 @@ class NettyClientStream extends AbstractClientStream implements NettyStream { private final WindowUpdateManager windowUpdateManager; private Status responseStatus = Status.UNKNOWN; private boolean isGrpcResponse; - private StringBuilder nonGrpcErrorMessage = new StringBuilder(); + private boolean seenHeaders; + private StringBuilder nonGrpcErrorMessage; NettyClientStream(ClientStreamListener listener, Channel channel, DefaultHttp2InboundFlowController inboundFlow) { @@ -83,16 +85,20 @@ class NettyClientStream extends AbstractClientStream implements NettyStream { */ public void inboundHeadersRecieved(Http2Headers headers, boolean endOfStream) { responseStatus = responseStatus(headers, responseStatus); - isGrpcResponse = isGrpcResponse(headers, responseStatus); - if (endOfStream) { - if (isGrpcResponse) { - // TODO(user): call stashTrailers() as appropriate, then provide endOfStream to - // deframer. - setStatus(responseStatus, new Metadata.Trailers()); - } else { - setStatus(responseStatus, new Metadata.Trailers()); + if (!seenHeaders) { + seenHeaders = true; + isGrpcResponse = isGrpcResponse(headers); + // If endOfStream, we have trailers and no "headers" were sent. + if (!endOfStream && GRPC_V2_PROTOCOL) { + deframer2.delayProcessing(receiveHeaders(Utils.convertHeaders(headers))); } } + if (endOfStream) { + if (GRPC_V2_PROTOCOL) { + stashTrailers(Utils.convertTrailers(headers)); + } + endOfStream(); + } } /** @@ -111,21 +117,38 @@ class NettyClientStream extends AbstractClientStream implements NettyStream { if (isGrpcResponse) { // Retain the ByteBuf until it is released by the deframer. if (!GRPC_V2_PROTOCOL) { - deframer.deframe(new NettyBuffer(frame.retain()), endOfStream); + deframer.deframe(new NettyBuffer(frame.retain()), false); } else { - deframer2.deframe(new NettyBuffer(frame.retain()), endOfStream); + deframer2.deframe(new NettyBuffer(frame.retain()), false); } } else { // It's not a GRPC response, assume that the frame contains a text-based error message. // TODO(user): Should we send RST_STREAM as well? // TODO(user): is there a better way to handle large non-GRPC error messages? - nonGrpcErrorMessage.append(frame.toString(UTF_8)); - - if (endOfStream) { - String msg = nonGrpcErrorMessage.toString(); - setStatus(responseStatus.withDescription(msg), new Metadata.Trailers()); + if (nonGrpcErrorMessage == null) { + nonGrpcErrorMessage = new StringBuilder(); } + nonGrpcErrorMessage.append(frame.toString(UTF_8)); + } + + if (endOfStream) { + endOfStream(); + } + } + + private void endOfStream() { + if (isGrpcResponse) { + if (!GRPC_V2_PROTOCOL) { + deframer.deframe(Buffers.empty(), true); + } else { + deframer2.deframe(Buffers.empty(), true); + } + } else { + if (nonGrpcErrorMessage != null && nonGrpcErrorMessage.length() > 0) { + responseStatus = responseStatus.withDescription(nonGrpcErrorMessage.toString()); + } + setStatus(responseStatus, new Metadata.Trailers()); } } @@ -144,7 +167,7 @@ class NettyClientStream extends AbstractClientStream implements NettyStream { /** * Determines whether or not the response from the server is a GRPC response. */ - private boolean isGrpcResponse(Http2Headers headers, Status status) { + private boolean isGrpcResponse(Http2Headers headers) { if (isGrpcResponse) { // Already verified that it's a gRPC response. return true; @@ -155,24 +178,28 @@ class NettyClientStream extends AbstractClientStream implements NettyStream { return false; } - // GRPC responses should always return OK. Updated this code once b/16290036 is fixed. - if (status.isOk()) { - // ESF currently returns the wrong content-type for grpc. + AsciiString contentType = headers.get(CONTENT_TYPE_HEADER); + if (CONTENT_TYPE_PROTORPC.equalsIgnoreCase(contentType)) { return true; } - AsciiString contentType = headers.get(CONTENT_TYPE_HEADER); - return CONTENT_TYPE_PROTORPC.equalsIgnoreCase(contentType); + // Since ESF returns the wrong content-type, assume that any 200 response is gRPC, until + // b/16290036 is fixed. + AsciiString statusLine = headers.status(); + if (statusLine != null) { + HttpResponseStatus httpStatus = HttpResponseStatus.parseLine(statusLine); + if (HttpResponseStatus.OK.equals(httpStatus)) { + return true; + } + } + + return false; } /** * Parses the response status and converts it to a transport code. */ private static Status responseStatus(Http2Headers headers, Status defaultValue) { - if (headers == null) { - return defaultValue; - } - // First, check to see if we found a v2 protocol grpc-status header. AsciiString grpcStatus = headers.get(GRPC_STATUS_HEADER); if (grpcStatus != null) { @@ -181,10 +208,15 @@ class NettyClientStream extends AbstractClientStream implements NettyStream { // Next, check the HTTP/2 status. AsciiString statusLine = headers.status(); - if (statusLine == null) { - return defaultValue; + if (statusLine != null) { + HttpResponseStatus httpStatus = HttpResponseStatus.parseLine(statusLine); + Status status = HttpUtil.httpStatusToGrpcStatus(httpStatus.code()); + // Only use OK when provided via the GRPC status header. + if (!status.isOk()) { + return status; + } } - HttpResponseStatus status = HttpResponseStatus.parseLine(statusLine); - return HttpUtil.httpStatusToGrpcStatus(status.code()); + + return defaultValue; } } diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransport.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransport.java index c6f59493d9..f3080debf4 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransport.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransport.java @@ -99,7 +99,7 @@ class NettyClientTransport extends AbstractClientTransport { try { // Convert the headers into Netty HTTP/2 headers. AsciiString defaultPath = new AsciiString("/" + method.getName()); - Http2Headers http2Headers = Utils.convertHeaders(headers, ssl, defaultPath, authority); + Http2Headers http2Headers = Utils.convertClientHeaders(headers, ssl, defaultPath, authority); // Write the request and await creation of the stream. channel.writeAndFlush(new CreateStreamCommand(http2Headers, stream)).get(); diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerHandler.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerHandler.java index 71622fbbcc..f24bbe0464 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerHandler.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerHandler.java @@ -3,7 +3,6 @@ package com.google.net.stubby.newtransport.netty; import static com.google.net.stubby.newtransport.netty.Utils.CONTENT_TYPE_HEADER; import static com.google.net.stubby.newtransport.netty.Utils.CONTENT_TYPE_PROTORPC; import static com.google.net.stubby.newtransport.netty.Utils.HTTP_METHOD; -import static com.google.net.stubby.newtransport.netty.Utils.STATUS_OK; import static io.netty.buffer.Unpooled.EMPTY_BUFFER; import static io.netty.handler.codec.http2.Http2CodecUtil.toByteBuf; import static io.netty.handler.codec.http2.Http2Error.NO_ERROR; @@ -189,14 +188,7 @@ class NettyServerHandler extends Http2ConnectionHandler { ctx.flush(); } else if (msg instanceof SendResponseHeadersCommand) { SendResponseHeadersCommand cmd = (SendResponseHeadersCommand) msg; - encoder().writeHeaders(ctx, - cmd.streamId(), - new DefaultHttp2Headers() - .status(STATUS_OK) - .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_PROTORPC), - 0, - false, - promise); + encoder().writeHeaders(ctx, cmd.streamId(), cmd.headers(), 0, cmd.endOfStream(), promise); ctx.flush(); } else { AssertionError e = new AssertionError("Write called for unexpected type: " diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerStream.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerStream.java index cc5c1be028..d3990bca39 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerStream.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServerStream.java @@ -11,6 +11,8 @@ import com.google.net.stubby.newtransport.StreamState; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.handler.codec.http2.DefaultHttp2InboundFlowController; +import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.Http2Headers; import java.nio.ByteBuffer; @@ -61,7 +63,10 @@ class NettyServerStream extends AbstractServerStream implements NettyStream { @Override protected void sendFrame(ByteBuffer frame, boolean endOfStream) { if (!headersSent) { - channel.write(new SendResponseHeadersCommand(id)); + Http2Headers headers = new DefaultHttp2Headers() + .status(Utils.STATUS_OK) + .set(Utils.CONTENT_TYPE_HEADER, Utils.CONTENT_TYPE_PROTORPC); + channel.write(new SendResponseHeadersCommand(id, headers, false)); headersSent = true; } SendGrpcFrameCommand cmd = @@ -71,7 +76,13 @@ class NettyServerStream extends AbstractServerStream implements NettyStream { @Override protected void sendTrailers(Metadata.Trailers trailers) { - // TODO(user): send trailers + Http2Headers http2Trailers = Utils.convertTrailers(trailers); + if (!headersSent) { + http2Trailers.status(Utils.STATUS_OK) + .set(Utils.CONTENT_TYPE_HEADER, Utils.CONTENT_TYPE_PROTORPC); + headersSent = true; + } + channel.writeAndFlush(new SendResponseHeadersCommand(id, http2Trailers, true)); } @Override diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/SendResponseHeadersCommand.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/SendResponseHeadersCommand.java index 2c761b4b81..f7d5386ec9 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/SendResponseHeadersCommand.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/SendResponseHeadersCommand.java @@ -1,31 +1,50 @@ package com.google.net.stubby.newtransport.netty; +import com.google.common.base.Preconditions; + +import io.netty.handler.codec.http2.Http2Headers; + /** * Command sent from the transport to the Netty channel to send response headers to the client. */ class SendResponseHeadersCommand { private final int streamId; + private final Http2Headers headers; + private final boolean endOfStream; - SendResponseHeadersCommand(int streamId) { + SendResponseHeadersCommand(int streamId, Http2Headers headers, boolean endOfStream) { this.streamId = streamId; + this.headers = Preconditions.checkNotNull(headers); + this.endOfStream = endOfStream; } int streamId() { return streamId; } + Http2Headers headers() { + return headers; + } + + boolean endOfStream() { + return endOfStream; + } + @Override public boolean equals(Object that) { if (that == null || !that.getClass().equals(SendResponseHeadersCommand.class)) { return false; } SendResponseHeadersCommand thatCmd = (SendResponseHeadersCommand) that; - return thatCmd.streamId == streamId; + return thatCmd.streamId == streamId + && thatCmd.headers.equals(headers) + && thatCmd.endOfStream == endOfStream; } @Override public String toString() { - return getClass().getSimpleName() + "(streamId=" + streamId + ")"; + return getClass().getSimpleName() + "(streamId=" + streamId + ", headers=" + headers + + ", endOfStream=" + endOfStream + ")"; } @Override diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/Utils.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/Utils.java index 742f5edaaa..e81be7149d 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/Utils.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/Utils.java @@ -38,6 +38,21 @@ class Utils { } public static Metadata.Headers convertHeaders(Http2Headers http2Headers) { + Metadata.Headers headers = new Metadata.Headers(convertHeadersToArray(http2Headers)); + if (http2Headers.authority() != null) { + headers.setAuthority(http2Headers.authority().toString()); + } + if (http2Headers.path() != null) { + headers.setPath(http2Headers.path().toString()); + } + return headers; + } + + public static Metadata.Trailers convertTrailers(Http2Headers http2Headers) { + return new Metadata.Trailers(convertHeadersToArray(http2Headers)); + } + + private static byte[][] convertHeadersToArray(Http2Headers http2Headers) { // The Netty AsciiString class is really just a wrapper around a byte[] and supports // arbitrary binary data, not just ASCII. byte[][] headerValues = new byte[http2Headers.size()*2][]; @@ -46,33 +61,24 @@ class Utils { headerValues[i++] = entry.getKey().array(); headerValues[i++] = entry.getValue().array(); } - return new Metadata.Headers(headerValues); + return headerValues; } - public static Http2Headers convertHeaders(Metadata.Headers headers, + public static Http2Headers convertClientHeaders(Metadata.Headers headers, boolean ssl, AsciiString defaultPath, AsciiString defaultAuthority) { - Preconditions.checkNotNull(headers, "headers"); Preconditions.checkNotNull(defaultPath, "defaultPath"); Preconditions.checkNotNull(defaultAuthority, "defaultAuthority"); - - Http2Headers http2Headers = new DefaultHttp2Headers(); - // Add any application-provided headers first. - byte[][] serializedHeaders = headers.serialize(); - for (int i = 0; i < serializedHeaders.length; i++) { - http2Headers.add(new AsciiString(serializedHeaders[i], false), - new AsciiString(serializedHeaders[++i], false)); - } + Http2Headers http2Headers = convertMetadata(headers); // Now set GRPC-specific default headers. - http2Headers - .authority(defaultAuthority) + http2Headers.authority(defaultAuthority) .path(defaultPath) .method(HTTP_METHOD) - .scheme(ssl? HTTPS : HTTP) - .add(CONTENT_TYPE_HEADER, CONTENT_TYPE_PROTORPC); + .scheme(ssl ? HTTPS : HTTP) + .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_PROTORPC); // Override the default authority and path if provided by the headers. if (headers.getAuthority() != null) { @@ -85,6 +91,25 @@ class Utils { return http2Headers; } + public static Http2Headers convertServerHeaders(Metadata.Headers headers) { + return convertMetadata(headers); + } + + public static Http2Headers convertTrailers(Metadata.Trailers trailers) { + return convertMetadata(trailers); + } + + private static Http2Headers convertMetadata(Metadata headers) { + Preconditions.checkNotNull(headers, "headers"); + Http2Headers http2Headers = new DefaultHttp2Headers(); + byte[][] serializedHeaders = headers.serialize(); + for (int i = 0; i < serializedHeaders.length; i++) { + http2Headers.add(new AsciiString(serializedHeaders[i], false), + new AsciiString(serializedHeaders[++i], false)); + } + return http2Headers; + } + private Utils() { // Prevents instantiation } diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerStreamTest.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerStreamTest.java index 4af0a9a182..8123705b89 100644 --- a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerStreamTest.java +++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyServerStreamTest.java @@ -15,6 +15,9 @@ import com.google.net.stubby.newtransport.StreamState; import io.netty.buffer.EmptyByteBuf; import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.handler.codec.AsciiString; +import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.Http2Headers; import org.junit.Test; import org.junit.runner.RunWith; @@ -32,7 +35,10 @@ public class NettyServerStreamTest extends NettyStreamTestBase { public void writeMessageShouldSendResponse() throws Exception { stream.writeMessage(input, input.available(), accepted); stream.flush(); - verify(channel).write(new SendResponseHeadersCommand(STREAM_ID)); + Http2Headers headers = new DefaultHttp2Headers() + .status(Utils.STATUS_OK) + .set(Utils.CONTENT_TYPE_HEADER, Utils.CONTENT_TYPE_PROTORPC); + verify(channel).write(new SendResponseHeadersCommand(STREAM_ID, headers, false)); verify(channel).writeAndFlush(new SendGrpcFrameCommand(STREAM_ID, messageFrame(), false)); verify(accepted).run(); }