Adding support for TLS negotiation to new Netty transport.

This also fixes the ESF test.
-------------
Created by MOE: http://code.google.com/p/moe-java
MOE_MIGRATED_REVID=72144581
This commit is contained in:
nathanmittler 2014-07-29 07:20:05 -07:00 committed by Eric Anderson
parent ed85499d63
commit 5f334f7c52
8 changed files with 230 additions and 31 deletions

View File

@ -3,13 +3,17 @@ package com.google.net.stubby.http2.netty;
import com.google.net.stubby.Response;
import com.google.net.stubby.transport.Framer;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
/**
* A HTTP2 based implementation of a {@link Response}.
*/
class Http2Response extends Http2Operation implements Response {
public static ResponseBuilder builder(final int id, final Http2Codec.Http2Writer writer,
final Framer framer) {
final Framer framer) {
return new ResponseBuilder() {
@Override
public Response build(int id) {
@ -23,7 +27,10 @@ class Http2Response extends Http2Operation implements Response {
};
}
private Http2Response(int id, Http2Codec.Http2Writer writer, Framer framer) {
private Http2Response(int id, Http2Codec.Http2Writer writer, Framer framer) {
super(id, writer, framer);
Http2Headers headers = DefaultHttp2Headers.newBuilder().status("200")
.add("content-type", Http2Session.PROTORPC).build();
writer.writeHeaders(id, headers, false, true);
}
}

View File

@ -1,5 +1,7 @@
package com.google.net.stubby.newtransport;
import com.google.net.stubby.transport.Transport;
/**
* Constants for GRPC-over-HTTP (or HTTP/2)
*/
@ -10,6 +12,12 @@ public final class HttpUtil {
*/
public static final String CONTENT_TYPE_HEADER = "content-type";
/**
* The Content-Length header name. Defined here since it is not explicitly defined by the HTTP/2
* spec.
*/
public static final String CONTENT_LENGTH_HEADER = "content-length";
/**
* Content-Type used for GRPC-over-HTTP/2.
*/
@ -20,5 +28,24 @@ public final class HttpUtil {
*/
public static final String HTTP_METHOD = "POST";
/**
* Maps HTTP error response status codes to transport codes.
*/
public static Transport.Code httpStatusToTransportCode(int httpStatusCode) {
if (httpStatusCode < 300) {
return Transport.Code.OK;
}
if (httpStatusCode < 400) {
return Transport.Code.UNAVAILABLE;
}
if (httpStatusCode < 500) {
return Transport.Code.INVALID_ARGUMENT;
}
if (httpStatusCode < 600) {
return Transport.Code.FAILED_PRECONDITION;
}
return Transport.Code.INTERNAL;
}
private HttpUtil() {}
}

View File

@ -22,6 +22,7 @@ import io.netty.handler.codec.http2.Http2Connection;
import io.netty.handler.codec.http2.Http2ConnectionAdapter;
import io.netty.handler.codec.http2.Http2Error;
import io.netty.handler.codec.http2.Http2Exception;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.handler.codec.http2.Http2Stream;
import io.netty.handler.codec.http2.Http2StreamException;
import io.netty.handler.codec.http2.Http2StreamRemovalPolicy;
@ -112,6 +113,21 @@ class NettyClientHandler extends AbstractHttp2ConnectionHandler {
}
}
@Override
public void onHeadersRead(ChannelHandlerContext ctx,
int streamId,
Http2Headers headers,
int streamDependency,
short weight,
boolean exclusive,
int padding,
boolean endStream,
boolean endSegment) throws Http2Exception {
// TODO(user): Assuming that all headers fit in a single HEADERS frame.
NettyClientStream stream = clientStream(connection().requireStream(streamId));
stream.inboundHeadersRecieved(headers);
}
/**
* Handler for an inbound HTTP/2 DATA frame.
*/
@ -321,9 +337,8 @@ class NettyClientHandler extends AbstractHttp2ConnectionHandler {
.add(CONTENT_TYPE_HEADER, CONTENT_TYPE_PROTORPC)
.path("/" + pendingStream.method.getName())
.build();
writeHeaders(ctx(), ctx().newPromise(), streamId, headersBuilder.build(),
0, false, false).addListener(
new ChannelFutureListener() {
writeHeaders(ctx(), ctx().newPromise(), streamId, headersBuilder.build(), 0, false, false)
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {

View File

@ -1,13 +1,19 @@
package com.google.net.stubby.newtransport.netty;
import static com.google.net.stubby.newtransport.StreamState.CLOSED;
import static io.netty.util.CharsetUtil.UTF_8;
import com.google.common.base.Preconditions;
import com.google.net.stubby.Status;
import com.google.net.stubby.newtransport.AbstractStream;
import com.google.net.stubby.newtransport.ClientStream;
import com.google.net.stubby.newtransport.Deframer;
import com.google.net.stubby.newtransport.HttpUtil;
import com.google.net.stubby.newtransport.StreamListener;
import com.google.net.stubby.transport.Transport;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelPromise;
@ -23,6 +29,9 @@ class NettyClientStream extends AbstractStream implements ClientStream {
private volatile int id = PENDING_STREAM_ID;
private final Channel channel;
private final Deframer<ByteBuf> deframer;
private Transport.Code responseCode = Transport.Code.UNKNOWN;
private boolean isGrpcResponse;
private StringBuilder nonGrpcErrorMessage = new StringBuilder();
NettyClientStream(StreamListener listener, Channel channel) {
super(listener);
@ -49,6 +58,14 @@ class NettyClientStream extends AbstractStream implements ClientStream {
channel.writeAndFlush(new CancelStreamCommand(this));
}
/**
* Called in the channel thread to process headers received from the server.
*/
public void inboundHeadersRecieved(Http2Headers headers) {
responseCode = responseCode(headers);
isGrpcResponse = isGrpcResponse(headers, responseCode);
}
/**
* Called in the channel thread to process the content of an inbound DATA frame.
*
@ -64,11 +81,24 @@ class NettyClientStream extends AbstractStream implements ClientStream {
return;
}
// Retain the ByteBuf until it is released by the deframer.
deframer.deliverFrame(frame.retain(), endOfStream);
if (isGrpcResponse) {
// Retain the ByteBuf until it is released by the deframer.
deframer.deliverFrame(frame.retain(), endOfStream);
// TODO(user): add flow control.
promise.setSuccess();
// TODO(user): add flow control.
promise.setSuccess();
} else {
// It's not a GRPC response, assume that the frame contains a text-based error message.
// TODO(user): Should we send RST_STREAM as well?
// TODO(user): is there a better way to handle large non-GRPC error messages?
nonGrpcErrorMessage.append(frame.toString(UTF_8));
if (endOfStream) {
String msg = nonGrpcErrorMessage.toString();
setStatus(new Status(responseCode, msg));
}
}
}
@Override
@ -86,4 +116,40 @@ class NettyClientStream extends AbstractStream implements ClientStream {
buf.writeBytes(source);
return buf;
}
/**
* Determines whether or not the response from the server is a GRPC response.
*/
private static boolean isGrpcResponse(Http2Headers headers, Transport.Code code) {
if (headers == null) {
// No headers, not a GRPC response.
return false;
}
// GRPC responses should always return OK. Updated this code once b/16290036 is fixed.
if (code == Transport.Code.OK) {
// ESF currently returns the wrong content-type for grpc.
return true;
}
String contentType = headers.get(HttpUtil.CONTENT_TYPE_HEADER);
return HttpUtil.CONTENT_TYPE_PROTORPC.equalsIgnoreCase(contentType);
}
/**
* Parses the response status and converts it to a transport code.
*/
private static Transport.Code responseCode(Http2Headers headers) {
if (headers == null) {
return Transport.Code.UNKNOWN;
}
String statusLine = headers.status();
if (statusLine == null) {
return Transport.Code.UNKNOWN;
}
HttpResponseStatus status = HttpResponseStatus.parseLine(statusLine);
return HttpUtil.httpStatusToTransportCode(status.code());
}
}

View File

@ -8,20 +8,22 @@ import com.google.net.stubby.newtransport.AbstractClientTransport;
import com.google.net.stubby.newtransport.ClientStream;
import com.google.net.stubby.newtransport.ClientTransport;
import com.google.net.stubby.newtransport.StreamListener;
import com.google.net.stubby.newtransport.netty.NettyClientTransportFactory.NegotiationType;
import com.google.net.stubby.testing.utils.ssl.SslContextFactory;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http2.DefaultHttp2StreamRemovalPolicy;
import java.util.concurrent.ExecutionException;
import javax.net.ssl.SSLEngine;
/**
* A Netty-based {@link ClientTransport} implementation.
*/
@ -30,31 +32,41 @@ class NettyClientTransport extends AbstractClientTransport {
private final String host;
private final int port;
private final EventLoopGroup eventGroup;
private final ChannelInitializer<SocketChannel> channelInitializer;
private final Http2Negotiator.Negotiation negotiation;
private Channel channel;
NettyClientTransport(String host, int port, boolean ssl) {
this(host, port, ssl, new NioEventLoopGroup());
NettyClientTransport(String host, int port, NegotiationType negotiationType) {
this(host, port, negotiationType, new NioEventLoopGroup());
}
NettyClientTransport(String host, int port, boolean ssl, EventLoopGroup eventGroup) {
NettyClientTransport(String host, int port, NegotiationType negotiationType,
EventLoopGroup eventGroup) {
Preconditions.checkNotNull(host, "host");
Preconditions.checkArgument(port >= 0, "port must be positive");
Preconditions.checkNotNull(eventGroup, "eventGroup");
Preconditions.checkNotNull(negotiationType, "negotiationType");
this.host = host;
this.port = port;
this.eventGroup = eventGroup;
final DefaultHttp2StreamRemovalPolicy streamRemovalPolicy =
new DefaultHttp2StreamRemovalPolicy();
final NettyClientHandler handler = new NettyClientHandler(host, ssl, streamRemovalPolicy);
// TODO(user): handle SSL.
channelInitializer = new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
ch.pipeline().addLast(streamRemovalPolicy);
ch.pipeline().addLast(handler);
}
};
final NettyClientHandler handler =
new NettyClientHandler(host, negotiationType == NegotiationType.TLS, streamRemovalPolicy);
switch (negotiationType) {
case PLAINTEXT:
negotiation = Http2Negotiator.plaintext(handler);
break;
case PLAINTEXT_UPGRADE:
negotiation = Http2Negotiator.plaintextUpgrade(handler);
break;
case TLS:
SSLEngine sslEngine = SslContextFactory.getClientContext().createSSLEngine();
sslEngine.setUseClientMode(true);
negotiation = Http2Negotiator.tls(handler, sslEngine);
break;
default:
throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType);
}
}
@Override
@ -84,7 +96,7 @@ class NettyClientTransport extends AbstractClientTransport {
b.group(eventGroup);
b.channel(NioSocketChannel.class);
b.option(SO_KEEPALIVE, true);
b.handler(channelInitializer);
b.handler(negotiation.initializer());
// Start the connection operation to the server.
b.connect(host, port).addListener(new ChannelFutureListener() {

View File

@ -10,21 +10,43 @@ import io.netty.channel.EventLoopGroup;
*/
public class NettyClientTransportFactory implements ClientTransportFactory {
/**
* Identifies the negotiation used for starting up HTTP/2.
*/
public enum NegotiationType {
/**
* Uses TLS ALPN/NPN negotiation, assumes an SSL connection.
*/
TLS,
/**
* Use the HTTP UPGRADE protocol for a plaintext (non-SSL) upgrade from HTTP/1.1 to HTTP/2.
*/
PLAINTEXT_UPGRADE,
/**
* Just assume the connection is plaintext (non-SSL) and the remote endpoint supports HTTP/2
* directly without an upgrade.
*/
PLAINTEXT
}
private final String host;
private final int port;
private final boolean ssl;
private final NegotiationType negotiationType;
private final EventLoopGroup group;
public NettyClientTransportFactory(String host, int port, boolean ssl, EventLoopGroup group) {
public NettyClientTransportFactory(String host, int port, NegotiationType negotiationType,
EventLoopGroup group) {
this.group = Preconditions.checkNotNull(group, "group");
Preconditions.checkArgument(port > 0, "Port must be positive");
this.host = Preconditions.checkNotNull(host, "host");
this.negotiationType = Preconditions.checkNotNull(negotiationType, "negotiationType");
this.port = port;
this.ssl = ssl;
}
@Override
public NettyClientTransport newClientTransport() {
return new NettyClientTransport(host, port, ssl, group);
return new NettyClientTransport(host, port, negotiationType, group);
}
}

View File

@ -18,6 +18,7 @@ import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableMap;
import com.google.net.stubby.MethodDescriptor;
import com.google.net.stubby.Status;
import com.google.net.stubby.newtransport.HttpUtil;
import com.google.net.stubby.newtransport.StreamState;
import com.google.net.stubby.transport.Transport;
@ -33,6 +34,7 @@ import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
@ -176,6 +178,18 @@ public class NettyClientHandlerTest {
verify(promise).setFailure(any(Throwable.class));
}
@Test
public void inboundHeadersShouldForwardToStream() throws Exception {
createStream();
// Read a headers frame first.
Http2Headers headers = DefaultHttp2Headers.newBuilder().status("200")
.set(HttpUtil.CONTENT_TYPE_HEADER, HttpUtil.CONTENT_TYPE_PROTORPC).build();
ByteBuf headersFrame = headersFrame(3, headers);
handler.channelRead(this.ctx, headersFrame);
verify(stream).inboundHeadersRecieved(headers);
}
@Test
public void inboundDataShouldForwardToStream() throws Exception {
createStream();
@ -235,6 +249,12 @@ public class NettyClientHandlerTest {
mockContext();
}
private ByteBuf headersFrame(int streamId, Http2Headers headers) {
ChannelHandlerContext ctx = newContext();
frameWriter.writeHeaders(ctx, promise, streamId, headers, 0, false, false);
return captureWrite(ctx);
}
private ByteBuf dataFrame(int streamId, boolean endStream) {
// Need to retain the content since the frameWriter releases it.
content.retain();

View File

@ -14,6 +14,7 @@ import static org.mockito.Mockito.when;
import com.google.common.io.ByteStreams;
import com.google.net.stubby.Status;
import com.google.net.stubby.newtransport.HttpUtil;
import com.google.net.stubby.newtransport.StreamListener;
import com.google.net.stubby.newtransport.StreamState;
import com.google.net.stubby.transport.Transport;
@ -30,6 +31,8 @@ import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
@ -167,6 +170,9 @@ public class NettyClientStreamTest {
@Test
public void inboundContextShouldCallListener() throws Exception {
// Receive headers first so that it's a valid GRPC response.
stream.inboundHeadersRecieved(grpcResponseHeaders());
stream.inboundDataReceived(contextFrame(), false, promise);
ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
verify(listener).contextRead(eq(CONTEXT_KEY), captor.capture(), eq(MESSAGE.length()));
@ -176,6 +182,9 @@ public class NettyClientStreamTest {
@Test
public void inboundMessageShouldCallListener() throws Exception {
// Receive headers first so that it's a valid GRPC response.
stream.inboundHeadersRecieved(grpcResponseHeaders());
stream.inboundDataReceived(messageFrame(), false, promise);
ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
verify(listener).messageRead(captor.capture(), eq(MESSAGE.length()));
@ -186,6 +195,10 @@ public class NettyClientStreamTest {
@Test
public void inboundStatusShouldSetStatus() throws Exception {
stream.id(1);
// Receive headers first so that it's a valid GRPC response.
stream.inboundHeadersRecieved(grpcResponseHeaders());
stream.inboundDataReceived(statusFrame(), false, promise);
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(listener).closed(captor.capture());
@ -194,6 +207,14 @@ public class NettyClientStreamTest {
assertEquals(StreamState.CLOSED, stream.state());
}
@Test
public void nonGrpcResponseShouldSetStatus() throws Exception {
stream.inboundDataReceived(Unpooled.copiedBuffer(MESSAGE, UTF_8), true, promise);
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(listener).closed(captor.capture());
assertEquals(MESSAGE, captor.getValue().getDescription());
}
private String toString(InputStream in) throws Exception {
byte[] bytes = new byte[in.available()];
ByteStreams.readFully(in, bytes);
@ -201,8 +222,12 @@ public class NettyClientStreamTest {
}
private ByteBuf contextFrame() throws Exception {
byte[] body = ContextValue.newBuilder().setKey(CONTEXT_KEY)
.setValue(ByteString.copyFromUtf8(MESSAGE)).build().toByteArray();
byte[] body = ContextValue
.newBuilder()
.setKey(CONTEXT_KEY)
.setValue(ByteString.copyFromUtf8(MESSAGE))
.build()
.toByteArray();
ByteArrayOutputStream os = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(os);
dos.write(CONTEXT_VALUE_FRAME);
@ -246,6 +271,11 @@ public class NettyClientStreamTest {
return buf;
}
private Http2Headers grpcResponseHeaders() {
return DefaultHttp2Headers.newBuilder().status("200")
.set(HttpUtil.CONTENT_TYPE_HEADER, HttpUtil.CONTENT_TYPE_PROTORPC).build();
}
private void mockChannelFuture(boolean succeeded) {
when(future.isDone()).thenReturn(true);
when(future.isCancelled()).thenReturn(false);