Add server support for TLS.

Note that we don't yet have plumbing to use a particular certificate for tests, so it isn't integration-test worthy yet.
-------------
Created by MOE: http://code.google.com/p/moe-java
MOE_MIGRATED_REVID=78369023
This commit is contained in:
ejona 2014-10-23 12:05:19 -07:00 committed by Eric Anderson
parent 988f219b04
commit 46118bb195
5 changed files with 77 additions and 24 deletions

View File

@ -44,7 +44,7 @@ abstract class AbstractServiceBuilder<ProductT extends Service,
* <p>The returned service has not been started at this point. You will need to start it by * <p>The returned service has not been started at this point. You will need to start it by
* yourself or use {@link #buildAndStart()}. * yourself or use {@link #buildAndStart()}.
*/ */
private ProductT build() { public ProductT build() {
final ExecutorService executor = (userExecutor == null) final ExecutorService executor = (userExecutor == null)
? Executors.newCachedThreadPool() : userExecutor; ? Executors.newCachedThreadPool() : userExecutor;
ProductT service = buildImpl(executor); ProductT service = buildImpl(executor);

View File

@ -64,6 +64,17 @@ public class Http2Negotiator {
ListenableFuture<Void> completeFuture(); ListenableFuture<Void> completeFuture();
} }
/**
* Create a TLS handler for HTTP/2 capable of using ALPN/NPN.
*/
public static ChannelHandler serverTls(SSLEngine sslEngine) {
Preconditions.checkNotNull(sslEngine, "sslEngine");
if (!installJettyTLSProtocolSelection(sslEngine, SettableFuture.<Void>create(), true)) {
throw new IllegalStateException("NPN/ALPN extensions not installed");
}
return new SslHandler(sslEngine, false);
}
/** /**
* Creates an TLS negotiation for HTTP/2 using ALPN/NPN. * Creates an TLS negotiation for HTTP/2 using ALPN/NPN.
*/ */
@ -72,7 +83,7 @@ public class Http2Negotiator {
Preconditions.checkNotNull(sslEngine, "sslEngine"); Preconditions.checkNotNull(sslEngine, "sslEngine");
final SettableFuture<Void> completeFuture = SettableFuture.create(); final SettableFuture<Void> completeFuture = SettableFuture.create();
if (!installJettyTLSProtocolSelection(sslEngine, completeFuture)) { if (!installJettyTLSProtocolSelection(sslEngine, completeFuture, false)) {
throw new IllegalStateException("NPN/ALPN extensions not installed"); throw new IllegalStateException("NPN/ALPN extensions not installed");
} }
final ChannelInitializer<SocketChannel> initializer = new ChannelInitializer<SocketChannel>() { final ChannelInitializer<SocketChannel> initializer = new ChannelInitializer<SocketChannel>() {
@ -236,7 +247,7 @@ public class Http2Negotiator {
* @return true if NPN/ALPN support is available. * @return true if NPN/ALPN support is available.
*/ */
private static boolean installJettyTLSProtocolSelection(final SSLEngine engine, private static boolean installJettyTLSProtocolSelection(final SSLEngine engine,
final SettableFuture<Void> protocolNegotiated) { final SettableFuture<Void> protocolNegotiated, boolean server) {
for (String protocolNegoClassName : JETTY_TLS_NEGOTIATION_IMPL) { for (String protocolNegoClassName : JETTY_TLS_NEGOTIATION_IMPL) {
try { try {
Class<?> negoClass; Class<?> negoClass;
@ -249,38 +260,53 @@ public class Http2Negotiator {
} }
Class<?> providerClass = Class.forName(protocolNegoClassName + "$Provider"); Class<?> providerClass = Class.forName(protocolNegoClassName + "$Provider");
Class<?> clientProviderClass = Class.forName(protocolNegoClassName + "$ClientProvider"); Class<?> clientProviderClass = Class.forName(protocolNegoClassName + "$ClientProvider");
Class<?> serverProviderClass = Class.forName(protocolNegoClassName + "$ServerProvider");
Method putMethod = negoClass.getMethod("put", SSLEngine.class, providerClass); Method putMethod = negoClass.getMethod("put", SSLEngine.class, providerClass);
final Method removeMethod = negoClass.getMethod("remove", SSLEngine.class); final Method removeMethod = negoClass.getMethod("remove", SSLEngine.class);
putMethod.invoke(null, engine, Proxy.newProxyInstance( putMethod.invoke(null, engine, Proxy.newProxyInstance(
Http2Negotiator.class.getClassLoader(), new Class[] {clientProviderClass}, Http2Negotiator.class.getClassLoader(),
new Class[] {server ? serverProviderClass : clientProviderClass},
new InvocationHandler() { new InvocationHandler() {
@Override @Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
String methodName = method.getName(); String methodName = method.getName();
if ("supports".equals(methodName)) { if ("supports".equals(methodName)) {
// both // NPN client
return true; return true;
} }
if ("unsupported".equals(methodName)) { if ("unsupported".equals(methodName)) {
// both // all
removeMethod.invoke(null, engine); removeMethod.invoke(null, engine);
protocolNegotiated.setException(new IllegalStateException( protocolNegotiated.setException(new RuntimeException(
"ALPN/NPN protocol " + HTTP_VERSION_NAME + " not supported by server")); "ALPN/NPN protocol " + HTTP_VERSION_NAME + " not supported by endpoint"));
return null; return null;
} }
if ("protocols".equals(methodName)) { if ("protocols".equals(methodName)) {
// ALPN only // ALPN client, NPN server
return ImmutableList.of(HTTP_VERSION_NAME); return ImmutableList.of(HTTP_VERSION_NAME);
} }
if ("selected".equals(methodName)) { if ("selected".equals(methodName) || "protocolSelected".equals(methodName)) {
// ALPN only // ALPN client, NPN server
// Only 'supports' one protocol so we know what was selected.
removeMethod.invoke(null, engine); removeMethod.invoke(null, engine);
String protocol = (String) args[0];
if (!HTTP_VERSION_NAME.equals(protocol)) {
RuntimeException e = new RuntimeException(
"Unsupported protocol selected via ALPN/NPN: " + protocol);
protocolNegotiated.setException(e);
if ("selected".equals(methodName)) {
// ALPN client
// Throwing exception causes TLS alert.
throw e;
} else {
return null;
}
}
protocolNegotiated.set(null); protocolNegotiated.set(null);
return null; return null;
} }
if ("selectProtocol".equals(methodName)) { if ("select".equals(methodName) || "selectProtocol".equals(methodName)) {
// NPN only // ALPN server, NPN client
removeMethod.invoke(null, engine);
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
List<String> names = (List<String>) args[0]; List<String> names = (List<String>) args[0];
for (String name : names) { for (String name : names) {
@ -289,9 +315,13 @@ public class Http2Negotiator {
return name; return name;
} }
} }
protocolNegotiated.setException( RuntimeException e =
new IllegalStateException("Protocol not available via ALPN/NPN: " + names)); new RuntimeException("Protocol not available via ALPN/NPN: " + names);
removeMethod.invoke(null, engine); protocolNegotiated.setException(e);
if ("select".equals(methodName)) {
// ALPN server
throw e; // Throwing exception causes TLS alert.
}
return null; return null;
} }
throw new IllegalStateException("Unknown method " + methodName); throw new IllegalStateException("Unknown method " + methodName);

View File

@ -13,9 +13,11 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.ssl.SslContext;
import javax.annotation.Nullable;
/** /**
* Implementation of the {@link com.google.common.util.concurrent.Service} interface for a * Implementation of the {@link com.google.common.util.concurrent.Service} interface for a
@ -28,12 +30,13 @@ public class NettyServer extends AbstractService {
private final EventLoopGroup workerGroup; private final EventLoopGroup workerGroup;
private Channel channel; private Channel channel;
public NettyServer(ServerListener serverListener, int port) { public NettyServer(ServerListener serverListener, int port, EventLoopGroup bossGroup,
this(serverListener, port, new NioEventLoopGroup(), new NioEventLoopGroup()); EventLoopGroup workerGroup) {
this(serverListener, port, bossGroup, workerGroup, null);
} }
public NettyServer(final ServerListener serverListener, int port, EventLoopGroup bossGroup, public NettyServer(final ServerListener serverListener, int port, EventLoopGroup bossGroup,
EventLoopGroup workerGroup) { EventLoopGroup workerGroup, @Nullable final SslContext sslContext) {
Preconditions.checkNotNull(bossGroup, "bossGroup"); Preconditions.checkNotNull(bossGroup, "bossGroup");
Preconditions.checkNotNull(workerGroup, "workerGroup"); Preconditions.checkNotNull(workerGroup, "workerGroup");
Preconditions.checkArgument(port >= 0, "port must be positive"); Preconditions.checkArgument(port >= 0, "port must be positive");
@ -41,7 +44,7 @@ public class NettyServer extends AbstractService {
this.channelInitializer = new ChannelInitializer<SocketChannel>() { this.channelInitializer = new ChannelInitializer<SocketChannel>() {
@Override @Override
public void initChannel(SocketChannel ch) throws Exception { public void initChannel(SocketChannel ch) throws Exception {
NettyServerTransport transport = new NettyServerTransport(ch, serverListener); NettyServerTransport transport = new NettyServerTransport(ch, serverListener, sslContext);
transport.startAsync(); transport.startAsync();
// TODO(user): Should we wait for transport shutdown before shutting down server? // TODO(user): Should we wait for transport shutdown before shutting down server?
} }

View File

@ -8,6 +8,7 @@ import com.google.net.stubby.newtransport.ServerListener;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.ssl.SslContext;
/** /**
* The convenient builder for a netty-based GRPC server. * The convenient builder for a netty-based GRPC server.
@ -18,6 +19,7 @@ public final class NettyServerBuilder extends AbstractServerBuilder<NettyServerB
private EventLoopGroup userBossEventLoopGroup; private EventLoopGroup userBossEventLoopGroup;
private EventLoopGroup userWorkerEventLoopGroup; private EventLoopGroup userWorkerEventLoopGroup;
private SslContext sslContext;
public static NettyServerBuilder forPort(int port) { public static NettyServerBuilder forPort(int port) {
return new NettyServerBuilder(port); return new NettyServerBuilder(port);
@ -64,13 +66,22 @@ public final class NettyServerBuilder extends AbstractServerBuilder<NettyServerB
return this; return this;
} }
/**
* Sets the TLS context to use for encryption. Providing a context enables encryption.
*/
public NettyServerBuilder sslContext(SslContext sslContext) {
this.sslContext = sslContext;
return this;
}
@Override @Override
protected Service buildTransportServer(ServerListener serverListener) { protected Service buildTransportServer(ServerListener serverListener) {
final EventLoopGroup bossEventLoopGroup = (userBossEventLoopGroup == null) final EventLoopGroup bossEventLoopGroup = (userBossEventLoopGroup == null)
? new NioEventLoopGroup() : userBossEventLoopGroup; ? new NioEventLoopGroup() : userBossEventLoopGroup;
final EventLoopGroup workerEventLoopGroup = (userWorkerEventLoopGroup == null) final EventLoopGroup workerEventLoopGroup = (userWorkerEventLoopGroup == null)
? new NioEventLoopGroup() : userWorkerEventLoopGroup; ? new NioEventLoopGroup() : userWorkerEventLoopGroup;
NettyServer server = new NettyServer(serverListener, port); NettyServer server =
new NettyServer(serverListener, port, bossEventLoopGroup, workerEventLoopGroup, sslContext);
if (userBossEventLoopGroup == null) { if (userBossEventLoopGroup == null) {
server.addListener(new ClosureHook() { server.addListener(new ClosureHook() {
@Override @Override

View File

@ -21,8 +21,11 @@ import io.netty.handler.codec.http2.Http2FrameWriter;
import io.netty.handler.codec.http2.Http2InboundFrameLogger; import io.netty.handler.codec.http2.Http2InboundFrameLogger;
import io.netty.handler.codec.http2.Http2OutboundFlowController; import io.netty.handler.codec.http2.Http2OutboundFlowController;
import io.netty.handler.codec.http2.Http2OutboundFrameLogger; import io.netty.handler.codec.http2.Http2OutboundFrameLogger;
import io.netty.handler.ssl.SslContext;
import io.netty.util.internal.logging.InternalLogLevel; import io.netty.util.internal.logging.InternalLogLevel;
import javax.annotation.Nullable;
/** /**
* The Netty-based server transport. * The Netty-based server transport.
*/ */
@ -30,11 +33,14 @@ class NettyServerTransport extends AbstractService {
private static final Http2FrameLogger frameLogger = new Http2FrameLogger(InternalLogLevel.DEBUG); private static final Http2FrameLogger frameLogger = new Http2FrameLogger(InternalLogLevel.DEBUG);
private final SocketChannel channel; private final SocketChannel channel;
private final ServerListener serverListener; private final ServerListener serverListener;
private final SslContext sslContext;
private NettyServerHandler handler; private NettyServerHandler handler;
NettyServerTransport(SocketChannel channel, ServerListener serverListener) { NettyServerTransport(SocketChannel channel, ServerListener serverListener,
@Nullable SslContext sslContext) {
this.channel = Preconditions.checkNotNull(channel, "channel"); this.channel = Preconditions.checkNotNull(channel, "channel");
this.serverListener = Preconditions.checkNotNull(serverListener, "serverListener"); this.serverListener = Preconditions.checkNotNull(serverListener, "serverListener");
this.sslContext = sslContext;
} }
@Override @Override
@ -64,6 +70,9 @@ class NettyServerTransport extends AbstractService {
} }
}); });
if (sslContext != null) {
channel.pipeline().addLast(Http2Negotiator.serverTls(sslContext.newEngine(channel.alloc())));
}
channel.pipeline().addLast(handler); channel.pipeline().addLast(handler);
notifyStarted(); notifyStarted();