From f94b77c87fc3a559880928dd650e15d43bc00885 Mon Sep 17 00:00:00 2001 From: Carl Mastrangelo Date: Thu, 16 May 2019 18:15:47 -0700 Subject: [PATCH] netty: change server to new protocol negotiator model Changes: * PlaintextProtocolNegotiator is the same between client and server * ServerTlsHandler is rewritten to not handle errors * Also, it now sets the security level attribute, which I don't think it did previously * NettyServerTransport now uses WBAEH, similar to the client. I don't think the buffer is needed, but it does correctly handle errors during the startup --- .../io/grpc/netty/NettyServerHandler.java | 2 + .../io/grpc/netty/NettyServerTransport.java | 6 +- .../io/grpc/netty/ProtocolNegotiators.java | 112 ++++++------------ .../grpc/netty/ProtocolNegotiatorsTest.java | 53 +++++---- 4 files changed, 74 insertions(+), 99 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index 41a630eac4..de31e5149f 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -512,6 +512,8 @@ class NettyServerHandler extends AbstractNettyHandler { Attributes attrs, InternalChannelz.Security securityInfo) { negotiationAttributes = attrs; this.securityInfo = securityInfo; + super.handleProtocolNegotiationCompleted(attrs, securityInfo); + NettyClientHandler.writeBufferingAndRemove(ctx().channel()); } InternalChannelz.Security getSecurityInfo() { diff --git a/netty/src/main/java/io/grpc/netty/NettyServerTransport.java b/netty/src/main/java/io/grpc/netty/NettyServerTransport.java index d48fbe8975..cd26fe28b4 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerTransport.java @@ -137,12 +137,14 @@ class NettyServerTransport implements ServerTransport { } } + ChannelHandler negotiationHandler = protocolNegotiator.newHandler(grpcHandler); + ChannelHandler bufferingHandler = new WriteBufferingAndExceptionHandler(negotiationHandler); + ChannelFutureListener terminationNotifier = new TerminationNotifier(); channelUnused.addListener(terminationNotifier); channel.closeFuture().addListener(terminationNotifier); - ChannelHandler negotiationHandler = protocolNegotiator.newHandler(grpcHandler); - channel.pipeline().addLast(negotiationHandler); + channel.pipeline().addLast(bufferingHandler); } @Override diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 49904c7fb8..f2032887fc 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -25,7 +25,6 @@ import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.Grpc; -import io.grpc.InternalChannelz; import io.grpc.InternalChannelz.Security; import io.grpc.InternalChannelz.Tls; import io.grpc.SecurityLevel; @@ -36,11 +35,9 @@ import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandler; import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.DefaultHttpRequest; import io.netty.handler.codec.http.HttpClientCodec; @@ -109,34 +106,7 @@ final class ProtocolNegotiators { * Create a server plaintext handler for gRPC. */ public static ProtocolNegotiator serverPlaintext() { - return new ProtocolNegotiator() { - @Override - public ChannelHandler newHandler(final GrpcHttp2ConnectionHandler handler) { - class PlaintextHandler extends ChannelHandlerAdapter { - @Override - public void handlerAdded(ChannelHandlerContext ctx) throws Exception { - // Set sttributes before replace to be sure we pass it before accepting any requests. - handler.handleProtocolNegotiationCompleted(Attributes.newBuilder() - .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress()) - .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, ctx.channel().localAddress()) - .build(), - /*securityInfo=*/ null); - // Just replace this handler with the gRPC handler. - ctx.pipeline().replace(this, null, handler); - } - } - - return new PlaintextHandler(); - } - - @Override - public void close() {} - - @Override - public AsciiString scheme() { - return Utils.HTTP; - } - }; + return new PlaintextProtocolNegotiator(); } /** @@ -147,7 +117,10 @@ final class ProtocolNegotiators { return new ProtocolNegotiator() { @Override public ChannelHandler newHandler(GrpcHttp2ConnectionHandler handler) { - return new ServerTlsHandler(sslContext, handler); + ChannelHandler gnh = new GrpcNegotiationHandler(handler); + ChannelHandler sth = new ServerTlsHandler(gnh, sslContext); + ChannelHandler wauh = new WaitUntilActiveHandler(sth); + return wauh; } @Override @@ -161,67 +134,56 @@ final class ProtocolNegotiators { }; } - @VisibleForTesting static final class ServerTlsHandler extends ChannelInboundHandlerAdapter { - private final GrpcHttp2ConnectionHandler grpcHandler; + private final ChannelHandler next; private final SslContext sslContext; - ServerTlsHandler(SslContext sslContext, GrpcHttp2ConnectionHandler grpcHandler) { - this.sslContext = sslContext; - this.grpcHandler = grpcHandler; + private ProtocolNegotiationEvent pne = ProtocolNegotiationEvent.DEFAULT; + + ServerTlsHandler(ChannelHandler next, SslContext sslContext) { + this.sslContext = checkNotNull(sslContext, "sslContext"); + this.next = checkNotNull(next, "next"); } @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { super.handlerAdded(ctx); - SSLEngine sslEngine = sslContext.newEngine(ctx.alloc()); - ctx.pipeline().addFirst(new SslHandler(sslEngine, false)); - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - fail(ctx, cause); + ctx.pipeline().addBefore(ctx.name(), null, new SslHandler(sslEngine, false)); } @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - if (evt instanceof SslHandshakeCompletionEvent) { + if (evt instanceof ProtocolNegotiationEvent) { + pne = (ProtocolNegotiationEvent) evt; + } else if (evt instanceof SslHandshakeCompletionEvent) { SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt; - if (handshakeEvent.isSuccess()) { - if (NEXT_PROTOCOL_VERSIONS.contains(sslHandler(ctx.pipeline()).applicationProtocol())) { - SSLSession session = sslHandler(ctx.pipeline()).engine().getSession(); - // Successfully negotiated the protocol. - // Notify about completion and pass down SSLSession in attributes. - grpcHandler.handleProtocolNegotiationCompleted( - Attributes.newBuilder() - .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session) - .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress()) - .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, ctx.channel().localAddress()) - .build(), - new InternalChannelz.Security(new InternalChannelz.Tls(session))); - // Replace this handler with the GRPC handler. - ctx.pipeline().replace(this, null, grpcHandler); - } else { - fail(ctx, - unavailableException( - "Failed protocol negotiation: Unable to find compatible protocol")); - } - } else { - fail(ctx, handshakeEvent.cause()); + if (!handshakeEvent.isSuccess()) { + logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed for new client.", null); + ctx.fireExceptionCaught(handshakeEvent.cause()); + return; } + SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); + if (!NEXT_PROTOCOL_VERSIONS.contains(sslHandler.applicationProtocol())) { + logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed for new client.", null); + ctx.fireExceptionCaught(unavailableException( + "Failed protocol negotiation: Unable to find compatible protocol")); + return; + } + ctx.pipeline().replace(ctx.name(), null, next); + fireProtocolNegotiationEvent(ctx, sslHandler.engine().getSession()); + } else { + super.userEventTriggered(ctx, evt); } - super.userEventTriggered(ctx, evt); } - private SslHandler sslHandler(ChannelPipeline pipeline) { - return pipeline.get(SslHandler.class); - } - - @SuppressWarnings("FutureReturnValueIgnored") - private void fail(ChannelHandlerContext ctx, Throwable exception) { - logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed for new client.", exception); - ctx.close(); + private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession session) { + Security security = new Security(new Tls(session)); + Attributes attrs = pne.getAttributes().toBuilder() + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session) + .build(); + ctx.fireUserEventTriggered(pne.withAttributes(attrs).withSecurity(security)); } } diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index 60f35ce0cb..638105d019 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -238,21 +238,10 @@ public class ProtocolNegotiatorsTest { Object unused = ProtocolNegotiators.serverTls(null); } - @Test - public void tlsAdapter_exceptionClosesChannel() throws Exception { - ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); - - // Use addFirst due to the funny error handling in EmbeddedChannel. - pipeline.addFirst(handler); - - pipeline.fireExceptionCaught(new Exception("bad")); - - assertFalse(channel.isOpen()); - } @Test public void tlsHandler_handlerAddedAddsSslHandler() throws Exception { - ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); + ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext); pipeline.addLast(handler); @@ -261,7 +250,7 @@ public class ProtocolNegotiatorsTest { @Test public void tlsHandler_userEventTriggeredNonSslEvent() throws Exception { - ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); + ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext); pipeline.addLast(handler); channelHandlerCtx = pipeline.context(handler); Object nonSslEvent = new Object(); @@ -282,32 +271,52 @@ public class ProtocolNegotiatorsTest { } }; - ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); + ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext); pipeline.addLast(handler); + final AtomicReference error = new AtomicReference<>(); + ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + error.set(cause); + } + }; + + pipeline.addLast(errorCapture); + pipeline.replace(SslHandler.class, null, badSslHandler); channelHandlerCtx = pipeline.context(handler); Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; pipeline.fireUserEventTriggered(sslEvent); - // No h2 protocol was specified, so this should be closed. - assertFalse(channel.isOpen()); + // No h2 protocol was specified, so there should be an error, (normally handled by WBAEH) + assertThat(error.get()).hasMessageThat().contains("Unable to find compatible protocol"); ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler); assertNull(grpcHandlerCtx); } @Test public void tlsHandler_userEventTriggeredSslEvent_handshakeFailure() throws Exception { - ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); + ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext); pipeline.addLast(handler); channelHandlerCtx = pipeline.context(handler); Object sslEvent = new SslHandshakeCompletionEvent(new RuntimeException("bad")); + final AtomicReference error = new AtomicReference<>(); + ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + error.set(cause); + } + }; + + pipeline.addLast(errorCapture); + pipeline.fireUserEventTriggered(sslEvent); - // No h2 protocol was specified, so this should be closed. - assertFalse(channel.isOpen()); + // No h2 protocol was specified, so there should be an error, (normally handled by WBAEH) + assertThat(error.get()).hasMessageThat().contains("bad"); ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler); assertNull(grpcHandlerCtx); } @@ -321,7 +330,7 @@ public class ProtocolNegotiatorsTest { } }; - ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); + ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); @@ -344,7 +353,7 @@ public class ProtocolNegotiatorsTest { } }; - ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); + ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); @@ -360,7 +369,7 @@ public class ProtocolNegotiatorsTest { @Test public void engineLog() { - ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); + ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext); pipeline.addLast(handler); channelHandlerCtx = pipeline.context(handler);