diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index 60f2a3d49f..a889a040da 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -19,6 +19,7 @@ package io.grpc.netty; import io.grpc.ChannelLogger; import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler; import io.grpc.netty.ProtocolNegotiators.GrpcNegotiationHandler; +import io.grpc.netty.ProtocolNegotiators.ProtocolNegotiationHandler; import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; @@ -99,4 +100,16 @@ public final class InternalProtocolNegotiators { ChannelHandler next, SslContext sslContext, String authority) { return new ClientTlsHandler(next, sslContext, authority); } + + public static class ProtocolNegotiationHandler + extends ProtocolNegotiators.ProtocolNegotiationHandler { + + protected ProtocolNegotiationHandler(ChannelHandler next, String negotiatorName) { + super(next, negotiatorName); + } + + protected ProtocolNegotiationHandler(ChannelHandler next) { + super(next); + } + } } diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index f70d90153d..03a4ffe10e 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -22,6 +22,7 @@ import static io.grpc.netty.GrpcSslContexts.NEXT_PROTOCOL_VERSIONS; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.ForOverride; import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; @@ -321,17 +322,14 @@ final class ProtocolNegotiators { public void close() {} } - static final class ClientTlsHandler extends ChannelDuplexHandler { + static final class ClientTlsHandler extends ProtocolNegotiationHandler { - private final ChannelHandler next; private final SslContext sslContext; private final String host; private final int port; - private ProtocolNegotiationEvent pne; - ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority) { - this.next = checkNotNull(next, "next"); + super(next); this.sslContext = checkNotNull(sslContext, "sslContext"); HostPort hostPort = parseAuthority(authority); this.host = hostPort.host; @@ -339,30 +337,24 @@ final class ProtocolNegotiators { } @Override - public void handlerAdded(ChannelHandlerContext ctx) throws Exception { - negotiationLogger(ctx).log(ChannelLogLevel.INFO, "ClientTls started"); + protected void handlerAdded0(ChannelHandlerContext ctx) { SSLEngine sslEngine = sslContext.newEngine(ctx.alloc(), host, port); SSLParameters sslParams = sslEngine.getSSLParameters(); sslParams.setEndpointIdentificationAlgorithm("HTTPS"); sslEngine.setSSLParameters(sslParams); ctx.pipeline().addBefore(ctx.name(), /* name= */ null, new SslHandler(sslEngine, false)); - super.handlerAdded(ctx); } @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - if (evt instanceof ProtocolNegotiationEvent) { - checkState(pne == null, "negotiation already started"); - pne = (ProtocolNegotiationEvent) evt; - } else if (evt instanceof SslHandshakeCompletionEvent) { + protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof SslHandshakeCompletionEvent) { SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt; if (handshakeEvent.isSuccess()) { SslHandler handler = ctx.pipeline().get(SslHandler.class); if (NEXT_PROTOCOL_VERSIONS.contains(handler.applicationProtocol())) { // Successfully negotiated the protocol. logSslEngineDetails(Level.FINER, ctx, "TLS negotiation succeeded.", null); - ctx.pipeline().replace(ctx.name(), null, next); - fireProtocolNegotiationEvent(ctx, handler.engine().getSession()); + propagateTlsComplete(ctx, handler.engine().getSession()); } else { Exception ex = unavailableException("Failed ALPN negotiation: Unable to find compatible protocol"); @@ -373,19 +365,19 @@ final class ProtocolNegotiators { ctx.fireExceptionCaught(handshakeEvent.cause()); } } else { - super.userEventTriggered(ctx, evt); + super.userEventTriggered0(ctx, evt); } } - private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession session) { - checkState(pne != null, "negotiation not yet complete"); - negotiationLogger(ctx).log(ChannelLogLevel.INFO, "ClientTls finished"); + private void propagateTlsComplete(ChannelHandlerContext ctx, SSLSession session) { Security security = new Security(new Tls(session)); - Attributes attrs = pne.getAttributes().toBuilder() + ProtocolNegotiationEvent existingPne = getProtocolNegotiationEvent(); + Attributes attrs = existingPne.getAttributes().toBuilder() .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY) .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session) .build(); - ctx.fireUserEventTriggered(pne.withAttributes(attrs).withSecurity(security)); + replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security)); + fireProtocolNegotiationEvent(ctx); } } @@ -851,54 +843,113 @@ final class ProtocolNegotiators { * subsequent handlers to assume the channel is active and ready to send. Additionally, this a * {@link ProtocolNegotiationEvent}, with the connection addresses. */ - static final class WaitUntilActiveHandler extends ChannelInboundHandlerAdapter { - private final ChannelHandler next; - private ProtocolNegotiationEvent pne; + static final class WaitUntilActiveHandler extends ProtocolNegotiationHandler { - public WaitUntilActiveHandler(ChannelHandler next) { - this.next = checkNotNull(next, "next"); - } + boolean protocolNegotiationEventReceived; - @Override - public void handlerAdded(ChannelHandlerContext ctx) throws Exception { - negotiationLogger(ctx).log(ChannelLogLevel.INFO, "WaitUntilActive started"); - // This should be a noop, but just in case... - super.handlerAdded(ctx); + WaitUntilActiveHandler(ChannelHandler next) { + super(next); } @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { + if (protocolNegotiationEventReceived) { + replaceOnActive(ctx); + fireProtocolNegotiationEvent(ctx); + } // Still propagate channelActive to the new handler. super.channelActive(ctx); - if (pne != null) { + } + + @Override + protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) { + protocolNegotiationEventReceived = true; + if (ctx.channel().isActive()) { + replaceOnActive(ctx); fireProtocolNegotiationEvent(ctx); } } - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - if (evt instanceof ProtocolNegotiationEvent) { - checkState(pne == null, "negotiation already started"); - pne = (ProtocolNegotiationEvent) evt; - if (ctx.channel().isActive()) { - fireProtocolNegotiationEvent(ctx); - } - } else { - super.userEventTriggered(ctx, evt); - } - } - - private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx) { - checkState(pne != null, "negotiation not yet complete"); - negotiationLogger(ctx).log(ChannelLogLevel.INFO, "WaitUntilActive finished"); - ctx.pipeline().replace(ctx.name(), /* newName= */ null, next); - Attributes attrs = pne.getAttributes().toBuilder() + private void replaceOnActive(ChannelHandlerContext ctx) { + ProtocolNegotiationEvent existingPne = getProtocolNegotiationEvent(); + Attributes attrs = existingPne.getAttributes().toBuilder() .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, ctx.channel().localAddress()) .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress()) // Later handlers are expected to overwrite this. .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.NONE) .build(); - ctx.fireUserEventTriggered(pne.withAttributes(attrs)); + replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs)); + } + } + + /** + * ProtocolNegotiationHandler is a convenience handler that makes it easy to follow the rules for + * protocol negotiation. Handlers should strongly consider extending this handler. + */ + static class ProtocolNegotiationHandler extends ChannelDuplexHandler { + + private final ChannelHandler next; + private final String negotiatorName; + private ProtocolNegotiationEvent pne; + + protected ProtocolNegotiationHandler(ChannelHandler next, String negotiatorName) { + this.next = checkNotNull(next, "next"); + this.negotiatorName = negotiatorName; + } + + protected ProtocolNegotiationHandler(ChannelHandler next) { + this.next = checkNotNull(next, "next"); + this.negotiatorName = getClass().getSimpleName().replace("Handler", ""); + } + + @Override + public final void handlerAdded(ChannelHandlerContext ctx) throws Exception { + InternalProtocolNegotiators.negotiationLogger(ctx) + .log(ChannelLogLevel.DEBUG, negotiatorName + " started"); + handlerAdded0(ctx); + } + + @ForOverride + protected void handlerAdded0(ChannelHandlerContext ctx) throws Exception { + super.handlerAdded(ctx); + } + + @Override + public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof ProtocolNegotiationEvent) { + checkState(pne == null, "pre-existing negotiation: %s < %s", pne, evt); + pne = (ProtocolNegotiationEvent) evt; + protocolNegotiationEventTriggered(ctx); + } else { + userEventTriggered0(ctx, evt); + } + } + + protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws Exception { + super.userEventTriggered(ctx, evt); + } + + @ForOverride + protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) { + // no-op + } + + protected final ProtocolNegotiationEvent getProtocolNegotiationEvent() { + checkState(pne != null, "previous protocol negotiation event hasn't triggered"); + return pne; + } + + protected final void replaceProtocolNegotiationEvent(ProtocolNegotiationEvent pne) { + checkState(this.pne != null, "previous protocol negotiation event hasn't triggered"); + this.pne = checkNotNull(pne); + } + + protected final void fireProtocolNegotiationEvent(ChannelHandlerContext ctx) { + checkState(pne != null, "previous protocol negotiation event hasn't triggered"); + InternalProtocolNegotiators.negotiationLogger(ctx) + .log(ChannelLogLevel.INFO, negotiatorName + " completed"); + ctx.pipeline().replace(ctx.name(), /* newName= */ null, next); + ctx.fireUserEventTriggered(pne); } } }