netty: upstream ProtocolNegotiatiorHandler, and swap the appropriate classes to it

ALTS is not switched yet, since it is shared between client and server. Once the server is changed to use WBAEH, it can be moved too.
This commit is contained in:
Carl Mastrangelo 2019-06-26 18:23:12 -07:00 committed by GitHub
parent d5e1a4bb5d
commit 9e5f60b86a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 116 additions and 52 deletions

View File

@ -19,6 +19,7 @@ package io.grpc.netty;
import io.grpc.ChannelLogger;
import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler;
import io.grpc.netty.ProtocolNegotiators.GrpcNegotiationHandler;
import io.grpc.netty.ProtocolNegotiators.ProtocolNegotiationHandler;
import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
@ -99,4 +100,16 @@ public final class InternalProtocolNegotiators {
ChannelHandler next, SslContext sslContext, String authority) {
return new ClientTlsHandler(next, sslContext, authority);
}
public static class ProtocolNegotiationHandler
extends ProtocolNegotiators.ProtocolNegotiationHandler {
protected ProtocolNegotiationHandler(ChannelHandler next, String negotiatorName) {
super(next, negotiatorName);
}
protected ProtocolNegotiationHandler(ChannelHandler next) {
super(next);
}
}
}

View File

@ -22,6 +22,7 @@ import static io.grpc.netty.GrpcSslContexts.NEXT_PROTOCOL_VERSIONS;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.errorprone.annotations.ForOverride;
import io.grpc.Attributes;
import io.grpc.ChannelLogger;
import io.grpc.ChannelLogger.ChannelLogLevel;
@ -321,17 +322,14 @@ final class ProtocolNegotiators {
public void close() {}
}
static final class ClientTlsHandler extends ChannelDuplexHandler {
static final class ClientTlsHandler extends ProtocolNegotiationHandler {
private final ChannelHandler next;
private final SslContext sslContext;
private final String host;
private final int port;
private ProtocolNegotiationEvent pne;
ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority) {
this.next = checkNotNull(next, "next");
super(next);
this.sslContext = checkNotNull(sslContext, "sslContext");
HostPort hostPort = parseAuthority(authority);
this.host = hostPort.host;
@ -339,30 +337,24 @@ final class ProtocolNegotiators {
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
negotiationLogger(ctx).log(ChannelLogLevel.INFO, "ClientTls started");
protected void handlerAdded0(ChannelHandlerContext ctx) {
SSLEngine sslEngine = sslContext.newEngine(ctx.alloc(), host, port);
SSLParameters sslParams = sslEngine.getSSLParameters();
sslParams.setEndpointIdentificationAlgorithm("HTTPS");
sslEngine.setSSLParameters(sslParams);
ctx.pipeline().addBefore(ctx.name(), /* name= */ null, new SslHandler(sslEngine, false));
super.handlerAdded(ctx);
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof ProtocolNegotiationEvent) {
checkState(pne == null, "negotiation already started");
pne = (ProtocolNegotiationEvent) evt;
} else if (evt instanceof SslHandshakeCompletionEvent) {
protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof SslHandshakeCompletionEvent) {
SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
if (handshakeEvent.isSuccess()) {
SslHandler handler = ctx.pipeline().get(SslHandler.class);
if (NEXT_PROTOCOL_VERSIONS.contains(handler.applicationProtocol())) {
// Successfully negotiated the protocol.
logSslEngineDetails(Level.FINER, ctx, "TLS negotiation succeeded.", null);
ctx.pipeline().replace(ctx.name(), null, next);
fireProtocolNegotiationEvent(ctx, handler.engine().getSession());
propagateTlsComplete(ctx, handler.engine().getSession());
} else {
Exception ex =
unavailableException("Failed ALPN negotiation: Unable to find compatible protocol");
@ -373,19 +365,19 @@ final class ProtocolNegotiators {
ctx.fireExceptionCaught(handshakeEvent.cause());
}
} else {
super.userEventTriggered(ctx, evt);
super.userEventTriggered0(ctx, evt);
}
}
private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession session) {
checkState(pne != null, "negotiation not yet complete");
negotiationLogger(ctx).log(ChannelLogLevel.INFO, "ClientTls finished");
private void propagateTlsComplete(ChannelHandlerContext ctx, SSLSession session) {
Security security = new Security(new Tls(session));
Attributes attrs = pne.getAttributes().toBuilder()
ProtocolNegotiationEvent existingPne = getProtocolNegotiationEvent();
Attributes attrs = existingPne.getAttributes().toBuilder()
.set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY)
.set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
.build();
ctx.fireUserEventTriggered(pne.withAttributes(attrs).withSecurity(security));
replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security));
fireProtocolNegotiationEvent(ctx);
}
}
@ -851,54 +843,113 @@ final class ProtocolNegotiators {
* subsequent handlers to assume the channel is active and ready to send. Additionally, this a
* {@link ProtocolNegotiationEvent}, with the connection addresses.
*/
static final class WaitUntilActiveHandler extends ChannelInboundHandlerAdapter {
private final ChannelHandler next;
private ProtocolNegotiationEvent pne;
static final class WaitUntilActiveHandler extends ProtocolNegotiationHandler {
public WaitUntilActiveHandler(ChannelHandler next) {
this.next = checkNotNull(next, "next");
}
boolean protocolNegotiationEventReceived;
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
negotiationLogger(ctx).log(ChannelLogLevel.INFO, "WaitUntilActive started");
// This should be a noop, but just in case...
super.handlerAdded(ctx);
WaitUntilActiveHandler(ChannelHandler next) {
super(next);
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
if (protocolNegotiationEventReceived) {
replaceOnActive(ctx);
fireProtocolNegotiationEvent(ctx);
}
// Still propagate channelActive to the new handler.
super.channelActive(ctx);
if (pne != null) {
}
@Override
protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) {
protocolNegotiationEventReceived = true;
if (ctx.channel().isActive()) {
replaceOnActive(ctx);
fireProtocolNegotiationEvent(ctx);
}
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof ProtocolNegotiationEvent) {
checkState(pne == null, "negotiation already started");
pne = (ProtocolNegotiationEvent) evt;
if (ctx.channel().isActive()) {
fireProtocolNegotiationEvent(ctx);
}
} else {
super.userEventTriggered(ctx, evt);
}
}
private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx) {
checkState(pne != null, "negotiation not yet complete");
negotiationLogger(ctx).log(ChannelLogLevel.INFO, "WaitUntilActive finished");
ctx.pipeline().replace(ctx.name(), /* newName= */ null, next);
Attributes attrs = pne.getAttributes().toBuilder()
private void replaceOnActive(ChannelHandlerContext ctx) {
ProtocolNegotiationEvent existingPne = getProtocolNegotiationEvent();
Attributes attrs = existingPne.getAttributes().toBuilder()
.set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, ctx.channel().localAddress())
.set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress())
// Later handlers are expected to overwrite this.
.set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.NONE)
.build();
ctx.fireUserEventTriggered(pne.withAttributes(attrs));
replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs));
}
}
/**
* ProtocolNegotiationHandler is a convenience handler that makes it easy to follow the rules for
* protocol negotiation. Handlers should strongly consider extending this handler.
*/
static class ProtocolNegotiationHandler extends ChannelDuplexHandler {
private final ChannelHandler next;
private final String negotiatorName;
private ProtocolNegotiationEvent pne;
protected ProtocolNegotiationHandler(ChannelHandler next, String negotiatorName) {
this.next = checkNotNull(next, "next");
this.negotiatorName = negotiatorName;
}
protected ProtocolNegotiationHandler(ChannelHandler next) {
this.next = checkNotNull(next, "next");
this.negotiatorName = getClass().getSimpleName().replace("Handler", "");
}
@Override
public final void handlerAdded(ChannelHandlerContext ctx) throws Exception {
InternalProtocolNegotiators.negotiationLogger(ctx)
.log(ChannelLogLevel.DEBUG, negotiatorName + " started");
handlerAdded0(ctx);
}
@ForOverride
protected void handlerAdded0(ChannelHandlerContext ctx) throws Exception {
super.handlerAdded(ctx);
}
@Override
public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof ProtocolNegotiationEvent) {
checkState(pne == null, "pre-existing negotiation: %s < %s", pne, evt);
pne = (ProtocolNegotiationEvent) evt;
protocolNegotiationEventTriggered(ctx);
} else {
userEventTriggered0(ctx, evt);
}
}
protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws Exception {
super.userEventTriggered(ctx, evt);
}
@ForOverride
protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) {
// no-op
}
protected final ProtocolNegotiationEvent getProtocolNegotiationEvent() {
checkState(pne != null, "previous protocol negotiation event hasn't triggered");
return pne;
}
protected final void replaceProtocolNegotiationEvent(ProtocolNegotiationEvent pne) {
checkState(this.pne != null, "previous protocol negotiation event hasn't triggered");
this.pne = checkNotNull(pne);
}
protected final void fireProtocolNegotiationEvent(ChannelHandlerContext ctx) {
checkState(pne != null, "previous protocol negotiation event hasn't triggered");
InternalProtocolNegotiators.negotiationLogger(ctx)
.log(ChannelLogLevel.INFO, negotiatorName + " completed");
ctx.pipeline().replace(ctx.name(), /* newName= */ null, next);
ctx.fireUserEventTriggered(pne);
}
}
}