netty: migrate Server protocol negotiation to new style

* Revert "Revert "netty: change server to new protocol negotiator model" (#5798)"

This reverts commit 4e5e19f6fd.
This commit is contained in:
Carl Mastrangelo 2019-08-14 13:00:42 -07:00 committed by GitHub
parent ce53d0eac7
commit 458f4533db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 98 additions and 103 deletions

View File

@ -538,6 +538,8 @@ class NettyServerHandler extends AbstractNettyHandler {
Attributes attrs, InternalChannelz.Security securityInfo) { Attributes attrs, InternalChannelz.Security securityInfo) {
negotiationAttributes = attrs; negotiationAttributes = attrs;
this.securityInfo = securityInfo; this.securityInfo = securityInfo;
super.handleProtocolNegotiationCompleted(attrs, securityInfo);
NettyClientHandler.writeBufferingAndRemove(ctx().channel());
} }
InternalChannelz.Security getSecurityInfo() { InternalChannelz.Security getSecurityInfo() {

View File

@ -137,12 +137,14 @@ class NettyServerTransport implements ServerTransport {
} }
} }
ChannelHandler negotiationHandler = protocolNegotiator.newHandler(grpcHandler);
ChannelHandler bufferingHandler = new WriteBufferingAndExceptionHandler(negotiationHandler);
ChannelFutureListener terminationNotifier = new TerminationNotifier(); ChannelFutureListener terminationNotifier = new TerminationNotifier();
channelUnused.addListener(terminationNotifier); channelUnused.addListener(terminationNotifier);
channel.closeFuture().addListener(terminationNotifier); channel.closeFuture().addListener(terminationNotifier);
ChannelHandler negotiationHandler = protocolNegotiator.newHandler(grpcHandler); channel.pipeline().addLast(bufferingHandler);
channel.pipeline().addLast(negotiationHandler);
} }
@Override @Override

View File

