diff --git a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java index 76d9214053..886a8dc4c0 100644 --- a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java +++ b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java @@ -17,6 +17,7 @@ package io.grpc.alts.internal; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; import static io.grpc.alts.internal.AltsProtocolNegotiator.AUTH_CONTEXT_KEY; import static io.grpc.alts.internal.AltsProtocolNegotiator.TSI_PEER_KEY; @@ -84,7 +85,7 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder { private final HandshakeValidator handshakeValidator; private final ChannelHandler next; - private ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault(); + private ProtocolNegotiationEvent pne; /** * Constructs a TsiHandshakeHandler. @@ -148,6 +149,7 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof ProtocolNegotiationEvent) { + checkState(pne == null, "negotiation already started"); pne = (ProtocolNegotiationEvent) evt; } else { super.userEventTriggered(ctx, evt); @@ -156,6 +158,7 @@ public final class TsiHandshakeHandler extends ByteToMessageDecoder { private void fireProtocolNegotiationEvent( ChannelHandlerContext ctx, TsiPeer peer, Object authContext, SecurityDetails details) { + checkState(pne != null, "negotiation not yet complete"); InternalProtocolNegotiators.negotiationLogger(ctx) .log(ChannelLogLevel.INFO, "TsiHandshake finished"); ProtocolNegotiationEvent localPne = pne; diff --git a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java index 76b4267445..e470d114b2 100644 --- a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java @@ -36,6 +36,7 @@ import io.grpc.internal.FixedObjectPool; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiationEvent; import io.grpc.netty.NettyChannelBuilder; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; @@ -149,6 +150,7 @@ public class AltsProtocolNegotiatorTest { new AltsProtocolNegotiator.ServerAltsProtocolNegotiator(handshakerFactory, lazyFakeChannel) .newHandler(grpcHandler); channel = new EmbeddedChannel(uncaughtExceptionHandler, handler); + channel.pipeline().fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); } @After diff --git a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java index a9784b2a21..15c3ae4cc8 100644 --- a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java @@ -29,6 +29,7 @@ import io.grpc.internal.GrpcAttributes; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.InternalProtocolNegotiationEvent; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; @@ -96,6 +97,7 @@ public final class GoogleDefaultProtocolNegotiatorTest { // Add the negotiator handler last, but to the front. Putting this in ctor above would make it // throw early. chan.pipeline().addFirst(h); + chan.pipeline().fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); // Check that the message complained about the ALTS code, rather than SSL. ALTS throws on // being added, so it's hard to catch it at the right time to make this assertion. @@ -111,6 +113,7 @@ public final class GoogleDefaultProtocolNegotiatorTest { ChannelHandler h = googleProtocolNegotiator.newHandler(mockHandler); EmbeddedChannel chan = new EmbeddedChannel(h); + chan.pipeline().fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); assertThat(chan.pipeline().first().getClass().getSimpleName()).isEqualTo("SslHandler"); } diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 49904c7fb8..f70d90153d 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -17,6 +17,7 @@ package io.grpc.netty; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; import static io.grpc.netty.GrpcSslContexts.NEXT_PROTOCOL_VERSIONS; import com.google.common.annotations.VisibleForTesting; @@ -327,7 +328,7 @@ final class ProtocolNegotiators { private final String host; private final int port; - private ProtocolNegotiationEvent pne = ProtocolNegotiationEvent.DEFAULT; + private ProtocolNegotiationEvent pne; ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority) { this.next = checkNotNull(next, "next"); @@ -351,6 +352,7 @@ final class ProtocolNegotiators { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof ProtocolNegotiationEvent) { + checkState(pne == null, "negotiation already started"); pne = (ProtocolNegotiationEvent) evt; } else if (evt instanceof SslHandshakeCompletionEvent) { SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt; @@ -376,6 +378,7 @@ final class ProtocolNegotiators { } private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession session) { + checkState(pne != null, "negotiation not yet complete"); negotiationLogger(ctx).log(ChannelLogLevel.INFO, "ClientTls finished"); Security security = new Security(new Tls(session)); Attributes attrs = pne.getAttributes().toBuilder() @@ -466,7 +469,7 @@ final class ProtocolNegotiators { private final String authority; private final GrpcHttp2ConnectionHandler next; - private ProtocolNegotiationEvent pne = ProtocolNegotiationEvent.DEFAULT; + private ProtocolNegotiationEvent pne; Http2UpgradeAndGrpcHandler(String authority, GrpcHttp2ConnectionHandler next) { this.authority = checkNotNull(authority, "authority"); @@ -497,8 +500,10 @@ final class ProtocolNegotiators { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof ProtocolNegotiationEvent) { + checkState(pne == null, "negotiation already started"); pne = (ProtocolNegotiationEvent) evt; } else if (evt == HttpClientUpgradeHandler.UpgradeEvent.UPGRADE_SUCCESSFUL) { + checkState(pne != null, "negotiation not yet complete"); negotiationLogger(ctx).log(ChannelLogLevel.INFO, "Http2Upgrade finished"); ctx.pipeline().remove(ctx.name()); next.handleProtocolNegotiationCompleted(pne.getAttributes(), pne.getSecurity()); @@ -848,7 +853,7 @@ final class ProtocolNegotiators { */ static final class WaitUntilActiveHandler extends ChannelInboundHandlerAdapter { private final ChannelHandler next; - private ProtocolNegotiationEvent pne = ProtocolNegotiationEvent.DEFAULT; + private ProtocolNegotiationEvent pne; public WaitUntilActiveHandler(ChannelHandler next) { this.next = checkNotNull(next, "next"); @@ -859,31 +864,34 @@ final class ProtocolNegotiators { negotiationLogger(ctx).log(ChannelLogLevel.INFO, "WaitUntilActive started"); // This should be a noop, but just in case... super.handlerAdded(ctx); - if (ctx.channel().isActive()) { - ctx.pipeline().replace(ctx.name(), null, next); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + // Still propagate channelActive to the new handler. + super.channelActive(ctx); + if (pne != null) { fireProtocolNegotiationEvent(ctx); } } - @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { - ctx.pipeline().replace(ctx.name(), null, next); - // Still propagate channelActive to the new handler. - super.channelActive(ctx); - fireProtocolNegotiationEvent(ctx); - } - @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof ProtocolNegotiationEvent) { + checkState(pne == null, "negotiation already started"); pne = (ProtocolNegotiationEvent) evt; + if (ctx.channel().isActive()) { + fireProtocolNegotiationEvent(ctx); + } } else { super.userEventTriggered(ctx, evt); } } private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx) { + checkState(pne != null, "negotiation not yet complete"); negotiationLogger(ctx).log(ChannelLogLevel.INFO, "WaitUntilActive finished"); + ctx.pipeline().replace(ctx.name(), /* newName= */ null, next); Attributes attrs = pne.getAttributes().toBuilder() .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, ctx.channel().localAddress()) .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress()) diff --git a/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java b/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java index a47bcc7f29..103d300b53 100644 --- a/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java +++ b/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java @@ -56,6 +56,8 @@ final class WriteBufferingAndExceptionHandler extends ChannelDuplexHandler { public void handlerAdded(ChannelHandlerContext ctx) throws Exception { ctx.pipeline().addBefore(ctx.name(), null, next); super.handlerAdded(ctx); + // kick off protocol negotiation. + ctx.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); } @Override diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index 60f35ce0cb..bf18e3a876 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -169,6 +169,7 @@ public class ProtocolNegotiatorsTest { @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { ctx.pipeline().addLast(handler); + ctx.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); // do not propagate channelActive(). } }; @@ -226,6 +227,7 @@ public class ProtocolNegotiatorsTest { assertEquals(1, latch.getCount()); chan.connect(addr).sync(); + chan.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); assertTrue(latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS)); assertNull(chan.pipeline().context(WaitUntilActiveHandler.class)); } @@ -571,6 +573,7 @@ public class ProtocolNegotiatorsTest { .connect(addr) .sync() .channel(); + c.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); SocketAddress localAddr = c.localAddress(); ProtocolNegotiationEvent expectedEvent = ProtocolNegotiationEvent.DEFAULT .withAttributes(