diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java index 53749f3e1d..603af2d514 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java @@ -34,9 +34,12 @@ import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.InternalNettyChannelBuilder; import io.grpc.netty.InternalNettyChannelBuilder.ProtocolNegotiatorFactory; +import io.grpc.netty.InternalProtocolNegotiationEvent; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiators; import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; import java.security.GeneralSecurityException; @@ -163,7 +166,26 @@ public final class AltsProtocolNegotiator { ChannelHandler thh = new TsiHandshakeHandler(gnh, nettyHandshaker, new AltsHandshakeValidator()); ChannelHandler wuah = InternalProtocolNegotiators.waitUntilActiveHandler(thh); - return wuah; + ChannelHandler knh = new KickNegotiationHandler(wuah); + return knh; + } + + /** Kicks off negotiation of the server. This is a hack workaround until server uses WBAEH.*/ + // TODO(carl-mastrangelo): remove this once NettyServerTransport uses WBAEH. + private static final class KickNegotiationHandler extends ChannelInboundHandlerAdapter { + + private final ChannelHandler next; + + KickNegotiationHandler(ChannelHandler next) { + this.next = checkNotNull(next, "next"); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + super.handlerAdded(ctx); + ctx.pipeline().replace(ctx.name(), /*newName= */ null, next); + ctx.pipeline().fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); + } } @Override diff --git a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java index e470d114b2..76b4267445 100644 --- a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java @@ -36,7 +36,6 @@ import io.grpc.internal.FixedObjectPool; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; -import io.grpc.netty.InternalProtocolNegotiationEvent; import io.grpc.netty.NettyChannelBuilder; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; @@ -150,7 +149,6 @@ public class AltsProtocolNegotiatorTest { new AltsProtocolNegotiator.ServerAltsProtocolNegotiator(handshakerFactory, lazyFakeChannel) .newHandler(grpcHandler); channel = new EmbeddedChannel(uncaughtExceptionHandler, handler); - channel.pipeline().fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); } @After