@ -27,7 +27,6 @@ import io.grpc.Attributes;
import io.grpc.ChannelLogger; import io.grpc.ChannelLogger;
import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.Grpc; import io.grpc.Grpc;
import io.grpc.InternalChannelz;
import io.grpc.InternalChannelz.Security; import io.grpc.InternalChannelz.Security;
import io.grpc.InternalChannelz.Tls; import io.grpc.InternalChannelz.Tls;
import io.grpc.SecurityLevel; import io.grpc.SecurityLevel;
@ -38,11 +37,9 @@ import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler; import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultHttpRequest; import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpClientCodec;
@ -111,34 +108,7 @@ final class ProtocolNegotiators {
* Create a server plaintext handler for gRPC. * Create a server plaintext handler for gRPC.
*/ */
public static ProtocolNegotiator serverPlaintext() { public static ProtocolNegotiator serverPlaintext() {
return new ProtocolNegotiator() { return new PlaintextProtocolNegotiator();
@Override
public ChannelHandler newHandler(final GrpcHttp2ConnectionHandler handler) {
class PlaintextHandler extends ChannelHandlerAdapter {
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
// Set sttributes before replace to be sure we pass it before accepting any requests.
handler.handleProtocolNegotiationCompleted(Attributes.newBuilder()
.set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress())
.set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, ctx.channel().localAddress())
.build(),
/*securityInfo=*/ null);
// Just replace this handler with the gRPC handler.
ctx.pipeline().replace(this, null, handler);
}
}
return new PlaintextHandler();
}
@Override
public void close() {}
@Override
public AsciiString scheme() {
return Utils.HTTP;
}
};
} }
/** /**
@ -149,7 +119,10 @@ final class ProtocolNegotiators {
return new ProtocolNegotiator() { return new ProtocolNegotiator() {
@Override @Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler handler) { public ChannelHandler newHandler(GrpcHttp2ConnectionHandler handler) {
return new ServerTlsHandler(sslContext, handler); ChannelHandler gnh = new GrpcNegotiationHandler(handler);
ChannelHandler sth = new ServerTlsHandler(gnh, sslContext);
ChannelHandler wauh = new WaitUntilActiveHandler(sth);
return wauh;
} }
@Override @Override
@ -163,67 +136,56 @@ final class ProtocolNegotiators {
}; };
} }
@VisibleForTesting
static final class ServerTlsHandler extends ChannelInboundHandlerAdapter { static final class ServerTlsHandler extends ChannelInboundHandlerAdapter {
private final GrpcHttp2ConnectionHandler grpcHandler; private final ChannelHandler next;
private final SslContext sslContext; private final SslContext sslContext;
ServerTlsHandler(SslContext sslContext, GrpcHttp2ConnectionHandler grpcHandler) { private ProtocolNegotiationEvent pne = ProtocolNegotiationEvent.DEFAULT;
this.sslContext = sslContext;
this.grpcHandler = grpcHandler; ServerTlsHandler(ChannelHandler next, SslContext sslContext) {
this.sslContext = checkNotNull(sslContext, "sslContext");
this.next = checkNotNull(next, "next");
} }
@Override @Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception { public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
super.handlerAdded(ctx); super.handlerAdded(ctx);
SSLEngine sslEngine = sslContext.newEngine(ctx.alloc()); SSLEngine sslEngine = sslContext.newEngine(ctx.alloc());
ctx.pipeline().addFirst(new SslHandler(sslEngine, false)); ctx.pipeline().addBefore(ctx.name(), null, new SslHandler(sslEngine, false));
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
fail(ctx, cause);
} }
@Override @Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof SslHandshakeCompletionEvent) { if (evt instanceof ProtocolNegotiationEvent) {
pne = (ProtocolNegotiationEvent) evt;
} else if (evt instanceof SslHandshakeCompletionEvent) {
SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt; SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
if (handshakeEvent.isSuccess()) { if (!handshakeEvent.isSuccess()) {
if (NEXT_PROTOCOL_VERSIONS.contains(sslHandler(ctx.pipeline()).applicationProtocol())) { logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed for new client.", null);
SSLSession session = sslHandler(ctx.pipeline()).engine().getSession(); ctx.fireExceptionCaught(handshakeEvent.cause());
// Successfully negotiated the protocol. return;
// Notify about completion and pass down SSLSession in attributes.
grpcHandler.handleProtocolNegotiationCompleted(
Attributes.newBuilder()
.set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
.set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress())
.set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, ctx.channel().localAddress())
.build(),
new InternalChannelz.Security(new InternalChannelz.Tls(session)));
// Replace this handler with the GRPC handler.
ctx.pipeline().replace(this, null, grpcHandler);
} else {
fail(ctx,
unavailableException(
"Failed protocol negotiation: Unable to find compatible protocol"));
}
} else {
fail(ctx, handshakeEvent.cause());
} }
SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
if (!NEXT_PROTOCOL_VERSIONS.contains(sslHandler.applicationProtocol())) {
logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed for new client.", null);
ctx.fireExceptionCaught(unavailableException(
"Failed protocol negotiation: Unable to find compatible protocol"));
return;
}
ctx.pipeline().replace(ctx.name(), null, next);
fireProtocolNegotiationEvent(ctx, sslHandler.engine().getSession());
} else {
super.userEventTriggered(ctx, evt);
} }
super.userEventTriggered(ctx, evt);
} }
private SslHandler sslHandler(ChannelPipeline pipeline) { private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession session) {
return pipeline.get(SslHandler.class); Security security = new Security(new Tls(session));
} Attributes attrs = pne.getAttributes().toBuilder()
.set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY)
@SuppressWarnings("FutureReturnValueIgnored") .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
private void fail(ChannelHandlerContext ctx, Throwable exception) { .build();
logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed for new client.", exception); ctx.fireUserEventTriggered(pne.withAttributes(attrs).withSecurity(security));
ctx.close();
} }
} }

View File

@ -17,6 +17,8 @@
package io.grpc.netty; package io.grpc.netty;
import static com.google.common.base.Charsets.UTF_8; import static com.google.common.base.Charsets.UTF_8;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
@ -50,6 +52,7 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler; import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultEventLoop; import io.netty.channel.DefaultEventLoop;
@ -240,21 +243,10 @@ public class ProtocolNegotiatorsTest {
Object unused = ProtocolNegotiators.serverTls(null); Object unused = ProtocolNegotiators.serverTls(null);
} }
@Test
public void tlsAdapter_exceptionClosesChannel() throws Exception {
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
// Use addFirst due to the funny error handling in EmbeddedChannel.
pipeline.addFirst(handler);
pipeline.fireExceptionCaught(new Exception("bad"));
assertFalse(channel.isOpen());
}
@Test @Test
public void tlsHandler_handlerAddedAddsSslHandler() throws Exception { public void tlsHandler_handlerAddedAddsSslHandler() throws Exception {
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext);
pipeline.addLast(handler); pipeline.addLast(handler);
@ -263,7 +255,7 @@ public class ProtocolNegotiatorsTest {
@Test @Test
public void tlsHandler_userEventTriggeredNonSslEvent() throws Exception { public void tlsHandler_userEventTriggeredNonSslEvent() throws Exception {
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext);
pipeline.addLast(handler); pipeline.addLast(handler);
channelHandlerCtx = pipeline.context(handler); channelHandlerCtx = pipeline.context(handler);
Object nonSslEvent = new Object(); Object nonSslEvent = new Object();
@ -284,32 +276,52 @@ public class ProtocolNegotiatorsTest {
} }
}; };
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext);
pipeline.addLast(handler); pipeline.addLast(handler);
final AtomicReference<Throwable> error = new AtomicReference<>();
ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
error.set(cause);
}
};
pipeline.addLast(errorCapture);
pipeline.replace(SslHandler.class, null, badSslHandler); pipeline.replace(SslHandler.class, null, badSslHandler);
channelHandlerCtx = pipeline.context(handler); channelHandlerCtx = pipeline.context(handler);
Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
pipeline.fireUserEventTriggered(sslEvent); pipeline.fireUserEventTriggered(sslEvent);
// No h2 protocol was specified, so this should be closed. // No h2 protocol was specified, so there should be an error, (normally handled by WBAEH)
assertFalse(channel.isOpen()); assertThat(error.get()).hasMessageThat().contains("Unable to find compatible protocol");
ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler); ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
assertNull(grpcHandlerCtx); assertNull(grpcHandlerCtx);
} }
@Test @Test
public void tlsHandler_userEventTriggeredSslEvent_handshakeFailure() throws Exception { public void tlsHandler_userEventTriggeredSslEvent_handshakeFailure() throws Exception {
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext);
pipeline.addLast(handler); pipeline.addLast(handler);
channelHandlerCtx = pipeline.context(handler); channelHandlerCtx = pipeline.context(handler);
Object sslEvent = new SslHandshakeCompletionEvent(new RuntimeException("bad")); Object sslEvent = new SslHandshakeCompletionEvent(new RuntimeException("bad"));
final AtomicReference<Throwable> error = new AtomicReference<>();
ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
error.set(cause);
}
};
pipeline.addLast(errorCapture);
pipeline.fireUserEventTriggered(sslEvent); pipeline.fireUserEventTriggered(sslEvent);
// No h2 protocol was specified, so this should be closed. // No h2 protocol was specified, so there should be an error, (normally handled by WBAEH)
assertFalse(channel.isOpen()); assertThat(error.get()).hasMessageThat().contains("bad");
ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler); ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
assertNull(grpcHandlerCtx); assertNull(grpcHandlerCtx);
} }
@ -323,7 +335,7 @@ public class ProtocolNegotiatorsTest {
} }
}; };
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext);
pipeline.addLast(handler); pipeline.addLast(handler);
pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.replace(SslHandler.class, null, goodSslHandler);
@ -346,7 +358,7 @@ public class ProtocolNegotiatorsTest {
} }
}; };
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext);
pipeline.addLast(handler); pipeline.addLast(handler);
pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.replace(SslHandler.class, null, goodSslHandler);
@ -362,7 +374,7 @@ public class ProtocolNegotiatorsTest {
@Test @Test
public void engineLog() { public void engineLog() {
ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler); ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext);
pipeline.addLast(handler); pipeline.addLast(handler);
channelHandlerCtx = pipeline.context(handler); channelHandlerCtx = pipeline.context(handler);
@ -620,7 +632,7 @@ public class ProtocolNegotiatorsTest {
FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler(); FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext); ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext);
WriteBufferingAndExceptionHandler wbaeh = WriteBufferingAndExceptionHandler clientWbaeh =
new WriteBufferingAndExceptionHandler(pn.newHandler(gh)); new WriteBufferingAndExceptionHandler(pn.newHandler(gh));
SocketAddress addr = new LocalAddress("addr"); SocketAddress addr = new LocalAddress("addr");
@ -628,22 +640,24 @@ public class ProtocolNegotiatorsTest {
ChannelHandler sh = ChannelHandler sh =
ProtocolNegotiators.serverTls(serverSslContext) ProtocolNegotiators.serverTls(serverSslContext)
.newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler()); .newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler());
WriteBufferingAndExceptionHandler serverWbaeh = new WriteBufferingAndExceptionHandler(sh);
Channel s = new ServerBootstrap() Channel s = new ServerBootstrap()
.childHandler(sh) .childHandler(serverWbaeh)
.group(group) .group(group)
.channel(LocalServerChannel.class) .channel(LocalServerChannel.class)
.bind(addr) .bind(addr)
.sync() .sync()
.channel(); .channel();
Channel c = new Bootstrap() Channel c = new Bootstrap()
.handler(wbaeh) .handler(clientWbaeh)
.channel(LocalChannel.class) .channel(LocalChannel.class)
.group(group) .group(group)
.register() .register()
.sync() .sync()
.channel(); .channel();
ChannelFuture write = c.writeAndFlush(NettyClientHandler.NOOP_MESSAGE); ChannelFuture write = c.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
c.connect(addr); c.connect(addr).sync();
write.sync();
boolean completed = gh.negotiated.await(TIMEOUT_SECONDS, TimeUnit.SECONDS); boolean completed = gh.negotiated.await(TIMEOUT_SECONDS, TimeUnit.SECONDS);
if (!completed) { if (!completed) {
@ -749,6 +763,7 @@ public class ProtocolNegotiatorsTest {
private Attributes attrs; private Attributes attrs;
private Security securityInfo; private Security securityInfo;
private final CountDownLatch negotiated = new CountDownLatch(1); private final CountDownLatch negotiated = new CountDownLatch(1);
private ChannelHandlerContext ctx;
FakeGrpcHttp2ConnectionHandler(ChannelPromise channelUnused, FakeGrpcHttp2ConnectionHandler(ChannelPromise channelUnused,
Http2ConnectionDecoder decoder, Http2ConnectionDecoder decoder,
@ -761,9 +776,22 @@ public class ProtocolNegotiatorsTest {
@Override @Override
public void handleProtocolNegotiationCompleted(Attributes attrs, Security securityInfo) { public void handleProtocolNegotiationCompleted(Attributes attrs, Security securityInfo) {
checkNotNull(ctx, "handleProtocolNegotiationCompleted cannot be called before handlerAdded");
super.handleProtocolNegotiationCompleted(attrs, securityInfo); super.handleProtocolNegotiationCompleted(attrs, securityInfo);
this.attrs = attrs; this.attrs = attrs;
this.securityInfo = securityInfo; this.securityInfo = securityInfo;
// Add a temp handler that verifies first message is a NOOP_MESSAGE
ctx.pipeline().addBefore(ctx.name(), null, new ChannelOutboundHandlerAdapter() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
checkState(
msg == NettyClientHandler.NOOP_MESSAGE, "First message should be NOOP_MESSAGE");
promise.trySuccess();
ctx.pipeline().remove(this);
}
});
NettyClientHandler.writeBufferingAndRemove(ctx.channel());
negotiated.countDown(); negotiated.countDown();
} }
@ -774,6 +802,7 @@ public class ProtocolNegotiatorsTest {
} else { } else {
super.handlerAdded(ctx); super.handlerAdded(ctx);
} }
this.ctx = ctx;
} }
@Override @Override